summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp46
-rw-r--r--mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp15
-rw-r--r--mlir/examples/toy/Ch5/mlir/LateLowering.cpp10
-rw-r--r--mlir/include/mlir/IR/PatternMatch.h10
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp2
-rw-r--r--mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp26
-rw-r--r--mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp6
7 files changed, 36 insertions, 79 deletions
diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
index 0b640298e02..0b8d6bd9913 100644
--- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
+++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/LLVMIR/LLVMLowering.h"
#include "mlir/LLVMIR/Transforms.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
@@ -392,7 +393,9 @@ public:
: DialectOpConversion("some_consumer", 1, context) {}
void rewrite(Operation *op, ArrayRef<Value *> operands,
- PatternRewriter &rewriter) const override {}
+ PatternRewriter &rewriter) const override {
+ rewriter.replaceOp(op, llvm::None);
+ }
};
void linalg::getDescriptorConverters(mlir::OwningRewritePatternList &patterns,
@@ -403,7 +406,7 @@ void linalg::getDescriptorConverters(mlir::OwningRewritePatternList &patterns,
namespace {
// The conversion class from Linalg to LLVMIR.
-class Lowering : public DialectConversion {
+class Lowering : public LLVMLowering {
public:
explicit Lowering(std::function<void(mlir::OwningRewritePatternList &patterns,
mlir::MLIRContext *context)>
@@ -412,37 +415,13 @@ public:
protected:
// Initialize the list of converters.
- void initConverters(OwningRewritePatternList &patterns,
- MLIRContext *context) override {
- setup(patterns, context);
+ void initAdditionalConverters(OwningRewritePatternList &patterns) override {
+ setup(patterns, llvmDialect->getContext());
}
// This gets called for block and region arguments, and attributes.
- Type convertType(Type t) override { return linalg::convertLinalgType(t); }
-
- // This gets called for function signatures. Convert function arguments and
- // results to the LLVM types, but keep the outer function type as built-in
- // MLIR function type. This does not support multi-result functions because
- // LLVM does not.
- FunctionType convertFunctionSignatureType(
- FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
- SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) override {
- convertedArgAttrs.reserve(argAttrs.size());
- convertedArgAttrs.insert(convertedArgAttrs.end(), argAttrs.begin(),
- argAttrs.end());
-
- SmallVector<Type, 4> argTypes;
- argTypes.reserve(t.getNumInputs());
- for (auto ty : t.getInputs())
- argTypes.push_back(linalg::convertLinalgType(ty));
-
- SmallVector<Type, 1> resultTypes;
- resultTypes.reserve(t.getNumResults());
- for (auto ty : t.getResults())
- resultTypes.push_back(linalg::convertLinalgType(ty));
- assert(t.getNumResults() <= 1 && "NYI: multi-result functions");
-
- return FunctionType::get(argTypes, resultTypes, t.getContext());
+ Type convertAdditionalType(Type t) override {
+ return linalg::convertLinalgType(t);
}
private:
@@ -472,13 +451,6 @@ void linalg::convertToLLVM(mlir::Module &module) {
auto r = Lowering(getDescriptorConverters).convert(&module);
(void)r;
assert(succeeded(r) && "conversion failed");
-
- // Convert the remaining standard MLIR operations to the LLVM IR dialect using
- // the default converter.
- auto converter = createStdToLLVMConverter();
- r = converter->convert(&module);
- (void)r;
- assert(succeeded(r) && "second conversion failed");
}
namespace {
diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
index 8d945f3274c..b475ed2d2a1 100644
--- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
+++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
@@ -58,13 +58,6 @@ public:
: DialectOpConversion(Op::getOperationName(), 1, context) {}
using Base = LoadStoreOpConversion<Op>;
- // Match the Op specified as template argument.
- PatternMatchResult match(Operation *op) const override {
- if (isa<Op>(op))
- return matchSuccess();
- return matchFailure();
- }
-
// Compute the pointer to an element of the buffer underlying the view given
// current view indices. Use the base offset and strides stored in the view
// descriptor to emit IR iteratively computing the actual offset, followed by
@@ -128,6 +121,7 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
ArrayRef<Value *> indices = operands.drop_front(2);
Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
intrinsics::store(data, ptr);
+ rewriter.replaceOp(op, llvm::None);
}
};
@@ -154,11 +148,4 @@ void linalg::convertLinalg3ToLLVM(Module &module) {
auto r = lowering->convert(&module);
(void)r;
assert(succeeded(r) && "conversion failed");
-
- // Convert the remaining standard MLIR operations to the LLVM IR dialect using
- // the default converter.
- auto converter = createStdToLLVMConverter();
- r = converter->convert(&module);
- (void)r;
- assert(succeeded(r) && "second conversion failed");
}
diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
index 4cb34d1181a..00b3a74f018 100644
--- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
@@ -175,6 +175,7 @@ public:
});
// clang-format on
}
+ rewriter.replaceOp(op, llvm::None);
}
private:
@@ -308,14 +309,11 @@ public:
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
- auto retOp = cast<toy::ReturnOp>(op);
- using namespace edsc;
- auto loc = retOp.getLoc();
// Argument is optional, handle both cases.
- if (retOp.getNumOperands())
- rewriter.create<ReturnOp>(loc, operands[0]);
+ if (op->getNumOperands())
+ rewriter.replaceOpWithNewOp<ReturnOp>(op, operands[0]);
else
- rewriter.create<ReturnOp>(loc);
+ rewriter.replaceOpWithNewOp<ReturnOp>(op);
}
};
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 784674b8278..340e84457f6 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -254,8 +254,8 @@ public:
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types.
template <typename OpTy, typename... Args>
- void replaceOpWithNewOp(Operation *op, Args... args) {
- auto newOp = create<OpTy>(op->getLoc(), args...);
+ void replaceOpWithNewOp(Operation *op, Args &&... args) {
+ auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {});
}
@@ -263,9 +263,9 @@ public:
/// The result values of the two ops must be the same types. This allows
/// specifying a list of ops that may be removed if dead.
template <typename OpTy, typename... Args>
- void replaceOpWithNewOp(Operation *op, ArrayRef<Value *> valuesToRemoveIfDead,
- Args... args) {
- auto newOp = create<OpTy>(op->getLoc(), args...);
+ void replaceOpWithNewOp(ArrayRef<Value *> valuesToRemoveIfDead, Operation *op,
+ Args &&... args) {
+ auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(),
valuesToRemoveIfDead);
}
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
index 228e16d752e..44b1156ec28 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -115,7 +115,7 @@ void QuantizedConstRewrite::rewrite(Operation *op,
auto newConstOp =
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
rewriter.replaceOpWithNewOp<StorageCastOp>(
- op, {origConstOp}, *op->result_type_begin(), newConstOp);
+ {origConstOp}, op, *op->result_type_begin(), newConstOp);
}
void ConvertConstPass::runOnFunction() {
diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
index 6de9a151455..f1c43eff289 100644
--- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
+++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
@@ -311,7 +311,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
- return;
+ return rewriter.replaceOp(op, llvm::None);
if (numResults == 1)
return rewriter.replaceOp(op, newOp.getOperation()->getResult(0));
@@ -542,8 +542,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape);
Value *casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(), bufferPtr);
- rewriter.create<LLVM::CallOp>(op->getLoc(), ArrayRef<Type>(),
- rewriter.getFunctionAttr(freeFunc), casted);
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ op, ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
}
};
@@ -803,7 +803,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1],
operands.drop_front(2), rewriter, getModule());
- rewriter.create<LLVM::StoreOp>(op->getLoc(), operands[0], dataPtr);
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, operands[0], dataPtr);
}
};
@@ -818,8 +818,8 @@ struct OneToOneLLVMTerminatorLowering
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
PatternRewriter &rewriter) const override {
- rewriter.create<TargetOp>(op->getLoc(), properOperands, destinations,
- operands, op->getAttrs());
+ rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
+ operands, op->getAttrs());
}
};
@@ -838,17 +838,15 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
// If ReturnOp has 0 or 1 operand, create it and return immediately.
if (numArguments == 0) {
- rewriter.create<LLVM::ReturnOp>(
- op->getLoc(), llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(),
+ return rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+ op, llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(),
llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
- return;
}
if (numArguments == 1) {
- rewriter.create<LLVM::ReturnOp>(
- op->getLoc(), llvm::ArrayRef<Value *>(operands.front()),
+ return rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+ op, llvm::ArrayRef<Value *>(operands.front()),
llvm::ArrayRef<Block *>(), llvm::ArrayRef<llvm::ArrayRef<Value *>>(),
op->getAttrs());
- return;
}
// Otherwise, we need to pack the arguments into an LLVM struct type before
@@ -861,8 +859,8 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
op->getLoc(), packedType, packed, operands[i],
getIntegerArrayAttr(rewriter, i));
}
- rewriter.create<LLVM::ReturnOp>(
- op->getLoc(), llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+ op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),
llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
}
};
diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
index 2ecea9ca5a2..a4da12a922b 100644
--- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -243,6 +243,7 @@ public:
Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
positionAttr(rewriter, 0)));
call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
+ rewriter.replaceOp(op, llvm::None);
}
};
@@ -498,6 +499,7 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
ArrayRef<Value *> indices = operands.drop_front(2);
Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
llvm_store(data, ptr);
+ rewriter.replaceOp(op, llvm::None);
}
};
@@ -578,8 +580,8 @@ public:
auto fAttr = rewriter.getFunctionAttr(f);
auto named = rewriter.getNamedAttr("callee", fAttr);
- rewriter.create<LLVM::CallOp>(op->getLoc(), operands,
- ArrayRef<NamedAttribute>{named});
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
+ ArrayRef<NamedAttribute>{named});
}
};
OpenPOWER on IntegriCloud