diff options
Diffstat (limited to 'mlir/lib/Analysis/LoopAnalysis.cpp')
| -rw-r--r-- | mlir/lib/Analysis/LoopAnalysis.cpp | 101 |
1 files changed, 50 insertions, 51 deletions
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index dd14f38df55..b66b665c563 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -27,7 +27,7 @@ #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Statements.h" +#include "mlir/IR/Instructions.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" @@ -42,27 +42,27 @@ using namespace mlir; /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count /// expression is simplified before returning. -AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) { +AffineExpr mlir::getTripCountExpr(const ForInst &forInst) { // upper_bound - lower_bound int64_t loopSpan; - int64_t step = forStmt.getStep(); - auto *context = forStmt.getContext(); + int64_t step = forInst.getStep(); + auto *context = forInst.getContext(); - if (forStmt.hasConstantBounds()) { - int64_t lb = forStmt.getConstantLowerBound(); - int64_t ub = forStmt.getConstantUpperBound(); + if (forInst.hasConstantBounds()) { + int64_t lb = forInst.getConstantLowerBound(); + int64_t ub = forInst.getConstantUpperBound(); loopSpan = ub - lb; } else { - auto lbMap = forStmt.getLowerBoundMap(); - auto ubMap = forStmt.getUpperBoundMap(); + auto lbMap = forInst.getLowerBoundMap(); + auto ubMap = forInst.getUpperBoundMap(); // TODO(bondhugula): handle max/min of multiple expressions. if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1) return nullptr; // TODO(bondhugula): handle bounds with different operands. // Bounds have different operands, unhandled for now. - if (!forStmt.matchingBoundOperandList()) + if (!forInst.matchingBoundOperandList()) return nullptr; // ub_expr - lb_expr @@ -88,8 +88,8 @@ AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) { /// Returns the trip count of the loop if it's a constant, None otherwise. This /// method uses affine expression analysis (in turn using getTripCount) and is /// able to determine constant trip count in non-trivial cases. -llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) { - auto tripCountExpr = getTripCountExpr(forStmt); +llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForInst &forInst) { + auto tripCountExpr = getTripCountExpr(forInst); if (!tripCountExpr) return None; @@ -103,8 +103,8 @@ llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) { /// Returns the greatest known integral divisor of the trip count. Affine /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. -uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) { - auto tripCountExpr = getTripCountExpr(forStmt); +uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) { + auto tripCountExpr = getTripCountExpr(forInst); if (!tripCountExpr) return 1; @@ -125,7 +125,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) { } bool mlir::isAccessInvariant(const Value &iv, const Value &index) { - assert(isa<ForStmt>(iv) && "iv must be a ForStmt"); + assert(isa<ForInst>(iv) && "iv must be a ForInst"); assert(index.getType().isa<IndexType>() && "index must be of IndexType"); SmallVector<OperationInst *, 4> affineApplyOps; getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps); @@ -172,7 +172,7 @@ mlir::getInvariantAccesses(const Value &iv, } /// Given: -/// 1. an induction variable `iv` of type ForStmt; +/// 1. an induction variable `iv` of type ForInst; /// 2. a `memoryOp` of type const LoadOp& or const StoreOp&; /// 3. the index of the `fastestVaryingDim` along which to check; /// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access @@ -233,37 +233,37 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { return memRefType.getElementType().template isa<VectorType>(); } -static bool isVectorTransferReadOrWrite(const Statement &stmt) { - const auto *opStmt = cast<OperationInst>(&stmt); - return opStmt->isa<VectorTransferReadOp>() || - opStmt->isa<VectorTransferWriteOp>(); +static bool isVectorTransferReadOrWrite(const Instruction &inst) { + const auto *opInst = cast<OperationInst>(&inst); + return opInst->isa<VectorTransferReadOp>() || + opInst->isa<VectorTransferWriteOp>(); } -using VectorizableStmtFun = - std::function<bool(const ForStmt &, const OperationInst &)>; +using VectorizableInstFun = + std::function<bool(const ForInst &, const OperationInst &)>; -static bool isVectorizableLoopWithCond(const ForStmt &loop, - VectorizableStmtFun isVectorizableStmt) { +static bool isVectorizableLoopWithCond(const ForInst &loop, + VectorizableInstFun isVectorizableInst) { if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) { return false; } // No vectorization across conditionals for now. auto conditionals = matcher::If(); - auto *forStmt = const_cast<ForStmt *>(&loop); - auto conditionalsMatched = conditionals.match(forStmt); + auto *forInst = const_cast<ForInst *>(&loop); + auto conditionalsMatched = conditionals.match(forInst); if (!conditionalsMatched.empty()) { return false; } auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite); - auto vectorTransfersMatched = vectorTransfers.match(forStmt); + auto vectorTransfersMatched = vectorTransfers.match(forInst); if (!vectorTransfersMatched.empty()) { return false; } auto loadAndStores = matcher::Op(matcher::isLoadOrStore); - auto loadAndStoresMatched = loadAndStores.match(forStmt); + auto loadAndStoresMatched = loadAndStores.match(forInst); for (auto ls : loadAndStoresMatched) { auto *op = cast<OperationInst>(ls.first); auto load = op->dyn_cast<LoadOp>(); @@ -275,7 +275,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop, if (vector) { return false; } - if (!isVectorizableStmt(loop, *op)) { + if (!isVectorizableInst(loop, *op)) { return false; } } @@ -283,9 +283,9 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop, } bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( - const ForStmt &loop, unsigned fastestVaryingDim) { - VectorizableStmtFun fun( - [fastestVaryingDim](const ForStmt &loop, const OperationInst &op) { + const ForInst &loop, unsigned fastestVaryingDim) { + VectorizableInstFun fun( + [fastestVaryingDim](const ForInst &loop, const OperationInst &op) { auto load = op.dyn_cast<LoadOp>(); auto store = op.dyn_cast<StoreOp>(); return load ? isContiguousAccess(loop, *load, fastestVaryingDim) @@ -294,37 +294,36 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( return isVectorizableLoopWithCond(loop, fun); } -bool mlir::isVectorizableLoop(const ForStmt &loop) { - VectorizableStmtFun fun( +bool mlir::isVectorizableLoop(const ForInst &loop) { + VectorizableInstFun fun( // TODO: implement me - [](const ForStmt &loop, const OperationInst &op) { return true; }); + [](const ForInst &loop, const OperationInst &op) { return true; }); return isVectorizableLoopWithCond(loop, fun); } -/// Checks whether SSA dominance would be violated if a for stmt's body -/// statements are shifted by the specified shifts. This method checks if a +/// Checks whether SSA dominance would be violated if a for inst's body +/// instructions are shifted by the specified shifts. This method checks if a /// 'def' and all its uses have the same shift factor. // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. -bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, +bool mlir::isInstwiseShiftValid(const ForInst &forInst, ArrayRef<uint64_t> shifts) { - auto *forBody = forStmt.getBody(); + auto *forBody = forInst.getBody(); assert(shifts.size() == forBody->getInstructions().size()); unsigned s = 0; - for (const auto &stmt : *forBody) { - // A for or if stmt does not produce any def/results (that are used + for (const auto &inst : *forBody) { + // A for or if inst does not produce any def/results (that are used // outside). - if (const auto *opStmt = dyn_cast<OperationInst>(&stmt)) { - for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) { - const Value *result = opStmt->getResult(i); + if (const auto *opInst = dyn_cast<OperationInst>(&inst)) { + for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) { + const Value *result = opInst->getResult(i); for (const InstOperand &use : result->getUses()) { - // If an ancestor statement doesn't lie in the block of forStmt, there - // is no shift to check. - // This is a naive way. If performance becomes an issue, a map can - // be used to store 'shifts' - to look up the shift for a statement in - // constant time. - if (auto *ancStmt = forBody->findAncestorInstInBlock(*use.getOwner())) - if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancStmt)]) + // If an ancestor instruction doesn't lie in the block of forInst, + // there is no shift to check. This is a naive way. If performance + // becomes an issue, a map can be used to store 'shifts' - to look up + // the shift for a instruction in constant time. + if (auto *ancInst = forBody->findAncestorInstInBlock(*use.getOwner())) + if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancInst)]) return false; } } |

