summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h2
-rw-r--r--mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp4
-rw-r--r--mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp3
-rw-r--r--mlir/examples/Linalg/Linalg3/lib/Transforms.cpp4
-rw-r--r--mlir/examples/toy/Ch4/mlir/ToyCombine.cpp6
-rw-r--r--mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp2
-rw-r--r--mlir/examples/toy/Ch5/mlir/LateLowering.cpp6
-rw-r--r--mlir/examples/toy/Ch5/mlir/ToyCombine.cpp8
-rw-r--r--mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h2
-rw-r--r--mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h2
-rw-r--r--mlir/include/mlir/IR/OperationSupport.h4
-rw-r--r--mlir/include/mlir/IR/PatternMatch.h62
-rw-r--r--mlir/include/mlir/Transforms/LowerAffine.h2
-rw-r--r--mlir/lib/AffineOps/AffineOps.cpp16
-rw-r--r--mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp3
-rw-r--r--mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp3
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp5
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp2
-rw-r--r--mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp5
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp2
-rw-r--r--mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp3
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp2
-rw-r--r--mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp3
-rw-r--r--mlir/lib/IR/PatternMatch.cpp11
-rw-r--r--mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp11
-rw-r--r--mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp9
-rw-r--r--mlir/lib/StandardOps/Ops.cpp25
-rw-r--r--mlir/lib/Transforms/DialectConversion.cpp3
-rw-r--r--mlir/lib/Transforms/LowerAffine.cpp9
-rw-r--r--mlir/lib/Transforms/LowerVectorTransfers.cpp8
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp6
-rw-r--r--mlir/test/lib/TestDialect/TestPatterns.cpp8
-rw-r--r--mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp2
-rw-r--r--mlir/tools/mlir-tblgen/RewriterGen.cpp3
34 files changed, 113 insertions, 133 deletions
diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h
index 2d4a4a2e6c2..8a5eddda6b0 100644
--- a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h
+++ b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h
@@ -31,7 +31,7 @@ class MLIRContext;
class ModuleOp;
class RewritePattern;
class Type;
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList;
namespace LLVM {
class LLVMType;
} // end namespace LLVM
diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
index 411a7afb284..58e61596153 100644
--- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
+++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
@@ -395,8 +395,8 @@ public:
void linalg::populateLinalg1ToLLVMConversionPatterns(
mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) {
- RewriteListBuilder<DropConsumer, RangeOpConversion, SliceOpConversion,
- ViewOpConversion>::build(patterns, context);
+ patterns.insert<DropConsumer, RangeOpConversion, SliceOpConversion,
+ ViewOpConversion>(context);
}
namespace {
diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
index 8c77737ff3d..e4a401ea70f 100644
--- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
+++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
@@ -145,8 +145,7 @@ struct LinalgTypeConverter : public LLVMTypeConverter {
// coverters to the list.
static void populateLinalg3ToLLVMConversionPatterns(
mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) {
- RewriteListBuilder<LoadOpConversion, StoreOpConversion>::build(patterns,
- context);
+ patterns.insert<LoadOpConversion, StoreOpConversion>(context);
}
LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
index d81eec0a370..8f97f4317f7 100644
--- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
+++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
@@ -261,8 +261,8 @@ struct LowerLinalgLoadStorePass
void runOnFunction() {
OwningRewritePatternList patterns;
auto *context = &getContext();
- patterns.push_back(llvm::make_unique<Rewriter<linalg::LoadOp>>(context));
- patterns.push_back(llvm::make_unique<Rewriter<linalg::StoreOp>>(context));
+ patterns.insert<Rewriter<linalg::LoadOp>, Rewriter<linalg::StoreOp>>(
+ context);
applyPatternsGreedily(getFunction(), std::move(patterns));
}
};
diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
index 92e80d2dfa3..b89cb85ff06 100644
--- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
@@ -142,14 +142,14 @@ struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
// Register our patterns for rewrite by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyRedundantTranspose>(context));
+ results.insert<SimplifyRedundantTranspose>(context);
}
// Register our patterns for rewrite by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
- mlir::RewriteListBuilder<SimplifyReshapeConstant, SimplifyReshapeReshape,
- SimplifyNullReshape>::build(results, context);
+ results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape,
+ SimplifyNullReshape>(context);
}
} // namespace toy
diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
index f3463ba4e0f..72bc2891db6 100644
--- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
+++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
@@ -132,7 +132,7 @@ struct EarlyLoweringPass : public FunctionPass<EarlyLoweringPass> {
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
OwningRewritePatternList patterns;
- RewriteListBuilder<MulOpConversion>::build(patterns, &getContext());
+ patterns.insert<MulOpConversion>(&getContext());
if (failed(applyPartialConversion(getFunction(), target,
std::move(patterns)))) {
emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering Toy\n");
diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
index 5a01122c28a..8b2cc214a55 100644
--- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
@@ -352,9 +352,9 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
void runOnModule() override {
ToyTypeConverter typeConverter;
OwningRewritePatternList toyPatterns;
- RewriteListBuilder<AddOpConversion, PrintOpConversion, ConstantOpConversion,
- TransposeOpConversion,
- ReturnOpConversion>::build(toyPatterns, &getContext());
+ toyPatterns.insert<AddOpConversion, PrintOpConversion, ConstantOpConversion,
+ TransposeOpConversion, ReturnOpConversion>(
+ &getContext());
mlir::populateFuncOpTypeConversionPattern(toyPatterns, &getContext(),
typeConverter);
diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
index 8e9e8ebcd55..4798ad188d1 100644
--- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
@@ -144,14 +144,14 @@ struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
// Register our patterns for rewrite by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyRedundantTranspose>(context));
+ results.insert<SimplifyRedundantTranspose>(context);
}
// Register our patterns for rewrite by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
- mlir::RewriteListBuilder<SimplifyReshapeConstant, SimplifyReshapeReshape,
- SimplifyNullReshape>::build(results, context);
+ results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape,
+ SimplifyNullReshape>(context);
}
namespace {
@@ -180,7 +180,7 @@ struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern<TypeCastOp> {
void TypeCastOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyIdentityTypeCast>(context));
+ results.insert<SimplifyIdentityTypeCast>(context);
}
} // namespace toy
diff --git a/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h b/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h
index e8ab2732d31..78e4356607f 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h
@@ -29,7 +29,7 @@ class MLIRContext;
class RewritePattern;
// Owning list of rewriting patterns.
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList;
/// Collect a set of patterns to lower from loop.for, loop.if, and
/// loop.terminator to CFG operations within the Standard dialect, in particular
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index 361294a729e..941e382905f 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -38,7 +38,7 @@ class RewritePattern;
class Type;
// Owning list of rewriting patterns.
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList;
/// Type for a callback constructing the owning list of patterns for the
/// conversion to the LLVMIR dialect. The callback is expected to append
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index c76f1d620af..204da29b39a 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -57,9 +57,7 @@ class Value;
/// either OpTy or OperandAdaptor<OpTy> seamlessly.
template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
-/// This is a vector that owns the patterns inside of it.
-using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList;
enum class OperationProperty {
/// This bit is set for an operation if it is a commutative operation: that
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index d739a804438..e3897b1d63a 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -394,8 +394,39 @@ private:
// Pattern-driven rewriters
//===----------------------------------------------------------------------===//
-/// This is a vector that owns the patterns inside of it.
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList {
+ using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
+
+public:
+ PatternListT::iterator begin() { return patterns.begin(); }
+ PatternListT::iterator end() { return patterns.end(); }
+ PatternListT::const_iterator begin() const { return patterns.begin(); }
+ PatternListT::const_iterator end() const { return patterns.end(); }
+
+ //===--------------------------------------------------------------------===//
+ // Pattern Insertion
+ //===--------------------------------------------------------------------===//
+
+ void insert(RewritePattern *pattern) { patterns.emplace_back(pattern); }
+
+ /// Add an instance of each of the pattern types 'Ts' to the pattern list with
+ /// the given arguments.
+ // Note: ConstructorArg is necessary here to separate the two variadic lists.
+ template <typename... Ts, typename ConstructorArg,
+ typename... ConstructorArgs>
+ void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
+ // The following expands a call to emplace_back for each of the pattern
+ // types 'Ts'. This magic is necessary due to a limitation in the places
+ // that a parameter pack can be expanded in c++11.
+ // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+ using dummy = int[];
+ (void)dummy{
+ 0, (patterns.emplace_back(llvm::make_unique<Ts>(arg, args...)), 0)...};
+ }
+
+private:
+ PatternListT patterns;
+};
/// This class manages optimization and execution of a group of rewrite
/// patterns, providing an API for finding and applying, the best match against
@@ -404,7 +435,7 @@ using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
class RewritePatternMatcher {
public:
/// Create a RewritePatternMatcher with the specified set of patterns.
- explicit RewritePatternMatcher(OwningRewritePatternList &&patterns);
+ explicit RewritePatternMatcher(OwningRewritePatternList &patterns);
/// Try to match the given operation to a pattern and rewrite it. Return
/// true if any pattern matches.
@@ -416,7 +447,7 @@ private:
/// The group of patterns that are matched for optimization through this
/// matcher.
- OwningRewritePatternList patterns;
+ std::vector<RewritePattern *> patterns;
};
/// Rewrite the regions of the specified operation, which must be isolated from
@@ -427,29 +458,6 @@ private:
///
bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns);
-/// Helper class to create a list of rewrite patterns given a list of their
-/// types and a list of attributes perfect-forwarded to each of the conversion
-/// constructors.
-template <typename Arg, typename... Args> struct RewriteListBuilder {
- template <typename... ConstructorArgs>
- static void build(OwningRewritePatternList &patterns,
- ConstructorArgs &&... constructorArgs) {
- RewriteListBuilder<Args...>::build(
- patterns, std::forward<ConstructorArgs>(constructorArgs)...);
- RewriteListBuilder<Arg>::build(
- patterns, std::forward<ConstructorArgs>(constructorArgs)...);
- }
-};
-
-// Template specialization to stop recursion.
-template <typename Arg> struct RewriteListBuilder<Arg> {
- template <typename... ConstructorArgs>
- static void build(OwningRewritePatternList &patterns,
- ConstructorArgs &&... constructorArgs) {
- patterns.emplace_back(llvm::make_unique<Arg>(
- std::forward<ConstructorArgs>(constructorArgs)...));
- }
-};
} // end namespace mlir
#endif // MLIR_PATTERN_MATCH_H
diff --git a/mlir/include/mlir/Transforms/LowerAffine.h b/mlir/include/mlir/Transforms/LowerAffine.h
index 9ad3f66def5..5fae4763bf7 100644
--- a/mlir/include/mlir/Transforms/LowerAffine.h
+++ b/mlir/include/mlir/Transforms/LowerAffine.h
@@ -32,7 +32,7 @@ class RewritePattern;
class Value;
// Owning list of rewriting patterns.
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+class OwningRewritePatternList;
/// Emit code that computes the given affine expression using standard
/// arithmetic operations applied to the provided dimension and symbol values.
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp
index 9a026231ab2..767c2e344d9 100644
--- a/mlir/lib/AffineOps/AffineOps.cpp
+++ b/mlir/lib/AffineOps/AffineOps.cpp
@@ -708,7 +708,7 @@ struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
void AffineApplyOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyAffineApply>(context));
+ results.insert<SimplifyAffineApply>(context);
}
//===----------------------------------------------------------------------===//
@@ -912,8 +912,7 @@ LogicalResult AffineDmaStartOp::verify() {
void AffineDmaStartOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
/// dma_start(memrefcast) -> dma_start
- results.push_back(
- llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+ results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
@@ -989,8 +988,7 @@ LogicalResult AffineDmaWaitOp::verify() {
void AffineDmaWaitOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
/// dma_wait(memrefcast) -> dma_wait
- results.push_back(
- llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+ results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
@@ -1333,7 +1331,7 @@ struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.push_back(llvm::make_unique<AffineForLoopBoundFolder>(context));
+ results.insert<AffineForLoopBoundFolder>(context);
}
AffineBound AffineForOp::getLowerBound() {
@@ -1659,8 +1657,7 @@ LogicalResult AffineLoadOp::verify() {
void AffineLoadOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
/// load(memrefcast) -> load
- results.push_back(
- llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+ results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
@@ -1752,8 +1749,7 @@ LogicalResult AffineStoreOp::verify() {
void AffineStoreOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
/// load(memrefcast) -> load
- results.push_back(
- llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+ results.insert<MemRefCastFolder>(getOperationName(), context);
}
#define GET_OP_CLASSES
diff --git a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp
index c37decf69e6..034aa22f922 100644
--- a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp
+++ b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp
@@ -258,8 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
void mlir::populateLoopToStdConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
- RewriteListBuilder<ForLowering, IfLowering, TerminatorLowering>::build(
- patterns, ctx);
+ patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
}
void ControlFlowToCFGPass::runOnFunction() {
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 4eadb874908..58f01fc6689 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -104,8 +104,7 @@ void GPUToSPIRVPass::runOnModule() {
SPIRVTypeConverter typeConverter(context);
SPIRVEntryFnTypeConverter entryFnConverter(context);
OwningRewritePatternList patterns;
- RewriteListBuilder<KernelFnConversion>::build(
- patterns, context, typeConverter, entryFnConverter);
+ patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter);
populateStandardToSPIRVPatterns(context, patterns);
ConversionTarget target(*context);
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index af8812c8cf4..09ddcd1e475 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1023,7 +1023,7 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// FIXME: this should be tablegen'ed
- RewriteListBuilder<
+ patterns.insert<
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering,
@@ -1032,8 +1032,7 @@ void mlir::populateStdToLLVMConversionPatterns(
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering,
- SubIOpLowering, XOrOpLowering>::build(patterns, *converter.getDialect(),
- converter);
+ SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter);
}
// Convert types using the stored LLVM IR module.
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index d32d8668046..067f2aeda06 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -201,6 +201,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns);
// Add the return op conversion.
- RewriteListBuilder<ReturnToSPIRVConversion>::build(patterns, context);
+ patterns.insert<ReturnToSPIRVConversion>(context);
}
} // namespace mlir
diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
index dafc8e711f5..d2f3881710c 100644
--- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
+++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -368,8 +368,7 @@ void LowerUniformRealMathPass::runOnFunction() {
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
- patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
- patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
+ patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
applyPatternsGreedily(fn, std::move(patterns));
}
@@ -389,7 +388,7 @@ void LowerUniformCastsPass::runOnFunction() {
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
- patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
+ patterns.insert<UniformDequantizePattern>(context);
applyPatternsGreedily(fn, std::move(patterns));
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index bda5979939c..2fbaa49f56e 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -372,7 +372,7 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- RewriteListBuilder<PropagateConstantBounds>::build(results, context);
+ results.insert<PropagateConstantBounds>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
index e237e8b6eb2..3bd49d43adc 100644
--- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
@@ -60,8 +60,7 @@ public:
void StorageCastOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.push_back(
- llvm::make_unique<RemoveRedundantStorageCastsRewrite>(context));
+ patterns.insert<RemoveRedundantStorageCastsRewrite>(context);
}
QuantizationDialect::QuantizationDialect(MLIRContext *context)
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
index 8469fa2ea70..2276fbd21c9 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -108,7 +108,7 @@ void ConvertConstPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
auto *context = &getContext();
- patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
+ patterns.insert<QuantizedConstRewrite>(context);
applyPatternsGreedily(func, std::move(patterns));
}
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
index 32d8c8a81c1..8f5d1b33c64 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -97,8 +97,7 @@ void ConvertSimulatedQuantPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
auto *context = &getContext();
- patterns.push_back(
- llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
+ patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
applyPatternsGreedily(func, std::move(patterns));
if (hadFailure)
signalPassFailure();
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5010b845c78..94fa7ab43f7 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -149,12 +149,13 @@ void PatternRewriter::updatedRootInPlace(
//===----------------------------------------------------------------------===//
RewritePatternMatcher::RewritePatternMatcher(
- OwningRewritePatternList &&patterns)
- : patterns(std::move(patterns)) {
+ OwningRewritePatternList &patterns) {
+ for (auto &pattern : patterns)
+ this->patterns.push_back(pattern.get());
+
// Sort the patterns by benefit to simplify the matching logic.
std::stable_sort(this->patterns.begin(), this->patterns.end(),
- [](const std::unique_ptr<RewritePattern> &l,
- const std::unique_ptr<RewritePattern> &r) {
+ [](RewritePattern *l, RewritePattern *r) {
return r->getBenefit() < l->getBenefit();
});
}
@@ -162,7 +163,7 @@ RewritePatternMatcher::RewritePatternMatcher(
/// Try to match the given operation to a pattern and rewrite it.
bool RewritePatternMatcher::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) {
- for (auto &pattern : patterns) {
+ for (auto *pattern : patterns) {
// Ignore patterns that are for the wrong root or are impossible to match.
if (pattern->getRootKind() != op->getName() ||
pattern->getBenefit().isImpossibleToMatch())
diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
index 6b62a8e1340..7c2ea5945f4 100644
--- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -678,12 +678,11 @@ static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
OwningRewritePatternList &patterns,
MLIRContext *ctx) {
- RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion,
- BufferSizeOpConversion, DimOpConversion,
- LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
- LoadOpConversion, RangeOpConversion, SliceOpConversion,
- StoreOpConversion, ViewOpConversion>::build(patterns, ctx,
- converter);
+ patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
+ BufferSizeOpConversion, DimOpConversion,
+ LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
+ LoadOpConversion, RangeOpConversion, SliceOpConversion,
+ StoreOpConversion, ViewOpConversion>(ctx, converter);
}
namespace {
diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
index 6b376db8516..3de89137c3c 100644
--- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
+++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
@@ -60,12 +60,9 @@ void RemoveInstrumentationPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
auto *context = &getContext();
- patterns.push_back(
- llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));
- patterns.push_back(
- llvm::make_unique<RemoveIdentityOpRewrite<StatisticsRefOp>>(context));
- patterns.push_back(
- llvm::make_unique<RemoveIdentityOpRewrite<CoupledRefOp>>(context));
+ patterns.insert<RemoveIdentityOpRewrite<StatisticsOp>,
+ RemoveIdentityOpRewrite<StatisticsRefOp>,
+ RemoveIdentityOpRewrite<CoupledRefOp>>(context);
applyPatternsGreedily(func, std::move(patterns));
}
diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp
index df99f00c110..9ecd99a5169 100644
--- a/mlir/lib/StandardOps/Ops.cpp
+++ b/mlir/lib/StandardOps/Ops.cpp
@@ -365,8 +365,7 @@ struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- RewriteListBuilder<SimplifyAllocConst, SimplifyDeadAlloc>::build(results,
- context);
+ results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context);
}
//===----------------------------------------------------------------------===//
@@ -544,8 +543,7 @@ static LogicalResult verify(CallIndirectOp op) {
void CallIndirectOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- results.push_back(
- llvm::make_unique<SimplifyIndirectCallWithKnownCallee>(context));
+ results.insert<SimplifyIndirectCallWithKnownCallee>(context);
}
//===----------------------------------------------------------------------===//
@@ -1015,7 +1013,7 @@ static void print(OpAsmPrinter *p, CondBranchOp op) {
void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- results.push_back(llvm::make_unique<SimplifyConstCondBranchPred>(context));
+ results.insert<SimplifyConstCondBranchPred>(context);
}
//===----------------------------------------------------------------------===//
@@ -1231,9 +1229,8 @@ static LogicalResult verify(DeallocOp op) {
void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// dealloc(memrefcast) -> dealloc
- results.push_back(
- llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
- results.push_back(llvm::make_unique<SimplifyDeadDealloc>(context));
+ results.insert<MemRefCastFolder>(getOperationName(), context);
+ results.insert<SimplifyDeadDealloc>(context);
}
//===----------------------------------------------------------------------===//
@@ -1497,8 +1494,7 @@ LogicalResult DmaStartOp::verify() {
void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// dma_start(memrefcast) -> dma_start
- results.push_back(
- llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+ results.insert<MemRefCastFolder>(getOperationName(), context);
}
// ---------------------------------------------------------------------------
@@ -1561,8 +1557,7 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// dma_wait(memrefcast) -> dma_wait
- results.push_back(
- llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+ results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
@@ -1695,8 +1690,7 @@ static LogicalResult verify(LoadOp op) {
void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// load(memrefcast) -> load
- results.push_back(
- llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+ results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
@@ -2007,8 +2001,7 @@ static LogicalResult verify(StoreOp op) {
void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// store(memrefcast) -> store
- results.push_back(
- llvm::make_unique<MemRefCastFolder>(getOperationName(), context));
+ results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 50c636f708e..6f264b0af35 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1243,8 +1243,7 @@ struct FuncOpSignatureConversion : public ConversionPattern {
void mlir::populateFuncOpTypeConversionPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx,
TypeConverter &converter) {
- RewriteListBuilder<FuncOpSignatureConversion>::build(patterns, ctx,
- converter);
+ patterns.insert<FuncOpSignatureConversion>(ctx, converter);
}
/// This function converts the type signature of the given block, by invoking
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index f35f963b8ae..1c558efd8e4 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -507,10 +507,11 @@ public:
void mlir::populateAffineToStdConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
- RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
- AffineDmaWaitLowering, AffineLoadLowering,
- AffineStoreLowering, AffineForLowering, AffineIfLowering,
- AffineTerminatorLowering>::build(patterns, ctx);
+ patterns
+ .insert<AffineApplyLowering, AffineDmaStartLowering,
+ AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering,
+ AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(
+ ctx);
}
namespace {
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index 3585e2befd6..ef67488023f 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -365,12 +365,8 @@ struct LowerVectorTransfersPass
void runOnFunction() {
OwningRewritePatternList patterns;
auto *context = &getContext();
- patterns.push_back(
- llvm::make_unique<VectorTransferRewriter<VectorTransferReadOp>>(
- context));
- patterns.push_back(
- llvm::make_unique<VectorTransferRewriter<VectorTransferWriteOp>>(
- context));
+ patterns.insert<VectorTransferRewriter<VectorTransferReadOp>,
+ VectorTransferRewriter<VectorTransferWriteOp>>(context);
applyPatternsGreedily(getFunction(), std::move(patterns));
}
};
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 52952178b37..1df4ceec8f3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -44,8 +44,8 @@ namespace {
class GreedyPatternRewriteDriver : public PatternRewriter {
public:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
- OwningRewritePatternList &&patterns)
- : PatternRewriter(ctx), matcher(std::move(patterns)) {
+ OwningRewritePatternList &patterns)
+ : PatternRewriter(ctx), matcher(patterns) {
worklist.reserve(64);
}
@@ -224,7 +224,7 @@ bool mlir::applyPatternsGreedily(Operation *op,
if (!op->isKnownIsolatedFromAbove())
return false;
- GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns));
+ GreedyPatternRewriteDriver driver(op->getContext(), patterns);
bool converged = driver.simplify(op, maxPatternMatchIterations);
LLVM_DEBUG(if (!converged) {
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
index 201dfc3005c..ed94eed4fdd 100644
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -41,7 +41,7 @@ struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
populateWithGenerated(&getContext(), &patterns);
// Verify named pattern is generated with expected name.
- RewriteListBuilder<TestNamedPatternRule>::build(patterns, &getContext());
+ patterns.insert<TestNamedPatternRule>(&getContext());
applyPatternsGreedily(getFunction(), std::move(patterns));
}
@@ -193,9 +193,9 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
- RewriteListBuilder<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
- TestDropOp, TestPassthroughInvalidOp,
- TestSplitReturnType>::build(patterns, &getContext());
+ patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
+ TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>(
+ &getContext());
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);
diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
index edf6aeae469..f75413fdaed 100644
--- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
+++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
@@ -133,7 +133,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) {
pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter,
OwningRewritePatternList &patterns) {
populateStdToLLVMConversionPatterns(converter, patterns);
- patterns.push_back(llvm::make_unique<GPULaunchFuncOpLowering>(converter));
+ patterns.insert<GPULaunchFuncOpLowering>(converter);
}));
pm.addPass(createLowerGpuOpsToNVVMOpsPass());
pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index d408ecfa5eb..24eeaf50d78 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -935,8 +935,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << "void populateWithGenerated(MLIRContext *context, "
<< "OwningRewritePatternList *patterns) {\n";
for (const auto &name : rewriterNames) {
- os << " patterns->push_back(llvm::make_unique<" << name
- << ">(context));\n";
+ os << " patterns->insert<" << name << ">(context);\n";
}
os << "}\n";
}
OpenPOWER on IntegriCloud