diff options
| author | River Riddle <riverriddle@google.com> | 2019-12-17 14:57:07 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-17 14:57:41 -0800 |
| commit | 74278dd01e5713920a35f1c3e0731e535667c19a (patch) | |
| tree | 3aa35ade367c4f86e092c52471346e6456e52aa0 /mlir/lib | |
| parent | 6fa3bd5b3e57806ffa34946bd36528f72bf06b58 (diff) | |
| download | bcm5719-llvm-74278dd01e5713920a35f1c3e0731e535667c19a.tar.gz bcm5719-llvm-74278dd01e5713920a35f1c3e0731e535667c19a.zip | |
NFC: Use TypeSwitch to simplify existing code.
PiperOrigin-RevId: 286066371
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Analysis/MemRefBoundCheck.cpp | 9 | ||||
| -rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 33 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | 70 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/Utils.cpp | 12 |
4 files changed, 47 insertions, 77 deletions
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 52379c0a1d0..4696ce64c22 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Passes.h" @@ -49,11 +50,9 @@ std::unique_ptr<OpPassBase<FuncOp>> mlir::createMemRefBoundCheckPass() { void MemRefBoundCheck::runOnFunction() { getFunction().walk([](Operation *opInst) { - if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) { - boundCheckLoadOrStoreOp(loadOp); - } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) { - boundCheckLoadOrStoreOp(storeOp); - } + TypeSwitch<Operation *>(opInst).Case<AffineLoadOp, AffineStoreOp>( + [](auto op) { boundCheckLoadOrStoreOp(op); }); + // TODO(bondhugula): do this for DMA ops as well. }); } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 51cdd7270d9..5d6a92fee92 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -21,6 +21,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -232,25 +233,19 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) { } // Dispatch based on the actual type. Return null type on error. -Type LLVMTypeConverter::convertStandardType(Type type) { - if (auto funcType = type.dyn_cast<FunctionType>()) - return convertFunctionType(funcType); - if (auto intType = type.dyn_cast<IntegerType>()) - return convertIntegerType(intType); - if (auto floatType = type.dyn_cast<FloatType>()) - return convertFloatType(floatType); - if (auto indexType = type.dyn_cast<IndexType>()) - return convertIndexType(indexType); - if (auto memRefType = type.dyn_cast<MemRefType>()) - return convertMemRefType(memRefType); - if (auto memRefType = type.dyn_cast<UnrankedMemRefType>()) - return convertUnrankedMemRefType(memRefType); - if (auto vectorType = type.dyn_cast<VectorType>()) - return convertVectorType(vectorType); - if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) - return llvmType; - - return {}; +Type LLVMTypeConverter::convertStandardType(Type t) { + return TypeSwitch<Type, Type>(t) + .Case([&](FloatType type) { return convertFloatType(type); }) + .Case([&](FunctionType type) { return convertFunctionType(type); }) + .Case([&](IndexType type) { return convertIndexType(type); }) + .Case([&](IntegerType type) { return convertIntegerType(type); }) + .Case([&](MemRefType type) { return convertMemRefType(type); }) + .Case([&](UnrankedMemRefType type) { + return convertUnrankedMemRefType(type); + }) + .Case([&](VectorType type) { return convertVectorType(type); }) + .Case([](LLVM::LLVMType type) { return type; }) + .Default([](Type) { return Type(); }); } LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 2cb75de084a..f7591bf7480 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/SPIRV/Serialization.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" @@ -1634,54 +1635,33 @@ Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { return success(); } -LogicalResult Serializer::processOperation(Operation *op) { - LLVM_DEBUG(llvm::dbgs() << "[op] '" << op->getName() << "'\n"); +LogicalResult Serializer::processOperation(Operation *opInst) { + LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); // First dispatch the ops that do not directly mirror an instruction from // the SPIR-V spec. - if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op)) { - return processAddressOfOp(addressOfOp); - } - if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) { - return processBranchOp(branchOp); - } - if (auto condBranchOp = dyn_cast<spirv::BranchConditionalOp>(op)) { - return processBranchConditionalOp(condBranchOp); - } - if (auto constOp = dyn_cast<spirv::ConstantOp>(op)) { - return processConstantOp(constOp); - } - if (auto fnOp = dyn_cast<FuncOp>(op)) { - return processFuncOp(fnOp); - } - if (auto varOp = dyn_cast<spirv::VariableOp>(op)) { - return processVariableOp(varOp); - } - if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) { - return processGlobalVariableOp(varOp); - } - if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) { - return processSelectionOp(selectionOp); - } - if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) { - return processLoopOp(loopOp); - } - if (isa<spirv::ModuleEndOp>(op)) { - return success(); - } - if (auto refOpOp = dyn_cast<spirv::ReferenceOfOp>(op)) { - return processReferenceOfOp(refOpOp); - } - if (auto specConstOp = dyn_cast<spirv::SpecConstantOp>(op)) { - return processSpecConstantOp(specConstOp); - } - if (auto undefOp = dyn_cast<spirv::UndefOp>(op)) { - return processUndefOp(undefOp); - } - - // Then handle all the ops that directly mirror SPIR-V instructions with - // auto-generated methods. - return dispatchToAutogenSerialization(op); + return TypeSwitch<Operation *, LogicalResult>(opInst) + .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) + .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) + .Case([&](spirv::BranchConditionalOp op) { + return processBranchConditionalOp(op); + }) + .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) + .Case([&](FuncOp op) { return processFuncOp(op); }) + .Case([&](spirv::GlobalVariableOp op) { + return processGlobalVariableOp(op); + }) + .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) + .Case([&](spirv::ModuleEndOp) { return success(); }) + .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) + .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) + .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) + .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) + .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) + + // Then handle all the ops that directly mirror SPIR-V instructions with + // auto-generated methods. + .Default([&](auto *op) { return dispatchToAutogenSerialization(op); }); } namespace { diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 190b6c3155e..79a6d7a6902 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/Utils.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Dominance.h" @@ -47,14 +48,9 @@ static bool isMemRefDereferencingOp(Operation &op) { /// Return the AffineMapAttr associated with memory 'op' on 'memref'. static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) { - if (auto loadOp = dyn_cast<AffineLoadOp>(op)) - return loadOp.getAffineMapAttrForMemRef(memref); - else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) - return storeOp.getAffineMapAttrForMemRef(memref); - else if (auto dmaStart = dyn_cast<AffineDmaStartOp>(op)) - return dmaStart.getAffineMapAttrForMemRef(memref); - assert(isa<AffineDmaWaitOp>(op)); - return cast<AffineDmaWaitOp>(op).getAffineMapAttrForMemRef(memref); + return TypeSwitch<Operation *, NamedAttribute>(op) + .Case<AffineDmaStartOp, AffineLoadOp, AffineStoreOp, AffineDmaWaitOp>( + [=](auto op) { return op.getAffineMapAttrForMemRef(memref); }); } // Perform the replacement in `op`. |

