diff options
34 files changed, 113 insertions, 133 deletions
diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h index 2d4a4a2e6c2..8a5eddda6b0 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h @@ -31,7 +31,7 @@ class MLIRContext; class ModuleOp; class RewritePattern; class Type; -using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; +class OwningRewritePatternList; namespace LLVM { class LLVMType; } // end namespace LLVM diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 411a7afb284..58e61596153 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -395,8 +395,8 @@ public: void linalg::populateLinalg1ToLLVMConversionPatterns( mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) { - RewriteListBuilder<DropConsumer, RangeOpConversion, SliceOpConversion, - ViewOpConversion>::build(patterns, context); + patterns.insert<DropConsumer, RangeOpConversion, SliceOpConversion, + ViewOpConversion>(context); } namespace { diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 8c77737ff3d..e4a401ea70f 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -145,8 +145,7 @@ struct LinalgTypeConverter : public LLVMTypeConverter { // coverters to the list. static void populateLinalg3ToLLVMConversionPatterns( mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) { - RewriteListBuilder<LoadOpConversion, StoreOpConversion>::build(patterns, - context); + patterns.insert<LoadOpConversion, StoreOpConversion>(context); } LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) { diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index d81eec0a370..8f97f4317f7 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -261,8 +261,8 @@ struct LowerLinalgLoadStorePass void runOnFunction() { OwningRewritePatternList patterns; auto *context = &getContext(); - patterns.push_back(llvm::make_unique<Rewriter<linalg::LoadOp>>(context)); - patterns.push_back(llvm::make_unique<Rewriter<linalg::StoreOp>>(context)); + patterns.insert<Rewriter<linalg::LoadOp>, Rewriter<linalg::StoreOp>>( + context); applyPatternsGreedily(getFunction(), std::move(patterns)); } }; diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index 92e80d2dfa3..b89cb85ff06 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -142,14 +142,14 @@ struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> { // Register our patterns for rewrite by the Canonicalization framework. void TransposeOp::getCanonicalizationPatterns( mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.push_back(llvm::make_unique<SimplifyRedundantTranspose>(context)); + results.insert<SimplifyRedundantTranspose>(context); } // Register our patterns for rewrite by the Canonicalization framework. void ReshapeOp::getCanonicalizationPatterns( mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - mlir::RewriteListBuilder<SimplifyReshapeConstant, SimplifyReshapeReshape, - SimplifyNullReshape>::build(results, context); + results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape, + SimplifyNullReshape>(context); } } // namespace toy diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index f3463ba4e0f..72bc2891db6 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -132,7 +132,7 @@ struct EarlyLoweringPass : public FunctionPass<EarlyLoweringPass> { target.addLegalOp<toy::AllocOp, toy::TypeCastOp>(); OwningRewritePatternList patterns; - RewriteListBuilder<MulOpConversion>::build(patterns, &getContext()); + patterns.insert<MulOpConversion>(&getContext()); if (failed(applyPartialConversion(getFunction(), target, std::move(patterns)))) { emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering Toy\n"); diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 5a01122c28a..8b2cc214a55 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -352,9 +352,9 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> { void runOnModule() override { ToyTypeConverter typeConverter; OwningRewritePatternList toyPatterns; - RewriteListBuilder<AddOpConversion, PrintOpConversion, ConstantOpConversion, - TransposeOpConversion, - ReturnOpConversion>::build(toyPatterns, &getContext()); + toyPatterns.insert<AddOpConversion, PrintOpConversion, ConstantOpConversion, + TransposeOpConversion, ReturnOpConversion>( + &getContext()); mlir::populateFuncOpTypeConversionPattern(toyPatterns, &getContext(), typeConverter); diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp index 8e9e8ebcd55..4798ad188d1 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -144,14 +144,14 @@ struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> { // Register our patterns for rewrite by the Canonicalization framework. void TransposeOp::getCanonicalizationPatterns( mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.push_back(llvm::make_unique<SimplifyRedundantTranspose>(context)); + results.insert<SimplifyRedundantTranspose>(context); } // Register our patterns for rewrite by the Canonicalization framework. void ReshapeOp::getCanonicalizationPatterns( mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - mlir::RewriteListBuilder<SimplifyReshapeConstant, SimplifyReshapeReshape, - SimplifyNullReshape>::build(results, context); + results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape, + SimplifyNullReshape>(context); } namespace { @@ -180,7 +180,7 @@ struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern<TypeCastOp> { void TypeCastOp::getCanonicalizationPatterns( mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.push_back(llvm::make_unique<SimplifyIdentityTypeCast>(context)); + results.insert<SimplifyIdentityTypeCast>(context); } } // namespace toy diff --git a/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h b/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h index e8ab2732d31..78e4356607f 100644 --- a/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h +++ b/mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h @@ -29,7 +29,7 @@ class MLIRContext; class RewritePattern; // Owning list of rewriting patterns. -using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; +class OwningRewritePatternList; /// Collect a set of patterns to lower from loop.for, loop.if, and /// loop.terminator to CFG operations within the Standard dialect, in particular diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index 361294a729e..941e382905f 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -38,7 +38,7 @@ class RewritePattern; class Type; // Owning list of rewriting patterns. -using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; +class OwningRewritePatternList; /// Type for a callback constructing the owning list of patterns for the /// conversion to the LLVMIR dialect. The callback is expected to append diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index c76f1d620af..204da29b39a 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -57,9 +57,7 @@ class Value; /// either OpTy or OperandAdaptor<OpTy> seamlessly. template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor; -/// This is a vector that owns the patterns inside of it. -using OwningPatternList = std::vector<std::unique_ptr<Pattern>>; -using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; +class OwningRewritePatternList; enum class OperationProperty { /// This bit is set for an operation if it is a commutative operation: that diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index d739a804438..e3897b1d63a 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -394,8 +394,39 @@ private: // Pattern-driven rewriters //===----------------------------------------------------------------------===// -/// This is a vector that owns the patterns inside of it. -using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; +class OwningRewritePatternList { + using PatternListT = std::vector<std::unique_ptr<RewritePattern>>; + +public: + PatternListT::iterator begin() { return patterns.begin(); } + PatternListT::iterator end() { return patterns.end(); } + PatternListT::const_iterator begin() const { return patterns.begin(); } + PatternListT::const_iterator end() const { return patterns.end(); } + + //===--------------------------------------------------------------------===// + // Pattern Insertion + //===--------------------------------------------------------------------===// + + void insert(RewritePattern *pattern) { patterns.emplace_back(pattern); } + + /// Add an instance of each of the pattern types 'Ts' to the pattern list with + /// the given arguments. + // Note: ConstructorArg is necessary here to separate the two variadic lists. + template <typename... Ts, typename ConstructorArg, + typename... ConstructorArgs> + void insert(ConstructorArg &&arg, ConstructorArgs &&... args) { + // The following expands a call to emplace_back for each of the pattern + // types 'Ts'. This magic is necessary due to a limitation in the places + // that a parameter pack can be expanded in c++11. + // FIXME: In c++17 this can be simplified by using 'fold expressions'. + using dummy = int[]; + (void)dummy{ + 0, (patterns.emplace_back(llvm::make_unique<Ts>(arg, args...)), 0)...}; + } + +private: + PatternListT patterns; +}; /// This class manages optimization and execution of a group of rewrite /// patterns, providing an API for finding and applying, the best match against @@ -404,7 +435,7 @@ using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; class RewritePatternMatcher { public: /// Create a RewritePatternMatcher with the specified set of patterns. - explicit RewritePatternMatcher(OwningRewritePatternList &&patterns); + explicit RewritePatternMatcher(OwningRewritePatternList &patterns); /// Try to match the given operation to a pattern and rewrite it. Return /// true if any pattern matches. @@ -416,7 +447,7 @@ private: /// The group of patterns that are matched for optimization through this /// matcher. - OwningRewritePatternList patterns; + std::vector<RewritePattern *> patterns; }; /// Rewrite the regions of the specified operation, which must be isolated from @@ -427,29 +458,6 @@ private: /// bool applyPatternsGreedily(Operation *op, OwningRewritePatternList &&patterns); -/// Helper class to create a list of rewrite patterns given a list of their -/// types and a list of attributes perfect-forwarded to each of the conversion -/// constructors. -template <typename Arg, typename... Args> struct RewriteListBuilder { - template <typename... ConstructorArgs> - static void build(OwningRewritePatternList &patterns, - ConstructorArgs &&... constructorArgs) { - RewriteListBuilder<Args...>::build( - patterns, std::forward<ConstructorArgs>(constructorArgs)...); - RewriteListBuilder<Arg>::build( - patterns, std::forward<ConstructorArgs>(constructorArgs)...); - } -}; - -// Template specialization to stop recursion. -template <typename Arg> struct RewriteListBuilder<Arg> { - template <typename... ConstructorArgs> - static void build(OwningRewritePatternList &patterns, - ConstructorArgs &&... constructorArgs) { - patterns.emplace_back(llvm::make_unique<Arg>( - std::forward<ConstructorArgs>(constructorArgs)...)); - } -}; } // end namespace mlir #endif // MLIR_PATTERN_MATCH_H diff --git a/mlir/include/mlir/Transforms/LowerAffine.h b/mlir/include/mlir/Transforms/LowerAffine.h index 9ad3f66def5..5fae4763bf7 100644 --- a/mlir/include/mlir/Transforms/LowerAffine.h +++ b/mlir/include/mlir/Transforms/LowerAffine.h @@ -32,7 +32,7 @@ class RewritePattern; class Value; // Owning list of rewriting patterns. -using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; +class OwningRewritePatternList; /// Emit code that computes the given affine expression using standard /// arithmetic operations applied to the provided dimension and symbol values. diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 9a026231ab2..767c2e344d9 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -708,7 +708,7 @@ struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> { void AffineApplyOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique<SimplifyAffineApply>(context)); + results.insert<SimplifyAffineApply>(context); } //===----------------------------------------------------------------------===// @@ -912,8 +912,7 @@ LogicalResult AffineDmaStartOp::verify() { void AffineDmaStartOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// dma_start(memrefcast) -> dma_start - results.push_back( - llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); + results.insert<MemRefCastFolder>(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -989,8 +988,7 @@ LogicalResult AffineDmaWaitOp::verify() { void AffineDmaWaitOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// dma_wait(memrefcast) -> dma_wait - results.push_back( - llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); + results.insert<MemRefCastFolder>(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -1333,7 +1331,7 @@ struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> { void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique<AffineForLoopBoundFolder>(context)); + results.insert<AffineForLoopBoundFolder>(context); } AffineBound AffineForOp::getLowerBound() { @@ -1659,8 +1657,7 @@ LogicalResult AffineLoadOp::verify() { void AffineLoadOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load - results.push_back( - llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); + results.insert<MemRefCastFolder>(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -1752,8 +1749,7 @@ LogicalResult AffineStoreOp::verify() { void AffineStoreOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load - results.push_back( - llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); + results.insert<MemRefCastFolder>(getOperationName(), context); } #define GET_OP_CLASSES diff --git a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp index c37decf69e6..034aa22f922 100644 --- a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp +++ b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp @@ -258,8 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { void mlir::populateLoopToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { - RewriteListBuilder<ForLowering, IfLowering, TerminatorLowering>::build( - patterns, ctx); + patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx); } void ControlFlowToCFGPass::runOnFunction() { diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 4eadb874908..58f01fc6689 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -104,8 +104,7 @@ void GPUToSPIRVPass::runOnModule() { SPIRVTypeConverter typeConverter(context); SPIRVEntryFnTypeConverter entryFnConverter(context); OwningRewritePatternList patterns; - RewriteListBuilder<KernelFnConversion>::build( - patterns, context, typeConverter, entryFnConverter); + patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter); populateStandardToSPIRVPatterns(context, patterns); ConversionTarget target(*context); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index af8812c8cf4..09ddcd1e475 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1023,7 +1023,7 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) { void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed - RewriteListBuilder< + patterns.insert< AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, DeallocOpLowering, @@ -1032,8 +1032,7 @@ void mlir::populateStdToLLVMConversionPatterns( MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering, SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering, - SubIOpLowering, XOrOpLowering>::build(patterns, *converter.getDialect(), - converter); + SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter); } // Convert types using the stored LLVM IR module. diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index d32d8668046..067f2aeda06 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -201,6 +201,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context, OwningRewritePatternList &patterns) { populateWithGenerated(context, &patterns); // Add the return op conversion. - RewriteListBuilder<ReturnToSPIRVConversion>::build(patterns, context); + patterns.insert<ReturnToSPIRVConversion>(context); } } // namespace mlir diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index dafc8e711f5..d2f3881710c 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -368,8 +368,7 @@ void LowerUniformRealMathPass::runOnFunction() { auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); - patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context)); - patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context)); + patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context); applyPatternsGreedily(fn, std::move(patterns)); } @@ -389,7 +388,7 @@ void LowerUniformCastsPass::runOnFunction() { auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); - patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context)); + patterns.insert<UniformDequantizePattern>(context); applyPatternsGreedily(fn, std::move(patterns)); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index bda5979939c..2fbaa49f56e 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -372,7 +372,7 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> { void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder<PropagateConstantBounds>::build(results, context); + results.insert<PropagateConstantBounds>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp index e237e8b6eb2..3bd49d43adc 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -60,8 +60,7 @@ public: void StorageCastOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.push_back( - llvm::make_unique<RemoveRedundantStorageCastsRewrite>(context)); + patterns.insert<RemoveRedundantStorageCastsRewrite>(context); } QuantizationDialect::QuantizationDialect(MLIRContext *context) diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 8469fa2ea70..2276fbd21c9 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -108,7 +108,7 @@ void ConvertConstPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto *context = &getContext(); - patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context)); + patterns.insert<QuantizedConstRewrite>(context); applyPatternsGreedily(func, std::move(patterns)); } diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index 32d8c8a81c1..8f5d1b33c64 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -97,8 +97,7 @@ void ConvertSimulatedQuantPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto *context = &getContext(); - patterns.push_back( - llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure)); + patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure); applyPatternsGreedily(func, std::move(patterns)); if (hadFailure) signalPassFailure(); diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 5010b845c78..94fa7ab43f7 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -149,12 +149,13 @@ void PatternRewriter::updatedRootInPlace( //===----------------------------------------------------------------------===// RewritePatternMatcher::RewritePatternMatcher( - OwningRewritePatternList &&patterns) - : patterns(std::move(patterns)) { + OwningRewritePatternList &patterns) { + for (auto &pattern : patterns) + this->patterns.push_back(pattern.get()); + // Sort the patterns by benefit to simplify the matching logic. std::stable_sort(this->patterns.begin(), this->patterns.end(), - [](const std::unique_ptr<RewritePattern> &l, - const std::unique_ptr<RewritePattern> &r) { + [](RewritePattern *l, RewritePattern *r) { return r->getBenefit() < l->getBenefit(); }); } @@ -162,7 +163,7 @@ RewritePatternMatcher::RewritePatternMatcher( /// Try to match the given operation to a pattern and rewrite it. bool RewritePatternMatcher::matchAndRewrite(Operation *op, PatternRewriter &rewriter) { - for (auto &pattern : patterns) { + for (auto *pattern : patterns) { // Ignore patterns that are for the wrong root or are impossible to match. if (pattern->getRootKind() != op->getName() || pattern->getBenefit().isImpossibleToMatch()) diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 6b62a8e1340..7c2ea5945f4 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -678,12 +678,11 @@ static void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx) { - RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion, - BufferSizeOpConversion, DimOpConversion, - LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>, - LoadOpConversion, RangeOpConversion, SliceOpConversion, - StoreOpConversion, ViewOpConversion>::build(patterns, ctx, - converter); + patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion, + BufferSizeOpConversion, DimOpConversion, + LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>, + LoadOpConversion, RangeOpConversion, SliceOpConversion, + StoreOpConversion, ViewOpConversion>(ctx, converter); } namespace { diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp index 6b376db8516..3de89137c3c 100644 --- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -60,12 +60,9 @@ void RemoveInstrumentationPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto *context = &getContext(); - patterns.push_back( - llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context)); - patterns.push_back( - llvm::make_unique<RemoveIdentityOpRewrite<StatisticsRefOp>>(context)); - patterns.push_back( - llvm::make_unique<RemoveIdentityOpRewrite<CoupledRefOp>>(context)); + patterns.insert<RemoveIdentityOpRewrite<StatisticsOp>, + RemoveIdentityOpRewrite<StatisticsRefOp>, + RemoveIdentityOpRewrite<CoupledRefOp>>(context); applyPatternsGreedily(func, std::move(patterns)); } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index df99f00c110..9ecd99a5169 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -365,8 +365,7 @@ struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> { void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder<SimplifyAllocConst, SimplifyDeadAlloc>::build(results, - context); + results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context); } //===----------------------------------------------------------------------===// @@ -544,8 +543,7 @@ static LogicalResult verify(CallIndirectOp op) { void CallIndirectOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.push_back( - llvm::make_unique<SimplifyIndirectCallWithKnownCallee>(context)); + results.insert<SimplifyIndirectCallWithKnownCallee>(context); } //===----------------------------------------------------------------------===// @@ -1015,7 +1013,7 @@ static void print(OpAsmPrinter *p, CondBranchOp op) { void CondBranchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique<SimplifyConstCondBranchPred>(context)); + results.insert<SimplifyConstCondBranchPred>(context); } //===----------------------------------------------------------------------===// @@ -1231,9 +1229,8 @@ static LogicalResult verify(DeallocOp op) { void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dealloc(memrefcast) -> dealloc - results.push_back( - llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); - results.push_back(llvm::make_unique<SimplifyDeadDealloc>(context)); + results.insert<MemRefCastFolder>(getOperationName(), context); + results.insert<SimplifyDeadDealloc>(context); } //===----------------------------------------------------------------------===// @@ -1497,8 +1494,7 @@ LogicalResult DmaStartOp::verify() { void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dma_start(memrefcast) -> dma_start - results.push_back( - llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); + results.insert<MemRefCastFolder>(getOperationName(), context); } // --------------------------------------------------------------------------- @@ -1561,8 +1557,7 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dma_wait(memrefcast) -> dma_wait - results.push_back( - llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); + results.insert<MemRefCastFolder>(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -1695,8 +1690,7 @@ static LogicalResult verify(LoadOp op) { void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load - results.push_back( - llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); + results.insert<MemRefCastFolder>(getOperationName(), context); } //===----------------------------------------------------------------------===// @@ -2007,8 +2001,7 @@ static LogicalResult verify(StoreOp op) { void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// store(memrefcast) -> store - results.push_back( - llvm::make_unique<MemRefCastFolder>(getOperationName(), context)); + results.insert<MemRefCastFolder>(getOperationName(), context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 50c636f708e..6f264b0af35 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1243,8 +1243,7 @@ struct FuncOpSignatureConversion : public ConversionPattern { void mlir::populateFuncOpTypeConversionPattern( OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &converter) { - RewriteListBuilder<FuncOpSignatureConversion>::build(patterns, ctx, - converter); + patterns.insert<FuncOpSignatureConversion>(ctx, converter); } /// This function converts the type signature of the given block, by invoking diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index f35f963b8ae..1c558efd8e4 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -507,10 +507,11 @@ public: void mlir::populateAffineToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { - RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering, - AffineDmaWaitLowering, AffineLoadLowering, - AffineStoreLowering, AffineForLowering, AffineIfLowering, - AffineTerminatorLowering>::build(patterns, ctx); + patterns + .insert<AffineApplyLowering, AffineDmaStartLowering, + AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering, + AffineForLowering, AffineIfLowering, AffineTerminatorLowering>( + ctx); } namespace { diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 3585e2befd6..ef67488023f 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -365,12 +365,8 @@ struct LowerVectorTransfersPass void runOnFunction() { OwningRewritePatternList patterns; auto *context = &getContext(); - patterns.push_back( - llvm::make_unique<VectorTransferRewriter<VectorTransferReadOp>>( - context)); - patterns.push_back( - llvm::make_unique<VectorTransferRewriter<VectorTransferWriteOp>>( - context)); + patterns.insert<VectorTransferRewriter<VectorTransferReadOp>, + VectorTransferRewriter<VectorTransferWriteOp>>(context); applyPatternsGreedily(getFunction(), std::move(patterns)); } }; diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 52952178b37..1df4ceec8f3 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -44,8 +44,8 @@ namespace { class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, - OwningRewritePatternList &&patterns) - : PatternRewriter(ctx), matcher(std::move(patterns)) { + OwningRewritePatternList &patterns) + : PatternRewriter(ctx), matcher(patterns) { worklist.reserve(64); } @@ -224,7 +224,7 @@ bool mlir::applyPatternsGreedily(Operation *op, if (!op->isKnownIsolatedFromAbove()) return false; - GreedyPatternRewriteDriver driver(op->getContext(), std::move(patterns)); + GreedyPatternRewriteDriver driver(op->getContext(), patterns); bool converged = driver.simplify(op, maxPatternMatchIterations); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 201dfc3005c..ed94eed4fdd 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -41,7 +41,7 @@ struct TestPatternDriver : public FunctionPass<TestPatternDriver> { populateWithGenerated(&getContext(), &patterns); // Verify named pattern is generated with expected name. - RewriteListBuilder<TestNamedPatternRule>::build(patterns, &getContext()); + patterns.insert<TestNamedPatternRule>(&getContext()); applyPatternsGreedily(getFunction(), std::move(patterns)); } @@ -193,9 +193,9 @@ struct TestLegalizePatternDriver TestTypeConverter converter; mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); - RewriteListBuilder<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, - TestDropOp, TestPassthroughInvalidOp, - TestSplitReturnType>::build(patterns, &getContext()); + patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, + TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>( + &getContext()); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp index edf6aeae469..f75413fdaed 100644 --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -133,7 +133,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) { pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { populateStdToLLVMConversionPatterns(converter, patterns); - patterns.push_back(llvm::make_unique<GPULaunchFuncOpLowering>(converter)); + patterns.insert<GPULaunchFuncOpLowering>(converter); })); pm.addPass(createLowerGpuOpsToNVVMOpsPass()); pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin)); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index d408ecfa5eb..24eeaf50d78 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -935,8 +935,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { os << "void populateWithGenerated(MLIRContext *context, " << "OwningRewritePatternList *patterns) {\n"; for (const auto &name : rewriterNames) { - os << " patterns->push_back(llvm::make_unique<" << name - << ">(context));\n"; + os << " patterns->insert<" << name << ">(context);\n"; } os << "}\n"; } |