summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/Canonicalizer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Canonicalizer.cpp')
-rw-r--r--mlir/lib/Transforms/Canonicalizer.cpp187
1 files changed, 7 insertions, 180 deletions
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 7f42ada5428..b7bfd785101 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -20,170 +20,13 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Passes.h"
-#include "mlir/Transforms/PatternMatch.h"
-#include <memory>
using namespace mlir;
//===----------------------------------------------------------------------===//
-// Definition of a few patterns for canonicalizing operations.
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
-/// into the root operation directly.
-struct MemRefCastFolder : public Pattern {
- /// The rootOpName is the name of the root operation to match against.
- MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
- : Pattern(rootOpName, context, 1) {}
-
- std::pair<PatternBenefit, std::unique_ptr<PatternState>>
- match(Operation *op) const override {
- for (auto *operand : op->getOperands())
- if (auto *memref = operand->getDefiningOperation())
- if (memref->isa<MemRefCastOp>())
- return matchSuccess();
-
- return matchFailure();
- }
-
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
- if (auto *memref = op->getOperand(i)->getDefiningOperation())
- if (auto cast = memref->dyn_cast<MemRefCastOp>())
- op->setOperand(i, cast->getOperand());
- rewriter.updatedRootInPlace(op);
- }
-};
-} // end anonymous namespace.
-
-namespace {
-/// subi(x,x) -> 0
-///
-struct SimplifyXMinusX : public Pattern {
- SimplifyXMinusX(MLIRContext *context)
- : Pattern(SubIOp::getOperationName(), context, 1) {}
-
- std::pair<PatternBenefit, std::unique_ptr<PatternState>>
- match(Operation *op) const override {
- auto subi = op->cast<SubIOp>();
- if (subi->getOperand(0) == subi->getOperand(1))
- return matchSuccess();
-
- return matchFailure();
- }
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- auto subi = op->cast<SubIOp>();
- auto result =
- rewriter.create<ConstantIntOp>(op->getLoc(), 0, subi->getType());
-
- rewriter.replaceSingleResultOp(op, result);
- }
-};
-} // end anonymous namespace.
-
-namespace {
-/// addi(x, 0) -> x
-///
-struct SimplifyAddX0 : public Pattern {
- SimplifyAddX0(MLIRContext *context)
- : Pattern(AddIOp::getOperationName(), context, 1) {}
-
- std::pair<PatternBenefit, std::unique_ptr<PatternState>>
- match(Operation *op) const override {
- auto addi = op->cast<AddIOp>();
- if (auto *operandOp = addi->getOperand(1)->getDefiningOperation())
- // TODO: Support splatted zero as well. We need a general zero pattern.
- if (auto cst = operandOp->dyn_cast<ConstantIntOp>()) {
- if (cst->getValue() == 0)
- return matchSuccess();
- }
-
- return matchFailure();
- }
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- rewriter.replaceSingleResultOp(op, op->getOperand(0));
- }
-};
-} // end anonymous namespace.
-
-namespace {
-/// Fold constant dimensions into an alloc instruction.
-struct SimplifyAllocConst : public Pattern {
- SimplifyAllocConst(MLIRContext *context)
- : Pattern(AllocOp::getOperationName(), context, 1) {}
-
- std::pair<PatternBenefit, std::unique_ptr<PatternState>>
- match(Operation *op) const override {
- auto alloc = op->cast<AllocOp>();
-
- // Check to see if any dimensions operands are constants. If so, we can
- // substitute and drop them.
- for (auto *operand : alloc->getOperands())
- if (auto *opOperation = operand->getDefiningOperation())
- if (opOperation->isa<ConstantIndexOp>())
- return matchSuccess();
- return matchFailure();
- }
-
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- auto allocOp = op->cast<AllocOp>();
- auto memrefType = allocOp->getType();
-
- // Ok, we have one or more constant operands. Collect the non-constant ones
- // and keep track of the resultant memref type to build.
- SmallVector<int, 4> newShapeConstants;
- newShapeConstants.reserve(memrefType->getRank());
- SmallVector<SSAValue *, 4> newOperands;
- SmallVector<SSAValue *, 4> droppedOperands;
-
- unsigned dynamicDimPos = 0;
- for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) {
- int dimSize = memrefType->getDimSize(dim);
- // If this is already static dimension, keep it.
- if (dimSize != -1) {
- newShapeConstants.push_back(dimSize);
- continue;
- }
- auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningOperation();
- OpPointer<ConstantIndexOp> constantIndexOp;
- if (defOp && (constantIndexOp = defOp->dyn_cast<ConstantIndexOp>())) {
- // Dynamic shape dimension will be folded.
- newShapeConstants.push_back(constantIndexOp->getValue());
- // Record to check for zero uses later below.
- droppedOperands.push_back(constantIndexOp);
- } else {
- // Dynamic shape dimension not folded; copy operand from old memref.
- newShapeConstants.push_back(-1);
- newOperands.push_back(allocOp->getOperand(dynamicDimPos));
- }
- dynamicDimPos++;
- }
-
- // Create new memref type (which will have fewer dynamic dimensions).
- auto *newMemRefType = MemRefType::get(
- newShapeConstants, memrefType->getElementType(),
- memrefType->getAffineMaps(), memrefType->getMemorySpace());
- assert(newOperands.size() == newMemRefType->getNumDynamicDims());
-
- // Create and insert the alloc op for the new memref.
- auto newAlloc =
- rewriter.create<AllocOp>(allocOp->getLoc(), newMemRefType, newOperands);
- // Insert a cast so we have the same type as the old alloc.
- auto resultCast = rewriter.create<MemRefCastOp>(allocOp->getLoc(), newAlloc,
- allocOp->getType());
-
- rewriter.replaceSingleResultOp(op, resultCast, droppedOperands);
- }
-};
-} // end anonymous namespace.
-
-//===----------------------------------------------------------------------===//
// The actual Canonicalizer Pass.
//===----------------------------------------------------------------------===//
@@ -208,29 +51,13 @@ PassResult Canonicalizer::runOnMLFunction(MLFunction *fn) {
PassResult Canonicalizer::runOnFunction(Function *fn) {
auto *context = fn->getContext();
-
- // TODO: Instead of a hard coded list of patterns, ask the operations
- // for their canonicalization patterns.
OwningPatternList patterns;
- patterns.push_back(std::make_unique<SimplifyXMinusX>(context));
- patterns.push_back(std::make_unique<SimplifyAddX0>(context));
- patterns.push_back(std::make_unique<SimplifyAllocConst>(context));
- /// load(memrefcast) -> load
- patterns.push_back(
- std::make_unique<MemRefCastFolder>(LoadOp::getOperationName(), context));
- /// store(memrefcast) -> store
- patterns.push_back(
- std::make_unique<MemRefCastFolder>(StoreOp::getOperationName(), context));
- /// dealloc(memrefcast) -> dealloc
- patterns.push_back(std::make_unique<MemRefCastFolder>(
- DeallocOp::getOperationName(), context));
- /// dma_start(memrefcast) -> dma_start
- patterns.push_back(std::make_unique<MemRefCastFolder>(
- DmaStartOp::getOperationName(), context));
- /// dma_wait(memrefcast) -> dma_wait
- patterns.push_back(std::make_unique<MemRefCastFolder>(
- DmaWaitOp::getOperationName(), context));
+ // TODO: Instead of adding all known patterns from the whole system lazily add
+ // and cache the canonicalization patterns for ops we see in practice when
+ // building the worklist. For now, we just grab everything.
+ for (auto *op : fn->getContext()->getRegisteredOperations())
+ op->getCanonicalizationPatterns(patterns, context);
applyPatternsGreedily(fn, std::move(patterns));
return success();
OpenPOWER on IntegriCloud