diff options
-rw-r--r-- | mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp | 46 | ||||
-rw-r--r-- | mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp | 15 | ||||
-rw-r--r-- | mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 10 | ||||
-rw-r--r-- | mlir/include/mlir/IR/PatternMatch.h | 10 | ||||
-rw-r--r-- | mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 26 | ||||
-rw-r--r-- | mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 6 |
7 files changed, 36 insertions, 79 deletions
diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 0b640298e02..0b8d6bd9913 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -26,6 +26,7 @@ #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" #include "mlir/LLVMIR/LLVMDialect.h" +#include "mlir/LLVMIR/LLVMLowering.h" #include "mlir/LLVMIR/Transforms.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" @@ -392,7 +393,9 @@ public: : DialectOpConversion("some_consumer", 1, context) {} void rewrite(Operation *op, ArrayRef<Value *> operands, - PatternRewriter &rewriter) const override {} + PatternRewriter &rewriter) const override { + rewriter.replaceOp(op, llvm::None); + } }; void linalg::getDescriptorConverters(mlir::OwningRewritePatternList &patterns, @@ -403,7 +406,7 @@ void linalg::getDescriptorConverters(mlir::OwningRewritePatternList &patterns, namespace { // The conversion class from Linalg to LLVMIR. -class Lowering : public DialectConversion { +class Lowering : public LLVMLowering { public: explicit Lowering(std::function<void(mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context)> @@ -412,37 +415,13 @@ public: protected: // Initialize the list of converters. - void initConverters(OwningRewritePatternList &patterns, - MLIRContext *context) override { - setup(patterns, context); + void initAdditionalConverters(OwningRewritePatternList &patterns) override { + setup(patterns, llvmDialect->getContext()); } // This gets called for block and region arguments, and attributes. - Type convertType(Type t) override { return linalg::convertLinalgType(t); } - - // This gets called for function signatures. Convert function arguments and - // results to the LLVM types, but keep the outer function type as built-in - // MLIR function type. This does not support multi-result functions because - // LLVM does not. - FunctionType convertFunctionSignatureType( - FunctionType t, ArrayRef<NamedAttributeList> argAttrs, - SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) override { - convertedArgAttrs.reserve(argAttrs.size()); - convertedArgAttrs.insert(convertedArgAttrs.end(), argAttrs.begin(), - argAttrs.end()); - - SmallVector<Type, 4> argTypes; - argTypes.reserve(t.getNumInputs()); - for (auto ty : t.getInputs()) - argTypes.push_back(linalg::convertLinalgType(ty)); - - SmallVector<Type, 1> resultTypes; - resultTypes.reserve(t.getNumResults()); - for (auto ty : t.getResults()) - resultTypes.push_back(linalg::convertLinalgType(ty)); - assert(t.getNumResults() <= 1 && "NYI: multi-result functions"); - - return FunctionType::get(argTypes, resultTypes, t.getContext()); + Type convertAdditionalType(Type t) override { + return linalg::convertLinalgType(t); } private: @@ -472,13 +451,6 @@ void linalg::convertToLLVM(mlir::Module &module) { auto r = Lowering(getDescriptorConverters).convert(&module); (void)r; assert(succeeded(r) && "conversion failed"); - - // Convert the remaining standard MLIR operations to the LLVM IR dialect using - // the default converter. - auto converter = createStdToLLVMConverter(); - r = converter->convert(&module); - (void)r; - assert(succeeded(r) && "second conversion failed"); } namespace { diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 8d945f3274c..b475ed2d2a1 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -58,13 +58,6 @@ public: : DialectOpConversion(Op::getOperationName(), 1, context) {} using Base = LoadStoreOpConversion<Op>; - // Match the Op specified as template argument. - PatternMatchResult match(Operation *op) const override { - if (isa<Op>(op)) - return matchSuccess(); - return matchFailure(); - } - // Compute the pointer to an element of the buffer underlying the view given // current view indices. Use the base offset and strides stored in the view // descriptor to emit IR iteratively computing the actual offset, followed by @@ -128,6 +121,7 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> { ArrayRef<Value *> indices = operands.drop_front(2); Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); intrinsics::store(data, ptr); + rewriter.replaceOp(op, llvm::None); } }; @@ -154,11 +148,4 @@ void linalg::convertLinalg3ToLLVM(Module &module) { auto r = lowering->convert(&module); (void)r; assert(succeeded(r) && "conversion failed"); - - // Convert the remaining standard MLIR operations to the LLVM IR dialect using - // the default converter. - auto converter = createStdToLLVMConverter(); - r = converter->convert(&module); - (void)r; - assert(succeeded(r) && "second conversion failed"); } diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 4cb34d1181a..00b3a74f018 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -175,6 +175,7 @@ public: }); // clang-format on } + rewriter.replaceOp(op, llvm::None); } private: @@ -308,14 +309,11 @@ public: void rewrite(Operation *op, ArrayRef<Value *> operands, PatternRewriter &rewriter) const override { - auto retOp = cast<toy::ReturnOp>(op); - using namespace edsc; - auto loc = retOp.getLoc(); // Argument is optional, handle both cases. - if (retOp.getNumOperands()) - rewriter.create<ReturnOp>(loc, operands[0]); + if (op->getNumOperands()) + rewriter.replaceOpWithNewOp<ReturnOp>(op, operands[0]); else - rewriter.create<ReturnOp>(loc); + rewriter.replaceOpWithNewOp<ReturnOp>(op); } }; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 784674b8278..340e84457f6 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -254,8 +254,8 @@ public: /// Replaces the result op with a new op that is created without verification. /// The result values of the two ops must be the same types. template <typename OpTy, typename... Args> - void replaceOpWithNewOp(Operation *op, Args... args) { - auto newOp = create<OpTy>(op->getLoc(), args...); + void replaceOpWithNewOp(Operation *op, Args &&... args) { + auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...); replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {}); } @@ -263,9 +263,9 @@ public: /// The result values of the two ops must be the same types. This allows /// specifying a list of ops that may be removed if dead. template <typename OpTy, typename... Args> - void replaceOpWithNewOp(Operation *op, ArrayRef<Value *> valuesToRemoveIfDead, - Args... args) { - auto newOp = create<OpTy>(op->getLoc(), args...); + void replaceOpWithNewOp(ArrayRef<Value *> valuesToRemoveIfDead, Operation *op, + Args &&... args) { + auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...); replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), valuesToRemoveIfDead); } diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 228e16d752e..44b1156ec28 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -115,7 +115,7 @@ void QuantizedConstRewrite::rewrite(Operation *op, auto newConstOp = rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue); rewriter.replaceOpWithNewOp<StorageCastOp>( - op, {origConstOp}, *op->result_type_begin(), newConstOp); + {origConstOp}, op, *op->result_type_begin(), newConstOp); } void ConvertConstPass::runOnFunction() { diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 6de9a151455..f1c43eff289 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -311,7 +311,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) - return; + return rewriter.replaceOp(op, llvm::None); if (numResults == 1) return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)); @@ -542,8 +542,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> { rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape); Value *casted = rewriter.create<LLVM::BitcastOp>( op->getLoc(), getVoidPtrType(), bufferPtr); - rewriter.create<LLVM::CallOp>(op->getLoc(), ArrayRef<Type>(), - rewriter.getFunctionAttr(freeFunc), casted); + rewriter.replaceOpWithNewOp<LLVM::CallOp>( + op, ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted); } }; @@ -803,7 +803,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> { Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1], operands.drop_front(2), rewriter, getModule()); - rewriter.create<LLVM::StoreOp>(op->getLoc(), operands[0], dataPtr); + rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, operands[0], dataPtr); } }; @@ -818,8 +818,8 @@ struct OneToOneLLVMTerminatorLowering ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, PatternRewriter &rewriter) const override { - rewriter.create<TargetOp>(op->getLoc(), properOperands, destinations, - operands, op->getAttrs()); + rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations, + operands, op->getAttrs()); } }; @@ -838,17 +838,15 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { - rewriter.create<LLVM::ReturnOp>( - op->getLoc(), llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(), + return rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( + op, llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(), llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs()); - return; } if (numArguments == 1) { - rewriter.create<LLVM::ReturnOp>( - op->getLoc(), llvm::ArrayRef<Value *>(operands.front()), + return rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( + op, llvm::ArrayRef<Value *>(operands.front()), llvm::ArrayRef<Block *>(), llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs()); - return; } // Otherwise, we need to pack the arguments into an LLVM struct type before @@ -861,8 +859,8 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { op->getLoc(), packedType, packed, operands[i], getIntegerArrayAttr(rewriter, i)); } - rewriter.create<LLVM::ReturnOp>( - op->getLoc(), llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(), + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( + op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(), llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs()); } }; diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 2ecea9ca5a2..a4da12a922b 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -243,6 +243,7 @@ public: Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0], positionAttr(rewriter, 0))); call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted); + rewriter.replaceOp(op, llvm::None); } }; @@ -498,6 +499,7 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> { ArrayRef<Value *> indices = operands.drop_front(2); Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); llvm_store(data, ptr); + rewriter.replaceOp(op, llvm::None); } }; @@ -578,8 +580,8 @@ public: auto fAttr = rewriter.getFunctionAttr(f); auto named = rewriter.getNamedAttr("callee", fAttr); - rewriter.create<LLVM::CallOp>(op->getLoc(), operands, - ArrayRef<NamedAttribute>{named}); + rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands, + ArrayRef<NamedAttribute>{named}); } }; |