diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils')
| -rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 343 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/LoopUtils.cpp | 316 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/Pass.cpp | 41 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/PatternMatch.cpp | 196 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/Utils.cpp | 394 |
5 files changed, 1290 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp new file mode 100644 index 00000000000..5ed8eac323e --- /dev/null +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -0,0 +1,343 @@ +//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements mlir::applyPatternsGreedily. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/StandardOps/StandardOps.h" +#include "mlir/Transforms/PatternMatch.h" +#include "llvm/ADT/DenseMap.h" +using namespace mlir; + +namespace { +class WorklistRewriter; + +/// This is a worklist-driven driver for the PatternMatcher, which repeatedly +/// applies the locally optimal patterns in a roughly "bottom up" way. +class GreedyPatternRewriteDriver { +public: + explicit GreedyPatternRewriteDriver(OwningPatternList &&patterns) + : matcher(std::move(patterns)) { + worklist.reserve(64); + } + + void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter); + + void addToWorklist(Operation *op) { + worklistMap[op] = worklist.size(); + worklist.push_back(op); + } + + Operation *popFromWorklist() { + auto *op = worklist.back(); + worklist.pop_back(); + + // This operation is no longer in the worklist, keep worklistMap up to date. + if (op) + worklistMap.erase(op); + return op; + } + + /// If the specified operation is in the worklist, remove it. If not, this is + /// a no-op. + void removeFromWorklist(Operation *op) { + auto it = worklistMap.find(op); + if (it != worklistMap.end()) { + assert(worklist[it->second] == op && "malformed worklist data structure"); + worklist[it->second] = nullptr; + } + } + +private: + /// The low-level pattern matcher. + PatternMatcher matcher; + + /// The worklist for this transformation keeps track of the operations that + /// need to be revisited, plus their index in the worklist. This allows us to + /// efficiently remove operations from the worklist when they are removed even + /// if they aren't the root of a pattern. + std::vector<Operation *> worklist; + DenseMap<Operation *, unsigned> worklistMap; + + /// As part of canonicalization, we move constants to the top of the entry + /// block of the current function and de-duplicate them. This keeps track of + /// constants we have done this for. + DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants; +}; +}; // end anonymous namespace + +/// This is a listener object that updates our worklists and other data +/// structures in response to operations being added and removed. +namespace { +class WorklistRewriter : public PatternRewriter { +public: + WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context) + : PatternRewriter(context), driver(driver) {} + + virtual void setInsertionPoint(Operation *op) = 0; + + // If an operation is about to be removed, make sure it is not in our + // worklist anymore because we'd get dangling references to it. + void notifyOperationRemoved(Operation *op) override { + driver.removeFromWorklist(op); + } + + GreedyPatternRewriteDriver &driver; +}; + +} // end anonymous namespace + +void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, + WorklistRewriter &rewriter) { + // These are scratch vectors used in the constant folding loop below. + SmallVector<Attribute *, 8> operandConstants, resultConstants; + + while (!worklist.empty()) { + auto *op = popFromWorklist(); + + // Nulls get added to the worklist when operations are removed, ignore them. + if (op == nullptr) + continue; + + // If we have a constant op, unique it into the entry block. + if (auto constant = op->dyn_cast<ConstantOp>()) { + // If this constant is dead, remove it, being careful to keep + // uniquedConstants up to date. + if (constant->use_empty()) { + auto it = + uniquedConstants.find({constant->getValue(), constant->getType()}); + if (it != uniquedConstants.end() && it->second == op) + uniquedConstants.erase(it); + constant->erase(); + continue; + } + + // Check to see if we already have a constant with this type and value: + auto &entry = uniquedConstants[std::make_pair(constant->getValue(), + constant->getType())]; + if (entry) { + // If this constant is already our uniqued one, then leave it alone. + if (entry == op) + continue; + + // Otherwise replace this redundant constant with the uniqued one. We + // know this is safe because we move constants to the top of the + // function when they are uniqued, so we know they dominate all uses. + constant->replaceAllUsesWith(entry->getResult(0)); + constant->erase(); + continue; + } + + // If we have no entry, then we should unique this constant as the + // canonical version. To ensure safe dominance, move the operation to the + // top of the function. + entry = op; + + // TODO: If we make terminators into Operations then we could turn this + // into a nice Operation::moveBefore(Operation*) method. We just need the + // guarantee that a block is non-empty. + if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) { + auto &entryBB = cfgFunc->front(); + cast<OperationInst>(op)->moveBefore(&entryBB, entryBB.begin()); + } else { + auto *mlFunc = cast<MLFunction>(currentFunction); + cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin()); + } + + continue; + } + + // If the operation has no side effects, and no users, then it is trivially + // dead - remove it. + if (op->hasNoSideEffect() && op->use_empty()) { + op->erase(); + continue; + } + + // Check to see if any operands to the instruction is constant and whether + // the operation knows how to constant fold itself. + operandConstants.clear(); + for (auto *operand : op->getOperands()) { + Attribute *operandCst = nullptr; + if (auto *operandOp = operand->getDefiningOperation()) { + if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>()) + operandCst = operandConstantOp->getValue(); + } + operandConstants.push_back(operandCst); + } + + // If constant folding was successful, create the result constants, RAUW the + // operation and remove it. + resultConstants.clear(); + if (!op->constantFold(operandConstants, resultConstants)) { + rewriter.setInsertionPoint(op); + + for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { + auto *res = op->getResult(i); + if (res->use_empty()) // ignore dead uses. + continue; + + // If we already have a canonicalized version of this constant, just + // reuse it. Otherwise create a new one. + SSAValue *cstValue; + auto it = uniquedConstants.find({resultConstants[i], res->getType()}); + if (it != uniquedConstants.end()) + cstValue = it->second->getResult(0); + else + cstValue = rewriter.create<ConstantOp>( + op->getLoc(), resultConstants[i], res->getType()); + res->replaceAllUsesWith(cstValue); + } + + assert(op->hasNoSideEffect() && "Constant folded op with side effects?"); + op->erase(); + continue; + } + + // If this is an associative binary operation with a constant on the LHS, + // move it to the right side. + if (operandConstants.size() == 2 && operandConstants[0] && + !operandConstants[1]) { + auto *newLHS = op->getOperand(1); + op->setOperand(1, op->getOperand(0)); + op->setOperand(0, newLHS); + } + + // Check to see if we have any patterns that match this node. + auto match = matcher.findMatch(op); + if (!match.first) + continue; + + // Make sure that any new operations are inserted at this point. + rewriter.setInsertionPoint(op); + match.first->rewrite(op, std::move(match.second), rewriter); + } + + uniquedConstants.clear(); +} + +static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) { + class MLFuncRewriter : public WorklistRewriter { + public: + MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder) + : WorklistRewriter(driver, builder.getContext()), builder(builder) {} + + // Implement the hook for creating operations, and make sure that newly + // created ops are added to the worklist for processing. + Operation *createOperation(const OperationState &state) override { + auto *result = builder.createOperation(state); + driver.addToWorklist(result); + return result; + } + + // When the root of a pattern is about to be replaced, it can trigger + // simplifications to its users - make sure to add them to the worklist + // before the root is changed. + void notifyRootReplaced(Operation *op) override { + auto *opStmt = cast<OperationStmt>(op); + for (auto *result : opStmt->getResults()) + // TODO: Add a result->getUsers() iterator. + for (auto &user : result->getUses()) { + if (auto *op = dyn_cast<OperationStmt>(user.getOwner())) + driver.addToWorklist(op); + } + + // TODO: Walk the operand list dropping them as we go. If any of them + // drop to zero uses, then add them to the worklist to allow them to be + // deleted as dead. + } + + void setInsertionPoint(Operation *op) override { + // Any new operations should be added before this statement. + builder.setInsertionPoint(cast<OperationStmt>(op)); + } + + private: + MLFuncBuilder &builder; + }; + + GreedyPatternRewriteDriver driver(std::move(patterns)); + fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); }); + + MLFuncBuilder mlBuilder(fn); + MLFuncRewriter rewriter(driver, mlBuilder); + driver.simplifyFunction(fn, rewriter); +} + +static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) { + class CFGFuncRewriter : public WorklistRewriter { + public: + CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder) + : WorklistRewriter(driver, builder.getContext()), builder(builder) {} + + // Implement the hook for creating operations, and make sure that newly + // created ops are added to the worklist for processing. + Operation *createOperation(const OperationState &state) override { + auto *result = builder.createOperation(state); + driver.addToWorklist(result); + return result; + } + + // When the root of a pattern is about to be replaced, it can trigger + // simplifications to its users - make sure to add them to the worklist + // before the root is changed. + void notifyRootReplaced(Operation *op) override { + auto *opStmt = cast<OperationInst>(op); + for (auto *result : opStmt->getResults()) + // TODO: Add a result->getUsers() iterator. + for (auto &user : result->getUses()) { + if (auto *op = dyn_cast<OperationInst>(user.getOwner())) + driver.addToWorklist(op); + } + + // TODO: Walk the operand list dropping them as we go. If any of them + // drop to zero uses, then add them to the worklist to allow them to be + // deleted as dead. + } + + void setInsertionPoint(Operation *op) override { + // Any new operations should be added before this instruction. + builder.setInsertionPoint(cast<OperationInst>(op)); + } + + private: + CFGFuncBuilder &builder; + }; + + GreedyPatternRewriteDriver driver(std::move(patterns)); + for (auto &bb : *fn) + for (auto &op : bb) + driver.addToWorklist(&op); + + CFGFuncBuilder cfgBuilder(fn); + CFGFuncRewriter rewriter(driver, cfgBuilder); + driver.simplifyFunction(fn, rewriter); +} + +/// Rewrite the specified function by repeatedly applying the highest benefit +/// patterns in a greedy work-list driven manner. +/// +void mlir::applyPatternsGreedily(Function *fn, OwningPatternList &&patterns) { + if (auto *cfg = dyn_cast<CFGFunction>(fn)) { + processCFGFunction(cfg, std::move(patterns)); + } else { + processMLFunction(cast<MLFunction>(fn), std::move(patterns)); + } +} diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp new file mode 100644 index 00000000000..a6a850280bb --- /dev/null +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -0,0 +1,316 @@ +//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements miscellaneous loop transformation routines. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/LoopUtils.h" + +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Statements.h" +#include "mlir/IR/StmtVisitor.h" +#include "mlir/StandardOps/StandardOps.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "LoopUtils" + +using namespace mlir; + +/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with +/// the specified trip count, stride, and unroll factor. Returns nullptr when +/// the trip count can't be expressed as an affine expression. +AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt, + unsigned unrollFactor, + MLFuncBuilder *builder) { + auto lbMap = forStmt.getLowerBoundMap(); + + // Single result lower bound map only. + if (lbMap.getNumResults() != 1) + return AffineMap::Null(); + + // Sometimes, the trip count cannot be expressed as an affine expression. + auto tripCount = getTripCountExpr(forStmt); + if (!tripCount) + return AffineMap::Null(); + + AffineExpr lb(lbMap.getResult(0)); + unsigned step = forStmt.getStep(); + auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step; + + return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), + {newUb}, {}); +} + +/// Returns the lower bound of the cleanup loop when unrolling a loop with lower +/// bound 'lb' and with the specified trip count, stride, and unroll factor. +/// Returns an AffinMap with nullptr storage (that evaluates to false) +/// when the trip count can't be expressed as an affine expression. +AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt, + unsigned unrollFactor, + MLFuncBuilder *builder) { + auto lbMap = forStmt.getLowerBoundMap(); + + // Single result lower bound map only. + if (lbMap.getNumResults() != 1) + return AffineMap::Null(); + + // Sometimes the trip count cannot be expressed as an affine expression. + AffineExpr tripCount(getTripCountExpr(forStmt)); + if (!tripCount) + return AffineMap::Null(); + + AffineExpr lb(lbMap.getResult(0)); + unsigned step = forStmt.getStep(); + auto newLb = lb + (tripCount - tripCount % unrollFactor) * step; + return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), + {newLb}, {}); +} + +/// Promotes the loop body of a forStmt to its containing block if the forStmt +/// was known to have a single iteration. Returns false otherwise. +// TODO(bondhugula): extend this for arbitrary affine bounds. +bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { + Optional<uint64_t> tripCount = getConstantTripCount(*forStmt); + if (!tripCount.hasValue() || tripCount.getValue() != 1) + return false; + + // TODO(mlir-team): there is no builder for a max. + if (forStmt->getLowerBoundMap().getNumResults() != 1) + return false; + + // Replaces all IV uses to its single iteration value. + if (!forStmt->use_empty()) { + if (forStmt->hasConstantLowerBound()) { + auto *mlFunc = forStmt->findFunction(); + MLFuncBuilder topBuilder(&mlFunc->front()); + auto constOp = topBuilder.create<ConstantIndexOp>( + forStmt->getLoc(), forStmt->getConstantLowerBound()); + forStmt->replaceAllUsesWith(constOp); + } else { + const AffineBound lb = forStmt->getLowerBound(); + SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(), + lb.operand_end()); + MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt)); + auto affineApplyOp = builder.create<AffineApplyOp>( + forStmt->getLoc(), lb.getMap(), lbOperands); + forStmt->replaceAllUsesWith(affineApplyOp->getResult(0)); + } + } + // Move the loop body statements to the loop's containing block. + auto *block = forStmt->getBlock(); + block->getStatements().splice(StmtBlock::iterator(forStmt), + forStmt->getStatements()); + forStmt->erase(); + return true; +} + +/// Promotes all single iteration for stmt's in the MLFunction, i.e., moves +/// their body into the containing StmtBlock. +void mlir::promoteSingleIterationLoops(MLFunction *f) { + // Gathers all innermost loops through a post order pruned walk. + class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> { + public: + void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); } + }; + + LoopBodyPromoter fsw; + fsw.walkPostOrder(f); +} + +/// Generates a for 'stmt' with the specified lower and upper bounds while +/// generating the right IV remappings for the delayed statements. The +/// statement blocks that go into the loop are specified in stmtGroupQueue +/// starting from the specified offset, and in that order; the first element of +/// the pair specifies the delay applied to that group of statements. Returns +/// nullptr if the generated loop simplifies to a single iteration one. +static ForStmt * +generateLoop(AffineMap lb, AffineMap ub, + const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> + &stmtGroupQueue, + unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) { + SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands()); + SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands()); + + auto *loopChunk = + b->createFor(srcForStmt->getLoc(), lbOperands, lb, ubOperands, ub); + OperationStmt::OperandMapTy operandMap; + + for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end(); + it != e; ++it) { + auto elt = *it; + // All 'same delay' statements get added with the operands being remapped + // (to results of cloned statements). + // Generate the remapping if the delay is not zero: oldIV = newIV - delay. + // TODO(bondhugula): check if srcForStmt is actually used in elt.second + // instead of just checking if it's used at all. + if (!srcForStmt->use_empty() && elt.first != 0) { + auto b = MLFuncBuilder::getForStmtBodyBuilder(loopChunk); + auto *oldIV = + b.create<AffineApplyOp>( + srcForStmt->getLoc(), + b.getSingleDimShiftAffineMap(-static_cast<int64_t>(elt.first)), + loopChunk) + ->getResult(0); + operandMap[srcForStmt] = cast<MLValue>(oldIV); + } else { + operandMap[srcForStmt] = static_cast<MLValue *>(loopChunk); + } + for (auto *stmt : elt.second) { + loopChunk->push_back(stmt->clone(operandMap, b->getContext())); + } + } + if (promoteIfSingleIteration(loopChunk)) + return nullptr; + return loopChunk; +} + +/// Skew the statements in the body of a 'for' statement with the specified +/// statement-wise delays. The delays are with respect to the original execution +/// order. A delay of zero for each statement will lead to no change. +// The skewing of statements with respect to one another can be used for example +// to allow overlap of asynchronous operations (such as DMA communication) with +// computation, or just relative shifting of statements for better register +// reuse, locality or parallelism. As such, the delays are typically expected to +// be at most of the order of the number of statements. This method should not +// be used as a substitute for loop distribution/fission. +// This method uses an algorithm// in time linear in the number of statements in +// the body of the for loop - (using the 'sweep line' paradigm). This method +// asserts preservation of SSA dominance. A check for that as well as that for +// memory-based depedence preservation check rests with the users of this +// method. +UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays, + bool unrollPrologueEpilogue) { + if (forStmt->getStatements().empty()) + return UtilResult::Success; + + // If the trip counts aren't constant, we would need versioning and + // conditional guards (or context information to prevent such versioning). The + // better way to pipeline for such loops is to first tile them and extract + // constant trip count "full tiles" before applying this. + auto mayBeConstTripCount = getConstantTripCount(*forStmt); + if (!mayBeConstTripCount.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";); + return UtilResult::Success; + } + uint64_t tripCount = mayBeConstTripCount.getValue(); + + assert(isStmtwiseShiftValid(*forStmt, delays) && + "shifts will lead to an invalid transformation\n"); + + unsigned numChildStmts = forStmt->getStatements().size(); + + // Do a linear time (counting) sort for the delays. + uint64_t maxDelay = 0; + for (unsigned i = 0; i < numChildStmts; i++) { + maxDelay = std::max(maxDelay, delays[i]); + } + // Such large delays are not the typical use case. + if (maxDelay >= numChildStmts) { + LLVM_DEBUG(llvm::dbgs() << "stmt delays too large - unexpected\n";); + return UtilResult::Success; + } + + // An array of statement groups sorted by delay amount; each group has all + // statements with the same delay in the order in which they appear in the + // body of the 'for' stmt. + std::vector<std::vector<Statement *>> sortedStmtGroups(maxDelay + 1); + unsigned pos = 0; + for (auto &stmt : *forStmt) { + auto delay = delays[pos++]; + sortedStmtGroups[delay].push_back(&stmt); + } + + // Unless the shifts have a specific pattern (which actually would be the + // common use case), prologue and epilogue are not meaningfully defined. + // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first + // loop generated as the prologue and the last as epilogue and unroll these + // fully. + ForStmt *prologue = nullptr; + ForStmt *epilogue = nullptr; + + // Do a sweep over the sorted delays while storing open groups in a + // vector, and generating loop portions as necessary during the sweep. A block + // of statements is paired with its delay. + std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue; + + auto origLbMap = forStmt->getLowerBoundMap(); + uint64_t lbDelay = 0; + MLFuncBuilder b(forStmt); + for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) { + // If nothing is delayed by d, continue. + if (sortedStmtGroups[d].empty()) + continue; + if (!stmtGroupQueue.empty()) { + assert(d >= 1 && + "Queue expected to be empty when the first block is found"); + // The interval for which the loop needs to be generated here is: + // ( lbDelay, min(lbDelay + tripCount - 1, d - 1) ] and the body of the + // loop needs to have all statements in stmtQueue in that order. + ForStmt *res; + if (lbDelay + tripCount - 1 < d - 1) { + res = generateLoop( + b.getShiftedAffineMap(origLbMap, lbDelay), + b.getShiftedAffineMap(origLbMap, lbDelay + tripCount - 1), + stmtGroupQueue, 0, forStmt, &b); + // Entire loop for the queued stmt groups generated, empty it. + stmtGroupQueue.clear(); + lbDelay += tripCount; + } else { + res = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay), + b.getShiftedAffineMap(origLbMap, d - 1), + stmtGroupQueue, 0, forStmt, &b); + lbDelay = d; + } + if (!prologue && res) + prologue = res; + epilogue = res; + } else { + // Start of first interval. + lbDelay = d; + } + // Augment the list of statements that get into the current open interval. + stmtGroupQueue.push_back({d, sortedStmtGroups[d]}); + } + + // Those statements groups left in the queue now need to be processed (FIFO) + // and their loops completed. + for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) { + uint64_t ubDelay = stmtGroupQueue[i].first + tripCount - 1; + epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbDelay), + b.getShiftedAffineMap(origLbMap, ubDelay), + stmtGroupQueue, i, forStmt, &b); + lbDelay = ubDelay + 1; + if (!prologue) + prologue = epilogue; + } + + // Erase the original for stmt. + forStmt->erase(); + + if (unrollPrologueEpilogue && prologue) + loopUnrollFull(prologue); + if (unrollPrologueEpilogue && !epilogue && epilogue != prologue) + loopUnrollFull(epilogue); + + return UtilResult::Success; +} diff --git a/mlir/lib/Transforms/Utils/Pass.cpp b/mlir/lib/Transforms/Utils/Pass.cpp new file mode 100644 index 00000000000..8b1110798bd --- /dev/null +++ b/mlir/lib/Transforms/Utils/Pass.cpp @@ -0,0 +1,41 @@ +//===- Pass.cpp - Pass infrastructure implementation ----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements common pass infrastructure. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Pass.h" +#include "mlir/IR/CFGFunction.h" +#include "mlir/IR/MLFunction.h" +#include "mlir/IR/Module.h" + +using namespace mlir; + +/// Function passes walk a module and look at each function with their +/// corresponding hooks and terminates upon error encountered. +PassResult FunctionPass::runOnModule(Module *m) { + for (auto &fn : *m) { + if (auto *mlFunc = dyn_cast<MLFunction>(&fn)) + if (runOnMLFunction(mlFunc)) + return failure(); + if (auto *cfgFunc = dyn_cast<CFGFunction>(&fn)) + if (runOnCFGFunction(cfgFunc)) + return failure(); + } + return success(); +} diff --git a/mlir/lib/Transforms/Utils/PatternMatch.cpp b/mlir/lib/Transforms/Utils/PatternMatch.cpp new file mode 100644 index 00000000000..6cc3b436c6d --- /dev/null +++ b/mlir/lib/Transforms/Utils/PatternMatch.cpp @@ -0,0 +1,196 @@ +//===- PatternMatch.cpp - Base classes for pattern match ------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/SSAValue.h" +#include "mlir/IR/Statements.h" +#include "mlir/StandardOps/StandardOps.h" +#include "mlir/Transforms/PatternMatch.h" +using namespace mlir; + +PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { + assert(representation == benefit && benefit != ImpossibleToMatchSentinel && + "This pattern match benefit is too large to represent"); +} + +unsigned short PatternBenefit::getBenefit() const { + assert(representation != ImpossibleToMatchSentinel && + "Pattern doesn't match"); + return representation; +} + +bool PatternBenefit::operator==(const PatternBenefit& other) { + if (isImpossibleToMatch()) + return other.isImpossibleToMatch(); + if (other.isImpossibleToMatch()) + return false; + return getBenefit() == other.getBenefit(); +} + +bool PatternBenefit::operator!=(const PatternBenefit& other) { + return !(*this == other); +} + +//===----------------------------------------------------------------------===// +// Pattern implementation +//===----------------------------------------------------------------------===// + +Pattern::Pattern(StringRef rootName, MLIRContext *context, + Optional<PatternBenefit> staticBenefit) + : rootKind(OperationName(rootName, context)), staticBenefit(staticBenefit) { +} + +Pattern::Pattern(StringRef rootName, MLIRContext *context, + unsigned staticBenefit) + : rootKind(rootName, context), staticBenefit(staticBenefit) {} + +Optional<PatternBenefit> Pattern::getStaticBenefit() const { + return staticBenefit; +} + +OperationName Pattern::getRootKind() const { return rootKind; } + +void Pattern::rewrite(Operation *op, std::unique_ptr<PatternState> state, + PatternRewriter &rewriter) const { + rewrite(op, rewriter); +} + +void Pattern::rewrite(Operation *op, PatternRewriter &rewriter) const { + llvm_unreachable("need to implement one of the rewrite functions!"); +} + +/// This method indicates that no match was found. +PatternMatchResult Pattern::matchFailure() { + return {PatternBenefit::impossibleToMatch(), std::unique_ptr<PatternState>()}; +} + +/// This method indicates that a match was found and has the specified cost. +PatternMatchResult +Pattern::matchSuccess(PatternBenefit benefit, + std::unique_ptr<PatternState> state) const { + assert((!getStaticBenefit().hasValue() || + getStaticBenefit().getValue() == benefit) && + "This version of matchSuccess must be called with a benefit that " + "matches the static benefit if set!"); + + return {benefit, std::move(state)}; +} + +/// This method indicates that a match was found for patterns that have a +/// known static benefit. +PatternMatchResult +Pattern::matchSuccess(std::unique_ptr<PatternState> state) const { + auto benefit = getStaticBenefit(); + assert(benefit.hasValue() && "Pattern doesn't have a static benefit"); + return matchSuccess(benefit.getValue(), std::move(state)); +} + +//===----------------------------------------------------------------------===// +// PatternRewriter implementation +//===----------------------------------------------------------------------===// + +PatternRewriter::~PatternRewriter() { + // Out of line to provide a vtable anchor for the class. +} + +/// This method is used as the final replacement hook for patterns that match +/// a single result value. In addition to replacing and removing the +/// specified operation, clients can specify a list of other nodes that this +/// replacement may make (perhaps transitively) dead. If any of those ops are +/// dead, this will remove them as well. +void PatternRewriter::replaceSingleResultOp( + Operation *op, SSAValue *newValue, ArrayRef<SSAValue *> opsToRemoveIfDead) { + // Notify the rewriter subclass that we're about to replace this root. + notifyRootReplaced(op); + + assert(op->getNumResults() == 1 && "op isn't a SingleResultOp!"); + op->getResult(0)->replaceAllUsesWith(newValue); + + notifyOperationRemoved(op); + op->erase(); + + // TODO: Process the opsToRemoveIfDead list, removing things and calling the + // notifyOperationRemoved hook in the process. +} + +/// This method is used as the final notification hook for patterns that end +/// up modifying the pattern root in place, by changing its operands. This is +/// a minor efficiency win (it avoids creating a new instruction and removing +/// the old one) but also often allows simpler code in the client. +/// +/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter +/// should remove if they are dead at this point. +/// +void PatternRewriter::updatedRootInPlace( + Operation *op, ArrayRef<SSAValue *> opsToRemoveIfDead) { + // Notify the rewriter subclass that we're about to replace this root. + notifyRootUpdated(op); + + // TODO: Process the opsToRemoveIfDead list, removing things and calling the + // notifyOperationRemoved hook in the process. +} + +//===----------------------------------------------------------------------===// +// PatternMatcher implementation +//===----------------------------------------------------------------------===// + +/// Find the highest benefit pattern available in the pattern set for the DAG +/// rooted at the specified node. This returns the pattern if found, or null +/// if there are no matches. +auto PatternMatcher::findMatch(Operation *op) -> MatchResult { + // TODO: This is a completely trivial implementation, expand this in the + // future. + + // Keep track of the best match, the benefit of it, and any matcher specific + // state it is maintaining. + MatchResult bestMatch = {nullptr, nullptr}; + Optional<PatternBenefit> bestBenefit; + + for (auto &pattern : patterns) { + // Ignore patterns that are for the wrong root. + if (pattern->getRootKind() != op->getName()) + continue; + + // If we know the static cost of the pattern is worse than what we've + // already found then don't run it. + auto staticBenefit = pattern->getStaticBenefit(); + if (staticBenefit.hasValue() && bestBenefit.hasValue() && + staticBenefit.getValue().getBenefit() < + bestBenefit.getValue().getBenefit()) + continue; + + // Check to see if this pattern matches this node. + auto result = pattern->match(op); + auto benefit = result.first; + + // If this pattern failed to match, ignore it. + if (benefit.isImpossibleToMatch()) + continue; + + // If it matched but had lower benefit than our best match so far, then + // ignore it. + if (bestBenefit.hasValue() && + benefit.getBenefit() < bestBenefit.getValue().getBenefit()) + continue; + + // Okay we found a match that is better than our previous one, remember it. + bestBenefit = benefit; + bestMatch = {pattern.get(), std::move(result.second)}; + } + + // If we found any match, return it. + return bestMatch; +} diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp new file mode 100644 index 00000000000..4432a8070de --- /dev/null +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -0,0 +1,394 @@ +//===- Utils.cpp ---- Misc utilities for code and data transformation -----===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements miscellaneous transformation routines for non-loop IR +// structures. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Utils.h" + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/IR/Builders.h" +#include "mlir/StandardOps/StandardOps.h" +#include "mlir/Support/MathExtras.h" +#include "llvm/ADT/DenseMap.h" + +using namespace mlir; + +/// Return true if this operation dereferences one or more memref's. +// Temporary utility: will be replaced when this is modeled through +// side-effects/op traits. TODO(b/117228571) +static bool isMemRefDereferencingOp(const Operation &op) { + if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() || + op.isa<DmaWaitOp>()) + return true; + return false; +} + +/// Replaces all uses of oldMemRef with newMemRef while optionally remapping +/// old memref's indices to the new memref using the supplied affine map +/// and adding any additional indices. The new memref could be of a different +/// shape or rank, but of the same elemental type. Additional indices are added +/// at the start for now. +// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily +// extended to add additional indices at any position. +bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, + MLValue *newMemRef, + ArrayRef<MLValue *> extraIndices, + AffineMap indexRemap) { + unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank(); + (void)newMemRefRank; // unused in opt mode + unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank(); + (void)newMemRefRank; + if (indexRemap) { + assert(indexRemap.getNumInputs() == oldMemRefRank); + assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); + } else { + assert(oldMemRefRank + extraIndices.size() == newMemRefRank); + } + + // Assert same elemental type. + assert(cast<MemRefType>(oldMemRef->getType())->getElementType() == + cast<MemRefType>(newMemRef->getType())->getElementType()); + + // Check if memref was used in a non-deferencing context. + for (const StmtOperand &use : oldMemRef->getUses()) { + auto *opStmt = cast<OperationStmt>(use.getOwner()); + // Failure: memref used in a non-deferencing op (potentially escapes); no + // replacement in these cases. + if (!isMemRefDereferencingOp(*opStmt)) + return false; + } + + // Walk all uses of old memref. Statement using the memref gets replaced. + for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) { + StmtOperand &use = *(it++); + auto *opStmt = cast<OperationStmt>(use.getOwner()); + assert(isMemRefDereferencingOp(*opStmt) && + "memref deferencing op expected"); + + auto getMemRefOperandPos = [&]() -> unsigned { + unsigned i; + for (i = 0; i < opStmt->getNumOperands(); i++) { + if (opStmt->getOperand(i) == oldMemRef) + break; + } + assert(i < opStmt->getNumOperands() && "operand guaranteed to be found"); + return i; + }; + unsigned memRefOperandPos = getMemRefOperandPos(); + + // Construct the new operation statement using this memref. + SmallVector<MLValue *, 8> operands; + operands.reserve(opStmt->getNumOperands() + extraIndices.size()); + // Insert the non-memref operands. + operands.insert(operands.end(), opStmt->operand_begin(), + opStmt->operand_begin() + memRefOperandPos); + operands.push_back(newMemRef); + + MLFuncBuilder builder(opStmt); + for (auto *extraIndex : extraIndices) { + // TODO(mlir-team): An operation/SSA value should provide a method to + // return the position of an SSA result in its defining + // operation. + assert(extraIndex->getDefiningStmt()->getNumResults() == 1 && + "single result op's expected to generate these indices"); + assert((cast<MLValue>(extraIndex)->isValidDim() || + cast<MLValue>(extraIndex)->isValidSymbol()) && + "invalid memory op index"); + operands.push_back(cast<MLValue>(extraIndex)); + } + + // Construct new indices. The indices of a memref come right after it, i.e., + // at position memRefOperandPos + 1. + SmallVector<SSAValue *, 4> indices( + opStmt->operand_begin() + memRefOperandPos + 1, + opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank); + if (indexRemap) { + auto remapOp = + builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap, indices); + // Remapped indices. + for (auto *index : remapOp->getOperation()->getResults()) + operands.push_back(cast<MLValue>(index)); + } else { + // No remapping specified. + for (auto *index : indices) + operands.push_back(cast<MLValue>(index)); + } + + // Insert the remaining operands unmodified. + operands.insert(operands.end(), + opStmt->operand_begin() + memRefOperandPos + 1 + + oldMemRefRank, + opStmt->operand_end()); + + // Result types don't change. Both memref's are of the same elemental type. + SmallVector<Type *, 8> resultTypes; + resultTypes.reserve(opStmt->getNumResults()); + for (const auto *result : opStmt->getResults()) + resultTypes.push_back(result->getType()); + + // Create the new operation. + auto *repOp = + builder.createOperation(opStmt->getLoc(), opStmt->getName(), operands, + resultTypes, opStmt->getAttrs()); + // Replace old memref's deferencing op's uses. + unsigned r = 0; + for (auto *res : opStmt->getResults()) { + res->replaceAllUsesWith(repOp->getResult(r++)); + } + opStmt->erase(); + } + return true; +} + +// Creates and inserts into 'builder' a new AffineApplyOp, with the number of +// its results equal to the number of 'operands, as a composition +// of all other AffineApplyOps reachable from input parameter 'operands'. If the +// operands were drawing results from multiple affine apply ops, this also leads +// to a collapse into a single affine apply op. The final results of the +// composed AffineApplyOp are returned in output parameter 'results'. +OperationStmt * +mlir::createComposedAffineApplyOp(MLFuncBuilder *builder, Location *loc, + ArrayRef<MLValue *> operands, + ArrayRef<OperationStmt *> affineApplyOps, + SmallVectorImpl<SSAValue *> &results) { + // Create identity map with same number of dimensions as number of operands. + auto map = builder->getMultiDimIdentityMap(operands.size()); + // Initialize AffineValueMap with identity map. + AffineValueMap valueMap(map, operands); + + for (auto *opStmt : affineApplyOps) { + assert(opStmt->isa<AffineApplyOp>()); + auto affineApplyOp = opStmt->cast<AffineApplyOp>(); + // Forward substitute 'affineApplyOp' into 'valueMap'. + valueMap.forwardSubstitute(*affineApplyOp); + } + // Compose affine maps from all ancestor AffineApplyOps. + // Create new AffineApplyOp from 'valueMap'. + unsigned numOperands = valueMap.getNumOperands(); + SmallVector<SSAValue *, 4> outOperands(numOperands); + for (unsigned i = 0; i < numOperands; ++i) { + outOperands[i] = valueMap.getOperand(i); + } + // Create new AffineApplyOp based on 'valueMap'. + auto affineApplyOp = + builder->create<AffineApplyOp>(loc, valueMap.getAffineMap(), outOperands); + results.resize(operands.size()); + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + results[i] = affineApplyOp->getResult(i); + } + return cast<OperationStmt>(affineApplyOp->getOperation()); +} + +/// Given an operation statement, inserts a new single affine apply operation, +/// that is exclusively used by this operation statement, and that provides all +/// operands that are results of an affine_apply as a function of loop iterators +/// and program parameters and whose results are. +/// +/// Before +/// +/// for %i = 0 to #map(%N) +/// %idx = affine_apply (d0) -> (d0 mod 2) (%i) +/// "send"(%idx, %A, ...) +/// "compute"(%idx) +/// +/// After +/// +/// for %i = 0 to #map(%N) +/// %idx = affine_apply (d0) -> (d0 mod 2) (%i) +/// "send"(%idx, %A, ...) +/// %idx_ = affine_apply (d0) -> (d0 mod 2) (%i) +/// "compute"(%idx_) +/// +/// This allows applying different transformations on send and compute (for eg. +/// different shifts/delays). +/// +/// Returns nullptr either if none of opStmt's operands were the result of an +/// affine_apply and thus there was no affine computation slice to create, or if +/// all the affine_apply op's supplying operands to this opStmt do not have any +/// uses besides this opStmt. Returns the new affine_apply operation statement +/// otherwise. +OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) { + // Collect all operands that are results of affine apply ops. + SmallVector<MLValue *, 4> subOperands; + subOperands.reserve(opStmt->getNumOperands()); + for (auto *operand : opStmt->getOperands()) { + auto *defStmt = operand->getDefiningStmt(); + if (defStmt && defStmt->isa<AffineApplyOp>()) { + subOperands.push_back(operand); + } + } + + // Gather sequence of AffineApplyOps reachable from 'subOperands'. + SmallVector<OperationStmt *, 4> affineApplyOps; + getReachableAffineApplyOps(subOperands, affineApplyOps); + // Skip transforming if there are no affine maps to compose. + if (affineApplyOps.empty()) + return nullptr; + + // Check if all uses of the affine apply op's lie in this op stmt + // itself, in which case there would be nothing to do. + bool localized = true; + for (auto *op : affineApplyOps) { + for (auto *result : op->getResults()) { + for (auto &use : result->getUses()) { + if (use.getOwner() != opStmt) { + localized = false; + break; + } + } + } + } + if (localized) + return nullptr; + + MLFuncBuilder builder(opStmt); + SmallVector<SSAValue *, 4> results; + auto *affineApplyStmt = createComposedAffineApplyOp( + &builder, opStmt->getLoc(), subOperands, affineApplyOps, results); + assert(results.size() == subOperands.size() && + "number of results should be the same as the number of subOperands"); + + // Construct the new operands that include the results from the composed + // affine apply op above instead of existing ones (subOperands). So, they + // differ from opStmt's operands only for those operands in 'subOperands', for + // which they will be replaced by the corresponding one from 'results'. + SmallVector<MLValue *, 4> newOperands(opStmt->getOperands()); + for (unsigned i = 0, e = newOperands.size(); i < e; i++) { + // Replace the subOperands from among the new operands. + unsigned j, f; + for (j = 0, f = subOperands.size(); j < f; j++) { + if (newOperands[i] == subOperands[j]) + break; + } + if (j < subOperands.size()) { + newOperands[i] = cast<MLValue>(results[j]); + } + } + + for (unsigned idx = 0; idx < newOperands.size(); idx++) { + opStmt->setOperand(idx, newOperands[idx]); + } + + return affineApplyStmt; +} + +void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) { + if (affineApplyOp->getOperation()->getOperationFunction()->getKind() != + Function::Kind::MLFunc) { + // TODO: Support forward substitution for CFGFunctions. + return; + } + auto *opStmt = cast<OperationStmt>(affineApplyOp->getOperation()); + // Iterate through all uses of all results of 'opStmt', forward substituting + // into any uses which are AffineApplyOps. + for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e; + ++resultIndex) { + const MLValue *result = opStmt->getResult(resultIndex); + for (auto it = result->use_begin(); it != result->use_end();) { + StmtOperand &use = *(it++); + auto *useStmt = use.getOwner(); + auto *useOpStmt = dyn_cast<OperationStmt>(useStmt); + // Skip if use is not AffineApplyOp. + if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>()) + continue; + // Advance iterator past 'opStmt' operands which also use 'result'. + while (it != result->use_end() && it->getOwner() == useStmt) + ++it; + + MLFuncBuilder builder(useOpStmt); + // Initialize AffineValueMap with 'affineApplyOp' which uses 'result'. + auto oldAffineApplyOp = useOpStmt->cast<AffineApplyOp>(); + AffineValueMap valueMap(*oldAffineApplyOp); + // Forward substitute 'result' at index 'i' into 'valueMap'. + valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex); + + // Create new AffineApplyOp from 'valueMap'. + unsigned numOperands = valueMap.getNumOperands(); + SmallVector<SSAValue *, 4> operands(numOperands); + for (unsigned i = 0; i < numOperands; ++i) { + operands[i] = valueMap.getOperand(i); + } + auto newAffineApplyOp = builder.create<AffineApplyOp>( + useOpStmt->getLoc(), valueMap.getAffineMap(), operands); + + // Update all uses to use results from 'newAffineApplyOp'. + for (unsigned i = 0, e = useOpStmt->getNumResults(); i < e; ++i) { + oldAffineApplyOp->getResult(i)->replaceAllUsesWith( + newAffineApplyOp->getResult(i)); + } + // Erase 'oldAffineApplyOp'. + oldAffineApplyOp->getOperation()->erase(); + } + } +} + +/// Folds the specified (lower or upper) bound to a constant if possible +/// considering its operands. Returns false if the folding happens for any of +/// the bounds, true otherwise. +bool mlir::constantFoldBounds(ForStmt *forStmt) { + auto foldLowerOrUpperBound = [forStmt](bool lower) { + // Check if the bound is already a constant. + if (lower && forStmt->hasConstantLowerBound()) + return true; + if (!lower && forStmt->hasConstantUpperBound()) + return true; + + // Check to see if each of the operands is the result of a constant. If so, + // get the value. If not, ignore it. + SmallVector<Attribute *, 8> operandConstants; + auto boundOperands = lower ? forStmt->getLowerBoundOperands() + : forStmt->getUpperBoundOperands(); + for (const auto *operand : boundOperands) { + Attribute *operandCst = nullptr; + if (auto *operandOp = operand->getDefiningOperation()) { + if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>()) + operandCst = operandConstantOp->getValue(); + } + operandConstants.push_back(operandCst); + } + + AffineMap boundMap = + lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap(); + assert(boundMap.getNumResults() >= 1 && + "bound maps should have at least one result"); + SmallVector<Attribute *, 4> foldedResults; + if (boundMap.constantFold(operandConstants, foldedResults)) + return true; + + // Compute the max or min as applicable over the results. + assert(!foldedResults.empty() && "bounds should have at least one result"); + auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue(); + for (unsigned i = 1; i < foldedResults.size(); i++) { + auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue(); + maxOrMin = lower ? std::max(maxOrMin, foldedResult) + : std::min(maxOrMin, foldedResult); + } + lower ? forStmt->setConstantLowerBound(maxOrMin) + : forStmt->setConstantUpperBound(maxOrMin); + + // Return false on success. + return false; + }; + + bool ret = foldLowerOrUpperBound(/*lower=*/true); + ret &= foldLowerOrUpperBound(/*lower=*/false); + return ret; +} |

