Diffstat (limited to 'mlir/lib/Transforms/Utils')
-rw-r--r--  mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp  343
-rw-r--r--  mlir/lib/Transforms/Utils/LoopUtils.cpp                   316
-rw-r--r--  mlir/lib/Transforms/Utils/Pass.cpp                         41
-rw-r--r--  mlir/lib/Transforms/Utils/PatternMatch.cpp                196
-rw-r--r--  mlir/lib/Transforms/Utils/Utils.cpp                       394
5 files changed, 1290 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
new file mode 100644
index 00000000000..5ed8eac323e
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -0,0 +1,343 @@
+//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements mlir::applyPatternsGreedily.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/Transforms/PatternMatch.h"
+#include "llvm/ADT/DenseMap.h"
+using namespace mlir;
+
+namespace {
+class WorklistRewriter;
+
+/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
+/// applies the locally optimal patterns in a roughly "bottom up" way.
+class GreedyPatternRewriteDriver {
+public:
+ explicit GreedyPatternRewriteDriver(OwningPatternList &&patterns)
+ : matcher(std::move(patterns)) {
+ worklist.reserve(64);
+ }
+
+ void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
+
+ void addToWorklist(Operation *op) {
+ worklistMap[op] = worklist.size();
+ worklist.push_back(op);
+ }
+
+ Operation *popFromWorklist() {
+ auto *op = worklist.back();
+ worklist.pop_back();
+
+ // This operation is no longer in the worklist, keep worklistMap up to date.
+ if (op)
+ worklistMap.erase(op);
+ return op;
+ }
+
+ /// If the specified operation is in the worklist, remove it. If not, this is
+ /// a no-op.
+ void removeFromWorklist(Operation *op) {
+ auto it = worklistMap.find(op);
+ if (it != worklistMap.end()) {
+ assert(worklist[it->second] == op && "malformed worklist data structure");
+ worklist[it->second] = nullptr;
+ }
+ }
+
+private:
+ /// The low-level pattern matcher.
+ PatternMatcher matcher;
+
+  /// The worklist for this transformation keeps track of the operations that
+  /// need to be revisited, plus their index in the worklist. This allows us to
+  /// efficiently remove operations from the worklist when they are erased from
+  /// the function, even if they aren't the root of a pattern.
+ std::vector<Operation *> worklist;
+ DenseMap<Operation *, unsigned> worklistMap;
+
+  /// As part of canonicalization, we move constants to the top of the entry
+  /// block of the current function and de-duplicate them. This map keeps track
+  /// of the constants we have uniqued so far.
+ DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants;
+};
+} // end anonymous namespace
+
+namespace {
+/// This is a listener object that updates our worklists and other data
+/// structures in response to operations being added and removed.
+class WorklistRewriter : public PatternRewriter {
+public:
+ WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
+ : PatternRewriter(context), driver(driver) {}
+
+ virtual void setInsertionPoint(Operation *op) = 0;
+
+ // If an operation is about to be removed, make sure it is not in our
+ // worklist anymore because we'd get dangling references to it.
+ void notifyOperationRemoved(Operation *op) override {
+ driver.removeFromWorklist(op);
+ }
+
+ GreedyPatternRewriteDriver &driver;
+};
+
+} // end anonymous namespace
+
+void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
+ WorklistRewriter &rewriter) {
+ // These are scratch vectors used in the constant folding loop below.
+ SmallVector<Attribute *, 8> operandConstants, resultConstants;
+
+ while (!worklist.empty()) {
+ auto *op = popFromWorklist();
+
+    // Nulls get added to the worklist when operations are removed; ignore them.
+ if (op == nullptr)
+ continue;
+
+ // If we have a constant op, unique it into the entry block.
+ if (auto constant = op->dyn_cast<ConstantOp>()) {
+ // If this constant is dead, remove it, being careful to keep
+ // uniquedConstants up to date.
+ if (constant->use_empty()) {
+ auto it =
+ uniquedConstants.find({constant->getValue(), constant->getType()});
+ if (it != uniquedConstants.end() && it->second == op)
+ uniquedConstants.erase(it);
+ constant->erase();
+ continue;
+ }
+
+ // Check to see if we already have a constant with this type and value:
+ auto &entry = uniquedConstants[std::make_pair(constant->getValue(),
+ constant->getType())];
+ if (entry) {
+ // If this constant is already our uniqued one, then leave it alone.
+ if (entry == op)
+ continue;
+
+ // Otherwise replace this redundant constant with the uniqued one. We
+ // know this is safe because we move constants to the top of the
+ // function when they are uniqued, so we know they dominate all uses.
+ constant->replaceAllUsesWith(entry->getResult(0));
+ constant->erase();
+ continue;
+ }
+
+ // If we have no entry, then we should unique this constant as the
+ // canonical version. To ensure safe dominance, move the operation to the
+ // top of the function.
+ entry = op;
+
+ // TODO: If we make terminators into Operations then we could turn this
+ // into a nice Operation::moveBefore(Operation*) method. We just need the
+ // guarantee that a block is non-empty.
+ if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
+ auto &entryBB = cfgFunc->front();
+ cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin());
+ } else {
+ auto *mlFunc = cast<MLFunction>(currentFunction);
+ cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
+ }
+
+ continue;
+ }
+
+ // If the operation has no side effects, and no users, then it is trivially
+ // dead - remove it.
+ if (op->hasNoSideEffect() && op->use_empty()) {
+ op->erase();
+ continue;
+ }
+
+    // Check to see if any operands to the instruction are constant and whether
+ // the operation knows how to constant fold itself.
+ operandConstants.clear();
+ for (auto *operand : op->getOperands()) {
+ Attribute *operandCst = nullptr;
+ if (auto *operandOp = operand->getDefiningOperation()) {
+ if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
+ operandCst = operandConstantOp->getValue();
+ }
+ operandConstants.push_back(operandCst);
+ }
+
+ // If constant folding was successful, create the result constants, RAUW the
+ // operation and remove it.
+ resultConstants.clear();
+ if (!op->constantFold(operandConstants, resultConstants)) {
+ rewriter.setInsertionPoint(op);
+
+ for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
+ auto *res = op->getResult(i);
+        if (res->use_empty()) // ignore dead results.
+ continue;
+
+ // If we already have a canonicalized version of this constant, just
+ // reuse it. Otherwise create a new one.
+ SSAValue *cstValue;
+ auto it = uniquedConstants.find({resultConstants[i], res->getType()});
+ if (it != uniquedConstants.end())
+ cstValue = it->second->getResult(0);
+ else
+ cstValue = rewriter.create<ConstantOp>(
+ op->getLoc(), resultConstants[i], res->getType());
+ res->replaceAllUsesWith(cstValue);
+ }
+
+ assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
+ op->erase();
+ continue;
+ }
+
+    // If this is a commutative binary operation with a constant on the LHS,
+    // move the constant to the right side.
+ if (operandConstants.size() == 2 && operandConstants[0] &&
+ !operandConstants[1]) {
+ auto *newLHS = op->getOperand(1);
+ op->setOperand(1, op->getOperand(0));
+ op->setOperand(0, newLHS);
+ }
+
+ // Check to see if we have any patterns that match this node.
+ auto match = matcher.findMatch(op);
+ if (!match.first)
+ continue;
+
+ // Make sure that any new operations are inserted at this point.
+ rewriter.setInsertionPoint(op);
+ match.first->rewrite(op, std::move(match.second), rewriter);
+ }
+
+ uniquedConstants.clear();
+}
+
+static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) {
+ class MLFuncRewriter : public WorklistRewriter {
+ public:
+ MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder)
+ : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
+
+ // Implement the hook for creating operations, and make sure that newly
+ // created ops are added to the worklist for processing.
+ Operation *createOperation(const OperationState &state) override {
+ auto *result = builder.createOperation(state);
+ driver.addToWorklist(result);
+ return result;
+ }
+
+ // When the root of a pattern is about to be replaced, it can trigger
+ // simplifications to its users - make sure to add them to the worklist
+ // before the root is changed.
+ void notifyRootReplaced(Operation *op) override {
+ auto *opStmt = cast<OperationStmt>(op);
+ for (auto *result : opStmt->getResults())
+ // TODO: Add a result->getUsers() iterator.
+ for (auto &user : result->getUses()) {
+ if (auto *op = dyn_cast<OperationStmt>(user.getOwner()))
+ driver.addToWorklist(op);
+ }
+
+ // TODO: Walk the operand list dropping them as we go. If any of them
+ // drop to zero uses, then add them to the worklist to allow them to be
+ // deleted as dead.
+ }
+
+ void setInsertionPoint(Operation *op) override {
+ // Any new operations should be added before this statement.
+ builder.setInsertionPoint(cast<OperationStmt>(op));
+ }
+
+ private:
+ MLFuncBuilder &builder;
+ };
+
+ GreedyPatternRewriteDriver driver(std::move(patterns));
+ fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
+
+ MLFuncBuilder mlBuilder(fn);
+ MLFuncRewriter rewriter(driver, mlBuilder);
+ driver.simplifyFunction(fn, rewriter);
+}
+
+static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) {
+ class CFGFuncRewriter : public WorklistRewriter {
+ public:
+ CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder)
+ : WorklistRewriter(driver, builder.getContext()), builder(builder) {}
+
+ // Implement the hook for creating operations, and make sure that newly
+ // created ops are added to the worklist for processing.
+ Operation *createOperation(const OperationState &state) override {
+ auto *result = builder.createOperation(state);
+ driver.addToWorklist(result);
+ return result;
+ }
+
+ // When the root of a pattern is about to be replaced, it can trigger
+ // simplifications to its users - make sure to add them to the worklist
+ // before the root is changed.
+ void notifyRootReplaced(Operation *op) override {
+ auto *opStmt = cast<OperationInst>(op);
+ for (auto *result : opStmt->getResults())
+ // TODO: Add a result->getUsers() iterator.
+ for (auto &user : result->getUses()) {
+ if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
+ driver.addToWorklist(op);
+ }
+
+ // TODO: Walk the operand list dropping them as we go. If any of them
+ // drop to zero uses, then add them to the worklist to allow them to be
+ // deleted as dead.
+ }
+
+ void setInsertionPoint(Operation *op) override {
+ // Any new operations should be added before this instruction.
+ builder.setInsertionPoint(cast<OperationInst>(op));
+ }
+
+ private:
+ CFGFuncBuilder &builder;
+ };
+
+ GreedyPatternRewriteDriver driver(std::move(patterns));
+ for (auto &bb : *fn)
+ for (auto &op : bb)
+ driver.addToWorklist(&op);
+
+ CFGFuncBuilder cfgBuilder(fn);
+ CFGFuncRewriter rewriter(driver, cfgBuilder);
+ driver.simplifyFunction(fn, rewriter);
+}
+
+/// Rewrite the specified function by repeatedly applying the highest benefit
+/// patterns in a greedy work-list driven manner.
+///
+void mlir::applyPatternsGreedily(Function *fn, OwningPatternList &&patterns) {
+ if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
+ processCFGFunction(cfg, std::move(patterns));
+ } else {
+ processMLFunction(cast<MLFunction>(fn), std::move(patterns));
+ }
+}
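
The heart of the driver above is a worklist supporting O(1) removal: a vector
of operation pointers paired with a pointer-to-index map, where removal nulls
out the slot rather than shifting the vector. Below is a minimal standalone
sketch of that data structure in plain C++ (hypothetical names, not the MLIR
API); the driver drains such a list until empty, re-adding any ops whose
operands change, which is what makes the rewrite greedy.

    #include <cassert>
    #include <unordered_map>
    #include <vector>

    template <typename T> class Worklist {
      std::vector<T *> list;
      std::unordered_map<T *, size_t> map;

    public:
      void push(T *item) {
        map[item] = list.size();
        list.push_back(item);
      }

      // Pops the most recently pushed item; may return null if that slot was
      // removed earlier, so callers skip null entries.
      T *pop() {
        T *item = list.back();
        list.pop_back();
        if (item)
          map.erase(item);
        return item;
      }

      // O(1) removal: null out the slot instead of erasing mid-vector.
      void remove(T *item) {
        auto it = map.find(item);
        if (it == map.end())
          return;
        assert(list[it->second] == item && "malformed worklist");
        list[it->second] = nullptr;
        map.erase(it);
      }

      bool empty() const { return list.empty(); }
    };
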
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
new file mode 100644
index 00000000000..a6a850280bb
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -0,0 +1,316 @@
+//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous loop transformation routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/LoopUtils.h"
+
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "LoopUtils"
+
+using namespace mlir;
+
+/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
+/// the specified trip count, stride, and unroll factor. Returns a null
+/// AffineMap when the trip count can't be expressed as an affine expression.
+AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
+ unsigned unrollFactor,
+ MLFuncBuilder *builder) {
+ auto lbMap = forStmt.getLowerBoundMap();
+
+ // Single result lower bound map only.
+ if (lbMap.getNumResults() != 1)
+ return AffineMap::Null();
+
+ // Sometimes, the trip count cannot be expressed as an affine expression.
+ auto tripCount = getTripCountExpr(forStmt);
+ if (!tripCount)
+ return AffineMap::Null();
+
+ AffineExpr lb(lbMap.getResult(0));
+ unsigned step = forStmt.getStep();
+ auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
+
+ return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
+ {newUb}, {});
+}
+
+/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
+/// bound 'lb' and with the specified trip count, stride, and unroll factor.
+/// Returns an AffineMap with null storage (one that evaluates to false)
+/// when the trip count can't be expressed as an affine expression.
+AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
+ unsigned unrollFactor,
+ MLFuncBuilder *builder) {
+ auto lbMap = forStmt.getLowerBoundMap();
+
+ // Single result lower bound map only.
+ if (lbMap.getNumResults() != 1)
+ return AffineMap::Null();
+
+ // Sometimes the trip count cannot be expressed as an affine expression.
+ AffineExpr tripCount(getTripCountExpr(forStmt));
+ if (!tripCount)
+ return AffineMap::Null();
+
+ AffineExpr lb(lbMap.getResult(0));
+ unsigned step = forStmt.getStep();
+ auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
+ return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
+ {newLb}, {});
+}
+
+/// Promotes the loop body of a forStmt to its containing block if the forStmt
+/// is known to have a single iteration. Returns false otherwise.
+// TODO(bondhugula): extend this for arbitrary affine bounds.
+bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
+ Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
+ if (!tripCount.hasValue() || tripCount.getValue() != 1)
+ return false;
+
+ // TODO(mlir-team): there is no builder for a max.
+ if (forStmt->getLowerBoundMap().getNumResults() != 1)
+ return false;
+
+  // Replace all uses of the IV with its single iteration value.
+ if (!forStmt->use_empty()) {
+ if (forStmt->hasConstantLowerBound()) {
+ auto *mlFunc = forStmt->findFunction();
+ MLFuncBuilder topBuilder(&mlFunc->front());
+ auto constOp = topBuilder.create<ConstantIndexOp>(
+ forStmt->getLoc(), forStmt->getConstantLowerBound());
+ forStmt->replaceAllUsesWith(constOp);
+ } else {
+ const AffineBound lb = forStmt->getLowerBound();
+ SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
+ lb.operand_end());
+ MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
+ auto affineApplyOp = builder.create<AffineApplyOp>(
+ forStmt->getLoc(), lb.getMap(), lbOperands);
+ forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
+ }
+ }
+ // Move the loop body statements to the loop's containing block.
+ auto *block = forStmt->getBlock();
+ block->getStatements().splice(StmtBlock::iterator(forStmt),
+ forStmt->getStatements());
+ forStmt->erase();
+ return true;
+}
+
+/// Promotes all single-iteration 'for' stmts in the MLFunction, i.e., moves
+/// their body into the containing StmtBlock.
+void mlir::promoteSingleIterationLoops(MLFunction *f) {
+  // Promotes single-iteration 'for' stmts through a post-order walk.
+ class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> {
+ public:
+ void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); }
+ };
+
+ LoopBodyPromoter fsw;
+ fsw.walkPostOrder(f);
+}
+
+/// Generates a 'for' stmt with the specified lower and upper bounds while
+/// generating the right IV remappings for the delayed statements. The
+/// statement blocks that go into the loop are specified in stmtGroupQueue
+/// starting from the specified offset, and in that order; the first element of
+/// the pair specifies the delay applied to that group of statements. Returns
+/// nullptr if the generated loop simplifies to a single iteration one.
+static ForStmt *
+generateLoop(AffineMap lb, AffineMap ub,
+ const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
+ &stmtGroupQueue,
+ unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
+ SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
+ SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
+
+ auto *loopChunk =
+ b->createFor(srcForStmt->getLoc(), lbOperands, lb, ubOperands, ub);
+ OperationStmt::OperandMapTy operandMap;
+
+ for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end();
+ it != e; ++it) {
+ auto elt = *it;
+ // All 'same delay' statements get added with the operands being remapped
+ // (to results of cloned statements).
+ // Generate the remapping if the delay is not zero: oldIV = newIV - delay.
+ // TODO(bondhugula): check if srcForStmt is actually used in elt.second
+ // instead of just checking if it's used at all.
+ if (!srcForStmt->use_empty() && elt.first != 0) {
+ auto b = MLFuncBuilder::getForStmtBodyBuilder(loopChunk);
+ auto *oldIV =
+ b.create<AffineApplyOp>(
+ srcForStmt->getLoc(),
+ b.getSingleDimShiftAffineMap(-static_cast<int64_t>(elt.first)),
+ loopChunk)
+ ->getResult(0);
+ operandMap[srcForStmt] = cast<MLValue>(oldIV);
+ } else {
+ operandMap[srcForStmt] = static_cast<MLValue *>(loopChunk);
+ }
+ for (auto *stmt : elt.second) {
+ loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
+ }
+ }
+ if (promoteIfSingleIteration(loopChunk))
+ return nullptr;
+ return loopChunk;
+}
+
+/// Skew the statements in the body of a 'for' statement with the specified
+/// statement-wise delays. The delays are with respect to the original execution
+/// order. A delay of zero for each statement will lead to no change.
+// The skewing of statements with respect to one another can be used for example
+// to allow overlap of asynchronous operations (such as DMA communication) with
+// computation, or just relative shifting of statements for better register
+// reuse, locality or parallelism. As such, the delays are typically expected to
+// be at most of the order of the number of statements. This method should not
+// be used as a substitute for loop distribution/fission.
+// This method uses an algorithm that runs in time linear in the number of
+// statements in the body of the for loop (the 'sweep line' paradigm). This
+// method asserts preservation of SSA dominance; checks for that, as well as
+// for memory-based dependence preservation, rest with the users of this
+// method.
+UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
+ bool unrollPrologueEpilogue) {
+ if (forStmt->getStatements().empty())
+ return UtilResult::Success;
+
+ // If the trip counts aren't constant, we would need versioning and
+ // conditional guards (or context information to prevent such versioning). The
+ // better way to pipeline for such loops is to first tile them and extract
+ // constant trip count "full tiles" before applying this.
+ auto mayBeConstTripCount = getConstantTripCount(*forStmt);
+ if (!mayBeConstTripCount.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
+ return UtilResult::Success;
+ }
+ uint64_t tripCount = mayBeConstTripCount.getValue();
+
+ assert(isStmtwiseShiftValid(*forStmt, delays) &&
+ "shifts will lead to an invalid transformation\n");
+
+ unsigned numChildStmts = forStmt->getStatements().size();
+
+ // Do a linear time (counting) sort for the delays.
+ uint64_t maxDelay = 0;
+ for (unsigned i = 0; i < numChildStmts; i++) {
+ maxDelay = std::max(maxDelay, delays[i]);
+ }
+ // Such large delays are not the typical use case.
+ if (maxDelay >= numChildStmts) {
+ LLVM_DEBUG(llvm::dbgs() << "stmt delays too large - unexpected\n";);
+ return UtilResult::Success;
+ }
+
+ // An array of statement groups sorted by delay amount; each group has all
+ // statements with the same delay in the order in which they appear in the
+ // body of the 'for' stmt.
+ std::vector<std::vector<Statement *>> sortedStmtGroups(maxDelay + 1);
+ unsigned pos = 0;
+ for (auto &stmt : *forStmt) {
+ auto delay = delays[pos++];
+ sortedStmtGroups[delay].push_back(&stmt);
+ }
+
+ // Unless the shifts have a specific pattern (which actually would be the
+ // common use case), prologue and epilogue are not meaningfully defined.
+ // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
+ // loop generated as the prologue and the last as epilogue and unroll these
+ // fully.
+ ForStmt *prologue = nullptr;
+ ForStmt *epilogue = nullptr;
+
+ // Do a sweep over the sorted delays while storing open groups in a
+ // vector, and generating loop portions as necessary during the sweep. A block
+ // of statements is paired with its delay.
+ std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
+
+ auto origLbMap = forStmt->getLowerBoundMap();
+ uint64_t lbDelay = 0;
+ MLFuncBuilder b(forStmt);
+ for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
+ // If nothing is delayed by d, continue.
+ if (sortedStmtGroups[d].empty())
+ continue;
+ if (!stmtGroupQueue.empty()) {
+ assert(d >= 1 &&
+ "Queue expected to be empty when the first block is found");
+ // The interval for which the loop needs to be generated here is:
+ // ( lbDelay, min(lbDelay + tripCount - 1, d - 1) ] and the body of the
+ // loop needs to have all statements in stmtQueue in that order.
+ ForStmt *res;
+ if (lbDelay + tripCount - 1 < d - 1) {
+ res = generateLoop(
+ b.getShiftedAffineMap(origLbMap, lbDelay),
+ b.getShiftedAffineMap(origLbMap, lbDelay + tripCount - 1),
+ stmtGroupQueue, 0, forStmt, &b);
+ // Entire loop for the queued stmt groups generated, empty it.
+ stmtGroupQueue.clear();
+ lbDelay += tripCount;
+ } else {
+ res = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
+ b.getShiftedAffineMap(origLbMap, d - 1),
+ stmtGroupQueue, 0, forStmt, &b);
+ lbDelay = d;
+ }
+ if (!prologue && res)
+ prologue = res;
+ epilogue = res;
+ } else {
+ // Start of first interval.
+ lbDelay = d;
+ }
+ // Augment the list of statements that get into the current open interval.
+ stmtGroupQueue.push_back({d, sortedStmtGroups[d]});
+ }
+
+  // The statement groups left in the queue now need to be processed (FIFO)
+ // and their loops completed.
+ for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) {
+ uint64_t ubDelay = stmtGroupQueue[i].first + tripCount - 1;
+ epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay),
+ b.getShiftedAffineMap(origLbMap, ubDelay),
+ stmtGroupQueue, i, forStmt, &b);
+ lbDelay = ubDelay + 1;
+ if (!prologue)
+ prologue = epilogue;
+ }
+
+ // Erase the original for stmt.
+ forStmt->erase();
+
+ if (unrollPrologueEpilogue && prologue)
+ loopUnrollFull(prologue);
+  if (unrollPrologueEpilogue && epilogue && epilogue != prologue)
+ loopUnrollFull(epilogue);
+
+ return UtilResult::Success;
+}
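
The skewing routine above hinges on a linear-time counting sort: one bucket per
delay value, with the original body order preserved within each bucket, after
which a sweep over the buckets emits loop chunks. A self-contained sketch of
the bucketing step (toy statement type, not the MLIR API):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      // Statements of a loop body paired with their delays.
      std::vector<std::string> stmts = {"dma.start", "compute", "dma.wait"};
      std::vector<uint64_t> delays = {0, 1, 0};

      uint64_t maxDelay = 0;
      for (uint64_t d : delays)
        maxDelay = std::max(maxDelay, d);

      // Counting sort: bucket statements by delay, keeping body order within
      // each bucket.
      std::vector<std::vector<std::string>> groups(maxDelay + 1);
      for (size_t i = 0; i < stmts.size(); ++i)
        groups[delays[i]].push_back(stmts[i]);

      // The sweep walks these groups in increasing delay, opening and closing
      // loop chunks; here we just print them.
      for (uint64_t d = 0; d <= maxDelay; ++d)
        for (const auto &s : groups[d])
          std::cout << "delay " << d << ": " << s << "\n";
    }
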
diff --git a/mlir/lib/Transforms/Utils/Pass.cpp b/mlir/lib/Transforms/Utils/Pass.cpp
new file mode 100644
index 00000000000..8b1110798bd
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/Pass.cpp
@@ -0,0 +1,41 @@
+//===- Pass.cpp - Pass infrastructure implementation ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements common pass infrastructure.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Pass.h"
+#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Module.h"
+
+using namespace mlir;
+
+/// Function passes walk a module, running the corresponding hook on each
+/// function, and terminate upon the first error encountered.
+PassResult FunctionPass::runOnModule(Module *m) {
+ for (auto &fn : *m) {
+ if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
+ if (runOnMLFunction(mlFunc))
+ return failure();
+ if (auto *cfgFunc = dyn_cast<CFGFunction>(&fn))
+ if (runOnCFGFunction(cfgFunc))
+ return failure();
+ }
+ return success();
+}
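
The pass hook dispatch above is a small but recurring pattern: a module pass
visits each function, forwards to the hook matching the function's kind, and
stops at the first failure. A standalone sketch with toy types (assuming the
convention, as in the code above, that a true return means failure):

    #include <vector>

    struct Function { enum class Kind { CFG, ML } kind; };
    struct Module { std::vector<Function> functions; };

    class FunctionPass {
    public:
      virtual ~FunctionPass() = default;
      // Hooks return true on failure, matching the convention above.
      virtual bool runOnCFGFunction(Function &) { return false; }
      virtual bool runOnMLFunction(Function &) { return false; }

      bool runOnModule(Module &m) {
        for (Function &fn : m.functions) {
          bool failed = fn.kind == Function::Kind::CFG ? runOnCFGFunction(fn)
                                                       : runOnMLFunction(fn);
          if (failed)
            return true; // propagate the first error.
        }
        return false;
      }
    };
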
diff --git a/mlir/lib/Transforms/Utils/PatternMatch.cpp b/mlir/lib/Transforms/Utils/PatternMatch.cpp
new file mode 100644
index 00000000000..6cc3b436c6d
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/PatternMatch.cpp
@@ -0,0 +1,196 @@
+//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/SSAValue.h"
+#include "mlir/IR/Statements.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/Transforms/PatternMatch.h"
+using namespace mlir;
+
+PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
+ assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
+ "This pattern match benefit is too large to represent");
+}
+
+unsigned short PatternBenefit::getBenefit() const {
+ assert(representation != ImpossibleToMatchSentinel &&
+ "Pattern doesn't match");
+ return representation;
+}
+
+bool PatternBenefit::operator==(const PatternBenefit &other) {
+ if (isImpossibleToMatch())
+ return other.isImpossibleToMatch();
+ if (other.isImpossibleToMatch())
+ return false;
+ return getBenefit() == other.getBenefit();
+}
+
+bool PatternBenefit::operator!=(const PatternBenefit &other) {
+ return !(*this == other);
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern implementation
+//===----------------------------------------------------------------------===//
+
+Pattern::Pattern(StringRef rootName, MLIRContext *context,
+ Optional<PatternBenefit> staticBenefit)
+ : rootKind(OperationName(rootName, context)), staticBenefit(staticBenefit) {
+}
+
+Pattern::Pattern(StringRef rootName, MLIRContext *context,
+ unsigned staticBenefit)
+ : rootKind(rootName, context), staticBenefit(staticBenefit) {}
+
+Optional<PatternBenefit> Pattern::getStaticBenefit() const {
+ return staticBenefit;
+}
+
+OperationName Pattern::getRootKind() const { return rootKind; }
+
+void Pattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
+ PatternRewriter &rewriter) const {
+ rewrite(op, rewriter);
+}
+
+void Pattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
+ llvm_unreachable("need to implement one of the rewrite functions!");
+}
+
+/// This method indicates that no match was found.
+PatternMatchResult Pattern::matchFailure() {
+ return {PatternBenefit::impossibleToMatch(), std::unique_ptr<PatternState>()};
+}
+
+/// This method indicates that a match was found and has the specified cost.
+PatternMatchResult
+Pattern::matchSuccess(PatternBenefit benefit,
+ std::unique_ptr<PatternState> state) const {
+ assert((!getStaticBenefit().hasValue() ||
+ getStaticBenefit().getValue() == benefit) &&
+ "This version of matchSuccess must be called with a benefit that "
+ "matches the static benefit if set!");
+
+ return {benefit, std::move(state)};
+}
+
+/// This method indicates that a match was found for patterns that have a
+/// known static benefit.
+PatternMatchResult
+Pattern::matchSuccess(std::unique_ptr<PatternState> state) const {
+ auto benefit = getStaticBenefit();
+ assert(benefit.hasValue() && "Pattern doesn't have a static benefit");
+ return matchSuccess(benefit.getValue(), std::move(state));
+}
+
+//===----------------------------------------------------------------------===//
+// PatternRewriter implementation
+//===----------------------------------------------------------------------===//
+
+PatternRewriter::~PatternRewriter() {
+ // Out of line to provide a vtable anchor for the class.
+}
+
+/// This method is used as the final replacement hook for patterns that match
+/// a single result value. In addition to replacing and removing the
+/// specified operation, clients can specify a list of other nodes that this
+/// replacement may make (perhaps transitively) dead. If any of those ops are
+/// dead, this will remove them as well.
+void PatternRewriter::replaceSingleResultOp(
+ Operation *op, SSAValue *newValue, ArrayRef<SSAValue *> opsToRemoveIfDead) {
+ // Notify the rewriter subclass that we're about to replace this root.
+ notifyRootReplaced(op);
+
+ assert(op->getNumResults() == 1 && "op isn't a SingleResultOp!");
+ op->getResult(0)->replaceAllUsesWith(newValue);
+
+ notifyOperationRemoved(op);
+ op->erase();
+
+ // TODO: Process the opsToRemoveIfDead list, removing things and calling the
+ // notifyOperationRemoved hook in the process.
+}
+
+/// This method is used as the final notification hook for patterns that end
+/// up modifying the pattern root in place, by changing its operands. This is
+/// a minor efficiency win (it avoids creating a new instruction and removing
+/// the old one) but also often allows simpler code in the client.
+///
+/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
+/// should remove if they are dead at this point.
+///
+void PatternRewriter::updatedRootInPlace(
+ Operation *op, ArrayRef<SSAValue *> opsToRemoveIfDead) {
+ // Notify the rewriter subclass that we're about to replace this root.
+ notifyRootUpdated(op);
+
+ // TODO: Process the opsToRemoveIfDead list, removing things and calling the
+ // notifyOperationRemoved hook in the process.
+}
+
+//===----------------------------------------------------------------------===//
+// PatternMatcher implementation
+//===----------------------------------------------------------------------===//
+
+/// Find the highest benefit pattern available in the pattern set for the DAG
+/// rooted at the specified node. This returns the pattern if found, or null
+/// if there are no matches.
+auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
+ // TODO: This is a completely trivial implementation, expand this in the
+ // future.
+
+ // Keep track of the best match, the benefit of it, and any matcher specific
+ // state it is maintaining.
+ MatchResult bestMatch = {nullptr, nullptr};
+ Optional<PatternBenefit> bestBenefit;
+
+ for (auto &pattern : patterns) {
+ // Ignore patterns that are for the wrong root.
+ if (pattern->getRootKind() != op->getName())
+ continue;
+
+ // If we know the static cost of the pattern is worse than what we've
+ // already found then don't run it.
+ auto staticBenefit = pattern->getStaticBenefit();
+ if (staticBenefit.hasValue() && bestBenefit.hasValue() &&
+ staticBenefit.getValue().getBenefit() <
+ bestBenefit.getValue().getBenefit())
+ continue;
+
+ // Check to see if this pattern matches this node.
+ auto result = pattern->match(op);
+ auto benefit = result.first;
+
+ // If this pattern failed to match, ignore it.
+ if (benefit.isImpossibleToMatch())
+ continue;
+
+ // If it matched but had lower benefit than our best match so far, then
+ // ignore it.
+ if (bestBenefit.hasValue() &&
+ benefit.getBenefit() < bestBenefit.getValue().getBenefit())
+ continue;
+
+ // Okay we found a match that is better than our previous one, remember it.
+ bestBenefit = benefit;
+ bestMatch = {pattern.get(), std::move(result.second)};
+ }
+
+ // If we found any match, return it.
+ return bestMatch;
+}
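
findMatch above is an arg-max over the applicable patterns: skip patterns for
the wrong root, skip impossible matches, and keep the highest benefit seen so
far (later patterns win ties). A standalone sketch of the same selection loop,
with a toy pattern type where benefits are plain ints and nullopt stands in
for "impossible to match":

    #include <optional>
    #include <vector>

    struct ToyPattern {
      int rootKind;
      // Returns the benefit of matching 'op', or nullopt on no match.
      std::optional<int> (*match)(int op);
    };

    const ToyPattern *findBestMatch(int op, int opKind,
                                    const std::vector<ToyPattern> &patterns) {
      const ToyPattern *best = nullptr;
      std::optional<int> bestBenefit;
      for (const ToyPattern &p : patterns) {
        if (p.rootKind != opKind)
          continue; // wrong root operation.
        std::optional<int> benefit = p.match(op);
        if (!benefit)
          continue; // impossible to match.
        if (bestBenefit && *benefit < *bestBenefit)
          continue; // strictly worse than the best match so far.
        bestBenefit = benefit;
        best = &p; // equal benefit: the later pattern wins, as above.
      }
      return best;
    }
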
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
new file mode 100644
index 00000000000..4432a8070de
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -0,0 +1,394 @@
+//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous transformation routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Utils.h"
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/DenseMap.h"
+
+using namespace mlir;
+
+/// Returns true if this operation dereferences one or more memrefs.
+// Temporary utility: will be replaced when this is modeled through
+// side-effects/op traits. TODO(b/117228571)
+static bool isMemRefDereferencingOp(const Operation &op) {
+ if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
+ op.isa<DmaWaitOp>())
+ return true;
+ return false;
+}
+
+/// Replaces all uses of oldMemRef with newMemRef while optionally remapping
+/// old memref's indices to the new memref using the supplied affine map
+/// and adding any additional indices. The new memref could be of a different
+/// shape or rank, but of the same elemental type. Additional indices are added
+/// at the start for now.
+// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
+// extended to add additional indices at any position.
+bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
+ MLValue *newMemRef,
+ ArrayRef<MLValue *> extraIndices,
+ AffineMap indexRemap) {
+ unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
+ (void)newMemRefRank; // unused in opt mode
+ unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
+ if (indexRemap) {
+ assert(indexRemap.getNumInputs() == oldMemRefRank);
+ assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
+ } else {
+ assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+ }
+
+ // Assert same elemental type.
+ assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
+ cast<MemRefType>(newMemRef->getType())->getElementType());
+
+  // Check if the memref is used in a non-dereferencing context.
+  for (const StmtOperand &use : oldMemRef->getUses()) {
+    auto *opStmt = cast<OperationStmt>(use.getOwner());
+    // Failure: memref used in a non-dereferencing op (potentially escapes); no
+    // replacement in these cases.
+    if (!isMemRefDereferencingOp(*opStmt))
+ return false;
+ }
+
+  // Walk all uses of the old memref; each statement using it gets replaced.
+ for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
+ StmtOperand &use = *(it++);
+ auto *opStmt = cast<OperationStmt>(use.getOwner());
+ assert(isMemRefDereferencingOp(*opStmt) &&
+ "memref deferencing op expected");
+
+ auto getMemRefOperandPos = [&]() -> unsigned {
+ unsigned i;
+ for (i = 0; i < opStmt->getNumOperands(); i++) {
+ if (opStmt->getOperand(i) == oldMemRef)
+ break;
+ }
+ assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
+ return i;
+ };
+ unsigned memRefOperandPos = getMemRefOperandPos();
+
+ // Construct the new operation statement using this memref.
+ SmallVector<MLValue *, 8> operands;
+ operands.reserve(opStmt->getNumOperands() + extraIndices.size());
+ // Insert the non-memref operands.
+ operands.insert(operands.end(), opStmt->operand_begin(),
+ opStmt->operand_begin() + memRefOperandPos);
+ operands.push_back(newMemRef);
+
+ MLFuncBuilder builder(opStmt);
+ for (auto *extraIndex : extraIndices) {
+ // TODO(mlir-team): An operation/SSA value should provide a method to
+ // return the position of an SSA result in its defining
+ // operation.
+ assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
+ "single result op's expected to generate these indices");
+ assert((cast<MLValue>(extraIndex)->isValidDim() ||
+ cast<MLValue>(extraIndex)->isValidSymbol()) &&
+ "invalid memory op index");
+ operands.push_back(cast<MLValue>(extraIndex));
+ }
+
+ // Construct new indices. The indices of a memref come right after it, i.e.,
+ // at position memRefOperandPos + 1.
+ SmallVector<SSAValue *, 4> indices(
+ opStmt->operand_begin() + memRefOperandPos + 1,
+ opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
+ if (indexRemap) {
+ auto remapOp =
+ builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap, indices);
+ // Remapped indices.
+ for (auto *index : remapOp->getOperation()->getResults())
+ operands.push_back(cast<MLValue>(index));
+ } else {
+ // No remapping specified.
+ for (auto *index : indices)
+ operands.push_back(cast<MLValue>(index));
+ }
+
+ // Insert the remaining operands unmodified.
+ operands.insert(operands.end(),
+ opStmt->operand_begin() + memRefOperandPos + 1 +
+ oldMemRefRank,
+ opStmt->operand_end());
+
+ // Result types don't change. Both memref's are of the same elemental type.
+ SmallVector<Type *, 8> resultTypes;
+ resultTypes.reserve(opStmt->getNumResults());
+ for (const auto *result : opStmt->getResults())
+ resultTypes.push_back(result->getType());
+
+ // Create the new operation.
+ auto *repOp =
+ builder.createOperation(opStmt->getLoc(), opStmt->getName(), operands,
+ resultTypes, opStmt->getAttrs());
+    // Replace the old memref's dereferencing op's uses.
+ unsigned r = 0;
+ for (auto *res : opStmt->getResults()) {
+ res->replaceAllUsesWith(repOp->getResult(r++));
+ }
+ opStmt->erase();
+ }
+ return true;
+}
+
+// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
+// its results equal to the number of 'operands', as a composition of all other
+// AffineApplyOps reachable from the input parameter 'operands'. If the
+// operands draw results from multiple affine apply ops, this also collapses
+// them into a single affine apply op. The final results of the composed
+// AffineApplyOp are returned in the output parameter 'results'.
+OperationStmt *
+mlir::createComposedAffineApplyOp(MLFuncBuilder *builder, Location *loc,
+ ArrayRef<MLValue *> operands,
+ ArrayRef<OperationStmt *> affineApplyOps,
+ SmallVectorImpl<SSAValue *> &results) {
+ // Create identity map with same number of dimensions as number of operands.
+ auto map = builder->getMultiDimIdentityMap(operands.size());
+ // Initialize AffineValueMap with identity map.
+ AffineValueMap valueMap(map, operands);
+
+ for (auto *opStmt : affineApplyOps) {
+ assert(opStmt->isa<AffineApplyOp>());
+ auto affineApplyOp = opStmt->cast<AffineApplyOp>();
+ // Forward substitute 'affineApplyOp' into 'valueMap'.
+ valueMap.forwardSubstitute(*affineApplyOp);
+ }
+ // Compose affine maps from all ancestor AffineApplyOps.
+ // Create new AffineApplyOp from 'valueMap'.
+ unsigned numOperands = valueMap.getNumOperands();
+ SmallVector<SSAValue *, 4> outOperands(numOperands);
+ for (unsigned i = 0; i < numOperands; ++i) {
+ outOperands[i] = valueMap.getOperand(i);
+ }
+ // Create new AffineApplyOp based on 'valueMap'.
+ auto affineApplyOp =
+ builder->create<AffineApplyOp>(loc, valueMap.getAffineMap(), outOperands);
+ results.resize(operands.size());
+ for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+ results[i] = affineApplyOp->getResult(i);
+ }
+ return cast<OperationStmt>(affineApplyOp->getOperation());
+}
+
+/// Given an operation statement, inserts a single new affine apply operation
+/// that is exclusively used by this operation statement, and that provides all
+/// operands that are results of an affine_apply as a function of loop
+/// iterators and program parameters.
+///
+/// Before
+///
+/// for %i = 0 to #map(%N)
+/// %idx = affine_apply (d0) -> (d0 mod 2) (%i)
+/// "send"(%idx, %A, ...)
+/// "compute"(%idx)
+///
+/// After
+///
+/// for %i = 0 to #map(%N)
+/// %idx = affine_apply (d0) -> (d0 mod 2) (%i)
+/// "send"(%idx, %A, ...)
+/// %idx_ = affine_apply (d0) -> (d0 mod 2) (%i)
+/// "compute"(%idx_)
+///
+/// This allows applying different transformations on send and compute (e.g.,
+/// different shifts/delays).
+///
+/// Returns nullptr either if none of opStmt's operands were the result of an
+/// affine_apply and thus there was no affine computation slice to create, or if
+/// all the affine_apply op's supplying operands to this opStmt do not have any
+/// uses besides this opStmt. Returns the new affine_apply operation statement
+/// otherwise.
+OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
+ // Collect all operands that are results of affine apply ops.
+ SmallVector<MLValue *, 4> subOperands;
+ subOperands.reserve(opStmt->getNumOperands());
+ for (auto *operand : opStmt->getOperands()) {
+ auto *defStmt = operand->getDefiningStmt();
+ if (defStmt && defStmt->isa<AffineApplyOp>()) {
+ subOperands.push_back(operand);
+ }
+ }
+
+ // Gather sequence of AffineApplyOps reachable from 'subOperands'.
+ SmallVector<OperationStmt *, 4> affineApplyOps;
+ getReachableAffineApplyOps(subOperands, affineApplyOps);
+ // Skip transforming if there are no affine maps to compose.
+ if (affineApplyOps.empty())
+ return nullptr;
+
+  // Check if all uses of the affine apply ops lie in this op stmt
+ // itself, in which case there would be nothing to do.
+ bool localized = true;
+ for (auto *op : affineApplyOps) {
+ for (auto *result : op->getResults()) {
+ for (auto &use : result->getUses()) {
+ if (use.getOwner() != opStmt) {
+ localized = false;
+ break;
+ }
+ }
+ }
+ }
+ if (localized)
+ return nullptr;
+
+ MLFuncBuilder builder(opStmt);
+ SmallVector<SSAValue *, 4> results;
+ auto *affineApplyStmt = createComposedAffineApplyOp(
+ &builder, opStmt->getLoc(), subOperands, affineApplyOps, results);
+ assert(results.size() == subOperands.size() &&
+ "number of results should be the same as the number of subOperands");
+
+ // Construct the new operands that include the results from the composed
+ // affine apply op above instead of existing ones (subOperands). So, they
+ // differ from opStmt's operands only for those operands in 'subOperands', for
+ // which they will be replaced by the corresponding one from 'results'.
+ SmallVector<MLValue *, 4> newOperands(opStmt->getOperands());
+ for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
+ // Replace the subOperands from among the new operands.
+ unsigned j, f;
+ for (j = 0, f = subOperands.size(); j < f; j++) {
+ if (newOperands[i] == subOperands[j])
+ break;
+ }
+ if (j < subOperands.size()) {
+ newOperands[i] = cast<MLValue>(results[j]);
+ }
+ }
+
+ for (unsigned idx = 0; idx < newOperands.size(); idx++) {
+ opStmt->setOperand(idx, newOperands[idx]);
+ }
+
+ return affineApplyStmt;
+}
+
+void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
+ if (affineApplyOp->getOperation()->getOperationFunction()->getKind() !=
+ Function::Kind::MLFunc) {
+ // TODO: Support forward substitution for CFGFunctions.
+ return;
+ }
+ auto *opStmt = cast<OperationStmt>(affineApplyOp->getOperation());
+ // Iterate through all uses of all results of 'opStmt', forward substituting
+ // into any uses which are AffineApplyOps.
+ for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
+ ++resultIndex) {
+ const MLValue *result = opStmt->getResult(resultIndex);
+ for (auto it = result->use_begin(); it != result->use_end();) {
+ StmtOperand &use = *(it++);
+ auto *useStmt = use.getOwner();
+ auto *useOpStmt = dyn_cast<OperationStmt>(useStmt);
+ // Skip if use is not AffineApplyOp.
+ if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>())
+ continue;
+      // Advance the iterator past other operands of 'useStmt' that also use 'result'.
+ while (it != result->use_end() && it->getOwner() == useStmt)
+ ++it;
+
+ MLFuncBuilder builder(useOpStmt);
+      // Initialize AffineValueMap with 'oldAffineApplyOp', which uses 'result'.
+ auto oldAffineApplyOp = useOpStmt->cast<AffineApplyOp>();
+ AffineValueMap valueMap(*oldAffineApplyOp);
+      // Forward substitute 'result' at index 'resultIndex' into 'valueMap'.
+ valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex);
+
+ // Create new AffineApplyOp from 'valueMap'.
+ unsigned numOperands = valueMap.getNumOperands();
+ SmallVector<SSAValue *, 4> operands(numOperands);
+ for (unsigned i = 0; i < numOperands; ++i) {
+ operands[i] = valueMap.getOperand(i);
+ }
+ auto newAffineApplyOp = builder.create<AffineApplyOp>(
+ useOpStmt->getLoc(), valueMap.getAffineMap(), operands);
+
+ // Update all uses to use results from 'newAffineApplyOp'.
+ for (unsigned i = 0, e = useOpStmt->getNumResults(); i < e; ++i) {
+ oldAffineApplyOp->getResult(i)->replaceAllUsesWith(
+ newAffineApplyOp->getResult(i));
+ }
+ // Erase 'oldAffineApplyOp'.
+ oldAffineApplyOp->getOperation()->erase();
+ }
+ }
+}
+
+/// Folds the lower and upper bounds of the specified 'for' statement to
+/// constants if possible. Returns false if the folding happens for any of the
+/// bounds, true otherwise.
+bool mlir::constantFoldBounds(ForStmt *forStmt) {
+ auto foldLowerOrUpperBound = [forStmt](bool lower) {
+ // Check if the bound is already a constant.
+ if (lower && forStmt->hasConstantLowerBound())
+ return true;
+ if (!lower && forStmt->hasConstantUpperBound())
+ return true;
+
+ // Check to see if each of the operands is the result of a constant. If so,
+ // get the value. If not, ignore it.
+ SmallVector<Attribute *, 8> operandConstants;
+ auto boundOperands = lower ? forStmt->getLowerBoundOperands()
+ : forStmt->getUpperBoundOperands();
+ for (const auto *operand : boundOperands) {
+ Attribute *operandCst = nullptr;
+ if (auto *operandOp = operand->getDefiningOperation()) {
+ if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
+ operandCst = operandConstantOp->getValue();
+ }
+ operandConstants.push_back(operandCst);
+ }
+
+ AffineMap boundMap =
+ lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
+ assert(boundMap.getNumResults() >= 1 &&
+ "bound maps should have at least one result");
+ SmallVector<Attribute *, 4> foldedResults;
+ if (boundMap.constantFold(operandConstants, foldedResults))
+ return true;
+
+ // Compute the max or min as applicable over the results.
+ assert(!foldedResults.empty() && "bounds should have at least one result");
+ auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
+ for (unsigned i = 1; i < foldedResults.size(); i++) {
+ auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
+ maxOrMin = lower ? std::max(maxOrMin, foldedResult)
+ : std::min(maxOrMin, foldedResult);
+ }
+ lower ? forStmt->setConstantLowerBound(maxOrMin)
+ : forStmt->setConstantUpperBound(maxOrMin);
+
+ // Return false on success.
+ return false;
+ };
+
+ bool ret = foldLowerOrUpperBound(/*lower=*/true);
+ ret &= foldLowerOrUpperBound(/*lower=*/false);
+ return ret;
+}
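
The tail of constantFoldBounds is worth seeing in isolation: a lower bound map
with several results denotes a max over them, and an upper bound a min, so once
every result folds to a constant the bound collapses to a single value. A tiny
standalone example of that arithmetic:

    #include <algorithm>
    #include <cassert>
    #include <iostream>
    #include <vector>

    int main() {
      // Folded results of a multi-result lower bound map, e.g. max(0, 4, 2).
      std::vector<long> foldedResults = {0, 4, 2};
      bool lower = true;

      assert(!foldedResults.empty() && "bounds have at least one result");
      long bound = foldedResults[0];
      for (size_t i = 1; i < foldedResults.size(); ++i)
        bound = lower ? std::max(bound, foldedResults[i])
                      : std::min(bound, foldedResults[i]);
      // Prints "constant lower bound: 4".
      std::cout << (lower ? "constant lower bound: "
                          : "constant upper bound: ")
                << bound << "\n";
    }
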