Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/IR/MLIRContext.cpp                                                                33
-rw-r--r--  mlir/lib/IR/PatternMatch.cpp (renamed from mlir/lib/Transforms/Utils/PatternMatch.cpp)      3
-rw-r--r--  mlir/lib/StandardOps/StandardOps.cpp                                                      279
-rw-r--r--  mlir/lib/Transforms/Canonicalizer.cpp                                                     187
-rw-r--r--  mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp                                    3
5 files changed, 284 insertions, 221 deletions
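
In short, this change moves the canonicalization patterns out of the Canonicalizer pass and into the ops themselves: PatternMatch moves from Transforms/Utils into IR, MLIRContext gains accessors for the registered dialects and operations, StandardOps grows per-op pattern hooks, and the Canonicalizer now asks the operation registry for patterns instead of hard-coding a list. Each participating op implements a hook with the signature below (taken from the additions to StandardOps.cpp; SomeOp and SomeOpPattern are placeholders, not names from this patch):

    void SomeOp::getCanonicalizationPatterns(OwningPatternList &results,
                                             MLIRContext *context) {
      results.push_back(std::make_unique<SomeOpPattern>(context));
    }
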
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index a55c1449eb4..f6d236211d4 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -32,6 +32,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
+#include "third_party/llvm/llvm/include/llvm/ADT/STLExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Allocator.h"
@@ -440,12 +441,44 @@ void MLIRContext::emitDiagnostic(Location *location, const llvm::Twine &message,
// Dialect and Operation Registration
//===----------------------------------------------------------------------===//
+/// Return information about all registered IR dialects.
+std::vector<Dialect *> MLIRContext::getRegisteredDialects() const {
+ std::vector<Dialect *> result;
+ result.reserve(getImpl().dialects.size());
+ for (auto &dialect : getImpl().dialects)
+ result.push_back(dialect.get());
+ return result;
+}
+
/// Register this dialect object with the specified context. The context
/// takes ownership of the heap allocated dialect.
void Dialect::registerDialect(MLIRContext *context) {
context->getImpl().dialects.push_back(std::unique_ptr<Dialect>(this));
}
+/// Return information about all registered operations. This isn't very
+/// efficient; typically you should ask the operations about their properties
+/// directly.
+std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() const {
+ // We just have the operations in a non-deterministic hash table order. Dump
+ // into a temporary array, then sort it by operation name to get a stable
+ // ordering.
+ StringMap<AbstractOperation> &registeredOps = getImpl().registeredOperations;
+
+ std::vector<std::pair<StringRef, AbstractOperation *>> opsToSort;
+ opsToSort.reserve(registeredOps.size());
+ for (auto &elt : registeredOps)
+ opsToSort.push_back({elt.first(), &elt.second});
+
+ llvm::array_pod_sort(opsToSort.begin(), opsToSort.end());
+
+ std::vector<AbstractOperation *> result;
+ result.reserve(opsToSort.size());
+ for (auto &elt : opsToSort)
+ result.push_back(elt.second);
+ return result;
+}
+
void Dialect::addOperation(AbstractOperation opInfo) {
assert(opInfo.name.startswith(opPrefix) &&
"op name doesn't start with prefix");
diff --git a/mlir/lib/Transforms/Utils/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 6cc3b436c6d..01a5e57686a 100644
--- a/mlir/lib/Transforms/Utils/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -15,10 +15,9 @@
// limitations under the License.
// =============================================================================
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Statements.h"
-#include "mlir/StandardOps/StandardOps.h"
-#include "mlir/Transforms/PatternMatch.h"
using namespace mlir;
PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp
index d118c2bc0e4..ea460a17742 100644
--- a/mlir/lib/StandardOps/StandardOps.cpp
+++ b/mlir/lib/StandardOps/StandardOps.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/MathExtras.h"
@@ -41,6 +42,39 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
}
//===----------------------------------------------------------------------===//
+// Common canonicalization pattern support logic
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This is a common class used for patterns of the form
+/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
+/// into the root operation directly.
+struct MemRefCastFolder : public Pattern {
+ /// The rootOpName is the name of the root operation to match against.
+ MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
+ : Pattern(rootOpName, context, 1) {}
+
+ std::pair<PatternBenefit, std::unique_ptr<PatternState>>
+ match(Operation *op) const override {
+ for (auto *operand : op->getOperands())
+ if (auto *memref = operand->getDefiningOperation())
+ if (memref->isa<MemRefCastOp>())
+ return matchSuccess();
+
+ return matchFailure();
+ }
+
+ void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+ if (auto *memref = op->getOperand(i)->getDefiningOperation())
+ if (auto cast = memref->dyn_cast<MemRefCastOp>())
+ op->setOperand(i, cast->getOperand());
+ rewriter.updatedRootInPlace(op);
+ }
+};
+} // end anonymous namespace.
+
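
MemRefCastFolder is written against an operation name rather than a concrete op class, so any op whose memref operands can safely look through a memref_cast registers it under its own name. A sketch for a hypothetical MyMemRefOp (not part of this patch), mirroring the DeallocOp/LoadOp/StoreOp/DmaStartOp/DmaWaitOp registrations added later in this file:

    /// myop(memref_cast(x)) -> myop(x)
    void MyMemRefOp::getCanonicalizationPatterns(OwningPatternList &results,
                                                 MLIRContext *context) {
      results.push_back(
          std::make_unique<MemRefCastFolder>(getOperationName(), context));
    }
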
+//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//
@@ -72,6 +106,36 @@ Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
return nullptr;
}
+namespace {
+/// addi(x, 0) -> x
+///
+struct SimplifyAddX0 : public Pattern {
+ SimplifyAddX0(MLIRContext *context)
+ : Pattern(AddIOp::getOperationName(), context, 1) {}
+
+ std::pair<PatternBenefit, std::unique_ptr<PatternState>>
+ match(Operation *op) const override {
+ auto addi = op->cast<AddIOp>();
+ if (auto *operandOp = addi->getOperand(1)->getDefiningOperation())
+ // TODO: Support splatted zero as well. We need a general zero pattern.
+ if (auto cst = operandOp->dyn_cast<ConstantIntOp>()) {
+ if (cst->getValue() == 0)
+ return matchSuccess();
+ }
+
+ return matchFailure();
+ }
+ void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ rewriter.replaceSingleResultOp(op, op->getOperand(0));
+ }
+};
+} // end anonymous namespace.
+
+void AddIOp::getCanonicalizationPatterns(OwningPatternList &results,
+ MLIRContext *context) {
+ results.push_back(std::make_unique<SimplifyAddX0>(context));
+}
+
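
Once registered this way, the addi(x, 0) fold no longer needs any pass-side knowledge; it can be exercised through the same greedy driver the Canonicalizer uses. A sketch, assuming the StandardOps, PatternMatch, and pattern-driver headers are included and a Function pointer is at hand:

    // Apply only AddIOp's canonicalization patterns to a function.
    void simplifyAddI(Function *fn) {
      OwningPatternList patterns;
      AddIOp::getCanonicalizationPatterns(patterns, fn->getContext());
      applyPatternsGreedily(fn, std::move(patterns));
    }
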
//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//
@@ -148,6 +212,82 @@ bool AllocOp::verify() const {
return false;
}
+namespace {
+/// Fold constant dimensions into an alloc instruction.
+struct SimplifyAllocConst : public Pattern {
+ SimplifyAllocConst(MLIRContext *context)
+ : Pattern(AllocOp::getOperationName(), context, 1) {}
+
+ std::pair<PatternBenefit, std::unique_ptr<PatternState>>
+ match(Operation *op) const override {
+ auto alloc = op->cast<AllocOp>();
+
+ // Check to see if any dimensions operands are constants. If so, we can
+ // substitute and drop them.
+ for (auto *operand : alloc->getOperands())
+ if (auto *opOperation = operand->getDefiningOperation())
+ if (opOperation->isa<ConstantIndexOp>())
+ return matchSuccess();
+ return matchFailure();
+ }
+
+ void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ auto allocOp = op->cast<AllocOp>();
+ auto memrefType = allocOp->getType();
+
+ // Ok, we have one or more constant operands. Collect the non-constant ones
+ // and keep track of the resultant memref type to build.
+ SmallVector<int, 4> newShapeConstants;
+ newShapeConstants.reserve(memrefType->getRank());
+ SmallVector<SSAValue *, 4> newOperands;
+ SmallVector<SSAValue *, 4> droppedOperands;
+
+ unsigned dynamicDimPos = 0;
+ for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) {
+ int dimSize = memrefType->getDimSize(dim);
+ // If this is already static dimension, keep it.
+ if (dimSize != -1) {
+ newShapeConstants.push_back(dimSize);
+ continue;
+ }
+ auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningOperation();
+ OpPointer<ConstantIndexOp> constantIndexOp;
+ if (defOp && (constantIndexOp = defOp->dyn_cast<ConstantIndexOp>())) {
+ // Dynamic shape dimension will be folded.
+ newShapeConstants.push_back(constantIndexOp->getValue());
+ // Record to check for zero uses later below.
+ droppedOperands.push_back(constantIndexOp);
+ } else {
+ // Dynamic shape dimension not folded; copy operand from old memref.
+ newShapeConstants.push_back(-1);
+ newOperands.push_back(allocOp->getOperand(dynamicDimPos));
+ }
+ dynamicDimPos++;
+ }
+
+ // Create new memref type (which will have fewer dynamic dimensions).
+ auto *newMemRefType = MemRefType::get(
+ newShapeConstants, memrefType->getElementType(),
+ memrefType->getAffineMaps(), memrefType->getMemorySpace());
+ assert(newOperands.size() == newMemRefType->getNumDynamicDims());
+
+ // Create and insert the alloc op for the new memref.
+ auto newAlloc =
+ rewriter.create<AllocOp>(allocOp->getLoc(), newMemRefType, newOperands);
+ // Insert a cast so we have the same type as the old alloc.
+ auto resultCast = rewriter.create<MemRefCastOp>(allocOp->getLoc(), newAlloc,
+ allocOp->getType());
+
+ rewriter.replaceSingleResultOp(op, resultCast, droppedOperands);
+ }
+};
+} // end anonymous namespace.
+
+void AllocOp::getCanonicalizationPatterns(OwningPatternList &results,
+ MLIRContext *context) {
+ results.push_back(std::make_unique<SimplifyAllocConst>(context));
+}
+
//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
@@ -310,6 +450,13 @@ bool DeallocOp::verify() const {
return false;
}
+void DeallocOp::getCanonicalizationPatterns(OwningPatternList &results,
+ MLIRContext *context) {
+ /// dealloc(memrefcast) -> dealloc
+ results.push_back(
+ std::make_unique<MemRefCastFolder>(getOperationName(), context));
+}
+
//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//
@@ -465,6 +612,13 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
return false;
}
+void DmaStartOp::getCanonicalizationPatterns(OwningPatternList &results,
+ MLIRContext *context) {
+ /// dma_start(memrefcast) -> dma_start
+ results.push_back(
+ std::make_unique<MemRefCastFolder>(getOperationName(), context));
+}
+
// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------
@@ -509,6 +663,13 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
return false;
}
+void DmaWaitOp::getCanonicalizationPatterns(OwningPatternList &results,
+ MLIRContext *context) {
+ /// dma_wait(memrefcast) -> dma_wait
+ results.push_back(
+ std::make_unique<MemRefCastFolder>(getOperationName(), context));
+}
+
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
@@ -630,6 +791,13 @@ bool LoadOp::verify() const {
return false;
}
+void LoadOp::getCanonicalizationPatterns(OwningPatternList &results,
+ MLIRContext *context) {
+ /// load(memrefcast) -> load
+ results.push_back(
+ std::make_unique<MemRefCastFolder>(getOperationName(), context));
+}
+
//===----------------------------------------------------------------------===//
// MemRefCastOp
//===----------------------------------------------------------------------===//
@@ -710,43 +878,6 @@ Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
}
//===----------------------------------------------------------------------===//
-// TensorCastOp
-//===----------------------------------------------------------------------===//
-
-bool TensorCastOp::verify() const {
- auto *opType = dyn_cast<TensorType>(getOperand()->getType());
- auto *resType = dyn_cast<TensorType>(getType());
- if (!opType || !resType)
- return emitOpError("requires input and result types to be tensors");
-
- if (opType == resType)
- return emitOpError("requires the input and result type to be different");
-
- if (opType->getElementType() != resType->getElementType())
- return emitOpError(
- "requires input and result element types to be the same");
-
- // If the source or destination are unranked, then the cast is valid.
- auto *opRType = dyn_cast<RankedTensorType>(opType);
- auto *resRType = dyn_cast<RankedTensorType>(resType);
- if (!opRType || !resRType)
- return false;
-
- // If they are both ranked, they have to have the same rank, and any specified
- // dimensions must match.
- if (opRType->getRank() != resRType->getRank())
- return emitOpError("requires input and result ranks to match");
-
- for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
- int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
- if (opDim != -1 && resultDim != -1 && opDim != resultDim)
- return emitOpError("requires static dimensions to match");
- }
-
- return false;
-}
-
-//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
@@ -813,6 +944,13 @@ bool StoreOp::verify() const {
return false;
}
+void StoreOp::getCanonicalizationPatterns(OwningPatternList &results,
+ MLIRContext *context) {
+ /// store(memrefcast) -> store
+ results.push_back(
+ std::make_unique<MemRefCastFolder>(getOperationName(), context));
+}
+
//===----------------------------------------------------------------------===//
// SubFOp
//===----------------------------------------------------------------------===//
@@ -844,3 +982,70 @@ Attribute SubIOp::constantFold(ArrayRef<Attribute> operands,
return nullptr;
}
+
+namespace {
+/// subi(x,x) -> 0
+///
+struct SimplifyXMinusX : public Pattern {
+ SimplifyXMinusX(MLIRContext *context)
+ : Pattern(SubIOp::getOperationName(), context, 1) {}
+
+ std::pair<PatternBenefit, std::unique_ptr<PatternState>>
+ match(Operation *op) const override {
+ auto subi = op->cast<SubIOp>();
+ if (subi->getOperand(0) == subi->getOperand(1))
+ return matchSuccess();
+
+ return matchFailure();
+ }
+ void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ auto subi = op->cast<SubIOp>();
+ auto result =
+ rewriter.create<ConstantIntOp>(op->getLoc(), 0, subi->getType());
+
+ rewriter.replaceSingleResultOp(op, result);
+ }
+};
+} // end anonymous namespace.
+
+void SubIOp::getCanonicalizationPatterns(OwningPatternList &results,
+ MLIRContext *context) {
+ results.push_back(std::make_unique<SimplifyXMinusX>(context));
+}
+
+//===----------------------------------------------------------------------===//
+// TensorCastOp
+//===----------------------------------------------------------------------===//
+
+bool TensorCastOp::verify() const {
+ auto *opType = dyn_cast<TensorType>(getOperand()->getType());
+ auto *resType = dyn_cast<TensorType>(getType());
+ if (!opType || !resType)
+ return emitOpError("requires input and result types to be tensors");
+
+ if (opType == resType)
+ return emitOpError("requires the input and result type to be different");
+
+ if (opType->getElementType() != resType->getElementType())
+ return emitOpError(
+ "requires input and result element types to be the same");
+
+ // If the source or destination are unranked, then the cast is valid.
+ auto *opRType = dyn_cast<RankedTensorType>(opType);
+ auto *resRType = dyn_cast<RankedTensorType>(resType);
+ if (!opRType || !resRType)
+ return false;
+
+ // If they are both ranked, they have to have the same rank, and any specified
+ // dimensions must match.
+ if (opRType->getRank() != resRType->getRank())
+ return emitOpError("requires input and result ranks to match");
+
+ for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
+ int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
+ if (opDim != -1 && resultDim != -1 && opDim != resultDim)
+ return emitOpError("requires static dimensions to match");
+ }
+
+ return false;
+}
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 7f42ada5428..b7bfd785101 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -20,170 +20,13 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Passes.h"
-#include "mlir/Transforms/PatternMatch.h"
-#include <memory>
using namespace mlir;
//===----------------------------------------------------------------------===//
-// Definition of a few patterns for canonicalizing operations.
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
-/// into the root operation directly.
-struct MemRefCastFolder : public Pattern {
- /// The rootOpName is the name of the root operation to match against.
- MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
- : Pattern(rootOpName, context, 1) {}
-
- std::pair<PatternBenefit, std::unique_ptr<PatternState>>
- match(Operation *op) const override {
- for (auto *operand : op->getOperands())
- if (auto *memref = operand->getDefiningOperation())
- if (memref->isa<MemRefCastOp>())
- return matchSuccess();
-
- return matchFailure();
- }
-
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
- if (auto *memref = op->getOperand(i)->getDefiningOperation())
- if (auto cast = memref->dyn_cast<MemRefCastOp>())
- op->setOperand(i, cast->getOperand());
- rewriter.updatedRootInPlace(op);
- }
-};
-} // end anonymous namespace.
-
-namespace {
-/// subi(x,x) -> 0
-///
-struct SimplifyXMinusX : public Pattern {
- SimplifyXMinusX(MLIRContext *context)
- : Pattern(SubIOp::getOperationName(), context, 1) {}
-
- std::pair<PatternBenefit, std::unique_ptr<PatternState>>
- match(Operation *op) const override {
- auto subi = op->cast<SubIOp>();
- if (subi->getOperand(0) == subi->getOperand(1))
- return matchSuccess();
-
- return matchFailure();
- }
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- auto subi = op->cast<SubIOp>();
- auto result =
- rewriter.create<ConstantIntOp>(op->getLoc(), 0, subi->getType());
-
- rewriter.replaceSingleResultOp(op, result);
- }
-};
-} // end anonymous namespace.
-
-namespace {
-/// addi(x, 0) -> x
-///
-struct SimplifyAddX0 : public Pattern {
- SimplifyAddX0(MLIRContext *context)
- : Pattern(AddIOp::getOperationName(), context, 1) {}
-
- std::pair<PatternBenefit, std::unique_ptr<PatternState>>
- match(Operation *op) const override {
- auto addi = op->cast<AddIOp>();
- if (auto *operandOp = addi->getOperand(1)->getDefiningOperation())
- // TODO: Support splatted zero as well. We need a general zero pattern.
- if (auto cst = operandOp->dyn_cast<ConstantIntOp>()) {
- if (cst->getValue() == 0)
- return matchSuccess();
- }
-
- return matchFailure();
- }
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- rewriter.replaceSingleResultOp(op, op->getOperand(0));
- }
-};
-} // end anonymous namespace.
-
-namespace {
-/// Fold constant dimensions into an alloc instruction.
-struct SimplifyAllocConst : public Pattern {
- SimplifyAllocConst(MLIRContext *context)
- : Pattern(AllocOp::getOperationName(), context, 1) {}
-
- std::pair<PatternBenefit, std::unique_ptr<PatternState>>
- match(Operation *op) const override {
- auto alloc = op->cast<AllocOp>();
-
- // Check to see if any dimensions operands are constants. If so, we can
- // substitute and drop them.
- for (auto *operand : alloc->getOperands())
- if (auto *opOperation = operand->getDefiningOperation())
- if (opOperation->isa<ConstantIndexOp>())
- return matchSuccess();
- return matchFailure();
- }
-
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- auto allocOp = op->cast<AllocOp>();
- auto memrefType = allocOp->getType();
-
- // Ok, we have one or more constant operands. Collect the non-constant ones
- // and keep track of the resultant memref type to build.
- SmallVector<int, 4> newShapeConstants;
- newShapeConstants.reserve(memrefType->getRank());
- SmallVector<SSAValue *, 4> newOperands;
- SmallVector<SSAValue *, 4> droppedOperands;
-
- unsigned dynamicDimPos = 0;
- for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) {
- int dimSize = memrefType->getDimSize(dim);
- // If this is already static dimension, keep it.
- if (dimSize != -1) {
- newShapeConstants.push_back(dimSize);
- continue;
- }
- auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningOperation();
- OpPointer<ConstantIndexOp> constantIndexOp;
- if (defOp && (constantIndexOp = defOp->dyn_cast<ConstantIndexOp>())) {
- // Dynamic shape dimension will be folded.
- newShapeConstants.push_back(constantIndexOp->getValue());
- // Record to check for zero uses later below.
- droppedOperands.push_back(constantIndexOp);
- } else {
- // Dynamic shape dimension not folded; copy operand from old memref.
- newShapeConstants.push_back(-1);
- newOperands.push_back(allocOp->getOperand(dynamicDimPos));
- }
- dynamicDimPos++;
- }
-
- // Create new memref type (which will have fewer dynamic dimensions).
- auto *newMemRefType = MemRefType::get(
- newShapeConstants, memrefType->getElementType(),
- memrefType->getAffineMaps(), memrefType->getMemorySpace());
- assert(newOperands.size() == newMemRefType->getNumDynamicDims());
-
- // Create and insert the alloc op for the new memref.
- auto newAlloc =
- rewriter.create<AllocOp>(allocOp->getLoc(), newMemRefType, newOperands);
- // Insert a cast so we have the same type as the old alloc.
- auto resultCast = rewriter.create<MemRefCastOp>(allocOp->getLoc(), newAlloc,
- allocOp->getType());
-
- rewriter.replaceSingleResultOp(op, resultCast, droppedOperands);
- }
-};
-} // end anonymous namespace.
-
-//===----------------------------------------------------------------------===//
// The actual Canonicalizer Pass.
//===----------------------------------------------------------------------===//
@@ -208,29 +51,13 @@ PassResult Canonicalizer::runOnMLFunction(MLFunction *fn) {
PassResult Canonicalizer::runOnFunction(Function *fn) {
auto *context = fn->getContext();
-
- // TODO: Instead of a hard coded list of patterns, ask the operations
- // for their canonicalization patterns.
OwningPatternList patterns;
- patterns.push_back(std::make_unique<SimplifyXMinusX>(context));
- patterns.push_back(std::make_unique<SimplifyAddX0>(context));
- patterns.push_back(std::make_unique<SimplifyAllocConst>(context));
- /// load(memrefcast) -> load
- patterns.push_back(
- std::make_unique<MemRefCastFolder>(LoadOp::getOperationName(), context));
- /// store(memrefcast) -> store
- patterns.push_back(
- std::make_unique<MemRefCastFolder>(StoreOp::getOperationName(), context));
- /// dealloc(memrefcast) -> dealloc
- patterns.push_back(std::make_unique<MemRefCastFolder>(
- DeallocOp::getOperationName(), context));
- /// dma_start(memrefcast) -> dma_start
- patterns.push_back(std::make_unique<MemRefCastFolder>(
- DmaStartOp::getOperationName(), context));
- /// dma_wait(memrefcast) -> dma_wait
- patterns.push_back(std::make_unique<MemRefCastFolder>(
- DmaWaitOp::getOperationName(), context));
+ // TODO: Instead of adding all known patterns from the whole system, lazily
+ // add and cache the canonicalization patterns for ops we see in practice
+ // when building the worklist. For now, we just grab everything.
+ for (auto *op : fn->getContext()->getRegisteredOperations())
+ op->getCanonicalizationPatterns(patterns, context);
applyPatternsGreedily(fn, std::move(patterns));
return success();
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 44b211f1af9..ebad9e20316 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -21,8 +21,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/StandardOps/StandardOps.h"
-#include "mlir/Transforms/PatternMatch.h"
+#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;