summaryrefslogtreecommitdiffstats
path: root/mlir/unittests/ADT
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-17 10:07:26 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-17 10:08:06 -0800
commitf44cf23297089dd4beb6f81a7fdda4e59466dcdb (patch)
tree91fd982330235a7073376d589f9106a4a84522c4 /mlir/unittests/ADT
parent6e581e29a47b3005b14d8c1dac29bd9cd9c48381 (diff)
downloadbcm5719-llvm-f44cf23297089dd4beb6f81a7fdda4e59466dcdb.tar.gz
bcm5719-llvm-f44cf23297089dd4beb6f81a7fdda4e59466dcdb.zip
Add a new utility class TypeSwitch to ADT.
This class provides a simplified mechanism for defining a switch over a set of types using llvm casting functionality. More specifically, this allows for defining a switch over a value of type T where each case corresponds to a type(CaseT) that can be used with dyn_cast<CaseT>(...). An example is shown below: // Traditional piece of code: Operation *op = ...; if (auto constant = dyn_cast<ConstantOp>(op)) ...; else if (auto return = dyn_cast<ReturnOp>(op)) ...; else ...; // New piece of code: Operation *op = ...; TypeSwitch<Operation *>(op) .Case<ConstantOp>([](ConstantOp constant) { ... }) .Case<ReturnOp>([](ReturnOp return) { ... }) .Default([](Operation *op) { ... }); Aside from the above, TypeSwitch supports return values, void return, multiple types per case, etc. The usability is intended to be very similar to StringSwitch. (Using c++14 template lambdas makes everything even nicer) More complex example of how this makes certain things easier: LogicalResult process(Constant op); LogicalResult process(ReturnOp op); LogicalResult process(FuncOp op); TypeSwitch<Operation *, LogicalResult>(op) .Case<ConstantOp, ReturnOp, FuncOp>([](auto op) { return process(op); }) .Default([](Operation *op) { return op->emitError() << "could not be processed"; }); PiperOrigin-RevId: 286003613
Diffstat (limited to 'mlir/unittests/ADT')
-rw-r--r--mlir/unittests/ADT/CMakeLists.txt5
-rw-r--r--mlir/unittests/ADT/TypeSwitchTest.cpp97
2 files changed, 102 insertions, 0 deletions
diff --git a/mlir/unittests/ADT/CMakeLists.txt b/mlir/unittests/ADT/CMakeLists.txt
new file mode 100644
index 00000000000..cb122620512
--- /dev/null
+++ b/mlir/unittests/ADT/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_mlir_unittest(MLIRADTTests
+ TypeSwitchTest.cpp
+)
+
+target_link_libraries(MLIRADTTests PRIVATE MLIRSupport LLVMSupport)
diff --git a/mlir/unittests/ADT/TypeSwitchTest.cpp b/mlir/unittests/ADT/TypeSwitchTest.cpp
new file mode 100644
index 00000000000..b6a78de892e
--- /dev/null
+++ b/mlir/unittests/ADT/TypeSwitchTest.cpp
@@ -0,0 +1,97 @@
+//===- TypeSwitchTest.cpp - TypeSwitch unit tests -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/ADT/TypeSwitch.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+/// Utility classes to setup casting functionality.
+struct Base {
+ enum Kind { DerivedA, DerivedB, DerivedC, DerivedD, DerivedE };
+ Kind kind;
+};
+template <Base::Kind DerivedKind> struct DerivedImpl : Base {
+ DerivedImpl() : Base{DerivedKind} {}
+ static bool classof(const Base *base) { return base->kind == DerivedKind; }
+};
+struct DerivedA : public DerivedImpl<Base::DerivedA> {};
+struct DerivedB : public DerivedImpl<Base::DerivedB> {};
+struct DerivedC : public DerivedImpl<Base::DerivedC> {};
+struct DerivedD : public DerivedImpl<Base::DerivedD> {};
+struct DerivedE : public DerivedImpl<Base::DerivedE> {};
+} // end anonymous namespace
+
+TEST(StringSwitchTest, CaseResult) {
+ auto translate = [](auto value) {
+ return TypeSwitch<Base *, int>(&value)
+ .Case<DerivedA>([](DerivedA *) { return 0; })
+ .Case([](DerivedB *) { return 1; })
+ .Case([](DerivedC *) { return 2; })
+ .Default([](Base *) { return -1; });
+ };
+ EXPECT_EQ(0, translate(DerivedA()));
+ EXPECT_EQ(1, translate(DerivedB()));
+ EXPECT_EQ(2, translate(DerivedC()));
+ EXPECT_EQ(-1, translate(DerivedD()));
+}
+
+TEST(StringSwitchTest, CasesResult) {
+ auto translate = [](auto value) {
+ return TypeSwitch<Base *, int>(&value)
+ .Case<DerivedA, DerivedB, DerivedD>([](auto *) { return 0; })
+ .Case([](DerivedC *) { return 1; })
+ .Default([](Base *) { return -1; });
+ };
+ EXPECT_EQ(0, translate(DerivedA()));
+ EXPECT_EQ(0, translate(DerivedB()));
+ EXPECT_EQ(1, translate(DerivedC()));
+ EXPECT_EQ(0, translate(DerivedD()));
+ EXPECT_EQ(-1, translate(DerivedE()));
+}
+
+TEST(StringSwitchTest, CaseVoid) {
+ auto translate = [](auto value) {
+ int result = -2;
+ TypeSwitch<Base *>(&value)
+ .Case([&](DerivedA *) { result = 0; })
+ .Case([&](DerivedB *) { result = 1; })
+ .Case([&](DerivedC *) { result = 2; })
+ .Default([&](Base *) { result = -1; });
+ return result;
+ };
+ EXPECT_EQ(0, translate(DerivedA()));
+ EXPECT_EQ(1, translate(DerivedB()));
+ EXPECT_EQ(2, translate(DerivedC()));
+ EXPECT_EQ(-1, translate(DerivedD()));
+}
+
+TEST(StringSwitchTest, CasesVoid) {
+ auto translate = [](auto value) {
+ int result = -1;
+ TypeSwitch<Base *>(&value)
+ .Case<DerivedA, DerivedB, DerivedD>([&](auto *) { result = 0; })
+ .Case([&](DerivedC *) { result = 1; });
+ return result;
+ };
+ EXPECT_EQ(0, translate(DerivedA()));
+ EXPECT_EQ(0, translate(DerivedB()));
+ EXPECT_EQ(1, translate(DerivedC()));
+ EXPECT_EQ(0, translate(DerivedD()));
+ EXPECT_EQ(-1, translate(DerivedE()));
+}
OpenPOWER on IntegriCloud