summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-17 14:57:07 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-17 14:57:41 -0800
commit74278dd01e5713920a35f1c3e0731e535667c19a (patch)
tree3aa35ade367c4f86e092c52471346e6456e52aa0 /mlir/lib
parent6fa3bd5b3e57806ffa34946bd36528f72bf06b58 (diff)
downloadbcm5719-llvm-74278dd01e5713920a35f1c3e0731e535667c19a.tar.gz
bcm5719-llvm-74278dd01e5713920a35f1c3e0731e535667c19a.zip
NFC: Use TypeSwitch to simplify existing code.
PiperOrigin-RevId: 286066371
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/MemRefBoundCheck.cpp9
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp33
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp70
-rw-r--r--mlir/lib/Transforms/Utils/Utils.cpp12
4 files changed, 47 insertions, 77 deletions
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp
index 52379c0a1d0..4696ce64c22 100644
--- a/mlir/lib/Analysis/MemRefBoundCheck.cpp
+++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp
@@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Passes.h"
@@ -49,11 +50,9 @@ std::unique_ptr<OpPassBase<FuncOp>> mlir::createMemRefBoundCheckPass() {
void MemRefBoundCheck::runOnFunction() {
getFunction().walk([](Operation *opInst) {
- if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
- boundCheckLoadOrStoreOp(loadOp);
- } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
- boundCheckLoadOrStoreOp(storeOp);
- }
+ TypeSwitch<Operation *>(opInst).Case<AffineLoadOp, AffineStoreOp>(
+ [](auto op) { boundCheckLoadOrStoreOp(op); });
+
// TODO(bondhugula): do this for DMA ops as well.
});
}
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 51cdd7270d9..5d6a92fee92 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -21,6 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -232,25 +233,19 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
}
// Dispatch based on the actual type. Return null type on error.
-Type LLVMTypeConverter::convertStandardType(Type type) {
- if (auto funcType = type.dyn_cast<FunctionType>())
- return convertFunctionType(funcType);
- if (auto intType = type.dyn_cast<IntegerType>())
- return convertIntegerType(intType);
- if (auto floatType = type.dyn_cast<FloatType>())
- return convertFloatType(floatType);
- if (auto indexType = type.dyn_cast<IndexType>())
- return convertIndexType(indexType);
- if (auto memRefType = type.dyn_cast<MemRefType>())
- return convertMemRefType(memRefType);
- if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
- return convertUnrankedMemRefType(memRefType);
- if (auto vectorType = type.dyn_cast<VectorType>())
- return convertVectorType(vectorType);
- if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
- return llvmType;
-
- return {};
+Type LLVMTypeConverter::convertStandardType(Type t) {
+ return TypeSwitch<Type, Type>(t)
+ .Case([&](FloatType type) { return convertFloatType(type); })
+ .Case([&](FunctionType type) { return convertFunctionType(type); })
+ .Case([&](IndexType type) { return convertIndexType(type); })
+ .Case([&](IntegerType type) { return convertIntegerType(type); })
+ .Case([&](MemRefType type) { return convertMemRefType(type); })
+ .Case([&](UnrankedMemRefType type) {
+ return convertUnrankedMemRefType(type);
+ })
+ .Case([&](VectorType type) { return convertVectorType(type); })
+ .Case([](LLVM::LLVMType type) { return type; })
+ .Default([](Type) { return Type(); });
}
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 2cb75de084a..f7591bf7480 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
@@ -1634,54 +1635,33 @@ Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
return success();
}
-LogicalResult Serializer::processOperation(Operation *op) {
- LLVM_DEBUG(llvm::dbgs() << "[op] '" << op->getName() << "'\n");
+LogicalResult Serializer::processOperation(Operation *opInst) {
+ LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
// First dispatch the ops that do not directly mirror an instruction from
// the SPIR-V spec.
- if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op)) {
- return processAddressOfOp(addressOfOp);
- }
- if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
- return processBranchOp(branchOp);
- }
- if (auto condBranchOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
- return processBranchConditionalOp(condBranchOp);
- }
- if (auto constOp = dyn_cast<spirv::ConstantOp>(op)) {
- return processConstantOp(constOp);
- }
- if (auto fnOp = dyn_cast<FuncOp>(op)) {
- return processFuncOp(fnOp);
- }
- if (auto varOp = dyn_cast<spirv::VariableOp>(op)) {
- return processVariableOp(varOp);
- }
- if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
- return processGlobalVariableOp(varOp);
- }
- if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) {
- return processSelectionOp(selectionOp);
- }
- if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) {
- return processLoopOp(loopOp);
- }
- if (isa<spirv::ModuleEndOp>(op)) {
- return success();
- }
- if (auto refOpOp = dyn_cast<spirv::ReferenceOfOp>(op)) {
- return processReferenceOfOp(refOpOp);
- }
- if (auto specConstOp = dyn_cast<spirv::SpecConstantOp>(op)) {
- return processSpecConstantOp(specConstOp);
- }
- if (auto undefOp = dyn_cast<spirv::UndefOp>(op)) {
- return processUndefOp(undefOp);
- }
-
- // Then handle all the ops that directly mirror SPIR-V instructions with
- // auto-generated methods.
- return dispatchToAutogenSerialization(op);
+ return TypeSwitch<Operation *, LogicalResult>(opInst)
+ .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
+ .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
+ .Case([&](spirv::BranchConditionalOp op) {
+ return processBranchConditionalOp(op);
+ })
+ .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
+ .Case([&](FuncOp op) { return processFuncOp(op); })
+ .Case([&](spirv::GlobalVariableOp op) {
+ return processGlobalVariableOp(op);
+ })
+ .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
+ .Case([&](spirv::ModuleEndOp) { return success(); })
+ .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
+ .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
+ .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
+ .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
+ .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
+
+ // Then handle all the ops that directly mirror SPIR-V instructions with
+ // auto-generated methods.
+ .Default([&](auto *op) { return dispatchToAutogenSerialization(op); });
}
namespace {
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 190b6c3155e..79a6d7a6902 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -22,6 +22,7 @@
#include "mlir/Transforms/Utils.h"
+#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Dominance.h"
@@ -47,14 +48,9 @@ static bool isMemRefDereferencingOp(Operation &op) {
/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) {
- if (auto loadOp = dyn_cast<AffineLoadOp>(op))
- return loadOp.getAffineMapAttrForMemRef(memref);
- else if (auto storeOp = dyn_cast<AffineStoreOp>(op))
- return storeOp.getAffineMapAttrForMemRef(memref);
- else if (auto dmaStart = dyn_cast<AffineDmaStartOp>(op))
- return dmaStart.getAffineMapAttrForMemRef(memref);
- assert(isa<AffineDmaWaitOp>(op));
- return cast<AffineDmaWaitOp>(op).getAffineMapAttrForMemRef(memref);
+ return TypeSwitch<Operation *, NamedAttribute>(op)
+ .Case<AffineDmaStartOp, AffineLoadOp, AffineStoreOp, AffineDmaWaitOp>(
+ [=](auto op) { return op.getAffineMapAttrForMemRef(memref); });
}
// Perform the replacement in `op`.
OpenPOWER on IntegriCloud