diff options
| author | River Riddle <riverriddle@google.com> | 2019-05-11 15:17:28 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2019-05-20 13:37:10 -0700 |
| commit | 02e03b9bf4a1fe60b89d4bd662895ebcc374129b (patch) | |
| tree | f300b5e0886e01bd59c6fb7d7042b4e1fbf9ff3d /mlir | |
| parent | 360f8a209e21b058cc20949fc8600817b0a1044c (diff) | |
| download | bcm5719-llvm-02e03b9bf4a1fe60b89d4bd662895ebcc374129b.tar.gz bcm5719-llvm-02e03b9bf4a1fe60b89d4bd662895ebcc374129b.zip | |
Add support for using llvm::dyn_cast/cast/isa for operation casts and replace usages of Operation::dyn_cast with llvm::dyn_cast.
--
PiperOrigin-RevId: 247778391
Diffstat (limited to 'mlir')
43 files changed, 140 insertions, 130 deletions
diff --git a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp index ecb6309466a..a7fba179c79 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp @@ -31,7 +31,7 @@ ViewOp linalg::getViewBaseViewOp(Value *view) { auto viewType = view->getType().dyn_cast<ViewType>(); (void)viewType; assert(viewType.isa<ViewType>() && "expected a ViewType"); - while (auto slice = view->getDefiningOp()->dyn_cast<SliceOp>()) { + while (auto slice = dyn_cast<SliceOp>(view->getDefiningOp())) { view = slice.getParentView(); assert(viewType.isa<ViewType>() && "expected a ViewType"); } @@ -48,7 +48,7 @@ std::pair<mlir::Value *, unsigned> linalg::getViewRootIndexing(Value *view, (void)viewType; assert(viewType.isa<ViewType>() && "expected a ViewType"); assert(dim < viewType.getRank() && "dim exceeds rank"); - if (auto viewOp = view->getDefiningOp()->dyn_cast<ViewOp>()) + if (auto viewOp = dyn_cast<ViewOp>(view->getDefiningOp())) return std::make_pair(viewOp.getIndexing(dim), dim); auto sliceOp = view->getDefiningOp()->cast<SliceOp>(); diff --git a/mlir/examples/Linalg/Linalg1/lib/Common.cpp b/mlir/examples/Linalg/Linalg1/lib/Common.cpp index bfdc40a6aa0..278f9c57607 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Common.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Common.cpp @@ -40,7 +40,7 @@ linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder( assert(ivs.size() == indexings.size()); for (unsigned i = 0, e = indexings.size(); i < e; ++i) { auto rangeOp = - indexings[i].getValue()->getDefiningOp()->dyn_cast<RangeOp>(); + llvm::dyn_cast<RangeOp>(indexings[i].getValue()->getDefiningOp()); if (!rangeOp) { continue; } diff --git a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp index 372c08f9eea..5bcebc79c18 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp @@ -33,7 +33,7 @@ using namespace linalg::intrinsics; unsigned linalg::getViewRank(Value *view) { assert(view->getType().isa<ViewType>() && "expected a ViewType"); - if (auto viewOp = view->getDefiningOp()->dyn_cast<ViewOp>()) + if (auto viewOp = dyn_cast<ViewOp>(view->getDefiningOp())) return viewOp.getRank(); return view->getDefiningOp()->cast<SliceOp>().getRank(); } diff --git a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp index d1af7503d1b..83fd9ad3143 100644 --- a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp @@ -43,7 +43,7 @@ using namespace linalg::intrinsics; // analyses. This builds the chain. static SmallVector<Value *, 8> getViewChain(mlir::Value *v) { assert(v->getType().isa<ViewType>() && "ViewType expected"); - if (v->getDefiningOp()->dyn_cast<ViewOp>()) { + if (v->getDefiningOp()->isa<ViewOp>()) { return SmallVector<mlir::Value *, 8>{v}; } @@ -53,7 +53,7 @@ static SmallVector<Value *, 8> getViewChain(mlir::Value *v) { tmp.push_back(v); v = sliceOp.getParentView(); } while (!v->getType().isa<ViewType>()); - assert(v->getDefiningOp()->cast<ViewOp>() && "must be a ViewOp"); + assert(v->getDefiningOp()->isa<ViewOp>() && "must be a ViewOp"); tmp.push_back(v); return SmallVector<mlir::Value *, 8>(tmp.rbegin(), tmp.rend()); } diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h index 9339d7309e3..3090f29dcfc 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h @@ -91,7 +91,7 @@ inline llvm::SmallVector<mlir::Value *, 8> extractRangesFromViewOrSliceOp(mlir::Value *view) { // This expects a viewType which must come from either ViewOp or SliceOp. assert(view->getType().isa<linalg::ViewType>() && "expected ViewType"); - if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>()) + if (auto viewOp = llvm::dyn_cast<linalg::ViewOp>(view->getDefiningOp())) return viewOp.getRanges(); auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>(); diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index 42999aef7ae..bce7f58860d 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -46,9 +46,9 @@ void linalg::composeSliceOps(mlir::Function *f) { void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) { f->walk([](Operation *op) { - if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) { + if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) { matmulOp.writeAsFinerGrainTensorContraction(); - } else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) { + } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) { matvecOp.writeAsFinerGrainTensorContraction(); } else { return; @@ -205,11 +205,11 @@ writeContractionAsLoops(ContractionOp contraction) { llvm::Optional<SmallVector<mlir::AffineForOp, 4>> linalg::writeAsLoops(Operation *op) { - if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) { + if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) { return writeContractionAsLoops(matmulOp); - } else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) { + } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) { return writeContractionAsLoops(matvecOp); - } else if (auto dotOp = op->dyn_cast<linalg::DotOp>()) { + } else if (auto dotOp = dyn_cast<linalg::DotOp>(op)) { return writeContractionAsLoops(dotOp); } return llvm::None; @@ -276,7 +276,7 @@ PatternMatchResult Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto load = op->cast<linalg::LoadOp>(); - SliceOp slice = load.getView()->getDefiningOp()->dyn_cast<SliceOp>(); + SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) : load.getView()->getDefiningOp()->cast<ViewOp>(); ScopedContext scope(FuncBuilder(load), load.getLoc()); @@ -291,7 +291,7 @@ PatternMatchResult Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto store = op->cast<linalg::StoreOp>(); - SliceOp slice = store.getView()->getDefiningOp()->dyn_cast<SliceOp>(); + SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) : store.getView()->getDefiningOp()->cast<ViewOp>(); ScopedContext scope(FuncBuilder(store), store.getLoc()); diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index 05865e9e53c..6771257ae0f 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -52,8 +52,8 @@ void linalg::lowerToTiledLoops(mlir::Function *f, } static bool isZeroIndex(Value *v) { - return v->getDefiningOp() && v->getDefiningOp()->isa<ConstantIndexOp>() && - v->getDefiningOp()->dyn_cast<ConstantIndexOp>().getValue() == 0; + return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) && + cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0; } template <typename ConcreteOp> @@ -178,11 +178,11 @@ writeContractionAsTiledViews(TensorContractionBase<ConcreteOp> &contraction, llvm::Optional<SmallVector<mlir::AffineForOp, 8>> linalg::writeAsTiledViews(Operation *op, ArrayRef<Value *> tileSizes) { - if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) { + if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) { return writeContractionAsTiledViews(matmulOp, tileSizes); - } else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) { + } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) { return writeContractionAsTiledViews(matvecOp, tileSizes); - } else if (auto dotOp = op->dyn_cast<linalg::DotOp>()) { + } else if (auto dotOp = dyn_cast<linalg::DotOp>(op)) { return writeContractionAsTiledViews(dotOp, tileSizes); } return llvm::None; @@ -190,11 +190,11 @@ linalg::writeAsTiledViews(Operation *op, ArrayRef<Value *> tileSizes) { void linalg::lowerToTiledViews(mlir::Function *f, ArrayRef<Value *> tileSizes) { f->walk([tileSizes](Operation *op) { - if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) { + if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) { writeAsTiledViews(matmulOp, tileSizes); - } else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) { + } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) { writeAsTiledViews(matvecOp, tileSizes); - } else if (auto dotOp = op->dyn_cast<linalg::DotOp>()) { + } else if (auto dotOp = dyn_cast<linalg::DotOp>(op)) { writeAsTiledViews(dotOp, tileSizes); } else { return; diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index a11c88266b7..c9f98e7d6a9 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -238,13 +238,13 @@ public: LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); // The add operation is trivial: propagate the input type as is. - if (auto addOp = op->dyn_cast<AddOp>()) { + if (auto addOp = llvm::dyn_cast<AddOp>(op)) { op->getResult(0)->setType(op->getOperand(0)->getType()); continue; } // Transpose is easy: just invert the dimensions. - if (auto transpose = op->dyn_cast<TransposeOp>()) { + if (auto transpose = llvm::dyn_cast<TransposeOp>(op)) { SmallVector<int64_t, 2> dims; auto arrayTy = transpose.getOperand()->getType().cast<ToyArrayType>(); dims.insert(dims.end(), arrayTy.getShape().begin(), @@ -259,7 +259,7 @@ public: // catch it but shape inference earlier in the pass could generate an // invalid IR (from an invalid Toy input of course) and we wouldn't want // to crash here. - if (auto mulOp = op->dyn_cast<MulOp>()) { + if (auto mulOp = llvm::dyn_cast<MulOp>(op)) { auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>(); auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>(); auto lhsRank = lhs.getShape().size(); @@ -291,7 +291,7 @@ public: // for this function, queue the callee in the inter-procedural work list, // and return. The current function stays in the work list and will // restart after the callee is processed. - if (auto callOp = op->dyn_cast<GenericCallOp>()) { + if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) { auto calleeName = callOp.getCalleeName(); auto *callee = getModule().getNamedFunction(calleeName); if (!callee) { diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index f3e8ff06781..942ce866182 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -53,7 +53,7 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern { // Look through the input of the current transpose. mlir::Value *transposeInput = transpose.getOperand(); TransposeOp transposeInputOp = - mlir::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp()); + llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp()); // If the input is defined by another Transpose, bingo! if (!transposeInputOp) return matchFailure(); @@ -75,7 +75,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { mlir::PatternRewriter &rewriter) const override { ReshapeOp reshape = op->cast<ReshapeOp>(); // Look through the input of the current reshape. - ConstantOp constantOp = mlir::dyn_cast_or_null<ConstantOp>( + ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>( reshape.getOperand()->getDefiningOp()); // If the input is defined by another constant, bingo! if (!constantOp) diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 4ef62d33adc..534b5cbd2ab 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -366,7 +366,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> { // First patch calls type to return memref instead of ToyArray for (auto &function : getModule()) { function.walk([&](Operation *op) { - auto callOp = op->dyn_cast<CallOp>(); + auto callOp = dyn_cast<CallOp>(op); if (!callOp) return; if (!callOp.getNumResults()) @@ -382,14 +382,14 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> { for (auto &function : getModule()) { function.walk([&](Operation *op) { // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free). - if (auto allocOp = op->dyn_cast<toy::AllocOp>()) { + if (auto allocOp = dyn_cast<toy::AllocOp>(op)) { auto result = allocTensor(allocOp); allocOp.replaceAllUsesWith(result); allocOp.erase(); return; } // Eliminate all type.cast before lowering to LLVM. - if (auto typeCastOp = op->dyn_cast<toy::TypeCastOp>()) { + if (auto typeCastOp = dyn_cast<toy::TypeCastOp>(op)) { typeCastOp.replaceAllUsesWith(typeCastOp.getOperand()); typeCastOp.erase(); return; @@ -429,7 +429,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> { // Insert a `dealloc` operation right before the `return` operations, unless // it is returned itself in which case the caller is responsible for it. builder.getFunction()->walk([&](Operation *op) { - auto returnOp = op->dyn_cast<ReturnOp>(); + auto returnOp = dyn_cast<ReturnOp>(op); if (!returnOp) return; if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc) diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index a083e62f05f..4e17b234d14 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -238,7 +238,7 @@ public: LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); // The add operation is trivial: propagate the input type as is. - if (auto addOp = op->dyn_cast<AddOp>()) { + if (auto addOp = llvm::dyn_cast<AddOp>(op)) { op->getResult(0)->setType(op->getOperand(0)->getType()); continue; } @@ -261,7 +261,7 @@ public: // catch it but shape inference earlier in the pass could generate an // invalid IR (from an invalid Toy input of course) and we wouldn't want // to crash here. - if (auto mulOp = op->dyn_cast<MulOp>()) { + if (auto mulOp = llvm::dyn_cast<MulOp>(op)) { auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>(); auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>(); auto lhsRank = lhs.getShape().size(); @@ -295,7 +295,7 @@ public: // for this function, queue the callee in the inter-procedural work list, // and return. The current function stays in the work list and will // restart after the callee is processed. - if (auto callOp = op->dyn_cast<GenericCallOp>()) { + if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) { auto calleeName = callOp.getCalleeName(); auto *callee = getModule().getNamedFunction(calleeName); if (!callee) { diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index 5d23488c95d..39302f6c0f9 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -439,7 +439,7 @@ ValueHandle ValueHandle::create(Args... args) { if (op->getNumResults() == 1) { return ValueHandle(op->getResult(0)); } else if (op->getNumResults() == 0) { - if (auto f = op->dyn_cast<AffineForOp>()) { + if (auto f = dyn_cast<AffineForOp>(op)) { return ValueHandle(f.getInductionVar()); } } diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 1ee6c4806fb..7f182e882db 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -271,7 +271,7 @@ public: OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); auto *op = createOperation(state); - auto result = op->dyn_cast<OpTy>(); + auto result = dyn_cast<OpTy>(op); assert(result && "Builder didn't return the right type"); return result; } diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 0770d2cfa27..d4b85b56d0f 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -116,7 +116,7 @@ public: /// Specialization of walk to only visit operations of 'OpTy'. template <typename OpTy> void walk(std::function<void(OpTy)> callback) { walk([&](Operation *opInst) { - if (auto op = opInst->dyn_cast<OpTy>()) + if (auto op = dyn_cast<OpTy>(opInst)) callback(op); }); } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index b80e8aca9bc..2eff412a71e 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -792,7 +792,7 @@ public: /// This is the hook used by the AsmPrinter to emit this to the .mlir file. /// Op implementations should provide a print method. static void printAssembly(Operation *op, OpAsmPrinter *p) { - auto opPointer = op->dyn_cast<ConcreteType>(); + auto opPointer = dyn_cast<ConcreteType>(op); assert(opPointer && "op's name does not match name of concrete type instantiated with"); opPointer.print(p); @@ -825,11 +825,13 @@ public: /// This is a public constructor. Any op can be initialized to null. explicit Op() : OpState(nullptr) {} + Op(std::nullptr_t) : OpState(nullptr) {} -protected: - /// This is a private constructor only accessible through the - /// Operation::cast family of methods. - explicit Op(Operation *state) : OpState(state) {} + /// This is a public constructor to enable access via the llvm::cast family of + /// methods. This should not be used directly. + explicit Op(Operation *state) : OpState(state) { + assert(!state || isa<ConcreteOpType>(state)); + } friend class Operation; private: diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 54e49b73e3b..31ec8ea54a6 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -389,14 +389,6 @@ public: // Conversions to declared operations like DimOp //===--------------------------------------------------------------------===// - /// The dyn_cast methods perform a dynamic cast from an Operation to a typed - /// Op like DimOp. This returns a null Op on failure. - template <typename OpClass> OpClass dyn_cast() { - if (isa<OpClass>()) - return cast<OpClass>(); - return OpClass(); - } - /// The cast methods perform a cast from an Operation to a typed Op like /// DimOp. This aborts if the parameter to the template isn't an instance of /// the template type argument. @@ -417,10 +409,10 @@ public: /// including this one. void walk(const std::function<void(Operation *)> &callback); - /// Specialization of walk to only visit operations of 'OpTy'. - template <typename OpTy> void walk(std::function<void(OpTy)> callback) { + /// Specialization of walk to only visit operations of 'T'. + template <typename T> void walk(std::function<void(T)> callback) { walk([&](Operation *op) { - if (auto derivedOp = op->dyn_cast<OpTy>()) + if (auto derivedOp = dyn_cast<T>(op)) callback(derivedOp); }); } @@ -534,17 +526,6 @@ inline auto Operation::getOperands() -> operand_range { return {operand_begin(), operand_end()}; } -/// Provide dyn_cast_or_null functionality for Operation casts. -template <typename T> T dyn_cast_or_null(Operation *op) { - return op ? op->dyn_cast<T>() : T(); -} - -/// Provide isa_and_nonnull functionality for Operation casts, i.e. if the -/// operation is non-null and a class of 'T'. -template <typename T> bool isa_and_nonnull(Operation *op) { - return op && op->isa<T>(); -} - /// This class implements the result iterators for the Operation class /// in terms of getResult(idx). class ResultIterator final @@ -598,4 +579,30 @@ inline auto Operation::getResultTypes() } // end namespace mlir +namespace llvm { +/// Provide isa functionality for operation casts. +template <typename T> struct isa_impl<T, ::mlir::Operation> { + static inline bool doit(const ::mlir::Operation &op) { + return T::classof(const_cast<::mlir::Operation *>(&op)); + } +}; + +/// Provide specializations for operation casts as the resulting T is value +/// typed. +template <typename T> struct cast_retty_impl<T, ::mlir::Operation *> { + using ret_type = T; +}; +template <typename T> struct cast_retty_impl<T, ::mlir::Operation> { + using ret_type = T; +}; +template <class T> +struct cast_convert_val<T, ::mlir::Operation, ::mlir::Operation> { + static T doit(::mlir::Operation &val) { return T(&val); } +}; +template <class T> +struct cast_convert_val<T, ::mlir::Operation *, ::mlir::Operation *> { + static T doit(::mlir::Operation *val) { return T(val); } +}; +} // end namespace llvm + #endif // MLIR_IR_OPERATION_H diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 3b02ed55c34..51528c18d38 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -215,7 +215,7 @@ public: OperationState state(getContext(), location, OpTy::getOperationName()); OpTy::build(this, &state, args...); auto *op = createOperation(state); - auto result = op->dyn_cast<OpTy>(); + auto result = dyn_cast<OpTy>(op); assert(result && "Builder didn't return the right type"); return result; } @@ -231,7 +231,7 @@ public: // If the Operation we produce is valid, return it. if (!OpTy::verifyInvariants(op)) { - auto result = op->dyn_cast<OpTy>(); + auto result = dyn_cast<OpTy>(op); assert(result && "Builder didn't return the right type"); return result; } diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h index 031dceb518e..6676ad0d818 100644 --- a/mlir/include/mlir/Support/LLVM.h +++ b/mlir/include/mlir/Support/LLVM.h @@ -69,6 +69,7 @@ using llvm::cast_or_null; using llvm::dyn_cast; using llvm::dyn_cast_or_null; using llvm::isa; +using llvm::isa_and_nonnull; // Containers. using llvm::ArrayRef; diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 51209da7385..2dfed934ee0 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -61,11 +61,11 @@ bool mlir::isValidDim(Value *value) { if (op->getParentOp() == nullptr || op->isa<ConstantOp>()) return true; // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = op->dyn_cast<AffineApplyOp>()) + if (auto applyOp = dyn_cast<AffineApplyOp>(op)) return applyOp.isValidDim(); // The dim op is okay if its operand memref/tensor is defined at the top // level. - if (auto dimOp = op->dyn_cast<DimOp>()) + if (auto dimOp = dyn_cast<DimOp>(op)) return isTopLevelSymbol(dimOp.getOperand()); return false; } @@ -86,11 +86,11 @@ bool mlir::isValidSymbol(Value *value) { if (op->getParentOp() == nullptr || op->isa<ConstantOp>()) return true; // Affine apply operation is ok if all of its operands are ok. - if (auto applyOp = op->dyn_cast<AffineApplyOp>()) + if (auto applyOp = dyn_cast<AffineApplyOp>(op)) return applyOp.isValidSymbol(); // The dim op is okay if its operand memref/tensor is defined at the top // level. - if (auto dimOp = op->dyn_cast<DimOp>()) + if (auto dimOp = dyn_cast<DimOp>(op)) return isTopLevelSymbol(dimOp.getOperand()); return false; } diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 78caa4c2625..60f2b142986 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -320,8 +320,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, loadAndStores.match(forOp, &loadAndStoresMatched); for (auto ls : loadAndStoresMatched) { auto *op = ls.getMatchedOperation(); - auto load = op->dyn_cast<LoadOp>(); - auto store = op->dyn_cast<StoreOp>(); + auto load = dyn_cast<LoadOp>(op); + auto store = dyn_cast<StoreOp>(op); // Only scalar types are considered vectorizable, all load/store must be // vectorizable for a loop to qualify as vectorizable. // TODO(ntv): ponder whether we want to be more general here. @@ -338,8 +338,8 @@ isVectorizableLoopBodyWithOpCond(AffineForOp loop, bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim) { VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) { - auto load = op.dyn_cast<LoadOp>(); - auto store = op.dyn_cast<StoreOp>(); + auto load = dyn_cast<LoadOp>(op); + auto store = dyn_cast<StoreOp>(op); return load ? isContiguousAccess(loop.getInductionVar(), load, memRefDim) : isContiguousAccess(loop.getInductionVar(), store, memRefDim); }); diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 0fb88620fa1..4e23441d5a5 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -48,9 +48,9 @@ FunctionPassBase *mlir::createMemRefBoundCheckPass() { void MemRefBoundCheck::runOnFunction() { getFunction().walk([](Operation *opInst) { - if (auto loadOp = opInst->dyn_cast<LoadOp>()) { + if (auto loadOp = dyn_cast<LoadOp>(opInst)) { boundCheckLoadOrStoreOp(loadOp); - } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { + } else if (auto storeOp = dyn_cast<StoreOp>(opInst)) { boundCheckLoadOrStoreOp(storeOp); } // TODO(bondhugula): do this for DMA ops as well. diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index bce000a4c1f..155a2bbbd1b 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -50,7 +50,7 @@ static void getForwardSliceImpl(Operation *op, return; } - if (auto forOp = op->dyn_cast<AffineForOp>()) { + if (auto forOp = dyn_cast<AffineForOp>(op)) { for (auto &u : forOp.getInductionVar()->getUses()) { auto *ownerInst = u.getOwner(); if (forwardSlice->count(ownerInst) == 0) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 1eaab676567..8d963e4739c 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -44,7 +44,7 @@ void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) { AffineForOp currAffineForOp; // Traverse up the hierarchy collecing all 'affine.for' operation while // skipping over 'affine.if' operations. - while (currOp && ((currAffineForOp = currOp->dyn_cast<AffineForOp>()) || + while (currOp && ((currAffineForOp = dyn_cast<AffineForOp>(currOp)) || currOp->isa<AffineIfOp>())) { if (currAffineForOp) loops->push_back(currAffineForOp); @@ -239,7 +239,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, assert(isValidSymbol(symbol)); // Check if the symbol is a constant. if (auto *op = symbol->getDefiningOp()) { - if (auto constOp = op->dyn_cast<ConstantIndexOp>()) { + if (auto constOp = dyn_cast<ConstantIndexOp>(op)) { cst.setIdToConstant(*symbol, constOp.getValue()); } } @@ -467,7 +467,7 @@ static Operation *getInstAtPosition(ArrayRef<unsigned> positions, } if (level == positions.size() - 1) return &op; - if (auto childAffineForOp = op.dyn_cast<AffineForOp>()) + if (auto childAffineForOp = dyn_cast<AffineForOp>(op)) return getInstAtPosition(positions, level + 1, childAffineForOp.getBody()); @@ -633,7 +633,7 @@ mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, // Constructs MemRefAccess populating it with the memref, its indices and // opinst from 'loadOrStoreOpInst'. MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { - if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) { + if (auto loadOp = dyn_cast<LoadOp>(loadOrStoreOpInst)) { memref = loadOp.getMemRef(); opInst = loadOrStoreOpInst; auto loadMemrefType = loadOp.getMemRefType(); @@ -643,7 +643,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { } } else { assert(loadOrStoreOpInst->isa<StoreOp>() && "load/store op expected"); - auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>(); + auto storeOp = dyn_cast<StoreOp>(loadOrStoreOpInst); opInst = loadOrStoreOpInst; memref = storeOp.getMemRef(); auto storeMemrefType = storeOp.getMemRefType(); @@ -750,7 +750,7 @@ Optional<int64_t> mlir::getMemoryFootprintBytes(AffineForOp forOp, void mlir::getSequentialLoops( AffineForOp forOp, llvm::SmallDenseSet<Value *, 8> *sequentialLoops) { forOp.getOperation()->walk([&](Operation *op) { - if (auto innerFor = op->dyn_cast<AffineForOp>()) + if (auto innerFor = dyn_cast<AffineForOp>(op)) if (!isLoopParallel(innerFor)) sequentialLoops->insert(innerFor.getInductionVar()); }); diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index b45ac001be4..8fecf058bfc 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -152,7 +152,7 @@ static SetVector<Operation *> getParentsOfType(Operation *op) { SetVector<Operation *> res; auto *current = op; while (auto *parent = current->getParentOp()) { - if (auto typedParent = parent->template dyn_cast<T>()) { + if (auto typedParent = dyn_cast<T>(parent)) { assert(res.count(parent) == 0 && "Already inserted"); res.insert(parent); } @@ -177,7 +177,7 @@ AffineMap mlir::makePermutationMap( } } - if (auto load = op->dyn_cast<LoadOp>()) { + if (auto load = dyn_cast<LoadOp>(op)) { return ::makePermutationMap(load.getIndices(), enclosingLoopToVectorDim); } @@ -198,10 +198,10 @@ bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op, /// do not have to special case. Maybe a trait, or just a method, unclear atm. bool mustDivide = false; VectorType superVectorType; - if (auto read = op.dyn_cast<VectorTransferReadOp>()) { + if (auto read = dyn_cast<VectorTransferReadOp>(op)) { superVectorType = read.getResultType(); mustDivide = true; - } else if (auto write = op.dyn_cast<VectorTransferWriteOp>()) { + } else if (auto write = dyn_cast<VectorTransferWriteOp>(op)) { superVectorType = write.getVectorType(); mustDivide = true; } else if (op.getNumResults() == 0) { diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 610c8b66320..2c9117736ae 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -100,7 +100,7 @@ ValueHandle ValueHandle::create(StringRef name, ArrayRef<ValueHandle> operands, if (op->getNumResults() == 1) { return ValueHandle(op->getResult(0)); } - if (auto f = op->dyn_cast<AffineForOp>()) { + if (auto f = dyn_cast<AffineForOp>(op)) { return ValueHandle(f.getInductionVar()); } llvm_unreachable("unsupported operation, use an OperationHandle instead"); @@ -147,8 +147,8 @@ static llvm::Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs, if (!lbDef || !ubDef) return llvm::Optional<ValueHandle>(); - auto lbConst = lbDef->dyn_cast<ConstantIndexOp>(); - auto ubConst = ubDef->dyn_cast<ConstantIndexOp>(); + auto lbConst = dyn_cast<ConstantIndexOp>(lbDef); + auto ubConst = dyn_cast<ConstantIndexOp>(ubDef); if (!lbConst || !ubConst) return llvm::Optional<ValueHandle>(); diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index 434f7206e04..6e20542a818 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -319,11 +319,11 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<int64_t> tileSizes, // TODO(ntv) expose as a primitive for other passes. static LogicalResult tileLinalgOp(Operation *op, ArrayRef<int64_t> tileSizes, PerFunctionState &state) { - if (auto matmulOp = op->dyn_cast<MatmulOp>()) { + if (auto matmulOp = dyn_cast<MatmulOp>(op)) { return tileLinalgOp(matmulOp, tileSizes, state); - } else if (auto matvecOp = op->dyn_cast<MatvecOp>()) { + } else if (auto matvecOp = dyn_cast<MatvecOp>(op)) { return tileLinalgOp(matvecOp, tileSizes, state); - } else if (auto dotOp = op->dyn_cast<DotOp>()) { + } else if (auto dotOp = dyn_cast<DotOp>(op)) { return tileLinalgOp(dotOp, tileSizes, state); } return failure(); diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp index 4b77ece21dd..98cf4b75b6a 100644 --- a/mlir/lib/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Linalg/Utils/Utils.cpp @@ -68,9 +68,9 @@ ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()( SmallVector<Value *, 8> mlir::getRanges(Operation *op) { SmallVector<Value *, 8> res; - if (auto view = op->dyn_cast<ViewOp>()) { + if (auto view = dyn_cast<ViewOp>(op)) { res.append(view.getIndexings().begin(), view.getIndexings().end()); - } else if (auto slice = op->dyn_cast<SliceOp>()) { + } else if (auto slice = dyn_cast<SliceOp>(op)) { for (auto *i : slice.getIndexings()) if (i->getType().isa<RangeType>()) res.push_back(i); @@ -100,7 +100,7 @@ SmallVector<Value *, 8> mlir::getRanges(Operation *op) { Value *mlir::createOrReturnView(FuncBuilder *b, Location loc, Operation *viewDefiningOp, ArrayRef<Value *> ranges) { - if (auto view = viewDefiningOp->dyn_cast<ViewOp>()) { + if (auto view = dyn_cast<ViewOp>(viewDefiningOp)) { auto indexings = view.getIndexings(); if (std::equal(indexings.begin(), indexings.end(), ranges.begin())) return view.getResult(); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 05e3b13eb4c..bc68a78bd0a 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -134,7 +134,7 @@ struct MemRefCastFolder : public RewritePattern { void rewrite(Operation *op, PatternRewriter &rewriter) const override { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) if (auto *memref = op->getOperand(i)->getDefiningOp()) - if (auto cast = memref->dyn_cast<MemRefCastOp>()) + if (auto cast = dyn_cast<MemRefCastOp>(memref)) op->setOperand(i, cast.getOperand()); rewriter.updatedRootInPlace(op); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 8a9c649feb3..597efc3ba37 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -199,11 +199,11 @@ bool ModuleTranslation::convertOperation(Operation &opInst, // Emit branches. We need to look up the remapped blocks and ignore the block // arguments that were transformed into PHI nodes. - if (auto brOp = opInst.dyn_cast<LLVM::BrOp>()) { + if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) { builder.CreateBr(blockMapping[brOp.getSuccessor(0)]); return false; } - if (auto condbrOp = opInst.dyn_cast<LLVM::CondBrOp>()) { + if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) { builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)), blockMapping[condbrOp.getSuccessor(0)], blockMapping[condbrOp.getSuccessor(1)]); @@ -264,7 +264,7 @@ static Value *getPHISourceValue(Block *current, Block *pred, // For conditional branches, we need to check if the current block is reached // through the "true" or the "false" branch and take the relevant operands. - auto condBranchOp = terminator.dyn_cast<LLVM::CondBrOp>(); + auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator); assert(condBranchOp && "only branch operations can be terminators of a block that " "has successors"); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 10f47fe9be1..937399cc703 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -173,11 +173,11 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, MemRefRegion *region) { unsigned rank; - if (auto loadOp = opInst->dyn_cast<LoadOp>()) { + if (auto loadOp = dyn_cast<LoadOp>(opInst)) { rank = loadOp.getMemRefType().getRank(); region->memref = loadOp.getMemRef(); region->setWrite(false); - } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { + } else if (auto storeOp = dyn_cast<StoreOp>(opInst)) { rank = storeOp.getMemRefType().getRank(); region->memref = storeOp.getMemRef(); region->setWrite(true); @@ -483,7 +483,7 @@ bool DmaGeneration::runOnBlock(Block *block) { }); for (auto it = curBegin; it != block->end(); ++it) { - if (auto forOp = it->dyn_cast<AffineForOp>()) { + if (auto forOp = dyn_cast<AffineForOp>(&*it)) { // Returns true if the footprint is known to exceed capacity. auto exceedsCapacity = [&](AffineForOp forOp) { Optional<int64_t> footprint = @@ -607,10 +607,10 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // Walk this range of operations to gather all memory regions. block->walk(begin, end, [&](Operation *opInst) { // Gather regions to allocate to buffers in faster memory space. - if (auto loadOp = opInst->dyn_cast<LoadOp>()) { + if (auto loadOp = dyn_cast<LoadOp>(opInst)) { if (loadOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; - } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { + } else if (auto storeOp = dyn_cast<StoreOp>(opInst)) { if (storeOp.getMemRefType().getMemorySpace() != slowMemorySpace) return; } else { @@ -739,7 +739,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { // For a range of operations, a note will be emitted at the caller. AffineForOp forOp; uint64_t sizeInKib = llvm::divideCeil(totalDmaBuffersSizeInBytes, 1024); - if (llvm::DebugFlag && (forOp = begin->dyn_cast<AffineForOp>())) { + if (llvm::DebugFlag && (forOp = dyn_cast<AffineForOp>(&*begin))) { forOp.emitRemark() << sizeInKib << " KiB of DMA buffers in fast memory space for this block\n"; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 796d2164ad9..1c4a4d1f755 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -644,7 +644,7 @@ bool MemRefDependenceGraph::init(Function &f) { DenseMap<Operation *, unsigned> forToNodeMap; for (auto &op : f.front()) { - if (auto forOp = op.dyn_cast<AffineForOp>()) { + if (auto forOp = dyn_cast<AffineForOp>(op)) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; @@ -666,14 +666,14 @@ bool MemRefDependenceGraph::init(Function &f) { } forToNodeMap[&op] = node.id; nodes.insert({node.id, node}); - } else if (auto loadOp = op.dyn_cast<LoadOp>()) { + } else if (auto loadOp = dyn_cast<LoadOp>(op)) { // Create graph node for top-level load op. Node node(nextNodeId++, &op); node.loads.push_back(&op); auto *memref = op.cast<LoadOp>().getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (auto storeOp = op.dyn_cast<StoreOp>()) { + } else if (auto storeOp = dyn_cast<StoreOp>(op)) { // Create graph node for top-level store op. Node node(nextNodeId++, &op); node.stores.push_back(&op); @@ -2125,7 +2125,7 @@ public: auto *fn = dstNode->op->getFunction(); for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) { for (auto &use : fn->getArgument(i)->getUses()) { - if (auto loadOp = use.getOwner()->dyn_cast<LoadOp>()) { + if (auto loadOp = dyn_cast<LoadOp>(use.getOwner())) { // Gather loops surrounding 'use'. SmallVector<AffineForOp, 4> loops; getLoopIVs(*use.getOwner(), &loops); diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index ce42a5eba85..28e13d89ada 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -273,7 +273,7 @@ static void getTileableBands(Function &f, for (auto &block : f) for (auto &op : block) - if (auto forOp = op.dyn_cast<AffineForOp>()) + if (auto forOp = dyn_cast<AffineForOp>(op)) getMaximalPerfectLoopNest(forOp); } diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 366a7ede5eb..0a23295c8d9 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -92,7 +92,7 @@ void LoopUnrollAndJam::runOnFunction() { // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on // any for operation. auto &entryBlock = getFunction().front(); - if (auto forOp = entryBlock.front().dyn_cast<AffineForOp>()) + if (auto forOp = dyn_cast<AffineForOp>(entryBlock.front())) runOnAffineForOp(forOp); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index dc389c8e37a..1ffe5e3ddd7 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -620,10 +620,10 @@ void LowerAffinePass::runOnFunction() { // Rewrite all of the ifs and fors. We walked the operations in postorders, // so we know that we will rewrite them in the reverse order. for (auto *op : llvm::reverse(instsToRewrite)) { - if (auto ifOp = op->dyn_cast<AffineIfOp>()) { + if (auto ifOp = dyn_cast<AffineIfOp>(op)) { if (lowerAffineIf(ifOp)) return signalPassFailure(); - } else if (auto forOp = op->dyn_cast<AffineForOp>()) { + } else if (auto forOp = dyn_cast<AffineForOp>(op)) { if (lowerAffineFor(forOp)) return signalPassFailure(); } else if (lowerAffineApply(op->cast<AffineApplyOp>())) { diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 2f06a9aa3bf..28dfb2278e0 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -556,12 +556,12 @@ static bool instantiateMaterialization(Operation *op, if (op->getNumRegions() != 0) return op->emitError("NYI path Op with region"), true; - if (auto write = op->dyn_cast<VectorTransferWriteOp>()) { + if (auto write = dyn_cast<VectorTransferWriteOp>(op)) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); return clone == nullptr; } - if (auto read = op->dyn_cast<VectorTransferReadOp>()) { + if (auto read = dyn_cast<VectorTransferReadOp>(op)) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); if (!clone) { diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index a63d462c4a9..94df936c93f 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -103,7 +103,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { SmallVector<Operation *, 8> storeOps; unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); for (auto &use : loadOp.getMemRef()->getUses()) { - auto storeOp = use.getOwner()->dyn_cast<StoreOp>(); + auto storeOp = dyn_cast<StoreOp>(use.getOwner()); if (!storeOp) continue; auto *storeOpInst = storeOp.getOperation(); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 66fbf4a1306..0da97f7d169 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -181,7 +181,7 @@ static void findMatchingStartFinishInsts( // Collect outgoing DMA operations - needed to check for dependences below. SmallVector<DmaStartOp, 4> outgoingDmaOps; for (auto &op : *forOp.getBody()) { - auto dmaStartOp = op.dyn_cast<DmaStartOp>(); + auto dmaStartOp = dyn_cast<DmaStartOp>(op); if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } @@ -193,7 +193,7 @@ static void findMatchingStartFinishInsts( dmaFinishInsts.push_back(&op); continue; } - auto dmaStartOp = op.dyn_cast<DmaStartOp>(); + auto dmaStartOp = dyn_cast<DmaStartOp>(op); if (!dmaStartOp) continue; diff --git a/mlir/lib/Transforms/TestConstantFold.cpp b/mlir/lib/Transforms/TestConstantFold.cpp index 0990d7a73f6..ec1e971973e 100644 --- a/mlir/lib/Transforms/TestConstantFold.cpp +++ b/mlir/lib/Transforms/TestConstantFold.cpp @@ -48,7 +48,7 @@ void TestConstantFold::foldOperation(Operation *op, } // If this op is a constant that are used and cannot be de-duplicated, // remember it for cleanup later. - else if (auto constant = op->dyn_cast<ConstantOp>()) { + else if (auto constant = dyn_cast<ConstantOp>(op)) { existingConstants.push_back(op); } } diff --git a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp index fc8209be872..b907840b27d 100644 --- a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp @@ -40,7 +40,7 @@ bool ConstantFoldHelper::tryToConstantFold( // into the value it contains. We need to consider constants before the // constant folding logic to avoid re-creating the same constant later. // TODO: Extend to support dialect-specific constant ops. - if (auto constant = op->dyn_cast<ConstantOp>()) { + if (auto constant = dyn_cast<ConstantOp>(op)) { // If this constant is dead, update bookkeeping and signal the caller. if (constant.use_empty()) { notifyRemoval(op); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index a10e4a1ae49..7fbb48ecf99 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -363,7 +363,7 @@ void mlir::getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops, nestedLoops.push_back(curr); auto *currBody = curr.getBody(); while (currBody->begin() == std::prev(currBody->end(), 2) && - (curr = curr.getBody()->front().dyn_cast<AffineForOp>())) { + (curr = dyn_cast<AffineForOp>(curr.getBody()->front()))) { nestedLoops.push_back(curr); currBody = curr.getBody(); } diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 753f7cf750f..b64dc53e037 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -234,7 +234,7 @@ void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) { static bool affineApplyOp(Operation &op) { return op.isa<AffineApplyOp>(); } static bool singleResultAffineApplyOpWithoutUses(Operation &op) { - auto app = op.dyn_cast<AffineApplyOp>(); + auto app = dyn_cast<AffineApplyOp>(op); return app && app.use_empty(); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 025a6535a78..9b8768a6445 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -839,8 +839,8 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step, loadAndStores.match(loop.getOperation(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { auto *opInst = ls.getMatchedOperation(); - auto load = opInst->dyn_cast<LoadOp>(); - auto store = opInst->dyn_cast<StoreOp>(); + auto load = dyn_cast<LoadOp>(opInst); + auto store = dyn_cast<StoreOp>(opInst); LLVM_DEBUG(opInst->print(dbgs())); LogicalResult result = load ? vectorizeRootOrTerminal(loop.getInductionVar(), load, state) @@ -982,7 +982,7 @@ static Value *vectorizeOperand(Value *operand, Operation *op, return nullptr; } // 3. vectorize constant. - if (auto constant = operand->getDefiningOp()->dyn_cast<ConstantOp>()) { + if (auto constant = dyn_cast<ConstantOp>(operand->getDefiningOp())) { return vectorizeConstant( op, constant, VectorType::get(state->strategy->vectorSizes, operand->getType())); @@ -1012,7 +1012,7 @@ static Operation *vectorizeOneOperation(Operation *opInst, assert(!opInst->isa<VectorTransferWriteOp>() && "vector.transfer_write cannot be further vectorized"); - if (auto store = opInst->dyn_cast<StoreOp>()) { + if (auto store = dyn_cast<StoreOp>(opInst)) { auto *memRef = store.getMemRef(); auto *value = store.getValueToStore(); auto *vectorValue = vectorizeOperand(value, opInst, state); diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index ec566e28825..5c34ed160b2 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -161,8 +161,8 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) { } // Output the check and the rewritten builder string. - os << "if (auto op = opInst.dyn_cast<" << op.getQualCppClassName() - << ">()) {\n"; + os << "if (auto op = dyn_cast<" << op.getQualCppClassName() + << ">(opInst)) {\n"; os << bs.str() << builderStrRef << "\n"; os << " return false;\n"; os << "}\n"; |

