diff options
author | River Riddle <riverriddle@google.com> | 2019-12-17 10:07:26 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-17 10:08:06 -0800 |
commit | f44cf23297089dd4beb6f81a7fdda4e59466dcdb (patch) | |
tree | 91fd982330235a7073376d589f9106a4a84522c4 /mlir/unittests/ADT | |
parent | 6e581e29a47b3005b14d8c1dac29bd9cd9c48381 (diff) | |
download | bcm5719-llvm-f44cf23297089dd4beb6f81a7fdda4e59466dcdb.tar.gz bcm5719-llvm-f44cf23297089dd4beb6f81a7fdda4e59466dcdb.zip |
Add a new utility class TypeSwitch to ADT.
This class provides a simplified mechanism for defining a switch over a set of types using llvm casting functionality. More specifically, this allows for defining a switch over a value of type T where each case corresponds to a type(CaseT) that can be used with dyn_cast<CaseT>(...). An example is shown below:
// Traditional piece of code:
Operation *op = ...;
if (auto constant = dyn_cast<ConstantOp>(op))
...;
else if (auto return = dyn_cast<ReturnOp>(op))
...;
else
...;
// New piece of code:
Operation *op = ...;
TypeSwitch<Operation *>(op)
.Case<ConstantOp>([](ConstantOp constant) { ... })
.Case<ReturnOp>([](ReturnOp return) { ... })
.Default([](Operation *op) { ... });
Aside from the above, TypeSwitch supports return values, void return, multiple types per case, etc. The usability is intended to be very similar to StringSwitch.
(Using c++14 template lambdas makes everything even nicer)
More complex example of how this makes certain things easier:
LogicalResult process(Constant op);
LogicalResult process(ReturnOp op);
LogicalResult process(FuncOp op);
TypeSwitch<Operation *, LogicalResult>(op)
.Case<ConstantOp, ReturnOp, FuncOp>([](auto op) { return process(op); })
.Default([](Operation *op) { return op->emitError() << "could not be processed"; });
PiperOrigin-RevId: 286003613
Diffstat (limited to 'mlir/unittests/ADT')
-rw-r--r-- | mlir/unittests/ADT/CMakeLists.txt | 5 | ||||
-rw-r--r-- | mlir/unittests/ADT/TypeSwitchTest.cpp | 97 |
2 files changed, 102 insertions, 0 deletions
diff --git a/mlir/unittests/ADT/CMakeLists.txt b/mlir/unittests/ADT/CMakeLists.txt new file mode 100644 index 00000000000..cb122620512 --- /dev/null +++ b/mlir/unittests/ADT/CMakeLists.txt @@ -0,0 +1,5 @@ +add_mlir_unittest(MLIRADTTests + TypeSwitchTest.cpp +) + +target_link_libraries(MLIRADTTests PRIVATE MLIRSupport LLVMSupport) diff --git a/mlir/unittests/ADT/TypeSwitchTest.cpp b/mlir/unittests/ADT/TypeSwitchTest.cpp new file mode 100644 index 00000000000..b6a78de892e --- /dev/null +++ b/mlir/unittests/ADT/TypeSwitchTest.cpp @@ -0,0 +1,97 @@ +//===- TypeSwitchTest.cpp - TypeSwitch unit tests -------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/ADT/TypeSwitch.h" +#include "gtest/gtest.h" + +using namespace mlir; + +namespace { +/// Utility classes to setup casting functionality. +struct Base { + enum Kind { DerivedA, DerivedB, DerivedC, DerivedD, DerivedE }; + Kind kind; +}; +template <Base::Kind DerivedKind> struct DerivedImpl : Base { + DerivedImpl() : Base{DerivedKind} {} + static bool classof(const Base *base) { return base->kind == DerivedKind; } +}; +struct DerivedA : public DerivedImpl<Base::DerivedA> {}; +struct DerivedB : public DerivedImpl<Base::DerivedB> {}; +struct DerivedC : public DerivedImpl<Base::DerivedC> {}; +struct DerivedD : public DerivedImpl<Base::DerivedD> {}; +struct DerivedE : public DerivedImpl<Base::DerivedE> {}; +} // end anonymous namespace + +TEST(StringSwitchTest, CaseResult) { + auto translate = [](auto value) { + return TypeSwitch<Base *, int>(&value) + .Case<DerivedA>([](DerivedA *) { return 0; }) + .Case([](DerivedB *) { return 1; }) + .Case([](DerivedC *) { return 2; }) + .Default([](Base *) { return -1; }); + }; + EXPECT_EQ(0, translate(DerivedA())); + EXPECT_EQ(1, translate(DerivedB())); + EXPECT_EQ(2, translate(DerivedC())); + EXPECT_EQ(-1, translate(DerivedD())); +} + +TEST(StringSwitchTest, CasesResult) { + auto translate = [](auto value) { + return TypeSwitch<Base *, int>(&value) + .Case<DerivedA, DerivedB, DerivedD>([](auto *) { return 0; }) + .Case([](DerivedC *) { return 1; }) + .Default([](Base *) { return -1; }); + }; + EXPECT_EQ(0, translate(DerivedA())); + EXPECT_EQ(0, translate(DerivedB())); + EXPECT_EQ(1, translate(DerivedC())); + EXPECT_EQ(0, translate(DerivedD())); + EXPECT_EQ(-1, translate(DerivedE())); +} + +TEST(StringSwitchTest, CaseVoid) { + auto translate = [](auto value) { + int result = -2; + TypeSwitch<Base *>(&value) + .Case([&](DerivedA *) { result = 0; }) + .Case([&](DerivedB *) { result = 1; }) + .Case([&](DerivedC *) { result = 2; }) + .Default([&](Base *) { result = -1; }); + return result; + }; + EXPECT_EQ(0, translate(DerivedA())); + EXPECT_EQ(1, translate(DerivedB())); + EXPECT_EQ(2, translate(DerivedC())); + EXPECT_EQ(-1, translate(DerivedD())); +} + +TEST(StringSwitchTest, CasesVoid) { + auto translate = [](auto value) { + int result = -1; + TypeSwitch<Base *>(&value) + .Case<DerivedA, DerivedB, DerivedD>([&](auto *) { result = 0; }) + .Case([&](DerivedC *) { result = 1; }); + return result; + }; + EXPECT_EQ(0, translate(DerivedA())); + EXPECT_EQ(0, translate(DerivedB())); + EXPECT_EQ(1, translate(DerivedC())); + EXPECT_EQ(0, translate(DerivedD())); + EXPECT_EQ(-1, translate(DerivedE())); +} |