Diffstat (limited to 'mlir/lib')
40 files changed, 1556 insertions, 1549 deletions
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index f01735f26e1..8058af06b55 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -25,7 +25,7 @@
 #include "mlir/Analysis/Utils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/Support/Functional.h"
 #include "mlir/Support/MathExtras.h"
@@ -498,22 +498,22 @@ void mlir::getReachableAffineApplyOps(
   while (!worklist.empty()) {
     State &state = worklist.back();
-    auto *opStmt = state.value->getDefiningInst();
+    auto *opInst = state.value->getDefiningInst();
     // Note: getDefiningInst will return nullptr if the operand is not an
-    // OperationInst (i.e. ForStmt), which is a terminator for the search.
-    if (opStmt == nullptr || !opStmt->isa<AffineApplyOp>()) {
+    // OperationInst (i.e. ForInst), which is a terminator for the search.
+    if (opInst == nullptr || !opInst->isa<AffineApplyOp>()) {
       worklist.pop_back();
       continue;
     }
-    if (auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>()) {
+    if (auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>()) {
       if (state.operandIndex == 0) {
-        // Pre-Visit: Add 'opStmt' to reachable sequence.
-        affineApplyOps.push_back(opStmt);
+        // Pre-Visit: Add 'opInst' to reachable sequence.
+        affineApplyOps.push_back(opInst);
       }
-      if (state.operandIndex < opStmt->getNumOperands()) {
+      if (state.operandIndex < opInst->getNumOperands()) {
         // Visit: Add next 'affineApplyOp' operand to worklist.
         // Get next operand to visit at 'operandIndex'.
-        auto *nextOperand = opStmt->getOperand(state.operandIndex);
+        auto *nextOperand = opInst->getOperand(state.operandIndex);
         // Increment 'operandIndex' in 'state'.
         ++state.operandIndex;
         // Add 'nextOperand' to worklist.
@@ -533,47 +533,47 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
   SmallVector<OperationInst *, 4> affineApplyOps;
   getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps);
   // Compose AffineApplyOps in 'affineApplyOps'.
-  for (auto *opStmt : affineApplyOps) {
-    assert(opStmt->isa<AffineApplyOp>());
-    auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>();
+  for (auto *opInst : affineApplyOps) {
+    assert(opInst->isa<AffineApplyOp>());
+    auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>();
     // Forward substitute 'affineApplyOp' into 'valueMap'.
     valueMap->forwardSubstitute(*affineApplyOp);
   }
 }
 
 // Builds a system of constraints with dimensional identifiers corresponding to
-// the loop IVs of the forStmts appearing in that order. Any symbols founds in
+// the loop IVs of the forInsts appearing in that order. Any symbols found in
 // the bound operands are added as symbols in the system. Returns false for the
 // yet unimplemented cases.
 // TODO(andydavis,bondhugula) Handle non-unit steps through local variables or
 // stride information in FlatAffineConstraints. (For eg., by using iv - lb %
 // step = 0 and/or by introducing a method in FlatAffineConstraints
 // setExprStride(ArrayRef<int64_t> expr, int64_t stride)
-bool mlir::getIndexSet(ArrayRef<ForStmt *> forStmts,
+bool mlir::getIndexSet(ArrayRef<ForInst *> forInsts,
                        FlatAffineConstraints *domain) {
-  SmallVector<Value *, 4> indices(forStmts.begin(), forStmts.end());
+  SmallVector<Value *, 4> indices(forInsts.begin(), forInsts.end());
   // Reset while associated Values in 'indices' to the domain.
-  domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
-  for (auto *forStmt : forStmts) {
-    // Add constraints from forStmt's bounds.
-    if (!domain->addForStmtDomain(*forStmt))
+  domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
+  for (auto *forInst : forInsts) {
+    // Add constraints from forInst's bounds.
+    if (!domain->addForInstDomain(*forInst))
       return false;
   }
   return true;
 }
 
-// Computes the iteration domain for 'opStmt' and populates 'indexSet', which
-// encapsulates the constraints involving loops surrounding 'opStmt' and
+// Computes the iteration domain for 'opInst' and populates 'indexSet', which
+// encapsulates the constraints involving loops surrounding 'opInst' and
 // potentially involving any Function symbols. The dimensional identifiers in
-// 'indexSet' correspond to the loops surounding 'stmt' from outermost to
+// 'indexSet' correspond to the loops surrounding 'inst' from outermost to
 // innermost.
-// TODO(andydavis) Add support to handle IfStmts surrounding 'stmt'.
-static bool getStmtIndexSet(const Statement *stmt,
+// TODO(andydavis) Add support to handle IfInsts surrounding 'inst'.
+static bool getInstIndexSet(const Instruction *inst,
                             FlatAffineConstraints *indexSet) {
-  // TODO(andydavis) Extend this to gather enclosing IfStmts and consider
+  // TODO(andydavis) Extend this to gather enclosing IfInsts and consider
   // factoring it out into a utility function.
-  SmallVector<ForStmt *, 4> loops;
-  getLoopIVs(*stmt, &loops);
+  SmallVector<ForInst *, 4> loops;
+  getLoopIVs(*inst, &loops);
   return getIndexSet(loops, indexSet);
 }
 
@@ -672,7 +672,7 @@ static void buildDimAndSymbolPositionMaps(
   auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
     for (unsigned i = 0, e = values.size(); i < e; ++i) {
       auto *value = values[i];
-      if (!isa<ForStmt>(values[i]))
+      if (!isa<ForInst>(values[i]))
         valuePosMap->addSymbolValue(value);
       else if (isSrc)
         valuePosMap->addSrcValue(value);
@@ -840,13 +840,13 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
   // Add equality constraints for any operands that are defined by constant ops.
   auto addEqForConstOperands = [&](ArrayRef<const Value *> operands) {
     for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-      if (isa<ForStmt>(operands[i]))
+      if (isa<ForInst>(operands[i]))
         continue;
       auto *symbol = operands[i];
       assert(symbol->isValidSymbol());
       // Check if the symbol is a constant.
-      if (auto *opStmt = symbol->getDefiningInst()) {
-        if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
+      if (auto *opInst = symbol->getDefiningInst()) {
+        if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
           dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
                                             constOp->getValue());
         }
@@ -909,8 +909,8 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
       std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
   unsigned numCommonLoops = 0;
   for (unsigned i = 0; i < minNumLoops; ++i) {
-    if (!isa<ForStmt>(srcDomain.getIdValue(i)) ||
-        !isa<ForStmt>(dstDomain.getIdValue(i)) ||
+    if (!isa<ForInst>(srcDomain.getIdValue(i)) ||
+        !isa<ForInst>(dstDomain.getIdValue(i)) ||
         srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
       break;
     ++numCommonLoops;
@@ -918,26 +918,26 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
   return numCommonLoops;
 }
 
-// Returns Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'.
+// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
 static Block *getCommonBlock(const MemRefAccess &srcAccess,
                              const MemRefAccess &dstAccess,
                              const FlatAffineConstraints &srcDomain,
                              unsigned numCommonLoops) {
   if (numCommonLoops == 0) {
-    auto *block = srcAccess.opStmt->getBlock();
+    auto *block = srcAccess.opInst->getBlock();
     while (block->getContainingInst()) {
       block = block->getContainingInst()->getBlock();
     }
     return block;
   }
   auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
-  assert(isa<ForStmt>(commonForValue));
-  return cast<ForStmt>(commonForValue)->getBody();
+  assert(isa<ForInst>(commonForValue));
+  return cast<ForInst>(commonForValue)->getBody();
 }
 
-// Returns true if the ancestor operation statement of 'srcAccess' properly
-// dominates the ancestor operation statement of 'dstAccess' in the same
-// statement block. Returns false otherwise.
+// Returns true if the ancestor operation instruction of 'srcAccess' properly
+// dominates the ancestor operation instruction of 'dstAccess' in the same
+// instruction block. Returns false otherwise.
 // Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
 // the function is named 'srcMayExecuteBeforeDst'.
 // Note that 'numCommonLoops' is the number of contiguous surrounding outer
@@ -946,16 +946,16 @@ static bool srcMayExecuteBeforeDst(const MemRefAccess &srcAccess,
                                    const MemRefAccess &dstAccess,
                                    const FlatAffineConstraints &srcDomain,
                                    unsigned numCommonLoops) {
-  // Get Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'.
+  // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
   auto *commonBlock =
       getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
   // Check the dominance relationship between the respective ancestors of the
   // src and dst in the Block of the innermost among the common loops.
-  auto *srcStmt = commonBlock->findAncestorInstInBlock(*srcAccess.opStmt);
-  assert(srcStmt != nullptr);
-  auto *dstStmt = commonBlock->findAncestorInstInBlock(*dstAccess.opStmt);
-  assert(dstStmt != nullptr);
-  return mlir::properlyDominates(*srcStmt, *dstStmt);
+  auto *srcInst = commonBlock->findAncestorInstInBlock(*srcAccess.opInst);
+  assert(srcInst != nullptr);
+  auto *dstInst = commonBlock->findAncestorInstInBlock(*dstAccess.opInst);
+  assert(dstInst != nullptr);
+  return mlir::properlyDominates(*srcInst, *dstInst);
 }
 
 // Adds ordering constraints to 'dependenceDomain' based on number of loops
@@ -1119,7 +1119,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
 //    until operands of the AffineValueMap are loop IVs or symbols.
 // *) Build iteration domain constraints for each access. Iteration domain
 //    constraints are pairs of inequality constraints representing the
-//    upper/lower loop bounds for each ForStmt in the loop nest associated
+//    upper/lower loop bounds for each ForInst in the loop nest associated
 //    with each access.
 // *) Build dimension and symbol position maps for each access, which map
 //    Values from access functions and iteration domains to their position
@@ -1197,7 +1197,7 @@ bool mlir::checkMemrefAccessDependence(
   if (srcAccess.memref != dstAccess.memref)
     return false;
   // Return 'false' if one of these accesses is not a StoreOp.
-  if (!srcAccess.opStmt->isa<StoreOp>() && !dstAccess.opStmt->isa<StoreOp>())
+  if (!srcAccess.opInst->isa<StoreOp>() && !dstAccess.opInst->isa<StoreOp>())
     return false;
   // Get composed access function for 'srcAccess'.
@@ -1208,19 +1208,19 @@ bool mlir::checkMemrefAccessDependence(
   AffineValueMap dstAccessMap;
   dstAccess.getAccessMap(&dstAccessMap);
 
-  // Get iteration domain for the 'srcAccess' statement.
+  // Get iteration domain for the 'srcAccess' instruction.
   FlatAffineConstraints srcDomain;
-  if (!getStmtIndexSet(srcAccess.opStmt, &srcDomain))
+  if (!getInstIndexSet(srcAccess.opInst, &srcDomain))
     return false;
 
-  // Get iteration domain for 'dstAccess' statement.
+  // Get iteration domain for 'dstAccess' instruction.
   FlatAffineConstraints dstDomain;
-  if (!getStmtIndexSet(dstAccess.opStmt, &dstDomain))
+  if (!getInstIndexSet(dstAccess.opInst, &dstDomain))
     return false;
 
   // Return 'false' if loopDepth > numCommonLoops and if the ancestor operation
-  // statement of 'srcAccess' does not properly dominate the ancestor operation
-  // statement of 'dstAccess' in the same common statement block.
+  // instruction of 'srcAccess' does not properly dominate the ancestor
+  // operation instruction of 'dstAccess' in the same common instruction block.
   unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
   assert(loopDepth <= numCommonLoops + 1);
   if (loopDepth > numCommonLoops &&
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index a45c5ffdf5e..d4b8a05dbf8 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -24,8 +24,8 @@
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Instructions.h"
 #include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/Statements.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/Debug.h"
@@ -1248,22 +1248,22 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
   numSymbols = newSymbolCount;
 }
 
-bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
+bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
   unsigned pos;
   // Pre-condition for this method.
-  if (!findId(forStmt, &pos)) {
+  if (!findId(forInst, &pos)) {
     assert(0 && "Value not found");
     return false;
   }
 
-  if (forStmt.getStep() != 1)
+  if (forInst.getStep() != 1)
     LLVM_DEBUG(llvm::dbgs()
                << "Domain conservative: non-unit stride not handled\n");
 
   // Adds a lower or upper bound when the bounds aren't constant.
   auto addLowerOrUpperBound = [&](bool lower) -> bool {
-    auto operands = lower ? forStmt.getLowerBoundOperands()
-                          : forStmt.getUpperBoundOperands();
+    auto operands = lower ? forInst.getLowerBoundOperands()
+                          : forInst.getUpperBoundOperands();
     for (const auto &operand : operands) {
       unsigned loc;
       if (!findId(*operand, &loc)) {
@@ -1271,8 +1271,8 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
         addSymbolId(getNumSymbolIds(), const_cast<Value *>(operand));
         loc = getNumDimIds() + getNumSymbolIds() - 1;
         // Check if the symbol is a constant.
-        if (auto *opStmt = operand->getDefiningInst()) {
-          if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
+        if (auto *opInst = operand->getDefiningInst()) {
+          if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
             setIdToConstant(*operand, constOp->getValue());
           }
         }
@@ -1292,7 +1292,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
     }
 
     auto boundMap =
-        lower ? forStmt.getLowerBoundMap() : forStmt.getUpperBoundMap();
+        lower ? forInst.getLowerBoundMap() : forInst.getUpperBoundMap();
 
     FlatAffineConstraints localVarCst;
     std::vector<SmallVector<int64_t, 8>> flatExprs;
@@ -1322,16 +1322,16 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
     return true;
   };
 
-  if (forStmt.hasConstantLowerBound()) {
-    addConstantLowerBound(pos, forStmt.getConstantLowerBound());
+  if (forInst.hasConstantLowerBound()) {
+    addConstantLowerBound(pos, forInst.getConstantLowerBound());
   } else {
     // Non-constant lower bound case.
     if (!addLowerOrUpperBound(/*lower=*/true))
       return false;
   }
 
-  if (forStmt.hasConstantUpperBound()) {
-    addConstantUpperBound(pos, forStmt.getConstantUpperBound() - 1);
+  if (forInst.hasConstantUpperBound()) {
+    addConstantUpperBound(pos, forInst.getConstantUpperBound() - 1);
     return true;
   }
   // Non-constant upper bound case.
diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp
index 0c8db07dbb4..4ee1b393068 100644
--- a/mlir/lib/Analysis/Dominance.cpp
+++ b/mlir/lib/Analysis/Dominance.cpp
@@ -21,7 +21,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Dominance.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
 #include "llvm/Support/GenericDomTreeConstruction.h"
 
 using namespace mlir;
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index dd14f38df55..b66b665c563 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -27,7 +27,7 @@
 #include "mlir/Analysis/VectorAnalysis.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/SuperVectorOps/SuperVectorOps.h"
 #include "mlir/Support/Functional.h"
@@ -42,27 +42,27 @@ using namespace mlir;
 /// Returns the trip count of the loop as an affine expression if the latter is
 /// expressible as an affine expression, and nullptr otherwise. The trip count
 /// expression is simplified before returning.
-AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
+AffineExpr mlir::getTripCountExpr(const ForInst &forInst) {
   // upper_bound - lower_bound
   int64_t loopSpan;
 
-  int64_t step = forStmt.getStep();
-  auto *context = forStmt.getContext();
+  int64_t step = forInst.getStep();
+  auto *context = forInst.getContext();
 
-  if (forStmt.hasConstantBounds()) {
-    int64_t lb = forStmt.getConstantLowerBound();
-    int64_t ub = forStmt.getConstantUpperBound();
+  if (forInst.hasConstantBounds()) {
+    int64_t lb = forInst.getConstantLowerBound();
+    int64_t ub = forInst.getConstantUpperBound();
     loopSpan = ub - lb;
   } else {
-    auto lbMap = forStmt.getLowerBoundMap();
-    auto ubMap = forStmt.getUpperBoundMap();
+    auto lbMap = forInst.getLowerBoundMap();
+    auto ubMap = forInst.getUpperBoundMap();
     // TODO(bondhugula): handle max/min of multiple expressions.
     if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
       return nullptr;
 
     // TODO(bondhugula): handle bounds with different operands.
     // Bounds have different operands, unhandled for now.
-    if (!forStmt.matchingBoundOperandList())
+    if (!forInst.matchingBoundOperandList())
       return nullptr;
 
     // ub_expr - lb_expr
@@ -88,8 +88,8 @@ AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
 /// Returns the trip count of the loop if it's a constant, None otherwise. This
 /// method uses affine expression analysis (in turn using getTripCount) and is
 /// able to determine constant trip count in non-trivial cases.
-llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
-  auto tripCountExpr = getTripCountExpr(forStmt);
+llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForInst &forInst) {
+  auto tripCountExpr = getTripCountExpr(forInst);
 
   if (!tripCountExpr)
     return None;
@@ -103,8 +103,8 @@ llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
 /// Returns the greatest known integral divisor of the trip count. Affine
 /// expression analysis is used (indirectly through getTripCount), and
 /// this method is thus able to determine non-trivial divisors.
-uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
-  auto tripCountExpr = getTripCountExpr(forStmt);
+uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) {
+  auto tripCountExpr = getTripCountExpr(forInst);
 
   if (!tripCountExpr)
     return 1;
@@ -125,7 +125,7 @@
 }
 
 bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
-  assert(isa<ForStmt>(iv) && "iv must be a ForStmt");
+  assert(isa<ForInst>(iv) && "iv must be a ForInst");
   assert(index.getType().isa<IndexType>() && "index must be of IndexType");
 
   SmallVector<OperationInst *, 4> affineApplyOps;
   getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
@@ -172,7 +172,7 @@ mlir::getInvariantAccesses(const Value &iv,
 }
 
 /// Given:
-///   1. an induction variable `iv` of type ForStmt;
+///   1. an induction variable `iv` of type ForInst;
 ///   2. a `memoryOp` of type const LoadOp& or const StoreOp&;
 ///   3. the index of the `fastestVaryingDim` along which to check;
 /// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access
@@ -233,37 +233,37 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
   return memRefType.getElementType().template isa<VectorType>();
 }
 
-static bool isVectorTransferReadOrWrite(const Statement &stmt) {
-  const auto *opStmt = cast<OperationInst>(&stmt);
-  return opStmt->isa<VectorTransferReadOp>() ||
-         opStmt->isa<VectorTransferWriteOp>();
+static bool isVectorTransferReadOrWrite(const Instruction &inst) {
+  const auto *opInst = cast<OperationInst>(&inst);
+  return opInst->isa<VectorTransferReadOp>() ||
+         opInst->isa<VectorTransferWriteOp>();
 }
 
-using VectorizableStmtFun =
-    std::function<bool(const ForStmt &, const OperationInst &)>;
+using VectorizableInstFun =
+    std::function<bool(const ForInst &, const OperationInst &)>;
 
-static bool isVectorizableLoopWithCond(const ForStmt &loop,
-                                       VectorizableStmtFun isVectorizableStmt) {
+static bool isVectorizableLoopWithCond(const ForInst &loop,
+                                       VectorizableInstFun isVectorizableInst) {
   if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) {
     return false;
   }
 
   // No vectorization across conditionals for now.
   auto conditionals = matcher::If();
-  auto *forStmt = const_cast<ForStmt *>(&loop);
-  auto conditionalsMatched = conditionals.match(forStmt);
+  auto *forInst = const_cast<ForInst *>(&loop);
+  auto conditionalsMatched = conditionals.match(forInst);
   if (!conditionalsMatched.empty()) {
     return false;
   }
 
   auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
-  auto vectorTransfersMatched = vectorTransfers.match(forStmt);
+  auto vectorTransfersMatched = vectorTransfers.match(forInst);
   if (!vectorTransfersMatched.empty()) {
     return false;
   }
 
   auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
-  auto loadAndStoresMatched = loadAndStores.match(forStmt);
+  auto loadAndStoresMatched = loadAndStores.match(forInst);
   for (auto ls : loadAndStoresMatched) {
     auto *op = cast<OperationInst>(ls.first);
     auto load = op->dyn_cast<LoadOp>();
@@ -275,7 +275,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop,
     if (vector) {
       return false;
     }
-    if (!isVectorizableStmt(loop, *op)) {
+    if (!isVectorizableInst(loop, *op)) {
       return false;
     }
   }
@@ -283,9 +283,9 @@
 }
 
 bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
-    const ForStmt &loop, unsigned fastestVaryingDim) {
-  VectorizableStmtFun fun(
-      [fastestVaryingDim](const ForStmt &loop, const OperationInst &op) {
+    const ForInst &loop, unsigned fastestVaryingDim) {
+  VectorizableInstFun fun(
+      [fastestVaryingDim](const ForInst &loop, const OperationInst &op) {
         auto load = op.dyn_cast<LoadOp>();
         auto store = op.dyn_cast<StoreOp>();
         return load ? isContiguousAccess(loop, *load, fastestVaryingDim)
@@ -294,37 +294,36 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
   return isVectorizableLoopWithCond(loop, fun);
 }
 
-bool mlir::isVectorizableLoop(const ForStmt &loop) {
-  VectorizableStmtFun fun(
+bool mlir::isVectorizableLoop(const ForInst &loop) {
+  VectorizableInstFun fun(
       // TODO: implement me
-      [](const ForStmt &loop, const OperationInst &op) { return true; });
+      [](const ForInst &loop, const OperationInst &op) { return true; });
   return isVectorizableLoopWithCond(loop, fun);
 }
 
-/// Checks whether SSA dominance would be violated if a for stmt's body
-/// statements are shifted by the specified shifts. This method checks if a
+/// Checks whether SSA dominance would be violated if a for inst's body
+/// instructions are shifted by the specified shifts. This method checks if a
 /// 'def' and all its uses have the same shift factor.
 // TODO(mlir-team): extend this to check for memory-based dependence
 // violation when we have the support.
-bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
+bool mlir::isInstwiseShiftValid(const ForInst &forInst,
                                 ArrayRef<uint64_t> shifts) {
-  auto *forBody = forStmt.getBody();
+  auto *forBody = forInst.getBody();
   assert(shifts.size() == forBody->getInstructions().size());
   unsigned s = 0;
-  for (const auto &stmt : *forBody) {
-    // A for or if stmt does not produce any def/results (that are used
+  for (const auto &inst : *forBody) {
+    // A for or if inst does not produce any def/results (that are used
     // outside).
-    if (const auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
-      for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
-        const Value *result = opStmt->getResult(i);
+    if (const auto *opInst = dyn_cast<OperationInst>(&inst)) {
+      for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) {
+        const Value *result = opInst->getResult(i);
         for (const InstOperand &use : result->getUses()) {
-          // If an ancestor statement doesn't lie in the block of forStmt, there
-          // is no shift to check.
-          // This is a naive way. If performance becomes an issue, a map can
-          // be used to store 'shifts' - to look up the shift for a statement in
-          // constant time.
-          if (auto *ancStmt = forBody->findAncestorInstInBlock(*use.getOwner()))
-            if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancStmt)])
+          // If an ancestor instruction doesn't lie in the block of forInst,
+          // there is no shift to check. This is a naive way. If performance
+          // becomes an issue, a map can be used to store 'shifts' - to look up
+          // the shift for an instruction in constant time.
+          if (auto *ancInst = forBody->findAncestorInstInBlock(*use.getOwner()))
+            if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancInst)])
               return false;
         }
       }
diff --git a/mlir/lib/Analysis/MLFunctionMatcher.cpp b/mlir/lib/Analysis/MLFunctionMatcher.cpp
index 12ce8481516..5bb4548e670 100644
--- a/mlir/lib/Analysis/MLFunctionMatcher.cpp
+++ b/mlir/lib/Analysis/MLFunctionMatcher.cpp
@@ -31,29 +31,29 @@ struct MLFunctionMatchesStorage {
 
 /// Underlying storage for MLFunctionMatcher.
 struct MLFunctionMatcherStorage {
-  MLFunctionMatcherStorage(Statement::Kind k,
+  MLFunctionMatcherStorage(Instruction::Kind k,
                            MutableArrayRef<MLFunctionMatcher> c,
-                           FilterFunctionType filter, Statement *skip)
+                           FilterFunctionType filter, Instruction *skip)
       : kind(k), childrenMLFunctionMatchers(c.begin(), c.end()),
        filter(filter), skip(skip) {}
 
-  Statement::Kind kind;
+  Instruction::Kind kind;
   SmallVector<MLFunctionMatcher, 4> childrenMLFunctionMatchers;
   FilterFunctionType filter;
   /// skip is needed so that we can implement match without switching on the
-  /// type of the Statement.
+  /// type of the Instruction.
   /// The idea is that a MLFunctionMatcher first checks if it matches locally
   /// and then recursively applies its children matchers to its elem->children.
-  /// Since we want to rely on the StmtWalker impl rather than duplicate its
+  /// Since we want to rely on the InstWalker impl rather than duplicate
   /// the logic, we allow an off-by-one traversal to account for the fact that
   /// we write:
   ///
-  /// void match(Statement *elem) {
+  /// void match(Instruction *elem) {
   ///   for (auto &c : getChildrenMLFunctionMatchers()) {
   ///     MLFunctionMatcher childMLFunctionMatcher(...);
   ///                       ^~~~ Needs off-by-one skip.
   ///
-  Statement *skip;
+  Instruction *skip;
 };
 
 } // end namespace mlir
@@ -65,12 +65,12 @@ llvm::BumpPtrAllocator *&MLFunctionMatches::allocator() {
   return allocator;
 }
 
-void MLFunctionMatches::append(Statement *stmt, MLFunctionMatches children) {
+void MLFunctionMatches::append(Instruction *inst, MLFunctionMatches children) {
   if (!storage) {
     storage = allocator()->Allocate<MLFunctionMatchesStorage>();
-    new (storage) MLFunctionMatchesStorage(std::make_pair(stmt, children));
+    new (storage) MLFunctionMatchesStorage(std::make_pair(inst, children));
   } else {
-    storage->matches.push_back(std::make_pair(stmt, children));
+    storage->matches.push_back(std::make_pair(inst, children));
   }
 }
 MLFunctionMatches::iterator MLFunctionMatches::begin() {
@@ -98,10 +98,10 @@ MLFunctionMatches MLFunctionMatcher::match(Function *function) {
   return matches;
 }
 
-/// Calls walk on `statement`.
-MLFunctionMatches MLFunctionMatcher::match(Statement *statement) {
+/// Calls walk on `instruction`.
+MLFunctionMatches MLFunctionMatcher::match(Instruction *instruction) {
   assert(!matches && "MLFunctionMatcher already matched!");
-  this->walkPostOrder(statement);
+  this->walkPostOrder(instruction);
   return matches;
 }
 
@@ -117,17 +117,17 @@ unsigned MLFunctionMatcher::getDepth() {
   return depth + 1;
 }
 
-/// Matches a single statement in the following way:
-///   1. checks the kind of statement against the matcher, if different then
+/// Matches a single instruction in the following way:
+///   1. checks the kind of instruction against the matcher, if different then
 ///      there is no match;
-///   2. calls the customizable filter function to refine the single statement
+///   2. calls the customizable filter function to refine the single instruction
 ///      match with extra semantic constraints;
 ///   3. if all is good, recursively matches the children patterns;
-///   4. if all children match then the single statement matches too and is
+///   4. if all children match then the single instruction matches too and is
 ///      appended to the list of matches;
 ///   5. TODO(ntv) Optionally applies actions (lambda), in which case we will
 ///      want to traverse in post-order DFS to avoid invalidating iterators.
-void MLFunctionMatcher::matchOne(Statement *elem) {
+void MLFunctionMatcher::matchOne(Instruction *elem) {
   if (storage->skip == elem) {
     return;
   }
@@ -159,7 +159,8 @@ llvm::BumpPtrAllocator *&MLFunctionMatcher::allocator() {
   return allocator;
 }
 
-MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
+MLFunctionMatcher::MLFunctionMatcher(Instruction::Kind k,
+                                     MLFunctionMatcher child,
                                      FilterFunctionType filter)
     : storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
   // Initialize with placement new.
@@ -168,7 +169,7 @@ MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
 }
 
 MLFunctionMatcher::MLFunctionMatcher(
-    Statement::Kind k, MutableArrayRef<MLFunctionMatcher> children,
+    Instruction::Kind k, MutableArrayRef<MLFunctionMatcher> children,
     FilterFunctionType filter)
    : storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
   // Initialize with placement new.
@@ -178,14 +179,14 @@ MLFunctionMatcher::MLFunctionMatcher(
 
 MLFunctionMatcher
 MLFunctionMatcher::forkMLFunctionMatcherAt(MLFunctionMatcher tmpl,
-                                           Statement *stmt) {
+                                           Instruction *inst) {
   MLFunctionMatcher res(tmpl.getKind(), tmpl.getChildrenMLFunctionMatchers(),
                         tmpl.getFilterFunction());
-  res.storage->skip = stmt;
+  res.storage->skip = inst;
   return res;
 }
 
-Statement::Kind MLFunctionMatcher::getKind() { return storage->kind; }
+Instruction::Kind MLFunctionMatcher::getKind() { return storage->kind; }
 
 MutableArrayRef<MLFunctionMatcher>
 MLFunctionMatcher::getChildrenMLFunctionMatchers() {
@@ -200,54 +201,55 @@ namespace mlir {
 namespace matcher {
 
 MLFunctionMatcher Op(FilterFunctionType filter) {
-  return MLFunctionMatcher(Statement::Kind::OperationInst, {}, filter);
+  return MLFunctionMatcher(Instruction::Kind::OperationInst, {}, filter);
 }
 
 MLFunctionMatcher If(MLFunctionMatcher child) {
-  return MLFunctionMatcher(Statement::Kind::If, child, defaultFilterFunction);
+  return MLFunctionMatcher(Instruction::Kind::If, child, defaultFilterFunction);
 }
 MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child) {
-  return MLFunctionMatcher(Statement::Kind::If, child, filter);
+  return MLFunctionMatcher(Instruction::Kind::If, child, filter);
 }
 MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children) {
-  return MLFunctionMatcher(Statement::Kind::If, children,
+  return MLFunctionMatcher(Instruction::Kind::If, children,
                            defaultFilterFunction);
 }
 MLFunctionMatcher If(FilterFunctionType filter,
                      MutableArrayRef<MLFunctionMatcher> children) {
-  return MLFunctionMatcher(Statement::Kind::If, children, filter);
+  return MLFunctionMatcher(Instruction::Kind::If, children, filter);
 }
 
 MLFunctionMatcher For(MLFunctionMatcher child) {
-  return MLFunctionMatcher(Statement::Kind::For, child, defaultFilterFunction);
+  return MLFunctionMatcher(Instruction::Kind::For, child,
+                           defaultFilterFunction);
 }
 MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child) {
-  return MLFunctionMatcher(Statement::Kind::For, child, filter);
+  return MLFunctionMatcher(Instruction::Kind::For, child, filter);
 }
 MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children) {
-  return MLFunctionMatcher(Statement::Kind::For, children,
+  return MLFunctionMatcher(Instruction::Kind::For, children,
                            defaultFilterFunction);
 }
 MLFunctionMatcher For(FilterFunctionType filter,
                       MutableArrayRef<MLFunctionMatcher> children) {
-  return MLFunctionMatcher(Statement::Kind::For, children, filter);
+  return MLFunctionMatcher(Instruction::Kind::For, children, filter);
 }
 
 // TODO(ntv): parallel annotation on loops.
-bool isParallelLoop(const Statement &stmt) {
-  const auto *loop = cast<ForStmt>(&stmt);
+bool isParallelLoop(const Instruction &inst) {
+  const auto *loop = cast<ForInst>(&inst);
   return (void *)loop || true; // loop->isParallel();
 };
 
 // TODO(ntv): reduction annotation on loops.
-bool isReductionLoop(const Statement &stmt) {
-  const auto *loop = cast<ForStmt>(&stmt);
+bool isReductionLoop(const Instruction &inst) {
+  const auto *loop = cast<ForInst>(&inst);
   return (void *)loop || true; // loop->isReduction();
 };
 
-bool isLoadOrStore(const Statement &stmt) {
-  const auto *opStmt = dyn_cast<OperationInst>(&stmt);
-  return opStmt && (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>());
+bool isLoadOrStore(const Instruction &inst) {
+  const auto *opInst = dyn_cast<OperationInst>(&inst);
+  return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>());
 };
 
 } // end namespace matcher
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp
index ad935faf05d..e8b668892b8 100644
--- a/mlir/lib/Analysis/MemRefBoundCheck.cpp
+++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp
@@ -26,7 +26,7 @@
 #include "mlir/Analysis/Utils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "llvm/Support/Debug.h"
@@ -38,14 +38,14 @@ using namespace mlir;
 namespace {
 
 /// Checks for out of bound memref access subscripts.
-struct MemRefBoundCheck : public FunctionPass, StmtWalker<MemRefBoundCheck> {
+struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> {
   explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {}
 
   PassResult runOnMLFunction(Function *f) override;
   // Not applicable to CFG functions.
   PassResult runOnCFGFunction(Function *f) override { return success(); }
 
-  void visitOperationInst(OperationInst *opStmt);
+  void visitOperationInst(OperationInst *opInst);
 
   static char passID;
 };
@@ -58,10 +58,10 @@ FunctionPass *mlir::createMemRefBoundCheckPass() {
   return new MemRefBoundCheck();
 }
 
-void MemRefBoundCheck::visitOperationInst(OperationInst *opStmt) {
-  if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
+void MemRefBoundCheck::visitOperationInst(OperationInst *opInst) {
+  if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
     boundCheckLoadOrStoreOp(loadOp);
-  } else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
+  } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
     boundCheckLoadOrStoreOp(storeOp);
   }
   // TODO(bondhugula): do this for DMA ops as well.
diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
index bb668f78624..8391f15b6d3 100644
--- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp
+++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
@@ -25,7 +25,7 @@
 #include "mlir/Analysis/Utils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "llvm/Support/Debug.h"
@@ -39,7 +39,7 @@ namespace {
 // TODO(andydavis) Add common surrounding loop depth-wise dependence checks.
 /// Checks dependences between all pairs of memref accesses in a Function.
 struct MemRefDependenceCheck : public FunctionPass,
-                               StmtWalker<MemRefDependenceCheck> {
+                               InstWalker<MemRefDependenceCheck> {
   SmallVector<OperationInst *, 4> loadsAndStores;
   explicit MemRefDependenceCheck()
      : FunctionPass(&MemRefDependenceCheck::passID) {}
@@ -48,9 +48,9 @@ struct MemRefDependenceCheck : public FunctionPass,
   // Not applicable to CFG functions.
   PassResult runOnCFGFunction(Function *f) override { return success(); }
 
-  void visitOperationInst(OperationInst *opStmt) {
-    if (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>()) {
-      loadsAndStores.push_back(opStmt);
+  void visitOperationInst(OperationInst *opInst) {
+    if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) {
+      loadsAndStores.push_back(opInst);
     }
   }
   static char passID;
@@ -74,17 +74,17 @@ static void addMemRefAccessIndices(
   }
 }
 
-// Populates 'access' with memref, indices and opstmt from 'loadOrStoreOpStmt'.
-static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt,
+// Populates 'access' with memref, indices and opinst from 'loadOrStoreOpInst'.
+static void getMemRefAccess(const OperationInst *loadOrStoreOpInst,
                             MemRefAccess *access) {
-  access->opStmt = loadOrStoreOpStmt;
-  if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
+  access->opInst = loadOrStoreOpInst;
+  if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) {
     access->memref = loadOp->getMemRef();
     addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(),
                            access);
   } else {
-    assert(loadOrStoreOpStmt->isa<StoreOp>());
-    auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
+    assert(loadOrStoreOpInst->isa<StoreOp>());
+    auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>();
     access->memref = storeOp->getMemRef();
     addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(),
                            access);
@@ -93,8 +93,8 @@ static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt,
 
 // Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
 // where each lists loops from outer-most to inner-most in loop nest.
-static unsigned getNumCommonSurroundingLoops(ArrayRef<const ForStmt *> loopsA,
-                                             ArrayRef<const ForStmt *> loopsB) {
+static unsigned getNumCommonSurroundingLoops(ArrayRef<const ForInst *> loopsA,
+                                             ArrayRef<const ForInst *> loopsB) {
   unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
   unsigned numCommonLoops = 0;
   for (unsigned i = 0; i < minNumLoops; ++i) {
@@ -133,18 +133,18 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth,
 // the source access.
 static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) {
   for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) {
-    auto *srcOpStmt = loadsAndStores[i];
+    auto *srcOpInst = loadsAndStores[i];
     MemRefAccess srcAccess;
-    getMemRefAccess(srcOpStmt, &srcAccess);
-    SmallVector<ForStmt *, 4> srcLoops;
-    getLoopIVs(*srcOpStmt, &srcLoops);
+    getMemRefAccess(srcOpInst, &srcAccess);
+    SmallVector<ForInst *, 4> srcLoops;
+    getLoopIVs(*srcOpInst, &srcLoops);
     for (unsigned j = 0; j < e; ++j) {
-      auto *dstOpStmt = loadsAndStores[j];
+      auto *dstOpInst = loadsAndStores[j];
       MemRefAccess dstAccess;
-      getMemRefAccess(dstOpStmt, &dstAccess);
+      getMemRefAccess(dstOpInst, &dstAccess);
 
-      SmallVector<ForStmt *, 4> dstLoops;
-      getLoopIVs(*dstOpStmt, &dstLoops);
+      SmallVector<ForInst *, 4> dstLoops;
+      getLoopIVs(*dstOpInst, &dstLoops);
       unsigned numCommonLoops =
           getNumCommonSurroundingLoops(srcLoops, dstLoops);
       for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
@@ -156,7 +156,7 @@ static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) {
           // TODO(andydavis) Print dependence type (i.e. RAW, etc) and print
          // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance
          // vectors from ([1, 1], [3, 3]) to (1, 3).
-          srcOpStmt->emitNote(
+          srcOpInst->emitNote(
               "dependence from " + Twine(i) + " to " + Twine(j) +
               " at depth " + Twine(d) + " = " +
               getDirectionVectorStr(ret, numCommonLoops, d,
                                     dependenceComponents)
diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp
index f4c509a5132..07edb13d1a3 100644
--- a/mlir/lib/Analysis/OpStats.cpp
+++ b/mlir/lib/Analysis/OpStats.cpp
@@ -16,9 +16,9 @@
 // =============================================================================
 
 #include "mlir/IR/Function.h"
+#include "mlir/IR/InstVisitor.h"
+#include "mlir/IR/Instructions.h"
 #include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/Statements.h"
-#include "mlir/IR/StmtVisitor.h"
 #include "mlir/Pass.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/raw_ostream.h"
@@ -26,7 +26,7 @@ using namespace mlir;
 namespace {
-struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> {
+struct PrintOpStatsPass : public FunctionPass, InstWalker<PrintOpStatsPass> {
   explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs())
       : FunctionPass(&PrintOpStatsPass::passID), os(os) {}
 
@@ -38,7 +38,7 @@ struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> {
   // Process ML functions and operation statements in ML functions.
   PassResult runOnMLFunction(Function *function) override;
-  void visitOperationInst(OperationInst *stmt);
+  void visitOperationInst(OperationInst *inst);
 
   // Print summary of op stats.
   void printSummary();
@@ -69,8 +69,8 @@ PassResult PrintOpStatsPass::runOnCFGFunction(Function *function) {
   return success();
 }
 
-void PrintOpStatsPass::visitOperationInst(OperationInst *stmt) {
-  ++opCount[stmt->getName().getStringRef()];
+void PrintOpStatsPass::visitOperationInst(OperationInst *inst) {
+  ++opCount[inst->getName().getStringRef()];
 }
 
 PassResult PrintOpStatsPass::runOnMLFunction(Function *function) {
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 393d7c59de0..a8cec771f0d 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -22,7 +22,7 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/VectorAnalysis.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
 #include "mlir/Support/Functional.h"
 #include "mlir/Support/STLExtras.h"
@@ -38,36 +38,36 @@ using namespace mlir;
 using llvm::DenseSet;
 using llvm::SetVector;
 
-void mlir::getForwardSlice(Statement *stmt,
-                           SetVector<Statement *> *forwardSlice,
+void mlir::getForwardSlice(Instruction *inst,
+                           SetVector<Instruction *> *forwardSlice,
                            TransitiveFilter filter, bool topLevel) {
-  if (!stmt) {
+  if (!inst) {
     return;
   }
 
   // Evaluate whether we should keep this use.
   // This is useful in particular to implement scoping; i.e. return the
   // transitive forwardSlice in the current scope.
-  if (!filter(stmt)) {
+  if (!filter(inst)) {
     return;
   }
 
-  if (auto *opStmt = dyn_cast<OperationInst>(stmt)) {
-    assert(opStmt->getNumResults() <= 1 && "NYI: multiple results");
-    if (opStmt->getNumResults() > 0) {
-      for (auto &u : opStmt->getResult(0)->getUses()) {
-        auto *ownerStmt = u.getOwner();
-        if (forwardSlice->count(ownerStmt) == 0) {
-          getForwardSlice(ownerStmt, forwardSlice, filter,
+  if (auto *opInst = dyn_cast<OperationInst>(inst)) {
+    assert(opInst->getNumResults() <= 1 && "NYI: multiple results");
+    if (opInst->getNumResults() > 0) {
+      for (auto &u : opInst->getResult(0)->getUses()) {
+        auto *ownerInst = u.getOwner();
+        if (forwardSlice->count(ownerInst) == 0) {
+          getForwardSlice(ownerInst, forwardSlice, filter,
                           /*topLevel=*/false);
         }
       }
     }
-  } else if (auto *forStmt = dyn_cast<ForStmt>(stmt)) {
-    for (auto &u : forStmt->getUses()) {
-      auto *ownerStmt = u.getOwner();
-      if (forwardSlice->count(ownerStmt) == 0) {
-        getForwardSlice(ownerStmt, forwardSlice, filter,
+  } else if (auto *forInst = dyn_cast<ForInst>(inst)) {
+    for (auto &u : forInst->getUses()) {
+      auto *ownerInst = u.getOwner();
+      if (forwardSlice->count(ownerInst) == 0) {
+        getForwardSlice(ownerInst, forwardSlice, filter,
                         /*topLevel=*/false);
       }
     }
@@ -80,61 +80,61 @@ void mlir::getForwardSlice(Statement *stmt,
     // std::reverse does not work out of the box on SetVector and I want an
     // in-place swap based thing (the real std::reverse, not the LLVM adapter).
     // TODO(clattner): Consider adding an extra method?
-    std::vector<Statement *> v(forwardSlice->takeVector());
+    std::vector<Instruction *> v(forwardSlice->takeVector());
     forwardSlice->insert(v.rbegin(), v.rend());
   } else {
-    forwardSlice->insert(stmt);
+    forwardSlice->insert(inst);
   }
 }
 
-void mlir::getBackwardSlice(Statement *stmt,
-                            SetVector<Statement *> *backwardSlice,
+void mlir::getBackwardSlice(Instruction *inst,
+                            SetVector<Instruction *> *backwardSlice,
                             TransitiveFilter filter, bool topLevel) {
-  if (!stmt) {
+  if (!inst) {
     return;
   }
 
   // Evaluate whether we should keep this def.
   // This is useful in particular to implement scoping; i.e. return the
   // transitive forwardSlice in the current scope.
-  if (!filter(stmt)) {
+  if (!filter(inst)) {
     return;
   }
 
-  for (auto *operand : stmt->getOperands()) {
-    auto *stmt = operand->getDefiningInst();
-    if (backwardSlice->count(stmt) == 0) {
-      getBackwardSlice(stmt, backwardSlice, filter,
+  for (auto *operand : inst->getOperands()) {
+    auto *inst = operand->getDefiningInst();
+    if (backwardSlice->count(inst) == 0) {
+      getBackwardSlice(inst, backwardSlice, filter,
                        /*topLevel=*/false);
     }
   }
 
-  // Don't insert the top level statement, we just queried on it and don't
+  // Don't insert the top level instruction, we just queried on it and don't
   // want it in the results.
   if (!topLevel) {
-    backwardSlice->insert(stmt);
+    backwardSlice->insert(inst);
   }
 }
 
-SetVector<Statement *> mlir::getSlice(Statement *stmt,
-                                      TransitiveFilter backwardFilter,
-                                      TransitiveFilter forwardFilter) {
-  SetVector<Statement *> slice;
-  slice.insert(stmt);
+SetVector<Instruction *> mlir::getSlice(Instruction *inst,
+                                        TransitiveFilter backwardFilter,
+                                        TransitiveFilter forwardFilter) {
+  SetVector<Instruction *> slice;
+  slice.insert(inst);
 
   unsigned currentIndex = 0;
-  SetVector<Statement *> backwardSlice;
-  SetVector<Statement *> forwardSlice;
+  SetVector<Instruction *> backwardSlice;
+  SetVector<Instruction *> forwardSlice;
   while (currentIndex != slice.size()) {
-    auto *currentStmt = (slice)[currentIndex];
-    // Compute and insert the backwardSlice starting from currentStmt.
+    auto *currentInst = (slice)[currentIndex];
+    // Compute and insert the backwardSlice starting from currentInst.
     backwardSlice.clear();
-    getBackwardSlice(currentStmt, &backwardSlice, backwardFilter);
+    getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
     slice.insert(backwardSlice.begin(), backwardSlice.end());
 
-    // Compute and insert the forwardSlice starting from currentStmt.
+    // Compute and insert the forwardSlice starting from currentInst.
     forwardSlice.clear();
-    getForwardSlice(currentStmt, &forwardSlice, forwardFilter);
+    getForwardSlice(currentInst, &forwardSlice, forwardFilter);
     slice.insert(forwardSlice.begin(), forwardSlice.end());
     ++currentIndex;
   }
@@ -144,24 +144,24 @@ SetVector<Statement *> mlir::getSlice(Statement *stmt,
 
 namespace {
 /// DFS post-order implementation that maintains a global count to work across
 /// multiple invocations, to help implement topological sort on multi-root DAGs.
-/// We traverse all statements but only record the ones that appear in `toSort`
-/// for the final result.
+/// We traverse all instructions but only record the ones that appear in
+/// `toSort` for the final result.
 struct DFSState {
-  DFSState(const SetVector<Statement *> &set)
+  DFSState(const SetVector<Instruction *> &set)
       : toSort(set), topologicalCounts(), seen() {}
-  const SetVector<Statement *> &toSort;
-  SmallVector<Statement *, 16> topologicalCounts;
-  DenseSet<Statement *> seen;
+  const SetVector<Instruction *> &toSort;
+  SmallVector<Instruction *, 16> topologicalCounts;
+  DenseSet<Instruction *> seen;
 };
 } // namespace
 
-static void DFSPostorder(Statement *current, DFSState *state) {
-  auto *opStmt = cast<OperationInst>(current);
-  assert(opStmt->getNumResults() <= 1 && "NYI: multi-result");
-  if (opStmt->getNumResults() > 0) {
-    for (auto &u : opStmt->getResult(0)->getUses()) {
-      auto *stmt = u.getOwner();
-      DFSPostorder(stmt, state);
+static void DFSPostorder(Instruction *current, DFSState *state) {
+  auto *opInst = cast<OperationInst>(current);
+  assert(opInst->getNumResults() <= 1 && "NYI: multi-result");
+  if (opInst->getNumResults() > 0) {
+    for (auto &u : opInst->getResult(0)->getUses()) {
+      auto *inst = u.getOwner();
+      DFSPostorder(inst, state);
     }
   }
   bool inserted;
@@ -175,8 +175,8 @@
   }
 }
 
-SetVector<Statement *>
-mlir::topologicalSort(const SetVector<Statement *> &toSort) {
+SetVector<Instruction *>
+mlir::topologicalSort(const SetVector<Instruction *> &toSort) {
   if (toSort.empty()) {
     return toSort;
   }
@@ -189,7 +189,7 @@ mlir::topologicalSort(const SetVector<Statement *> &toSort) {
   }
 
   // Reorder and return.
-  SetVector<Statement *> res;
+  SetVector<Instruction *> res;
   for (auto it = state.topologicalCounts.rbegin(),
             eit = state.topologicalCounts.rend();
        it != eit; ++it) {
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index f6191418f54..a7fc5ac619e 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -34,8 +34,8 @@ using namespace mlir;
 
-/// Returns true if statement 'a' properly dominates statement b.
-bool mlir::properlyDominates(const Statement &a, const Statement &b) {
+/// Returns true if instruction 'a' properly dominates instruction b.
+bool mlir::properlyDominates(const Instruction &a, const Instruction &b) {
   if (&a == &b)
     return false;
 
@@ -64,24 +64,24 @@ bool mlir::properlyDominates(const Statement &a, const Statement &b) {
   return false;
 }
 
-/// Returns true if statement A dominates statement B.
-bool mlir::dominates(const Statement &a, const Statement &b) {
+/// Returns true if instruction A dominates instruction B.
+bool mlir::dominates(const Instruction &a, const Instruction &b) {
   return &a == &b || properlyDominates(a, b);
 }
 
-/// Populates 'loops' with IVs of the loops surrounding 'stmt' ordered from
-/// the outermost 'for' statement to the innermost one.
-void mlir::getLoopIVs(const Statement &stmt,
-                      SmallVectorImpl<ForStmt *> *loops) {
-  auto *currStmt = stmt.getParentStmt();
-  ForStmt *currForStmt;
-  // Traverse up the hierarchy collecing all 'for' statement while skipping over
-  // 'if' statements.
-  while (currStmt && ((currForStmt = dyn_cast<ForStmt>(currStmt)) ||
-                      isa<IfStmt>(currStmt))) {
-    if (currForStmt)
-      loops->push_back(currForStmt);
-    currStmt = currStmt->getParentStmt();
+/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from
+/// the outermost 'for' instruction to the innermost one.
+void mlir::getLoopIVs(const Instruction &inst,
+                      SmallVectorImpl<ForInst *> *loops) {
+  auto *currInst = inst.getParentInst();
+  ForInst *currForInst;
+  // Traverse up the hierarchy collecting all 'for' instructions while skipping
+  // over 'if' instructions.
+  while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) ||
+                      isa<IfInst>(currInst))) {
+    if (currForInst)
+      loops->push_back(currForInst);
+    currInst = currInst->getParentInst();
   }
   std::reverse(loops->begin(), loops->end());
 }
@@ -129,7 +129,7 @@ Optional<int64_t> MemRefRegion::getBoundingConstantSizeAndShape(
 
 /// Computes the memory region accessed by this memref with the region
 /// represented as constraints symbolic/parametric in 'loopDepth' loops
-/// surrounding opStmt and any additional Function symbols. Returns false if
+/// surrounding opInst and any additional Function symbols. Returns false if
 /// this fails due to yet unimplemented cases.
 // For example, the memref region for this load operation at loopDepth = 1 will
 // be as below:
@@ -145,21 +145,21 @@ Optional<int64_t> MemRefRegion::getBoundingConstantSizeAndShape(
 //
 // TODO(bondhugula): extend this to any other memref dereferencing ops
 // (dma_start, dma_wait).
-bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
+bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
                            MemRefRegion *region) {
   OpPointer<LoadOp> loadOp;
   OpPointer<StoreOp> storeOp;
   unsigned rank;
   SmallVector<Value *, 4> indices;
 
-  if ((loadOp = opStmt->dyn_cast<LoadOp>())) {
+  if ((loadOp = opInst->dyn_cast<LoadOp>())) {
     rank = loadOp->getMemRefType().getRank();
     for (auto *index : loadOp->getIndices()) {
       indices.push_back(index);
     }
     region->memref = loadOp->getMemRef();
     region->setWrite(false);
-  } else if ((storeOp = opStmt->dyn_cast<StoreOp>())) {
+  } else if ((storeOp = opInst->dyn_cast<StoreOp>())) {
     rank = storeOp->getMemRefType().getRank();
     for (auto *index : storeOp->getIndices()) {
       indices.push_back(index);
@@ -173,7 +173,7 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
 
   // Build the constraints for this region.
   FlatAffineConstraints *regionCst = region->getConstraints();
 
-  FuncBuilder b(opStmt);
+  FuncBuilder b(opInst);
   auto idMap = b.getMultiDimIdentityMap(rank);
 
   // Initialize 'accessValueMap' and compose with reachable AffineApplyOps.
@@ -192,20 +192,20 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
   unsigned numSymbols = accessMap.getNumSymbols();
   // Add inequalities for loop lower/upper bounds.
   for (unsigned i = 0; i < numDims + numSymbols; ++i) {
-    if (auto *loop = dyn_cast<ForStmt>(accessValueMap.getOperand(i))) {
+    if (auto *loop = dyn_cast<ForInst>(accessValueMap.getOperand(i))) {
       // Note that regionCst can now have more dimensions than accessMap if the
       // bounds expressions involve outer loops or other symbols.
-      // TODO(bondhugula): rewrite this to use getStmtIndexSet; this way
+      // TODO(bondhugula): rewrite this to use getInstIndexSet; this way
       // conditionals will be handled when the latter supports it.
-      if (!regionCst->addForStmtDomain(*loop))
+      if (!regionCst->addForInstDomain(*loop))
        return false;
     } else {
       // Has to be a valid symbol.
       auto *symbol = accessValueMap.getOperand(i);
       assert(symbol->isValidSymbol());
       // Check if the symbol is a constant.
-      if (auto *opStmt = symbol->getDefiningInst()) {
-        if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
+      if (auto *opInst = symbol->getDefiningInst()) {
+        if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
          regionCst->setIdToConstant(*symbol, constOp->getValue());
        }
      }
@@ -220,12 +220,12 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
 
   // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
   // this memref region is symbolic.
-  SmallVector<ForStmt *, 4> outerIVs;
-  getLoopIVs(*opStmt, &outerIVs);
+  SmallVector<ForInst *, 4> outerIVs;
+  getLoopIVs(*opInst, &outerIVs);
   outerIVs.resize(loopDepth);
   for (auto *operand : accessValueMap.getOperands()) {
-    ForStmt *iv;
-    if ((iv = dyn_cast<ForStmt>(operand)) &&
+    ForInst *iv;
+    if ((iv = dyn_cast<ForInst>(operand)) &&
        std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) {
       regionCst->projectOut(operand);
     }
@@ -282,9 +282,9 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
           std::is_same<LoadOrStoreOpPointer, OpPointer<StoreOp>>::value,
       "function argument should be either a LoadOp or a StoreOp");
 
-  OperationInst *opStmt = loadOrStoreOp->getInstruction();
+  OperationInst *opInst = loadOrStoreOp->getInstruction();
   MemRefRegion region;
-  if (!getMemRefRegion(opStmt, /*loopDepth=*/0, &region))
+  if (!getMemRefRegion(opInst, /*loopDepth=*/0, &region))
     return false;
   LLVM_DEBUG(llvm::dbgs() << "Memory region");
   LLVM_DEBUG(region.getConstraints()->dump());
@@ -333,43 +333,43 @@ template bool mlir::boundCheckLoadOrStoreOp(OpPointer<LoadOp> loadOp,
 template bool mlir::boundCheckLoadOrStoreOp(OpPointer<StoreOp> storeOp,
                                             bool emitError);
 
-// Returns in 'positions' the Block positions of 'stmt' in each ancestor
-// Block from the Block containing statement, stopping at 'limitBlock'.
-static void findStmtPosition(const Statement *stmt, Block *limitBlock,
+// Returns in 'positions' the Block positions of 'inst' in each ancestor
+// Block from the Block containing instruction, stopping at 'limitBlock'.
+static void findInstPosition(const Instruction *inst, Block *limitBlock,
                              SmallVectorImpl<unsigned> *positions) {
-  Block *block = stmt->getBlock();
+  Block *block = inst->getBlock();
   while (block != limitBlock) {
-    int stmtPosInBlock = block->findInstPositionInBlock(*stmt);
-    assert(stmtPosInBlock >= 0);
-    positions->push_back(stmtPosInBlock);
-    stmt = block->getContainingInst();
-    block = stmt->getBlock();
+    int instPosInBlock = block->findInstPositionInBlock(*inst);
+    assert(instPosInBlock >= 0);
+    positions->push_back(instPosInBlock);
+    inst = block->getContainingInst();
+    block = inst->getBlock();
   }
   std::reverse(positions->begin(), positions->end());
 }
 
-// Returns the Statement in a possibly nested set of Blocks, where the
-// position of the statement is represented by 'positions', which has a
+// Returns the Instruction in a possibly nested set of Blocks, where the
+// position of the instruction is represented by 'positions', which has a
 // Block position for each level of nesting.
-static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
-                                    unsigned level, Block *block) {
+static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
+                                      unsigned level, Block *block) {
   unsigned i = 0;
-  for (auto &stmt : *block) {
+  for (auto &inst : *block) {
     if (i != positions[level]) {
       ++i;
       continue;
     }
     if (level == positions.size() - 1)
-      return &stmt;
-    if (auto *childForStmt = dyn_cast<ForStmt>(&stmt))
-      return getStmtAtPosition(positions, level + 1, childForStmt->getBody());
+      return &inst;
+    if (auto *childForInst = dyn_cast<ForInst>(&inst))
+      return getInstAtPosition(positions, level + 1, childForInst->getBody());
 
-    if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
-      auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen());
+    if (auto *ifInst = dyn_cast<IfInst>(&inst)) {
+      auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen());
       if (ret != nullptr)
         return ret;
-      if (auto *elseClause = ifStmt->getElse())
-        return getStmtAtPosition(positions, level + 1, elseClause);
+      if (auto *elseClause = ifInst->getElse())
+        return getInstAtPosition(positions, level + 1, elseClause);
     }
   }
   return nullptr;
@@ -379,7 +379,7 @@ static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
 // dependence constraint system to create AffineMaps with which to adjust the
 // loop bounds of the inserted computation slice so that they are functions of
 // the loop IVs and symbols of the loops surrounding 'dstAccess'.
-ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
+ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
                                               MemRefAccess *dstAccess,
                                               unsigned srcLoopDepth,
                                               unsigned dstLoopDepth) {
@@ -390,14 +390,14 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
     return nullptr;
   }
   // Get loop nest surrounding src operation.
-  SmallVector<ForStmt *, 4> srcLoopNest;
-  getLoopIVs(*srcAccess->opStmt, &srcLoopNest);
+  SmallVector<ForInst *, 4> srcLoopNest;
+  getLoopIVs(*srcAccess->opInst, &srcLoopNest);
   unsigned srcLoopNestSize = srcLoopNest.size();
   assert(srcLoopDepth <= srcLoopNestSize);
 
   // Get loop nest surrounding dst operation.
-  SmallVector<ForStmt *, 4> dstLoopNest;
-  getLoopIVs(*dstAccess->opStmt, &dstLoopNest);
+  SmallVector<ForInst *, 4> dstLoopNest;
+  getLoopIVs(*dstAccess->opInst, &dstLoopNest);
   unsigned dstLoopNestSize = dstLoopNest.size();
   (void)dstLoopNestSize;
   assert(dstLoopDepth > 0);
@@ -425,7 +425,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
   }
   SmallVector<unsigned, 2> nonZeroDimIds;
   SmallVector<unsigned, 2> nonZeroSymbolIds;
-  srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opStmt->getContext(),
+  srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opInst->getContext(),
                                         &nonZeroDimIds, &nonZeroSymbolIds);
   if (srcIvMaps[i] == AffineMap::Null()) {
     continue;
  }
@@ -446,23 +446,23 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
     // with a symbol identifiers in 'nonZeroSymbolIds'.
   }
 
-  // Find the stmt block positions of 'srcAccess->opStmt' within 'srcLoopNest'.
+  // Find the inst block positions of 'srcAccess->opInst' within 'srcLoopNest'.
   SmallVector<unsigned, 4> positions;
-  findStmtPosition(srcAccess->opStmt, srcLoopNest[0]->getBlock(), &positions);
+  findInstPosition(srcAccess->opInst, srcLoopNest[0]->getBlock(), &positions);
 
-  // Clone src loop nest and insert it a the beginning of the statement block
+  // Clone src loop nest and insert it at the beginning of the instruction block
   // of the loop at 'dstLoopDepth' in 'dstLoopNest'.
- auto *dstForStmt = dstLoopNest[dstLoopDepth - 1]; - FuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin()); + auto *dstForInst = dstLoopNest[dstLoopDepth - 1]; + FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin()); DenseMap<const Value *, Value *> operandMap; - auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap)); - - // Lookup stmt in cloned 'sliceLoopNest' at 'positions'. - Statement *sliceStmt = - getStmtAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); - // Get loop nest surrounding 'sliceStmt'. - SmallVector<ForStmt *, 4> sliceSurroundingLoops; - getLoopIVs(*sliceStmt, &sliceSurroundingLoops); + auto *sliceLoopNest = cast<ForInst>(b.clone(*srcLoopNest[0], operandMap)); + + // Lookup inst in cloned 'sliceLoopNest' at 'positions'. + Instruction *sliceInst = + getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); + // Get loop nest surrounding 'sliceInst'. + SmallVector<ForInst *, 4> sliceSurroundingLoops; + getLoopIVs(*sliceInst, &sliceSurroundingLoops); unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size(); (void)sliceSurroundingLoopsSize; @@ -470,18 +470,18 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, unsigned sliceLoopLimit = dstLoopDepth + srcLoopNestSize; assert(sliceLoopLimit <= sliceSurroundingLoopsSize); for (unsigned i = dstLoopDepth; i < sliceLoopLimit; ++i) { - auto *forStmt = sliceSurroundingLoops[i]; + auto *forInst = sliceSurroundingLoops[i]; unsigned index = i - dstLoopDepth; AffineMap lbMap = srcIvMaps[index]; if (lbMap == AffineMap::Null()) continue; - forStmt->setLowerBound(srcIvOperands[index], lbMap); + forInst->setLowerBound(srcIvOperands[index], lbMap); // Create upper bound map with is lower bound map + 1; assert(lbMap.getNumResults() == 1); AffineExpr ubResultExpr = lbMap.getResult(0) + 1; AffineMap ubMap = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), {ubResultExpr}, {}); - forStmt->setUpperBound(srcIvOperands[index], ubMap); + forInst->setUpperBound(srcIvOperands[index], ubMap); } return sliceLoopNest; } diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index cd9451cd5e9..e092b29a13b 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -19,7 +19,7 @@ #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Statements.h" +#include "mlir/IR/Instructions.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" @@ -105,7 +105,7 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType, static AffineMap makePermutationMap( MLIRContext *context, llvm::iterator_range<OperationInst::operand_iterator> indices, - const DenseMap<ForStmt *, unsigned> &enclosingLoopToVectorDim) { + const DenseMap<ForInst *, unsigned> &enclosingLoopToVectorDim) { using functional::makePtrDynCaster; using functional::map; auto unwrappedIndices = map(makePtrDynCaster<Value, Value>(), indices); @@ -137,10 +137,11 @@ static AffineMap makePermutationMap( /// the specified type. /// TODO(ntv): could also be implemented as a collect parents followed by a /// filter and made available outside this file. 
-template <typename T> static SetVector<T *> getParentsOfType(Statement *stmt) { +template <typename T> +static SetVector<T *> getParentsOfType(Instruction *inst) { SetVector<T *> res; - auto *current = stmt; - while (auto *parent = current->getParentStmt()) { + auto *current = inst; + while (auto *parent = current->getParentInst()) { auto *typedParent = dyn_cast<T>(parent); if (typedParent) { assert(res.count(typedParent) == 0 && "Already inserted"); @@ -151,34 +152,34 @@ template <typename T> static SetVector<T *> getParentsOfType(Statement *stmt) { return res; } -/// Returns the enclosing ForStmt, from closest to farthest. -static SetVector<ForStmt *> getEnclosingForStmts(Statement *stmt) { - return getParentsOfType<ForStmt>(stmt); +/// Returns the enclosing ForInst, from closest to farthest. +static SetVector<ForInst *> getEnclosingforInsts(Instruction *inst) { + return getParentsOfType<ForInst>(inst); } AffineMap -mlir::makePermutationMap(OperationInst *opStmt, - const DenseMap<ForStmt *, unsigned> &loopToVectorDim) { - DenseMap<ForStmt *, unsigned> enclosingLoopToVectorDim; - auto enclosingLoops = getEnclosingForStmts(opStmt); - for (auto *forStmt : enclosingLoops) { - auto it = loopToVectorDim.find(forStmt); +mlir::makePermutationMap(OperationInst *opInst, + const DenseMap<ForInst *, unsigned> &loopToVectorDim) { + DenseMap<ForInst *, unsigned> enclosingLoopToVectorDim; + auto enclosingLoops = getEnclosingforInsts(opInst); + for (auto *forInst : enclosingLoops) { + auto it = loopToVectorDim.find(forInst); if (it != loopToVectorDim.end()) { enclosingLoopToVectorDim.insert(*it); } } - if (auto load = opStmt->dyn_cast<LoadOp>()) { - return ::makePermutationMap(opStmt->getContext(), load->getIndices(), + if (auto load = opInst->dyn_cast<LoadOp>()) { + return ::makePermutationMap(opInst->getContext(), load->getIndices(), enclosingLoopToVectorDim); } - auto store = opStmt->cast<StoreOp>(); - return ::makePermutationMap(opStmt->getContext(), store->getIndices(), + auto store = opInst->cast<StoreOp>(); + return ::makePermutationMap(opInst->getContext(), store->getIndices(), enclosingLoopToVectorDim); } -bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt, +bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opInst, VectorType subVectorType) { // First, extract the vector type and ditinguish between: // a. ops that *must* lower a super-vector (i.e. vector_transfer_read, @@ -191,20 +192,20 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt, /// do not have to special case. Maybe a trait, or just a method, unclear atm. 
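getParentsOfType is the generic engine behind getEnclosingforInsts: walk the parent chain and keep ancestors of one dynamic type, closest first. A minimal standalone analogue, using a hypothetical Inst/ForInst pair and dynamic_cast in place of LLVM's dyn_cast:

// Standalone analogue of getParentsOfType/getEnclosingforInsts. The types
// here are hypothetical stand-ins, not MLIR's classes.
#include <iostream>
#include <vector>

struct Inst {
  Inst *parent = nullptr;
  virtual ~Inst() = default; // polymorphic so dynamic_cast works
};
struct ForInst : Inst {};

static std::vector<ForInst *> getEnclosingFors(Inst *inst) {
  std::vector<ForInst *> res;
  for (Inst *cur = inst->parent; cur; cur = cur->parent)
    if (auto *forInst = dynamic_cast<ForInst *>(cur)) // dyn_cast in MLIR
      res.push_back(forInst);
  return res; // res[0] is the innermost enclosing loop
}

int main() {
  ForInst outer, inner;
  Inst load;
  inner.parent = &outer;
  load.parent = &inner;
  std::cout << getEnclosingFors(&load).size() << "\n"; // prints 2
}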
bool mustDivide = false; VectorType superVectorType; - if (auto read = opStmt.dyn_cast<VectorTransferReadOp>()) { + if (auto read = opInst.dyn_cast<VectorTransferReadOp>()) { superVectorType = read->getResultType(); mustDivide = true; - } else if (auto write = opStmt.dyn_cast<VectorTransferWriteOp>()) { + } else if (auto write = opInst.dyn_cast<VectorTransferWriteOp>()) { superVectorType = write->getVectorType(); mustDivide = true; - } else if (opStmt.getNumResults() == 0) { - if (!opStmt.isa<ReturnOp>()) { - opStmt.emitError("NYI: assuming only return statements can have 0 " + } else if (opInst.getNumResults() == 0) { + if (!opInst.isa<ReturnOp>()) { + opInst.emitError("NYI: assuming only return instructions can have 0 " " results at this point"); } return false; - } else if (opStmt.getNumResults() == 1) { - if (auto v = opStmt.getResult(0)->getType().dyn_cast<VectorType>()) { + } else if (opInst.getNumResults() == 1) { + if (auto v = opInst.getResult(0)->getType().dyn_cast<VectorType>()) { superVectorType = v; } else { // Not a vector type. @@ -213,7 +214,7 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt, } else { // Not a vector_transfer and has more than 1 result, fail hard for now to // wake us up when something changes. - opStmt.emitError("NYI: statement has more than 1 result"); + opInst.emitError("NYI: instruction has more than 1 result"); return false; } diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 4cad531ecaa..7217c5492a6 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -36,9 +36,9 @@ #include "mlir/Analysis/Dominance.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Function.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/Module.h" -#include "mlir/IR/Statements.h" -#include "mlir/IR/StmtVisitor.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/raw_ostream.h" @@ -239,14 +239,14 @@ bool CFGFuncVerifier::verifyBlock(const Block &block) { //===----------------------------------------------------------------------===// namespace { -struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> { +struct MLFuncVerifier : public Verifier, public InstWalker<MLFuncVerifier> { const Function &fn; bool hadError = false; MLFuncVerifier(const Function &fn) : Verifier(fn), fn(fn) {} - void visitOperationInst(OperationInst *opStmt) { - hadError |= verifyOperation(*opStmt); + void visitOperationInst(OperationInst *opInst) { + hadError |= verifyOperation(*opInst); } bool verify() { @@ -269,7 +269,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> { /// operations are properly dominated by their definitions. bool verifyDominance(); - /// Verify that function has a return statement that matches its signature. + /// Verify that function has a return instruction that matches its signature. bool verifyReturn(); }; } // end anonymous namespace @@ -285,48 +285,48 @@ bool MLFuncVerifier::verifyDominance() { for (auto *arg : fn.getArguments()) liveValues.insert(arg, true); - // This recursive function walks the statement list pushing scopes onto the + // This recursive function walks the instruction list pushing scopes onto the // stack as it goes, and popping them to remove them from the table. 
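The scoped walk that follows can be sketched in isolation. SimpleInst below is a hypothetical type and std::set stands in for llvm::ScopedHashTable; the behavior shown is that definitions become visible to later instructions and to nested blocks, then go out of scope when the walk leaves the block.

// Standalone sketch of the scoped-liveness walk in verifyDominance.
#include <set>
#include <string>
#include <vector>

struct SimpleInst {
  std::vector<std::string> operands;           // names of values used
  std::vector<std::string> results;            // names of values defined
  std::vector<std::vector<SimpleInst>> bodies; // nested blocks, if any
};

static bool verifyBlock(const std::vector<SimpleInst> &block,
                        std::set<std::string> &live) {
  std::vector<std::string> definedHere; // remembered so we can pop the scope
  for (const auto &inst : block) {
    for (const auto &use : inst.operands)
      if (!live.count(use))
        return false; // operand does not dominate this use
    for (const auto &def : inst.results)
      if (live.insert(def).second)
        definedHere.push_back(def);
    for (const auto &nested : inst.bodies)
      if (!verifyBlock(nested, live))
        return false;
  }
  for (const auto &def : definedHere)
    live.erase(def); // pop the scope on the way out
  return true;
}

int main() {
  std::set<std::string> live = {"%arg0"};
  std::vector<SimpleInst> fn = {{{"%arg0"}, {"%0"}, {}},
                                {{"%0"}, {}, {}}};
  return verifyBlock(fn, live) ? 0 : 1; // exits 0: all uses dominated
}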
std::function<bool(const Block &block)> walkBlock; walkBlock = [&](const Block &block) -> bool { HashTable::ScopeTy blockScope(liveValues); - // The induction variable of a for statement is live within its body. - if (auto *forStmt = dyn_cast_or_null<ForStmt>(block.getContainingInst())) - liveValues.insert(forStmt, true); + // The induction variable of a for instruction is live within its body. + if (auto *forInst = dyn_cast_or_null<ForInst>(block.getContainingInst())) + liveValues.insert(forInst, true); - for (auto &stmt : block) { + for (auto &inst : block) { // Verify that each of the operands are live. unsigned operandNo = 0; - for (auto *opValue : stmt.getOperands()) { + for (auto *opValue : inst.getOperands()) { if (!liveValues.count(opValue)) { - stmt.emitError("operand #" + Twine(operandNo) + + inst.emitError("operand #" + Twine(operandNo) + " does not dominate this use"); - if (auto *useStmt = opValue->getDefiningInst()) - useStmt->emitNote("operand defined here"); + if (auto *useInst = opValue->getDefiningInst()) + useInst->emitNote("operand defined here"); return true; } ++operandNo; } - if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) { + if (auto *opInst = dyn_cast<OperationInst>(&inst)) { // Operations define values, add them to the hash table. - for (auto *result : opStmt->getResults()) + for (auto *result : opInst->getResults()) liveValues.insert(result, true); continue; } // If this is an if or for, recursively walk the block they contain. - if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) { - if (walkBlock(*ifStmt->getThen())) + if (auto *ifInst = dyn_cast<IfInst>(&inst)) { + if (walkBlock(*ifInst->getThen())) return true; - if (auto *elseClause = ifStmt->getElse()) + if (auto *elseClause = ifInst->getElse()) if (walkBlock(*elseClause)) return true; } - if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) - if (walkBlock(*forStmt->getBody())) + if (auto *forInst = dyn_cast<ForInst>(&inst)) + if (walkBlock(*forInst->getBody())) return true; } @@ -338,13 +338,14 @@ bool MLFuncVerifier::verifyDominance() { } bool MLFuncVerifier::verifyReturn() { - // TODO: fold return verification in the pass that verifies all statements. - const char missingReturnMsg[] = "ML function must end with return statement"; + // TODO: fold return verification in the pass that verifies all instructions. 
+ const char missingReturnMsg[] = + "ML function must end with return instruction"; if (fn.getBody()->getInstructions().empty()) return failure(missingReturnMsg, fn); - const auto &stmt = fn.getBody()->getInstructions().back(); - if (const auto *op = dyn_cast<OperationInst>(&stmt)) { + const auto &inst = fn.getBody()->getInstructions().back(); + if (const auto *op = dyn_cast<OperationInst>(&inst)) { if (!op->isReturn()) return failure(missingReturnMsg, fn); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index daaaee7010c..cf822e025b8 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -25,11 +25,11 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Function.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Statements.h" -#include "mlir/IR/StmtVisitor.h" #include "mlir/IR/Types.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/APFloat.h" @@ -117,10 +117,10 @@ private: void visitExtFunction(const Function *fn); void visitCFGFunction(const Function *fn); void visitMLFunction(const Function *fn); - void visitStatement(const Statement *stmt); - void visitForStmt(const ForStmt *forStmt); - void visitIfStmt(const IfStmt *ifStmt); - void visitOperationInst(const OperationInst *opStmt); + void visitInstruction(const Instruction *inst); + void visitForInst(const ForInst *forInst); + void visitIfInst(const IfInst *ifInst); + void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); void visitOperation(const OperationInst *op); @@ -184,47 +184,47 @@ void ModuleState::visitCFGFunction(const Function *fn) { if (auto *opInst = dyn_cast<OperationInst>(&op)) visitOperation(opInst); else { - llvm_unreachable("IfStmt/ForStmt in a CFG Function isn't supported"); + llvm_unreachable("IfInst/ForInst in a CFG Function isn't supported"); } } } } -void ModuleState::visitIfStmt(const IfStmt *ifStmt) { - recordIntegerSetReference(ifStmt->getIntegerSet()); - for (auto &childStmt : *ifStmt->getThen()) - visitStatement(&childStmt); - if (ifStmt->hasElse()) - for (auto &childStmt : *ifStmt->getElse()) - visitStatement(&childStmt); +void ModuleState::visitIfInst(const IfInst *ifInst) { + recordIntegerSetReference(ifInst->getIntegerSet()); + for (auto &childInst : *ifInst->getThen()) + visitInstruction(&childInst); + if (ifInst->hasElse()) + for (auto &childInst : *ifInst->getElse()) + visitInstruction(&childInst); } -void ModuleState::visitForStmt(const ForStmt *forStmt) { - AffineMap lbMap = forStmt->getLowerBoundMap(); +void ModuleState::visitForInst(const ForInst *forInst) { + AffineMap lbMap = forInst->getLowerBoundMap(); if (!hasShorthandForm(lbMap)) recordAffineMapReference(lbMap); - AffineMap ubMap = forStmt->getUpperBoundMap(); + AffineMap ubMap = forInst->getUpperBoundMap(); if (!hasShorthandForm(ubMap)) recordAffineMapReference(ubMap); - for (auto &childStmt : *forStmt->getBody()) - visitStatement(&childStmt); + for (auto &childInst : *forInst->getBody()) + visitInstruction(&childInst); } -void ModuleState::visitOperationInst(const OperationInst *opStmt) { - for (auto attr : opStmt->getAttrs()) +void ModuleState::visitOperationInst(const OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) visitAttribute(attr.second); } -void ModuleState::visitStatement(const Statement *stmt) { - switch (stmt->getKind()) { - case 
Statement::Kind::If: - return visitIfStmt(cast<IfStmt>(stmt)); - case Statement::Kind::For: - return visitForStmt(cast<ForStmt>(stmt)); - case Statement::Kind::OperationInst: - return visitOperationInst(cast<OperationInst>(stmt)); +void ModuleState::visitInstruction(const Instruction *inst) { + switch (inst->getKind()) { + case Instruction::Kind::If: + return visitIfInst(cast<IfInst>(inst)); + case Instruction::Kind::For: + return visitForInst(cast<ForInst>(inst)); + case Instruction::Kind::OperationInst: + return visitOperationInst(cast<OperationInst>(inst)); default: return; } @@ -232,8 +232,8 @@ void ModuleState::visitStatement(const Statement *stmt) { void ModuleState::visitMLFunction(const Function *fn) { visitType(fn->getType()); - for (auto &stmt : *fn->getBody()) { - ModuleState::visitStatement(&stmt); + for (auto &inst : *fn->getBody()) { + ModuleState::visitInstruction(&inst); } } @@ -909,11 +909,11 @@ public: void printMLFunctionSignature(); void printOtherFunctionSignature(); - // Methods to print statements. - void print(const Statement *stmt); + // Methods to print instructions. + void print(const Instruction *inst); void print(const OperationInst *inst); - void print(const ForStmt *stmt); - void print(const IfStmt *stmt); + void print(const ForInst *inst); + void print(const IfInst *inst); void print(const Block *block); void printOperation(const OperationInst *op); @@ -959,7 +959,7 @@ public: void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims); void printBound(AffineBound bound, const char *prefix); - // Number of spaces used for indenting nested statements. + // Number of spaces used for indenting nested instructions. const static unsigned indentWidth = 2; protected: @@ -1019,22 +1019,22 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { // We number instruction that have results, and we only number the first // result. switch (inst.getKind()) { - case Statement::Kind::OperationInst: { + case Instruction::Kind::OperationInst: { auto *opInst = cast<OperationInst>(&inst); if (opInst->getNumResults() != 0) numberValueID(opInst->getResult(0)); break; } - case Statement::Kind::For: { - auto *forInst = cast<ForStmt>(&inst); + case Instruction::Kind::For: { + auto *forInst = cast<ForInst>(&inst); // Number the induction variable. numberValueID(forInst); // Recursively number the stuff in the body. numberValuesInBlock(*forInst->getBody()); break; } - case Statement::Kind::If: { - auto *ifInst = cast<IfStmt>(&inst); + case Instruction::Kind::If: { + auto *ifInst = cast<IfInst>(&inst); numberValuesInBlock(*ifInst->getThen()); if (auto *elseBlock = ifInst->getElse()) numberValuesInBlock(*elseBlock); @@ -1086,7 +1086,7 @@ void FunctionPrinter::numberValueID(const Value *value) { // done with it. 
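The printer's numbering scheme can be shown standalone: ordinary results draw sequential %N ids from one counter while loop induction variables draw %iN ids from another, mirroring the Value::Kind::ForInst case above. The Value struct below is a hypothetical simplification.

// Standalone sketch of the asm printer's value numbering.
#include <iostream>
#include <map>
#include <string>

struct Value {
  bool isLoopIV = false;
};

struct Numberer {
  std::map<const Value *, std::string> ids;
  unsigned nextValueID = 0, nextLoopID = 0;
  void number(const Value &v) {
    ids[&v] = v.isLoopIV ? "%i" + std::to_string(nextLoopID++)
                         : "%" + std::to_string(nextValueID++);
  }
};

int main() {
  Value result0, iv, result1;
  iv.isLoopIV = true;
  Numberer n;
  n.number(result0);
  n.number(iv);
  n.number(result1);
  std::cout << n.ids[&iv] << " " << n.ids[&result1] << "\n"; // %i0 %1
}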
valueIDs[value] = nextValueID++; return; - case Value::Kind::ForStmt: + case Value::Kind::ForInst: specialName << 'i' << nextLoopID++; break; } @@ -1220,21 +1220,21 @@ void FunctionPrinter::print(const Block *block) { currentIndent += indentWidth; - for (auto &stmt : block->getInstructions()) { - print(&stmt); + for (auto &inst : block->getInstructions()) { + print(&inst); os << '\n'; } currentIndent -= indentWidth; } -void FunctionPrinter::print(const Statement *stmt) { - switch (stmt->getKind()) { - case Statement::Kind::OperationInst: - return print(cast<OperationInst>(stmt)); - case Statement::Kind::For: - return print(cast<ForStmt>(stmt)); - case Statement::Kind::If: - return print(cast<IfStmt>(stmt)); +void FunctionPrinter::print(const Instruction *inst) { + switch (inst->getKind()) { + case Instruction::Kind::OperationInst: + return print(cast<OperationInst>(inst)); + case Instruction::Kind::For: + return print(cast<ForInst>(inst)); + case Instruction::Kind::If: + return print(cast<IfInst>(inst)); } } @@ -1243,33 +1243,33 @@ void FunctionPrinter::print(const OperationInst *inst) { printOperation(inst); } -void FunctionPrinter::print(const ForStmt *stmt) { +void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "for "; - printOperand(stmt); + printOperand(inst); os << " = "; - printBound(stmt->getLowerBound(), "max"); + printBound(inst->getLowerBound(), "max"); os << " to "; - printBound(stmt->getUpperBound(), "min"); + printBound(inst->getUpperBound(), "min"); - if (stmt->getStep() != 1) - os << " step " << stmt->getStep(); + if (inst->getStep() != 1) + os << " step " << inst->getStep(); os << " {\n"; - print(stmt->getBody()); + print(inst->getBody()); os.indent(currentIndent) << "}"; } -void FunctionPrinter::print(const IfStmt *stmt) { +void FunctionPrinter::print(const IfInst *inst) { os.indent(currentIndent) << "if "; - IntegerSet set = stmt->getIntegerSet(); + IntegerSet set = inst->getIntegerSet(); printIntegerSetReference(set); - printDimAndSymbolList(stmt->getInstOperands(), set.getNumDims()); + printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); os << " {\n"; - print(stmt->getThen()); + print(inst->getThen()); os.indent(currentIndent) << "}"; - if (stmt->hasElse()) { + if (inst->hasElse()) { os << " else {\n"; - print(stmt->getElse()); + print(inst->getElse()); os.indent(currentIndent) << "}"; } } @@ -1280,7 +1280,7 @@ void FunctionPrinter::printValueID(const Value *value, auto lookupValue = value; // If this is a reference to the result of a multi-result instruction or - // statement, print out the # identifier and make sure to map our lookup + // instruction, print out the # identifier and make sure to map our lookup // to the first result of the instruction. if (auto *result = dyn_cast<InstResult>(value)) { if (result->getOwner()->getNumResults() != 1) { @@ -1493,8 +1493,8 @@ void Value::print(raw_ostream &os) const { return; case Value::Kind::InstResult: return getDefiningInst()->print(os); - case Value::Kind::ForStmt: - return cast<ForStmt>(this)->print(os); + case Value::Kind::ForInst: + return cast<ForInst>(this)->print(os); } } diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index c7e84194c35..2efba2bbf69 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -26,16 +26,16 @@ Block::~Block() { llvm::DeleteContainerPointers(arguments); } -/// Returns the closest surrounding statement that contains this block or -/// nullptr if this is a top-level statement block. 
-Statement *Block::getContainingInst() { +/// Returns the closest surrounding instruction that contains this block or +/// nullptr if this is a top-level instruction block. +Instruction *Block::getContainingInst() { return parent ? parent->getContainingInst() : nullptr; } Function *Block::getFunction() { Block *block = this; - while (auto *stmt = block->getContainingInst()) { - block = stmt->getBlock(); + while (auto *inst = block->getContainingInst()) { + block = inst->getBlock(); if (!block) return nullptr; } @@ -49,11 +49,11 @@ Function *Block::getFunction() { /// the latter fails. const Instruction * Block::findAncestorInstInBlock(const Instruction &inst) const { - // Traverse up the statement hierarchy starting from the owner of operand to - // find the ancestor statement that resides in the block of 'forStmt'. + // Traverse up the instruction hierarchy starting from the owner of operand to + // find the ancestor instruction that resides in the block of 'forInst'. const auto *currInst = &inst; while (currInst->getBlock() != this) { - currInst = currInst->getParentStmt(); + currInst = currInst->getParentInst(); if (!currInst) return nullptr; } @@ -106,10 +106,10 @@ OperationInst *Block::getTerminator() { // Check if the last instruction is a terminator. auto &backInst = back(); - auto *opStmt = dyn_cast<OperationInst>(&backInst); - if (!opStmt || !opStmt->isTerminator()) + auto *opInst = dyn_cast<OperationInst>(&backInst); + if (!opInst || !opInst->isTerminator()) return nullptr; - return opStmt; + return opInst; } /// Return true if this block has no predecessors. @@ -184,10 +184,10 @@ Block *Block::splitBlock(iterator splitBefore) { BlockList::BlockList(Function *container) : container(container) {} -BlockList::BlockList(Statement *container) : container(container) {} +BlockList::BlockList(Instruction *container) : container(container) {} -Statement *BlockList::getContainingInst() { - return container.dyn_cast<Statement *>(); +Instruction *BlockList::getContainingInst() { + return container.dyn_cast<Instruction *>(); } Function *BlockList::getContainingFunction() { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index a9eb6fe8c8a..4c7c8ddae81 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -268,7 +268,7 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { } //===----------------------------------------------------------------------===// -// Statements. +// Instructions. //===----------------------------------------------------------------------===// /// Add new basic block and set the insertion point to the end of it. 
If an @@ -298,25 +298,25 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) { return op; } -ForStmt *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands, +ForInst *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands, AffineMap lbMap, ArrayRef<Value *> ubOperands, AffineMap ubMap, int64_t step) { - auto *stmt = - ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step); - block->getInstructions().insert(insertPoint, stmt); - return stmt; + auto *inst = + ForInst::create(location, lbOperands, lbMap, ubOperands, ubMap, step); + block->getInstructions().insert(insertPoint, inst); + return inst; } -ForStmt *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, +ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, int64_t step) { auto lbMap = AffineMap::getConstantMap(lb, context); auto ubMap = AffineMap::getConstantMap(ub, context); return createFor(location, {}, lbMap, {}, ubMap, step); } -IfStmt *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands, +IfInst *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands, IntegerSet set) { - auto *stmt = IfStmt::create(location, operands, set); - block->getInstructions().insert(insertPoint, stmt); - return stmt; + auto *inst = IfInst::create(location, operands, set); + block->getInstructions().insert(insertPoint, inst); + return inst; } diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index cbe84e10247..bacb504683b 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -18,9 +18,9 @@ #include "mlir/IR/Function.h" #include "AttributeListStorage.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" -#include "mlir/IR/StmtVisitor.h" #include "mlir/IR/Types.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringRef.h" @@ -161,21 +161,21 @@ bool Function::emitError(const Twine &message) const { // Function implementation. 
//===----------------------------------------------------------------------===// -const OperationInst *Function::getReturnStmt() const { +const OperationInst *Function::getReturn() const { return cast<OperationInst>(&getBody()->back()); } -OperationInst *Function::getReturnStmt() { +OperationInst *Function::getReturn() { return cast<OperationInst>(&getBody()->back()); } void Function::walk(std::function<void(OperationInst *)> callback) { - struct Walker : public StmtWalker<Walker> { + struct Walker : public InstWalker<Walker> { std::function<void(OperationInst *)> const &callback; Walker(std::function<void(OperationInst *)> const &callback) : callback(callback) {} - void visitOperationInst(OperationInst *opStmt) { callback(opStmt); } + void visitOperationInst(OperationInst *opInst) { callback(opInst); } }; Walker v(callback); @@ -183,12 +183,12 @@ void Function::walk(std::function<void(OperationInst *)> callback) { } void Function::walkPostOrder(std::function<void(OperationInst *)> callback) { - struct Walker : public StmtWalker<Walker> { + struct Walker : public InstWalker<Walker> { std::function<void(OperationInst *)> const &callback; Walker(std::function<void(OperationInst *)> const &callback) : callback(callback) {} - void visitOperationInst(OperationInst *opStmt) { callback(opStmt); } + void visitOperationInst(OperationInst *opInst) { callback(opInst); } }; Walker v(callback); diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Instruction.cpp index 6bd9944bb65..92f3c4ecba3 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -1,4 +1,5 @@ -//===- Statement.cpp - MLIR Statement Classes ----------------------------===// +//===- Instruction.cpp - MLIR Instruction Classes +//----------------------------===// // // Copyright 2019 The MLIR Authors. // @@ -20,10 +21,10 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Function.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Statements.h" -#include "mlir/IR/StmtVisitor.h" #include "llvm/ADT/DenseMap.h" using namespace mlir; @@ -54,41 +55,43 @@ template <> unsigned BlockOperand::getOperandNumber() const { } //===----------------------------------------------------------------------===// -// Statement +// Instruction //===----------------------------------------------------------------------===// -// Statements are deleted through the destroy() member because we don't have +// Instructions are deleted through the destroy() member because we don't have // a virtual destructor. -Statement::~Statement() { - assert(block == nullptr && "statement destroyed but still in a block"); +Instruction::~Instruction() { + assert(block == nullptr && "instruction destroyed but still in a block"); } -/// Destroy this statement or one of its subclasses. -void Statement::destroy() { +/// Destroy this instruction or one of its subclasses. +void Instruction::destroy() { switch (this->getKind()) { case Kind::OperationInst: cast<OperationInst>(this)->destroy(); break; case Kind::For: - delete cast<ForStmt>(this); + delete cast<ForInst>(this); break; case Kind::If: - delete cast<IfStmt>(this); + delete cast<IfInst>(this); break; } } -Statement *Statement::getParentStmt() const { +Instruction *Instruction::getParentInst() const { return block ? block->getContainingInst() : nullptr; } -Function *Statement::getFunction() const { +Function *Instruction::getFunction() const { return block ? 
block->getFunction() : nullptr; } -Value *Statement::getOperand(unsigned idx) { return getInstOperand(idx).get(); } +Value *Instruction::getOperand(unsigned idx) { + return getInstOperand(idx).get(); +} -const Value *Statement::getOperand(unsigned idx) const { +const Value *Instruction::getOperand(unsigned idx) const { return getInstOperand(idx).get(); } @@ -96,12 +99,12 @@ const Value *Statement::getOperand(unsigned idx) const { // it is an induction variable, or it is a result of affine apply operation // with dimension id arguments. bool Value::isValidDim() const { - if (auto *stmt = getDefiningInst()) { - // Top level statement or constant operation is ok. - if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>()) + if (auto *inst = getDefiningInst()) { + // Top level instruction or constant operation is ok. + if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>()) return true; // Affine apply operation is ok if all of its operands are ok. - if (auto op = stmt->dyn_cast<AffineApplyOp>()) + if (auto op = inst->dyn_cast<AffineApplyOp>()) return op->isValidDim(); return false; } @@ -114,12 +117,12 @@ bool Value::isValidDim() const { // the top level, or it is a result of affine apply operation with symbol // arguments. bool Value::isValidSymbol() const { - if (auto *stmt = getDefiningInst()) { - // Top level statement or constant operation is ok. - if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>()) + if (auto *inst = getDefiningInst()) { + // Top level instruction or constant operation is ok. + if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>()) return true; // Affine apply operation is ok if all of its operands are ok. - if (auto op = stmt->dyn_cast<AffineApplyOp>()) + if (auto op = inst->dyn_cast<AffineApplyOp>()) return op->isValidSymbol(); return false; } @@ -128,42 +131,42 @@ bool Value::isValidSymbol() const { return isa<BlockArgument>(this); } -void Statement::setOperand(unsigned idx, Value *value) { +void Instruction::setOperand(unsigned idx, Value *value) { getInstOperand(idx).set(value); } -unsigned Statement::getNumOperands() const { +unsigned Instruction::getNumOperands() const { switch (getKind()) { case Kind::OperationInst: return cast<OperationInst>(this)->getNumOperands(); case Kind::For: - return cast<ForStmt>(this)->getNumOperands(); + return cast<ForInst>(this)->getNumOperands(); case Kind::If: - return cast<IfStmt>(this)->getNumOperands(); + return cast<IfInst>(this)->getNumOperands(); } } -MutableArrayRef<InstOperand> Statement::getInstOperands() { +MutableArrayRef<InstOperand> Instruction::getInstOperands() { switch (getKind()) { case Kind::OperationInst: return cast<OperationInst>(this)->getInstOperands(); case Kind::For: - return cast<ForStmt>(this)->getInstOperands(); + return cast<ForInst>(this)->getInstOperands(); case Kind::If: - return cast<IfStmt>(this)->getInstOperands(); + return cast<IfInst>(this)->getInstOperands(); } } -/// Emit a note about this statement, reporting up to any diagnostic +/// Emit a note about this instruction, reporting up to any diagnostic /// handlers that may be listening. -void Statement::emitNote(const Twine &message) const { +void Instruction::emitNote(const Twine &message) const { getContext()->emitDiagnostic(getLoc(), message, MLIRContext::DiagnosticKind::Note); } -/// Emit a warning about this statement, reporting up to any diagnostic +/// Emit a warning about this instruction, reporting up to any diagnostic /// handlers that may be listening. 
-void Statement::emitWarning(const Twine &message) const {
+void Instruction::emitWarning(const Twine &message) const {
   getContext()->emitDiagnostic(getLoc(), message,
                                MLIRContext::DiagnosticKind::Warning);
 }
@@ -172,80 +175,80 @@ void Statement::emitWarning(const Twine &message) const {
 /// any diagnostic handlers that may be listening.  This function always
 /// returns true.  NOTE: This may terminate the containing application, only
 /// use when the IR is in an inconsistent state.
-bool Statement::emitError(const Twine &message) const {
+bool Instruction::emitError(const Twine &message) const {
   return getContext()->emitError(getLoc(), message);
 }
 
-// Returns whether the Statement is a terminator.
-bool Statement::isTerminator() const {
+// Returns whether the Instruction is a terminator.
+bool Instruction::isTerminator() const {
   if (auto *op = dyn_cast<OperationInst>(this))
     return op->isTerminator();
   return false;
 }
 
 //===----------------------------------------------------------------------===//
-// ilist_traits for Statement
+// ilist_traits for Instruction
 //===----------------------------------------------------------------------===//
 
-void llvm::ilist_traits<::mlir::Statement>::deleteNode(Statement *stmt) {
-  stmt->destroy();
+void llvm::ilist_traits<::mlir::Instruction>::deleteNode(Instruction *inst) {
+  inst->destroy();
 }
 
-Block *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
+Block *llvm::ilist_traits<::mlir::Instruction>::getContainingBlock() {
   size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
-  iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
+  iplist<Instruction> *Anchor(static_cast<iplist<Instruction> *>(this));
   return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
 }
 
-/// This is a trait method invoked when a statement is added to a block.  We
+/// This is a trait method invoked when an instruction is added to a block.  We
 /// keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
-  assert(!stmt->getBlock() && "already in a statement block!");
-  stmt->block = getContainingBlock();
+void llvm::ilist_traits<::mlir::Instruction>::addNodeToList(Instruction *inst) {
+  assert(!inst->getBlock() && "already in an instruction block!");
+  inst->block = getContainingBlock();
 }
 
-/// This is a trait method invoked when a statement is removed from a block.
+/// This is a trait method invoked when an instruction is removed from a block.
 /// We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
-    Statement *stmt) {
-  assert(stmt->block && "not already in a statement block!");
-  stmt->block = nullptr;
+void llvm::ilist_traits<::mlir::Instruction>::removeNodeFromList(
+    Instruction *inst) {
+  assert(inst->block && "not already in an instruction block!");
+  inst->block = nullptr;
 }
 
-/// This is a trait method invoked when a statement is moved from one block
+/// This is a trait method invoked when an instruction is moved from one block
 /// to another.  We keep the block pointer up to date.
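These hooks amount to keeping a back-pointer in sync with container membership. A standalone sketch with hypothetical Inst/Block types, where std::list stands in for llvm::iplist and its trait callbacks:

// Standalone sketch of the addNodeToList/removeNodeFromList invariant.
#include <cassert>
#include <list>

struct Block;
struct Inst {
  Block *block = nullptr; // owning block, kept up to date below
};

struct Block {
  std::list<Inst *> insts;
  void push_back(Inst *inst) {
    assert(!inst->block && "already in a block");
    inst->block = this; // what addNodeToList does
    insts.push_back(inst);
  }
  void remove(Inst *inst) {
    assert(inst->block == this && "not in this block");
    inst->block = nullptr; // what removeNodeFromList does
    insts.remove(inst);
  }
};

int main() {
  Block a, b;
  Inst i;
  a.push_back(&i);
  assert(i.block == &a);
  a.remove(&i);
  b.push_back(&i); // moving between blocks re-points the owner
  assert(i.block == &b);
}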
-void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList( - ilist_traits<Statement> &otherList, stmt_iterator first, - stmt_iterator last) { - // If we are transferring statements within the same block, the block +void llvm::ilist_traits<::mlir::Instruction>::transferNodesFromList( + ilist_traits<Instruction> &otherList, inst_iterator first, + inst_iterator last) { + // If we are transferring instructions within the same block, the block // pointer doesn't need to be updated. Block *curParent = getContainingBlock(); if (curParent == otherList.getContainingBlock()) return; - // Update the 'block' member of each statement. + // Update the 'block' member of each instruction. for (; first != last; ++first) first->block = curParent; } -/// Remove this statement (and its descendants) from its Block and delete +/// Remove this instruction (and its descendants) from its Block and delete /// all of them. -void Statement::erase() { - assert(getBlock() && "Statement has no block"); +void Instruction::erase() { + assert(getBlock() && "Instruction has no block"); getBlock()->getInstructions().erase(this); } -/// Unlink this statement from its current block and insert it right before -/// `existingStmt` which may be in the same or another block in the same +/// Unlink this instruction from its current block and insert it right before +/// `existingInst` which may be in the same or another block in the same /// function. -void Statement::moveBefore(Statement *existingStmt) { - moveBefore(existingStmt->getBlock(), existingStmt->getIterator()); +void Instruction::moveBefore(Instruction *existingInst) { + moveBefore(existingInst->getBlock(), existingInst->getIterator()); } /// Unlink this operation instruction from its current basic block and insert /// it right before `iterator` in the specified basic block. -void Statement::moveBefore(Block *block, - llvm::iplist<Statement>::iterator iterator) { +void Instruction::moveBefore(Block *block, + llvm::iplist<Instruction>::iterator iterator) { block->getInstructions().splice(iterator, getBlock()->getInstructions(), getIterator()); } @@ -253,7 +256,7 @@ void Statement::moveBefore(Block *block, /// This drops all operand uses from this instruction, which is an essential /// step in breaking cyclic dependences between references when they are to /// be deleted. -void Statement::dropAllReferences() { +void Instruction::dropAllReferences() { for (auto &op : getInstOperands()) op.drop(); @@ -284,17 +287,17 @@ OperationInst *OperationInst::create(Location location, OperationName name, resultTypes.size(), numSuccessors, numSuccessors, numOperands); void *rawMem = malloc(byteSize); - // Initialize the OperationInst part of the statement. - auto stmt = ::new (rawMem) + // Initialize the OperationInst part of the instruction. + auto inst = ::new (rawMem) OperationInst(location, name, numOperands, resultTypes.size(), numSuccessors, attributes, context); // Initialize the results and operands. - auto instResults = stmt->getInstResults(); + auto instResults = inst->getInstResults(); for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) - new (&instResults[i]) InstResult(resultTypes[i], stmt); + new (&instResults[i]) InstResult(resultTypes[i], inst); - auto InstOperands = stmt->getInstOperands(); + auto InstOperands = inst->getInstOperands(); // Initialize normal operands. unsigned operandIt = 0, operandE = operands.size(); @@ -305,7 +308,7 @@ OperationInst *OperationInst::create(Location location, OperationName name, // separately below. 
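The sentinel encoding handled below is easiest to see in isolation: a terminator's flat operand list carries a nullptr before each successor's operand group. A standalone sketch of the splitting logic, with a hypothetical Value type:

// Standalone sketch of the nullptr-sentinel successor-operand encoding,
// e.g. {a, b, nullptr, c, nullptr, d, e} means two leading operands, then
// successor 0 gets {c} and successor 1 gets {d, e}.
#include <cassert>
#include <vector>

struct Value {};

static std::vector<std::vector<Value *>>
splitSuccessorOperands(const std::vector<Value *> &operands) {
  std::vector<std::vector<Value *>> groups;
  for (Value *v : operands) {
    if (!v) {            // sentinel: start a new successor group
      groups.emplace_back();
      continue;
    }
    if (!groups.empty()) // entries before the first sentinel are normal
      groups.back().push_back(v);
  }
  return groups;
}

int main() {
  Value a, b, c, d, e;
  auto groups = splitSuccessorOperands({&a, &b, nullptr, &c, nullptr, &d, &e});
  assert(groups.size() == 2 && groups[0].size() == 1 && groups[1].size() == 2);
}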
if (!operands[operandIt]) break; - new (&InstOperands[nextOperand++]) InstOperand(stmt, operands[operandIt]); + new (&InstOperands[nextOperand++]) InstOperand(inst, operands[operandIt]); } unsigned currentSuccNum = 0; @@ -313,13 +316,13 @@ OperationInst *OperationInst::create(Location location, OperationName name, // Verify that the amount of sentinal operands is equivalent to the number // of successors. assert(currentSuccNum == numSuccessors); - return stmt; + return inst; } - assert(stmt->isTerminator() && + assert(inst->isTerminator() && "Sentinal operand found in non terminator operand list."); - auto instBlockOperands = stmt->getBlockOperands(); - unsigned *succOperandCountIt = stmt->getTrailingObjects<unsigned>(); + auto instBlockOperands = inst->getBlockOperands(); + unsigned *succOperandCountIt = inst->getTrailingObjects<unsigned>(); unsigned *succOperandCountE = succOperandCountIt + numSuccessors; (void)succOperandCountE; @@ -338,12 +341,12 @@ OperationInst *OperationInst::create(Location location, OperationName name, } new (&instBlockOperands[currentSuccNum]) - BlockOperand(stmt, successors[currentSuccNum]); + BlockOperand(inst, successors[currentSuccNum]); *succOperandCountIt = 0; ++currentSuccNum; continue; } - new (&InstOperands[nextOperand++]) InstOperand(stmt, operands[operandIt]); + new (&InstOperands[nextOperand++]) InstOperand(inst, operands[operandIt]); ++(*succOperandCountIt); } @@ -351,7 +354,7 @@ OperationInst *OperationInst::create(Location location, OperationName name, // successors. assert(currentSuccNum == numSuccessors); - return stmt; + return inst; } OperationInst::OperationInst(Location location, OperationName name, @@ -359,7 +362,7 @@ OperationInst::OperationInst(Location location, OperationName name, unsigned numSuccessors, ArrayRef<NamedAttribute> attributes, MLIRContext *context) - : Statement(Kind::OperationInst, location), numOperands(numOperands), + : Instruction(Kind::OperationInst, location), numOperands(numOperands), numResults(numResults), numSuccs(numSuccessors), name(name) { #ifndef NDEBUG for (auto elt : attributes) @@ -524,10 +527,10 @@ bool OperationInst::emitOpError(const Twine &message) const { } //===----------------------------------------------------------------------===// -// ForStmt +// ForInst //===----------------------------------------------------------------------===// -ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands, +ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands, AffineMap lbMap, ArrayRef<Value *> ubOperands, AffineMap ubMap, int64_t step) { assert(lbOperands.size() == lbMap.getNumInputs() && @@ -537,39 +540,39 @@ ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands, assert(step > 0 && "step has to be a positive integer constant"); unsigned numOperands = lbOperands.size() + ubOperands.size(); - ForStmt *stmt = new ForStmt(location, numOperands, lbMap, ubMap, step); + ForInst *inst = new ForInst(location, numOperands, lbMap, ubMap, step); unsigned i = 0; for (unsigned e = lbOperands.size(); i != e; ++i) - stmt->operands.emplace_back(InstOperand(stmt, lbOperands[i])); + inst->operands.emplace_back(InstOperand(inst, lbOperands[i])); for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j) - stmt->operands.emplace_back(InstOperand(stmt, ubOperands[j])); + inst->operands.emplace_back(InstOperand(inst, ubOperands[j])); - return stmt; + return inst; } -ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap, +ForInst::ForInst(Location 
location, unsigned numOperands, AffineMap lbMap, AffineMap ubMap, int64_t step) - : Statement(Statement::Kind::For, location), - Value(Value::Kind::ForStmt, + : Instruction(Instruction::Kind::For, location), + Value(Value::Kind::ForInst, Type::getIndex(lbMap.getResult(0).getContext())), body(this), lbMap(lbMap), ubMap(ubMap), step(step) { - // The body of a for stmt always has one block. + // The body of a for inst always has one block. body.push_back(new Block()); operands.reserve(numOperands); } -const AffineBound ForStmt::getLowerBound() const { +const AffineBound ForInst::getLowerBound() const { return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap); } -const AffineBound ForStmt::getUpperBound() const { +const AffineBound ForInst::getUpperBound() const { return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap); } -void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) { +void ForInst::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) { assert(lbOperands.size() == map.getNumInputs()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); @@ -586,7 +589,7 @@ void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) { this->lbMap = map; } -void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) { +void ForInst::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) { assert(ubOperands.size() == map.getNumInputs()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); @@ -603,57 +606,57 @@ void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) { this->ubMap = map; } -void ForStmt::setLowerBoundMap(AffineMap map) { +void ForInst::setLowerBoundMap(AffineMap map) { assert(lbMap.getNumDims() == map.getNumDims() && lbMap.getNumSymbols() == map.getNumSymbols()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); this->lbMap = map; } -void ForStmt::setUpperBoundMap(AffineMap map) { +void ForInst::setUpperBoundMap(AffineMap map) { assert(ubMap.getNumDims() == map.getNumDims() && ubMap.getNumSymbols() == map.getNumSymbols()); assert(map.getNumResults() >= 1 && "bound map has at least one result"); this->ubMap = map; } -bool ForStmt::hasConstantLowerBound() const { return lbMap.isSingleConstant(); } +bool ForInst::hasConstantLowerBound() const { return lbMap.isSingleConstant(); } -bool ForStmt::hasConstantUpperBound() const { return ubMap.isSingleConstant(); } +bool ForInst::hasConstantUpperBound() const { return ubMap.isSingleConstant(); } -int64_t ForStmt::getConstantLowerBound() const { +int64_t ForInst::getConstantLowerBound() const { return lbMap.getSingleConstantResult(); } -int64_t ForStmt::getConstantUpperBound() const { +int64_t ForInst::getConstantUpperBound() const { return ubMap.getSingleConstantResult(); } -void ForStmt::setConstantLowerBound(int64_t value) { +void ForInst::setConstantLowerBound(int64_t value) { setLowerBound({}, AffineMap::getConstantMap(value, getContext())); } -void ForStmt::setConstantUpperBound(int64_t value) { +void ForInst::setConstantUpperBound(int64_t value) { setUpperBound({}, AffineMap::getConstantMap(value, getContext())); } -ForStmt::operand_range ForStmt::getLowerBoundOperands() { +ForInst::operand_range ForInst::getLowerBoundOperands() { return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; } -ForStmt::const_operand_range ForStmt::getLowerBoundOperands() const { +ForInst::const_operand_range ForInst::getLowerBoundOperands() const { return {operand_begin(), operand_begin() 
+ getLowerBoundMap().getNumInputs()}; } -ForStmt::operand_range ForStmt::getUpperBoundOperands() { +ForInst::operand_range ForInst::getUpperBoundOperands() { return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; } -ForStmt::const_operand_range ForStmt::getUpperBoundOperands() const { +ForInst::const_operand_range ForInst::getUpperBoundOperands() const { return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; } -bool ForStmt::matchingBoundOperandList() const { +bool ForInst::matchingBoundOperandList() const { if (lbMap.getNumDims() != ubMap.getNumDims() || lbMap.getNumSymbols() != ubMap.getNumSymbols()) return false; @@ -668,46 +671,46 @@ bool ForStmt::matchingBoundOperandList() const { } //===----------------------------------------------------------------------===// -// IfStmt +// IfInst //===----------------------------------------------------------------------===// -IfStmt::IfStmt(Location location, unsigned numOperands, IntegerSet set) - : Statement(Kind::If, location), thenClause(this), elseClause(nullptr), +IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set) + : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr), set(set) { operands.reserve(numOperands); - // The then of an 'if' stmt always has one block. + // The then of an 'if' inst always has one block. thenClause.push_back(new Block()); } -IfStmt::~IfStmt() { +IfInst::~IfInst() { if (elseClause) delete elseClause; - // An IfStmt's IntegerSet 'set' should not be deleted since it is + // An IfInst's IntegerSet 'set' should not be deleted since it is // allocated through MLIRContext's bump pointer allocator. } -IfStmt *IfStmt::create(Location location, ArrayRef<Value *> operands, +IfInst *IfInst::create(Location location, ArrayRef<Value *> operands, IntegerSet set) { unsigned numOperands = operands.size(); assert(numOperands == set.getNumOperands() && "operand cound does not match the integer set operand count"); - IfStmt *stmt = new IfStmt(location, numOperands, set); + IfInst *inst = new IfInst(location, numOperands, set); for (auto *op : operands) - stmt->operands.emplace_back(InstOperand(stmt, op)); + inst->operands.emplace_back(InstOperand(inst, op)); - return stmt; + return inst; } -const AffineCondition IfStmt::getCondition() const { +const AffineCondition IfInst::getCondition() const { return AffineCondition(*this, set); } -MLIRContext *IfStmt::getContext() const { - // Check for degenerate case of if statement with no operands. +MLIRContext *IfInst::getContext() const { + // Check for degenerate case of if instruction with no operands. // This is unlikely, but legal. if (operands.empty()) return getFunction()->getContext(); @@ -716,16 +719,16 @@ MLIRContext *IfStmt::getContext() const { } //===----------------------------------------------------------------------===// -// Statement Cloning +// Instruction Cloning //===----------------------------------------------------------------------===// -/// Create a deep copy of this statement, remapping any operands that use -/// values outside of the statement using the map that is provided (leaving +/// Create a deep copy of this instruction, remapping any operands that use +/// values outside of the instruction using the map that is provided (leaving /// them alone if no entry is present). 
Replaces references to cloned -/// sub-statements to the corresponding statement that is copied, and adds +/// sub-instructions to the corresponding instruction that is copied, and adds /// those mappings to the map. -Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap, - MLIRContext *context) const { +Instruction *Instruction::clone(DenseMap<const Value *, Value *> &operandMap, + MLIRContext *context) const { // If the specified value is in operandMap, return the remapped value. // Otherwise return the value itself. auto remapOperand = [&](const Value *value) -> Value * { @@ -735,48 +738,48 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap, SmallVector<Value *, 8> operands; SmallVector<Block *, 2> successors; - if (auto *opStmt = dyn_cast<OperationInst>(this)) { - operands.reserve(getNumOperands() + opStmt->getNumSuccessors()); + if (auto *opInst = dyn_cast<OperationInst>(this)) { + operands.reserve(getNumOperands() + opInst->getNumSuccessors()); - if (!opStmt->isTerminator()) { + if (!opInst->isTerminator()) { // Non-terminators just add all the operands. for (auto *opValue : getOperands()) operands.push_back(remapOperand(opValue)); } else { // We add the operands separated by nullptr's for each successor. - unsigned firstSuccOperand = opStmt->getNumSuccessors() - ? opStmt->getSuccessorOperandIndex(0) - : opStmt->getNumOperands(); - auto InstOperands = opStmt->getInstOperands(); + unsigned firstSuccOperand = opInst->getNumSuccessors() + ? opInst->getSuccessorOperandIndex(0) + : opInst->getNumOperands(); + auto InstOperands = opInst->getInstOperands(); unsigned i = 0; for (; i != firstSuccOperand; ++i) operands.push_back(remapOperand(InstOperands[i].get())); - successors.reserve(opStmt->getNumSuccessors()); - for (unsigned succ = 0, e = opStmt->getNumSuccessors(); succ != e; + successors.reserve(opInst->getNumSuccessors()); + for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e; ++succ) { - successors.push_back(const_cast<Block *>(opStmt->getSuccessor(succ))); + successors.push_back(const_cast<Block *>(opInst->getSuccessor(succ))); // Add sentinel to delineate successor operands. operands.push_back(nullptr); // Remap the successors operands. - for (auto *operand : opStmt->getSuccessorOperands(succ)) + for (auto *operand : opInst->getSuccessorOperands(succ)) operands.push_back(remapOperand(operand)); } } SmallVector<Type, 8> resultTypes; - resultTypes.reserve(opStmt->getNumResults()); - for (auto *result : opStmt->getResults()) + resultTypes.reserve(opInst->getNumResults()); + for (auto *result : opInst->getResults()) resultTypes.push_back(result->getType()); - auto *newOp = OperationInst::create(getLoc(), opStmt->getName(), operands, - resultTypes, opStmt->getAttrs(), + auto *newOp = OperationInst::create(getLoc(), opInst->getName(), operands, + resultTypes, opInst->getAttrs(), successors, context); // Remember the mapping of any results. 
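The remapOperand lambda above is the heart of cloning: look the operand up in operandMap and fall back to the original value when no entry is present, i.e. when the operand was defined outside the cloned region. A standalone sketch, with std::unordered_map standing in for llvm::DenseMap and a hypothetical Value type:

// Standalone sketch of the remap-or-identity lookup used by clone().
#include <cassert>
#include <unordered_map>

struct Value {};

int main() {
  Value outside, original, replacement;
  std::unordered_map<const Value *, Value *> operandMap = {
      {&original, &replacement}};

  auto remapOperand = [&](const Value *value) -> Value * {
    auto it = operandMap.find(value);
    return it != operandMap.end() ? it->second : const_cast<Value *>(value);
  };

  assert(remapOperand(&original) == &replacement); // remapped
  assert(remapOperand(&outside) == &outside);      // left alone
}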
- for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i) - operandMap[opStmt->getResult(i)] = newOp->getResult(i); + for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i) + operandMap[opInst->getResult(i)] = newOp->getResult(i); return newOp; } @@ -784,43 +787,43 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap, for (auto *opValue : getOperands()) operands.push_back(remapOperand(opValue)); - if (auto *forStmt = dyn_cast<ForStmt>(this)) { - auto lbMap = forStmt->getLowerBoundMap(); - auto ubMap = forStmt->getUpperBoundMap(); + if (auto *forInst = dyn_cast<ForInst>(this)) { + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); - auto *newFor = ForStmt::create( + auto *newFor = ForInst::create( getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()), lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), - ubMap, forStmt->getStep()); + ubMap, forInst->getStep()); // Remember the induction variable mapping. - operandMap[forStmt] = newFor; + operandMap[forInst] = newFor; // Recursively clone the body of the for loop. - for (auto &subStmt : *forStmt->getBody()) - newFor->getBody()->push_back(subStmt.clone(operandMap, context)); + for (auto &subInst : *forInst->getBody()) + newFor->getBody()->push_back(subInst.clone(operandMap, context)); return newFor; } - // Otherwise, we must have an If statement. - auto *ifStmt = cast<IfStmt>(this); - auto *newIf = IfStmt::create(getLoc(), operands, ifStmt->getIntegerSet()); + // Otherwise, we must have an If instruction. + auto *ifInst = cast<IfInst>(this); + auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet()); auto *resultThen = newIf->getThen(); - for (auto &childStmt : *ifStmt->getThen()) - resultThen->push_back(childStmt.clone(operandMap, context)); + for (auto &childInst : *ifInst->getThen()) + resultThen->push_back(childInst.clone(operandMap, context)); - if (ifStmt->hasElse()) { + if (ifInst->hasElse()) { auto *resultElse = newIf->createElse(); - for (auto &childStmt : *ifStmt->getElse()) - resultElse->push_back(childStmt.clone(operandMap, context)); + for (auto &childInst : *ifInst->getElse()) + resultElse->push_back(childInst.clone(operandMap, context)); } return newIf; } -Statement *Statement::clone(MLIRContext *context) const { +Instruction *Instruction::clone(MLIRContext *context) const { DenseMap<const Value *, Value *> operandMap; return clone(operandMap, context); } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index ccd7d65f7c8..9cd4355e4aa 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -17,10 +17,10 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Statements.h" using namespace mlir; /// Form the OperationName for an op with the specified string. 
This either is
@@ -279,7 +279,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) {
   if (op->getFunction()->isML()) {
     Block *block = op->getBlock();
     if (!block || block->getContainingInst() || &block->back() != op)
-      return op->emitOpError("must be the last statement in the ML function");
+      return op->emitOpError("must be the last instruction in the ML function");
   } else {
     const Block *block = op->getBlock();
     if (!block || &block->back() != op)
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 8c41d488a8b..90d768c844e 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -16,7 +16,7 @@
 // =============================================================================
 
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
 #include "mlir/IR/Value.h"
 using namespace mlir;
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index c7a5e42dd99..a213f05a932 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -17,7 +17,7 @@
 
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Function.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
 using namespace mlir;
 
 /// If this value is the result of an Instruction, return the instruction
@@ -35,8 +35,8 @@ Function *Value::getFunction() {
     return cast<BlockArgument>(this)->getFunction();
   case Value::Kind::InstResult:
     return getDefiningInst()->getFunction();
-  case Value::Kind::ForStmt:
-    return cast<ForStmt>(this)->getFunction();
+  case Value::Kind::ForInst:
+    return cast<ForInst>(this)->getFunction();
   }
 }
@@ -59,10 +59,10 @@ MLIRContext *IROperandOwner::getContext() const {
   switch (getKind()) {
   case Kind::OperationInst:
     return cast<OperationInst>(this)->getContext();
-  case Kind::ForStmt:
-    return cast<ForStmt>(this)->getContext();
-  case Kind::IfStmt:
-    return cast<IfStmt>(this)->getContext();
+  case Kind::ForInst:
+    return cast<ForInst>(this)->getContext();
+  case Kind::IfInst:
+    return cast<IfInst>(this)->getContext();
   }
 }
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 6cc1aba72b3..3f05a4a145a 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -26,12 +26,12 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/InstVisitor.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/STLExtras.h"
 #include "mlir/Transforms/Utils.h"
@@ -2071,7 +2071,7 @@ FunctionParser::~FunctionParser() {
   }
 }
 
-/// Parse a SSA operand for an instruction or statement.
+/// Parse a SSA operand for an instruction.
 ///
 ///   ssa-use ::= ssa-id
 ///
@@ -2716,7 +2716,7 @@ ParseResult CFGFunctionParser::parseFunctionBody() {
 
 /// Basic block declaration.
 ///
-///   basic-block ::= bb-label instruction* terminator-stmt
+///   basic-block ::= bb-label instruction* terminator-inst
 ///   bb-label    ::= bb-id bb-arg-list? `:`
 ///   bb-id       ::= bare-id
 ///   bb-arg-list ::= `(` ssa-id-and-type-list? `)`
@@ -2786,16 +2786,16 @@ private:
   /// more specific builder type.
FuncBuilder builder; - ParseResult parseForStmt(); + ParseResult parseForInst(); ParseResult parseIntConstant(int64_t &val); ParseResult parseDimAndSymbolList(SmallVectorImpl<Value *> &operands, unsigned numDims, unsigned numOperands, const char *affineStructName); ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map, bool isLower); - ParseResult parseIfStmt(); + ParseResult parseIfInst(); ParseResult parseElseClause(Block *elseClause); - ParseResult parseStatements(Block *block); + ParseResult parseInstructions(Block *block); ParseResult parseBlock(Block *block); bool parseSuccessorAndUseList(Block *&dest, @@ -2809,19 +2809,19 @@ private: ParseResult MLFunctionParser::parseFunctionBody() { auto braceLoc = getToken().getLoc(); - // Parse statements in this function. + // Parse instructions in this function. if (parseBlock(function->getBody())) return ParseFailure; return finalizeFunction(function, braceLoc); } -/// For statement. +/// For instruction. /// -/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound -/// (`step` integer-literal)? `{` ml-stmt* `}` +/// ml-for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound +/// (`step` integer-literal)? `{` ml-inst* `}` /// -ParseResult MLFunctionParser::parseForStmt() { +ParseResult MLFunctionParser::parseForInst() { consumeToken(Token::kw_for); // Parse induction variable. @@ -2862,23 +2862,23 @@ ParseResult MLFunctionParser::parseForStmt() { return emitError("step has to be a positive integer"); } - // Create for statement. - ForStmt *forStmt = + // Create for instruction. + ForInst *forInst = builder.createFor(getEncodedSourceLocation(loc), lbOperands, lbMap, ubOperands, ubMap, step); // Create SSA value definition for the induction variable. - if (addDefinition({inductionVariableName, 0, loc}, forStmt)) + if (addDefinition({inductionVariableName, 0, loc}, forInst)) return ParseFailure; - // If parsing of the for statement body fails, - // MLIR contains for statement with those nested statements that have been + // If parsing of the for instruction body fails, the IR will contain + // the for instruction with those nested instructions that have been // successfully parsed. - if (parseBlock(forStmt->getBody())) + if (parseBlock(forInst->getBody())) return ParseFailure; // Reset insertion point to the current block. - builder.setInsertionPointToEnd(forStmt->getBlock()); + builder.setInsertionPointToEnd(forInst->getBlock()); return ParseSuccess; } @@ -3007,7 +3007,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<Value *> &operands, // Create an identity map using dim id for an induction variable and // symbol otherwise. This representation is optimized for storage. // Analysis passes may expand it into a multi-dimensional map if desired. - if (isa<ForStmt>(operands[0])) + if (isa<ForInst>(operands[0])) map = builder.getDimIdentityMap(); else map = builder.getSymbolIdentityMap(); @@ -3095,14 +3095,14 @@ IntegerSet Parser::parseIntegerSetInline() { return set; } -/// If statement. +/// If instruction. 
/// -/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}` -/// | ml-if-head `else` `if` ml-if-cond `{` ml-stmt* `}` -/// ml-if-stmt ::= ml-if-head -/// | ml-if-head `else` `{` ml-stmt* `}` +/// ml-if-head ::= `if` ml-if-cond `{` ml-inst* `}` +/// | ml-if-head `else` `if` ml-if-cond `{` ml-inst* `}` +/// ml-if-inst ::= ml-if-head +/// | ml-if-head `else` `{` ml-inst* `}` /// -ParseResult MLFunctionParser::parseIfStmt() { +ParseResult MLFunctionParser::parseIfInst() { auto loc = getToken().getLoc(); consumeToken(Token::kw_if); @@ -3115,25 +3115,25 @@ ParseResult MLFunctionParser::parseIfStmt() { "integer set")) return ParseFailure; - IfStmt *ifStmt = + IfInst *ifInst = builder.createIf(getEncodedSourceLocation(loc), operands, set); - Block *thenClause = ifStmt->getThen(); + Block *thenClause = ifInst->getThen(); - // When parsing of an if statement body fails, the IR contains - // the if statement with the portion of the body that has been + // When parsing of an if instruction body fails, the IR contains + // the if instruction with the portion of the body that has been // successfully parsed. if (parseBlock(thenClause)) return ParseFailure; if (consumeIf(Token::kw_else)) { - auto *elseClause = ifStmt->createElse(); + auto *elseClause = ifInst->createElse(); if (parseElseClause(elseClause)) return ParseFailure; } // Reset insertion point to the current block. - builder.setInsertionPointToEnd(ifStmt->getBlock()); + builder.setInsertionPointToEnd(ifInst->getBlock()); return ParseSuccess; } @@ -3141,25 +3141,25 @@ ParseResult MLFunctionParser::parseIfStmt() { ParseResult MLFunctionParser::parseElseClause(Block *elseClause) { if (getToken().is(Token::kw_if)) { builder.setInsertionPointToEnd(elseClause); - return parseIfStmt(); + return parseIfInst(); } return parseBlock(elseClause); } /// -/// Parse a list of statements ending with `return` or `}` +/// Parse a list of instructions ending with `return` or `}` /// -ParseResult MLFunctionParser::parseStatements(Block *block) { +ParseResult MLFunctionParser::parseInstructions(Block *block) { auto createOpFunc = [&](const OperationState &state) -> OperationInst * { return builder.createOperation(state); }; builder.setInsertionPointToEnd(block); - // Parse statements till we see '}' or 'return'. - // Return statement is parsed separately to emit a more intuitive error - // when '}' is missing after the return statement. + // Parse instructions until we see '}' or 'return'. + // The return instruction is parsed separately to emit a more intuitive error + // when '}' is missing after the return instruction. while (getToken().isNot(Token::r_brace, Token::kw_return)) { switch (getToken().getKind()) { default: @@ -3167,17 +3167,17 @@ ParseResult MLFunctionParser::parseStatements(Block *block) { return ParseFailure; break; case Token::kw_for: - if (parseForStmt()) + if (parseForInst()) return ParseFailure; break; case Token::kw_if: - if (parseIfStmt()) + if (parseIfInst()) return ParseFailure; break; } // end switch } - // Parse the return statement. + // Parse the return instruction. 
if (getToken().is(Token::kw_return)) if (parseOperation(createOpFunc)) return ParseFailure; @@ -3186,12 +3186,12 @@ ParseResult MLFunctionParser::parseStatements(Block *block) { } /// -/// Parse `{` ml-stmt* `}` +/// Parse `{` ml-inst* `}` /// ParseResult MLFunctionParser::parseBlock(Block *block) { - if (parseToken(Token::l_brace, "expected '{' before statement list") || - parseStatements(block) || - parseToken(Token::r_brace, "expected '}' after statement list")) + if (parseToken(Token::l_brace, "expected '{' before instruction list") || + parseInstructions(block) || + parseToken(Token::r_brace, "expected '}' after instruction list")) return ParseFailure; return ParseSuccess; @@ -3429,7 +3429,7 @@ ParseResult ModuleParser::parseCFGFunc() { /// ML function declarations. /// /// ml-func ::= `mlfunc` ml-func-signature -/// (`attributes` attribute-dict)? `{` ml-stmt* ml-return-stmt +/// (`attributes` attribute-dict)? `{` ml-inst* ml-return-inst /// `}` /// ParseResult ModuleParser::parseMLFunc() { diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 0f130e19e26..20e8e0af214 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -21,9 +21,9 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Instructions.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" -#include "mlir/IR/Statements.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/FileUtilities.h" diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index a5b45ba4098..80e3dd955c3 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -24,7 +24,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/Passes.h" @@ -207,24 +207,24 @@ struct CFGCSE : public CSEImpl { }; /// Common sub-expression elimination for ML functions. -struct MLCSE : public CSEImpl, StmtWalker<MLCSE> { - using StmtWalker<MLCSE>::walk; +struct MLCSE : public CSEImpl, InstWalker<MLCSE> { + using InstWalker<MLCSE>::walk; void run(Function *f) { - // Walk the function statements. + // Walk the function instructions. walk(f); // Finally, erase any redundant operations. eraseDeadOperations(); } - // Insert a scope for each statement range. + // Insert a scope for each instruction range. 
template <class Iterator> void walk(Iterator Start, Iterator End) { ScopedMapTy::ScopeTy scope(knownValues); - StmtWalker<MLCSE>::walk(Start, End); + InstWalker<MLCSE>::walk(Start, End); } - void visitOperationInst(OperationInst *stmt) { simplifyOperation(stmt); } + void visitOperationInst(OperationInst *inst) { simplifyOperation(inst); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index c97b83f8485..f5edf2d8b81 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -25,7 +25,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/Passes.h" @@ -36,20 +36,20 @@ using namespace mlir; namespace { -// ComposeAffineMaps walks stmt blocks in a Function, and for each +// ComposeAffineMaps walks inst blocks in a Function, and for each // AffineApplyOp, forward substitutes its results into any users which are // also AffineApplyOps. After forward substituting its results, AffineApplyOps // with no remaining uses are collected and erased after the walk. // TODO(andydavis) Remove this when Chris adds instruction combiner pass. -struct ComposeAffineMaps : public FunctionPass, StmtWalker<ComposeAffineMaps> { +struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> { std::vector<OperationInst *> affineApplyOpsToErase; explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} - using InstListType = llvm::iplist<Statement>; + using InstListType = llvm::iplist<Instruction>; void walk(InstListType::iterator Start, InstListType::iterator End); - void visitOperationInst(OperationInst *stmt); + void visitOperationInst(OperationInst *inst); PassResult runOnMLFunction(Function *f) override; - using StmtWalker<ComposeAffineMaps>::walk; + using InstWalker<ComposeAffineMaps>::walk; static char passID; }; @@ -66,14 +66,14 @@ void ComposeAffineMaps::walk(InstListType::iterator Start, InstListType::iterator End) { while (Start != End) { walk(&(*Start)); - // Increment iterator after walk as visit function can mutate stmt list + // Increment iterator after walk as visit function can mutate inst list // ahead of 'Start'. 
++Start; } } -void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) { - if (auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>()) { +void ComposeAffineMaps::visitOperationInst(OperationInst *opInst) { + if (auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>()) { forwardSubstitute(affineApplyOp); bool allUsesEmpty = true; for (auto *result : affineApplyOp->getInstruction()->getResults()) { @@ -83,7 +83,7 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) { } } if (allUsesEmpty) { - affineApplyOpsToErase.push_back(opStmt); + affineApplyOpsToErase.push_back(opInst); } } } @@ -91,8 +91,8 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) { PassResult ComposeAffineMaps::runOnMLFunction(Function *f) { affineApplyOpsToErase.clear(); walk(f); - for (auto *opStmt : affineApplyOpsToErase) { - opStmt->erase(); + for (auto *opInst : affineApplyOpsToErase) { + opInst->erase(); } return success(); } diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 08087777e72..f482e90d7ac 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -17,7 +17,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" @@ -26,20 +26,20 @@ using namespace mlir; namespace { /// Simple constant folding pass. -struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> { +struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> { ConstantFold() : FunctionPass(&ConstantFold::passID) {} // All constants in the function post folding. SmallVector<Value *, 8> existingConstants; // Operations that were folded and that need to be erased. - std::vector<OperationInst *> opStmtsToErase; + std::vector<OperationInst *> opInstsToErase; using ConstantFactoryType = std::function<Value *(Attribute, Type)>; bool foldOperation(OperationInst *op, SmallVectorImpl<Value *> &existingConstants, ConstantFactoryType constantFactory); - void visitOperationInst(OperationInst *stmt); - void visitForStmt(ForStmt *stmt); + void visitOperationInst(OperationInst *inst); + void visitForInst(ForInst *inst); PassResult runOnCFGFunction(Function *f) override; PassResult runOnMLFunction(Function *f) override; @@ -140,24 +140,24 @@ PassResult ConstantFold::runOnCFGFunction(Function *f) { } // Override the walker's operation visitor for constant folding. -void ConstantFold::visitOperationInst(OperationInst *stmt) { +void ConstantFold::visitOperationInst(OperationInst *inst) { auto constantFactory = [&](Attribute value, Type type) -> Value * { - FuncBuilder builder(stmt); - return builder.create<ConstantOp>(stmt->getLoc(), value, type); + FuncBuilder builder(inst); + return builder.create<ConstantOp>(inst->getLoc(), value, type); }; - if (!ConstantFold::foldOperation(stmt, existingConstants, constantFactory)) { - opStmtsToErase.push_back(stmt); + if (!ConstantFold::foldOperation(inst, existingConstants, constantFactory)) { + opInstsToErase.push_back(inst); } } -// Override the walker's 'for' statement visit for constant folding. -void ConstantFold::visitForStmt(ForStmt *forStmt) { - constantFoldBounds(forStmt); +// Override the walker's 'for' instruction visit for constant folding. 
+void ConstantFold::visitForInst(ForInst *forInst) { + constantFoldBounds(forInst); } PassResult ConstantFold::runOnMLFunction(Function *f) { existingConstants.clear(); - opStmtsToErase.clear(); + opInstsToErase.clear(); walk(f); // At this point, these operations are dead, remove them. @@ -165,8 +165,8 @@ PassResult ConstantFold::runOnMLFunction(Function *f) { // side effects. When we have side effect modeling, we should verify that // the operation is effect-free before we remove it. Until then this is // close enough. - for (auto *stmt : opStmtsToErase) { - stmt->erase(); + for (auto *inst : opInstsToErase) { + inst->erase(); } // By the time we are done, we may have simplified a bunch of code, leaving diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 821f35ca539..abce624b06f 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -21,9 +21,9 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" -#include "mlir/IR/StmtVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/Functional.h" @@ -39,14 +39,14 @@ using namespace mlir; namespace { // Generates CFG function equivalent to the given ML function. -class FunctionConverter : public StmtVisitor<FunctionConverter> { +class FunctionConverter : public InstVisitor<FunctionConverter> { public: FunctionConverter(Function *cfgFunc) : cfgFunc(cfgFunc), builder(cfgFunc) {} Function *convert(Function *mlFunc); - void visitForStmt(ForStmt *forStmt); - void visitIfStmt(IfStmt *ifStmt); - void visitOperationInst(OperationInst *opStmt); + void visitForInst(ForInst *forInst); + void visitIfInst(IfInst *ifInst); + void visitOperationInst(OperationInst *opInst); private: Value *getConstantIndexValue(int64_t value); @@ -64,49 +64,49 @@ private: } // end anonymous namespace // Return a vector of OperationInst's arguments as Values. For each -// statement operands, represented as Value, lookup its Value conterpart in +// instruction operand, represented as a Value, look up its Value counterpart in // the valueRemapping table. static llvm::SmallVector<mlir::Value *, 4> -operandsAs(Statement *opStmt, +operandsAs(Instruction *opInst, const llvm::DenseMap<const Value *, Value *> &valueRemapping) { llvm::SmallVector<Value *, 4> operands; - for (const Value *operand : opStmt->getOperands()) { + for (const Value *operand : opInst->getOperands()) { assert(valueRemapping.count(operand) != 0 && "operand is not defined"); operands.push_back(valueRemapping.lookup(operand)); } return operands; } -// Convert an operation statement into an operation instruction. +// Convert an operation instruction from the ML function into an equivalent one in the CFG function. // // The operation description (name, number and types of operands or results) // remains the same, but the operand values must be remapped to their CFG // counterparts. Update the value remapping as the conversion is performed. The operation // instruction is appended to the current block (end of SESE region). -void FunctionConverter::visitOperationInst(OperationInst *opStmt) { +void FunctionConverter::visitOperationInst(OperationInst *opInst) { // Set up basic operation state (context, name, operands). 
- OperationState state(cfgFunc->getContext(), opStmt->getLoc(), - opStmt->getName()); - state.addOperands(operandsAs(opStmt, valueRemapping)); + OperationState state(cfgFunc->getContext(), opInst->getLoc(), + opInst->getName()); + state.addOperands(operandsAs(opInst, valueRemapping)); // Set up operation return types. The corresponding Values will become // available after the operation is created. state.addTypes(functional::map( - [](Value *result) { return result->getType(); }, opStmt->getResults())); + [](Value *result) { return result->getType(); }, opInst->getResults())); // Copy attributes. - for (auto attr : opStmt->getAttrs()) { + for (auto attr : opInst->getAttrs()) { state.addAttribute(attr.first.strref(), attr.second); } - auto opInst = builder.createOperation(state); + auto op = builder.createOperation(state); // Make results of the operation accessible to the following operations // through remapping. - assert(opInst->getNumResults() == opStmt->getNumResults()); + assert(opInst->getNumResults() == op->getNumResults()); for (unsigned i = 0, n = opInst->getNumResults(); i < n; ++i) { valueRemapping.insert( - std::make_pair(opStmt->getResult(i), opInst->getResult(i))); + std::make_pair(opInst->getResult(i), op->getResult(i))); } } @@ -116,10 +116,10 @@ Value *FunctionConverter::getConstantIndexValue(int64_t value) { return op->getResult(); } -// Visit all statements in the given statement block. +// Visit all instructions in the given instruction block. void FunctionConverter::visitBlock(Block *Block) { - for (auto &stmt : *Block) - this->visit(&stmt); + for (auto &inst : *Block) + this->visit(&inst); } // Given a range of values, emit the code that reduces them with "min" or "max" @@ -211,7 +211,7 @@ Value *FunctionConverter::buildMinMaxReductionSeq( // | <new insertion point> | // +--------------------------------+ // -void FunctionConverter::visitForStmt(ForStmt *forStmt) { +void FunctionConverter::visitForInst(ForInst *forInst) { // First, store the loop insertion location so that we can go back to it after // creating the new blocks (block creation updates the insertion point). Block *loopInsertionPoint = builder.getInsertionBlock(); @@ -228,27 +228,27 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { // The loop condition block has an argument for loop induction variable. // Create it upfront and make the loop induction variable -> basic block - // argument remapping available to the following instructions. ForStatement + // argument remapping available to the following instructions. ForInstruction // is-a Value corresponding to the loop induction variable. builder.setInsertionPointToEnd(loopConditionBlock); Value *iv = loopConditionBlock->addArgument(builder.getIndexType()); - valueRemapping.insert(std::make_pair(forStmt, iv)); + valueRemapping.insert(std::make_pair(forInst, iv)); // Recursively construct loop body region. // Walking manually because we need custom logic before and after traversing // the list of children. builder.setInsertionPointToEnd(loopBodyFirstBlock); - visitBlock(forStmt->getBody()); + visitBlock(forInst->getBody()); // Builder point is currently at the last block of the loop body. Append the // induction variable stepping to this block and branch back to the exit // condition block. Construct an affine map f : (x -> x+step) and apply this // map to the induction variable. 
- auto affStep = builder.getAffineConstantExpr(forStmt->getStep()); + auto affStep = builder.getAffineConstantExpr(forInst->getStep()); auto affDim = builder.getAffineDimExpr(0); auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {}); auto stepOp = - builder.create<AffineApplyOp>(forStmt->getLoc(), affStepMap, iv); + builder.create<AffineApplyOp>(forInst->getLoc(), affStepMap, iv); Value *nextIvValue = stepOp->getResult(0); builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock, nextIvValue); @@ -262,22 +262,22 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { return valueRemapping.lookup(value); }; auto operands = - functional::map(remapOperands, forStmt->getLowerBoundOperands()); + functional::map(remapOperands, forInst->getLowerBoundOperands()); auto lbAffineApply = builder.create<AffineApplyOp>( - forStmt->getLoc(), forStmt->getLowerBoundMap(), operands); + forInst->getLoc(), forInst->getLowerBoundMap(), operands); Value *lowerBound = buildMinMaxReductionSeq( - forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults()); - operands = functional::map(remapOperands, forStmt->getUpperBoundOperands()); + forInst->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults()); + operands = functional::map(remapOperands, forInst->getUpperBoundOperands()); auto ubAffineApply = builder.create<AffineApplyOp>( - forStmt->getLoc(), forStmt->getUpperBoundMap(), operands); + forInst->getLoc(), forInst->getUpperBoundMap(), operands); Value *upperBound = buildMinMaxReductionSeq( - forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults()); + forInst->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults()); builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock, lowerBound); builder.setInsertionPointToEnd(loopConditionBlock); auto comparisonOp = builder.create<CmpIOp>( - forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound); + forInst->getLoc(), CmpIPredicate::SLT, iv, upperBound); auto comparisonResult = comparisonOp->getResult(); builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult, loopBodyFirstBlock, ArrayRef<Value *>(), @@ -288,16 +288,16 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { builder.setInsertionPointToEnd(postLoopBlock); } -// Convert an "if" statement into a flow of basic blocks. +// Convert an "if" instruction into a flow of basic blocks. // -// Create an SESE region for the if statement (including its "then" and optional -// "else" statement blocks) and append it to the end of the current region. The -// conditional region consists of a sequence of condition-checking blocks that -// implement the short-circuit scheme, followed by a "then" SESE region and an -// "else" SESE region, and the continuation block that post-dominates all blocks -// of the "if" statement. The flow of blocks that correspond to the "then" and -// "else" clauses are constructed recursively, enabling easy nesting of "if" -// statements and if-then-else-if chains. +// Create an SESE region for the if instruction (including its "then" and +// optional "else" instruction blocks) and append it to the end of the current +// region. The conditional region consists of a sequence of condition-checking +// blocks that implement the short-circuit scheme, followed by a "then" SESE +// region and an "else" SESE region, and the continuation block that +// post-dominates all blocks of the "if" instruction. 
The flow of blocks that +// correspond to the "then" and "else" clauses is constructed recursively, +// enabling easy nesting of "if" instructions and if-then-else-if chains. // // +--------------------------------+ // | <end of current SESE region> | @@ -365,17 +365,17 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { // | <new insertion point> | // +--------------------------------+ // -void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { - assert(ifStmt != nullptr); +void FunctionConverter::visitIfInst(IfInst *ifInst) { + assert(ifInst != nullptr); - auto integerSet = ifStmt->getCondition().getIntegerSet(); + auto integerSet = ifInst->getCondition().getIntegerSet(); // Create basic blocks for the 'then' block and for the 'else' block. // Although the 'else' block may be empty in the absence of an 'else' clause, create // it anyway for the sake of consistency and output IR readability. Also // create extra blocks for condition checking to prepare for short-circuit - // logic: conditions in the 'if' statement are conjunctive, so we can jump to - // the false branch as soon as one condition fails. `cond_br` requires + // logic: conditions in the 'if' instruction are conjunctive, so we can jump + // to the false branch as soon as one condition fails. `cond_br` requires // another block as a target when the condition is true, and that block will // contain the next condition. Block *ifInsertionBlock = builder.getInsertionBlock(); @@ -412,14 +412,14 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { builder.getAffineMap(integerSet.getNumDims(), integerSet.getNumSymbols(), constraintExpr, {}); auto affineApplyOp = builder.create<AffineApplyOp>( - ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping)); + ifInst->getLoc(), affineMap, operandsAs(ifInst, valueRemapping)); Value *affResult = affineApplyOp->getResult(0); // Compare the result of the apply and branch. auto comparisonOp = builder.create<CmpIOp>( - ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, + ifInst->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, affResult, zeroConstant); - builder.create<CondBranchOp>(ifStmt->getLoc(), comparisonOp->getResult(), + builder.create<CondBranchOp>(ifInst->getLoc(), comparisonOp->getResult(), nextBlock, /*trueArgs*/ ArrayRef<Value *>(), elseBlock, /*falseArgs*/ ArrayRef<Value *>()); @@ -429,13 +429,13 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { // Recursively traverse the 'then' block. builder.setInsertionPointToEnd(thenBlock); - visitBlock(ifStmt->getThen()); + visitBlock(ifInst->getThen()); Block *lastThenBlock = builder.getInsertionBlock(); // Recursively traverse the 'else' block if present. builder.setInsertionPointToEnd(elseBlock); - if (ifStmt->hasElse()) - visitBlock(ifStmt->getElse()); + if (ifInst->hasElse()) + visitBlock(ifInst->getElse()); Block *lastElseBlock = builder.getInsertionBlock(); // Create the continuation block here so that it appears lexically after the // "then" and "else" blocks, and branch from the ends of the "then" and "else" regions // to the continuation block. 
Block *continuationBlock = builder.createBlock(); builder.setInsertionPointToEnd(lastThenBlock); - builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock); + builder.create<BranchOp>(ifInst->getLoc(), continuationBlock); builder.setInsertionPointToEnd(lastElseBlock); - builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock); + builder.create<BranchOp>(ifInst->getLoc(), continuationBlock); // Make sure building can continue by setting up the continuation block as the // insertion point. @@ -454,12 +454,12 @@ // Entry point of the function converter. // -// Conversion is performed by recursively visiting statements of a Function. +// Conversion is performed by recursively visiting instructions of a Function. // It reasons in terms of single-entry single-exit (SESE) regions that are not // materialized in the code. Instead, the pointer to the last block of the // region is maintained throughout the conversion as the insertion point of the // IR builder since we never change the first block after its creation. "Block" -// statements such as loops and branches create new SESE regions for their +// instructions such as loops and branches create new SESE regions for their // bodies, and surround them with additional basic blocks for the control flow. // Individual operations are simply appended to the end of the last basic block // of the current region. The SESE invariant allows us to easily handle nested @@ -484,9 +484,9 @@ Function *FunctionConverter::convert(Function *mlFunc) { valueRemapping.insert(std::make_pair(mlArgument, cfgArgument)); } - // Convert statements in order. - for (auto &stmt : *mlFunc->getBody()) { - visit(&stmt); + // Convert instructions in order. + for (auto &inst : *mlFunc->getBody()) { + visit(&inst); } return cfgFunc; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 69344819ed8..bc7f31f0434 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -25,7 +25,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/Passes.h" @@ -49,7 +49,7 @@ namespace { /// buffers in 'fastMemorySpace', and replaces memory operations to the former /// with the latter. Only load ops are handled for now. /// TODO(bondhugula): extend this to store op's. -struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> { +struct DmaGeneration : public FunctionPass, InstWalker<DmaGeneration> { explicit DmaGeneration(unsigned slowMemorySpace = 0, unsigned fastMemorySpaceArg = 1, int minDmaTransferSize = 1024) @@ -65,10 +65,10 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> { // Not applicable to CFG functions. PassResult runOnCFGFunction(Function *f) override { return success(); } PassResult runOnMLFunction(Function *f) override; - void runOnForStmt(ForStmt *forStmt); + void runOnForInst(ForInst *forInst); - void visitOperationInst(OperationInst *opStmt); - bool generateDma(const MemRefRegion &region, ForStmt *forStmt, + void visitOperationInst(OperationInst *opInst); + bool generateDma(const MemRefRegion &region, ForInst *forInst, uint64_t *sizeInBytes); // List of memory regions to DMA for. 
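An aside on the loop-bound lowering in FunctionConverter::visitForInst above: multi-result bound maps are reduced to a single bound by chaining cmpi and select (buildMinMaxReductionSeq), taking the max of the lower-bound map results (predicate SGT) and the min of the upper-bound map results (predicate SLT). The following is a minimal standalone sketch of that reduction idea only, using plain integers in place of SSA values; reduceBounds and the sample values are hypothetical and not part of the MLIR API.

#include <cassert>
#include <cstddef>
#include <vector>

// Fold a non-empty list of candidate bound values into one, mirroring the
// cmpi + select chain: "max" (SGT) for lower bounds, "min" (SLT) for upper.
template <typename Pred>
static long reduceBounds(const std::vector<long> &values, Pred pred) {
  assert(!values.empty() && "a bound map has at least one result");
  long acc = values.front();
  for (std::size_t i = 1; i < values.size(); ++i)
    acc = pred(acc, values[i]) ? acc : values[i]; // select(cmpi(pred, a, b), a, b)
  return acc;
}

int main() {
  // Lower bound: max of the lb map results; upper bound: min of the ub results.
  long lb = reduceBounds({0, 4, 2}, [](long a, long b) { return a > b; });
  long ub = reduceBounds({128, 100}, [](long a, long b) { return a < b; });
  assert(lb == 4 && ub == 100);
  return 0;
}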
@@ -108,11 +108,11 @@ FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace, // Gather regions to promote to buffers in faster memory space. // TODO(bondhugula): handle store op's; only load's handled for now. -void DmaGeneration::visitOperationInst(OperationInst *opStmt) { - if (auto loadOp = opStmt->dyn_cast<LoadOp>()) { +void DmaGeneration::visitOperationInst(OperationInst *opInst) { + if (auto loadOp = opInst->dyn_cast<LoadOp>()) { if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace) return; - } else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) { + } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { if (storeOp->getMemRefType().getMemorySpace() != slowMemorySpace) return; } else { @@ -125,7 +125,7 @@ void DmaGeneration::visitOperationInst(OperationInst *opStmt) { // This way we would be allocating O(num of memref's) sets instead of // O(num of load/store op's). auto region = std::make_unique<MemRefRegion>(); - if (!getMemRefRegion(opStmt, dmaDepth, region.get())) { + if (!getMemRefRegion(opInst, dmaDepth, region.get())) { LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n"); return; } @@ -170,19 +170,19 @@ static void getMultiLevelStrides(const MemRefRegion &region, // Creates a buffer in the faster memory space for the specified region; // generates a DMA from the lower memory space to this one, and replaces all // loads to load from that buffer. Returns true if DMAs are generated. -bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt, +bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst, uint64_t *sizeInBytes) { // DMAs for read regions are going to be inserted just before the for loop. - FuncBuilder prologue(forStmt); + FuncBuilder prologue(forInst); // DMAs for write regions are going to be inserted just after the for loop. - FuncBuilder epilogue(forStmt->getBlock(), - std::next(Block::iterator(forStmt))); + FuncBuilder epilogue(forInst->getBlock(), + std::next(Block::iterator(forInst))); FuncBuilder *b = region.isWrite() ? &epilogue : &prologue; // Builder to create constants at the top level. - FuncBuilder top(forStmt->getFunction()); + FuncBuilder top(forInst->getFunction()); - auto loc = forStmt->getLoc(); + auto loc = forInst->getLoc(); auto *memref = region.memref; auto memRefType = memref->getType().cast<MemRefType>(); @@ -285,7 +285,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt, LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type: "); LLVM_DEBUG(fastMemRefType.dump(); llvm::dbgs() << "\n"); - // Create the fast memory space buffer just before the 'for' statement. + // Create the fast memory space buffer just before the 'for' instruction. fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType)->getResult(); // Record it. fastBufferMap[memref] = fastMemRef; @@ -361,58 +361,58 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt, remapExprs.push_back(dimExpr - offsets[i]); } auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); - // *Only* those uses within the body of 'forStmt' are replaced. + // *Only* those uses within the body of 'forInst' are replaced. replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/outerIVs, - /*domStmtFilter=*/&*forStmt->getBody()->begin()); + /*domInstFilter=*/&*forInst->getBody()->begin()); return true; } -/// Returns the nesting depth of this statement, i.e., the number of loops -/// surrounding this statement. 
+/// Returns the nesting depth of this instruction, i.e., the number of loops +/// surrounding this instruction. // TODO(bondhugula): move this to utilities later. -static unsigned getNestingDepth(const Statement &stmt) { - const Statement *currStmt = &stmt; +static unsigned getNestingDepth(const Instruction &inst) { + const Instruction *currInst = &inst; unsigned depth = 0; - while ((currStmt = currStmt->getParentStmt())) { - if (isa<ForStmt>(currStmt)) + while ((currInst = currInst->getParentInst())) { + if (isa<ForInst>(currInst)) depth++; } return depth; } -// TODO(bondhugula): make this run on a Block instead of a 'for' stmt. -void DmaGeneration::runOnForStmt(ForStmt *forStmt) { +// TODO(bondhugula): make this run on a Block instead of a 'for' inst. +void DmaGeneration::runOnForInst(ForInst *forInst) { // For now (for testing purposes), we'll run this on the outermost among 'for' - // stmt's with unit stride, i.e., right at the top of the tile if tiling has + // inst's with unit stride, i.e., right at the top of the tile if tiling has // been done. In the future, the DMA generation has to be done at a level // where the generated data fits in a higher level of the memory hierarchy; so // the pass has to be instantiated with additional information that we aren't // provided with at the moment. - if (forStmt->getStep() != 1) { - if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) { - runOnForStmt(innerFor); + if (forInst->getStep() != 1) { - if (auto *innerFor = dyn_cast<ForInst>(&*forInst->getBody()->begin())) { + runOnForInst(innerFor); } return; } // DMAs will be generated for this depth, i.e., for all data accessed by this // loop. - dmaDepth = getNestingDepth(*forStmt); + dmaDepth = getNestingDepth(*forInst); regions.clear(); fastBufferMap.clear(); - // Walk this 'for' statement to gather all memory regions. - walk(forStmt); + // Walk this 'for' instruction to gather all memory regions. + walk(forInst); uint64_t totalSizeInBytes = 0; bool ret = false; for (const auto &region : regions) { uint64_t sizeInBytes; - bool iRet = generateDma(*region, forStmt, &sizeInBytes); + bool iRet = generateDma(*region, forInst, &sizeInBytes); if (iRet) totalSizeInBytes += sizeInBytes; ret = ret | iRet; @@ -426,9 +426,9 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) { } PassResult DmaGeneration::runOnMLFunction(Function *f) { - for (auto &stmt : *f->getBody()) { - if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) { - runOnForStmt(forStmt); + for (auto &inst : *f->getBody()) { + if (auto *forInst = dyn_cast<ForInst>(&inst)) { + runOnForInst(forInst); } } // This function never leaves the IR in an invalid state. 
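For intuition about dmaDepth above: getNestingDepth simply counts enclosing 'for' instructions by chasing parent links. A minimal standalone model of the same walk, with a hypothetical Node type standing in for MLIR's Instruction parent chain:

#include <cassert>

// Hypothetical stand-in for the Instruction parent chain: each node knows
// its parent and whether it is a loop ('for') node.
struct Node {
  Node *parent;
  bool isFor;
};

// Mirror of getNestingDepth: count the 'for' ancestors of 'n'.
static unsigned nestingDepth(const Node &n) {
  unsigned depth = 0;
  for (Node *cur = n.parent; cur; cur = cur->parent)
    if (cur->isFor)
      ++depth;
  return depth;
}

int main() {
  Node outer{nullptr, true}; // for %i ...
  Node inner{&outer, true};  //   for %j ...
  Node load{&inner, false};  //     load ...
  assert(nestingDepth(load) == 2);
  return 0;
}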
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index d31337437ad..97dea753f88 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -27,7 +27,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" @@ -80,20 +80,20 @@ char LoopFusion::passID = 0; FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } -static void getSingleMemRefAccess(OperationInst *loadOrStoreOpStmt, +static void getSingleMemRefAccess(OperationInst *loadOrStoreOpInst, MemRefAccess *access) { - if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) { + if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) { access->memref = loadOp->getMemRef(); - access->opStmt = loadOrStoreOpStmt; + access->opInst = loadOrStoreOpInst; auto loadMemrefType = loadOp->getMemRefType(); access->indices.reserve(loadMemrefType.getRank()); for (auto *index : loadOp->getIndices()) { access->indices.push_back(index); } } else { - assert(loadOrStoreOpStmt->isa<StoreOp>()); - auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>(); - access->opStmt = loadOrStoreOpStmt; + assert(loadOrStoreOpInst->isa<StoreOp>()); + auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>(); + access->opInst = loadOrStoreOpInst; access->memref = storeOp->getMemRef(); auto storeMemrefType = storeOp->getMemRefType(); access->indices.reserve(storeMemrefType.getRank()); @@ -112,24 +112,24 @@ struct FusionCandidate { MemRefAccess dstAccess; }; -static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpStmt, - OperationInst *dstLoadOpStmt) { +static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpInst, + OperationInst *dstLoadOpInst) { FusionCandidate candidate; // Get store access for src loop nest. - getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess); + getSingleMemRefAccess(srcStoreOpInst, &candidate.srcAccess); // Get load access for dst loop nest. - getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess); + getSingleMemRefAccess(dstLoadOpInst, &candidate.dstAccess); return candidate; } -// Returns the loop depth of the loop nest surrounding 'opStmt'. -static unsigned getLoopDepth(OperationInst *opStmt) { +// Returns the loop depth of the loop nest surrounding 'opInst'. +static unsigned getLoopDepth(OperationInst *opInst) { unsigned loopDepth = 0; - auto *currStmt = opStmt->getParentStmt(); - ForStmt *currForStmt; - while (currStmt && (currForStmt = dyn_cast<ForStmt>(currStmt))) { + auto *currInst = opInst->getParentInst(); + ForInst *currForInst; + while (currInst && (currForInst = dyn_cast<ForInst>(currInst))) { ++loopDepth; - currStmt = currStmt->getParentStmt(); + currInst = currInst->getParentInst(); } return loopDepth; } @@ -137,28 +137,28 @@ static unsigned getLoopDepth(OperationInst *opStmt) { namespace { // LoopNestStateCollector walks loop nests and collects load and store -// operations, and whether or not an IfStmt was encountered in the loop nest. -class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> { +// operations, and whether or not an IfInst was encountered in the loop nest. 
-class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> { +// operations, and whether or not an IfInst was encountered in the loop nest. +class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> { public: - SmallVector<ForStmt *, 4> forStmts; - SmallVector<OperationInst *, 4> loadOpStmts; - SmallVector<OperationInst *, 4> storeOpStmts; - bool hasIfStmt = false; + SmallVector<ForInst *, 4> forInsts; + SmallVector<OperationInst *, 4> loadOpInsts; + SmallVector<OperationInst *, 4> storeOpInsts; + bool hasIfInst = false; - void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); } + void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; } + void visitIfInst(IfInst *ifInst) { hasIfInst = true; } - void visitOperationInst(OperationInst *opStmt) { - if (opStmt->isa<LoadOp>()) - loadOpStmts.push_back(opStmt); - if (opStmt->isa<StoreOp>()) - storeOpStmts.push_back(opStmt); + void visitOperationInst(OperationInst *opInst) { + if (opInst->isa<LoadOp>()) + loadOpInsts.push_back(opInst); + if (opInst->isa<StoreOp>()) + storeOpInsts.push_back(opInst); } }; // MemRefDependenceGraph is a graph data structure where graph nodes are -// top-level statements in a Function which contain load/store ops, and edges +// top-level instructions in a Function which contain load/store ops, and edges // are memref dependences between the nodes. // TODO(andydavis) Add a depth parameter to dependence graph construction. struct MemRefDependenceGraph { @@ -170,18 +170,18 @@ public: // The unique identifier of this node in the graph. unsigned id; // The top-level instruction which is (or contains) loads/stores. - Statement *stmt; + Instruction *inst; // List of load operations. SmallVector<OperationInst *, 4> loads; - // List of store op stmts. + // List of store op insts. SmallVector<OperationInst *, 4> stores; - Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {} + Node(unsigned id, Instruction *inst) : id(id), inst(inst) {} // Returns the load op count for 'memref'. unsigned getLoadOpCount(Value *memref) { unsigned loadOpCount = 0; - for (auto *loadOpStmt : loads) { - if (memref == loadOpStmt->cast<LoadOp>()->getMemRef()) + for (auto *loadOpInst : loads) { + if (memref == loadOpInst->cast<LoadOp>()->getMemRef()) ++loadOpCount; } return loadOpCount; @@ -190,8 +190,8 @@ public: // Returns the store op count for 'memref'. unsigned getStoreOpCount(Value *memref) { unsigned storeOpCount = 0; - for (auto *storeOpStmt : stores) { - if (memref == storeOpStmt->cast<StoreOp>()->getMemRef()) + for (auto *storeOpInst : stores) { + if (memref == storeOpInst->cast<StoreOp>()->getMemRef()) ++storeOpCount; } return storeOpCount; @@ -315,10 +315,10 @@ public: void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads, const SmallVectorImpl<OperationInst *> &stores) { Node *node = getNode(id); - for (auto *loadOpStmt : loads) - node->loads.push_back(loadOpStmt); - for (auto *storeOpStmt : stores) - node->stores.push_back(storeOpStmt); + for (auto *loadOpInst : loads) + node->loads.push_back(loadOpInst); + for (auto *storeOpInst : stores) + node->stores.push_back(storeOpInst); } void print(raw_ostream &os) const { @@ -341,55 +341,55 @@ public: void dump() const { print(llvm::errs()); } }; -// Intializes the data dependence graph by walking statements in 'f'. +// Initializes the data dependence graph by walking instructions in 'f'. // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. 
bool MemRefDependenceGraph::init(Function *f) { unsigned id = 0; DenseMap<Value *, SetVector<unsigned>> memrefAccesses; - for (auto &stmt : *f->getBody()) { - if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) { - // Create graph node 'id' to represent top-level 'forStmt' and record + for (auto &inst : *f->getBody()) { + if (auto *forInst = dyn_cast<ForInst>(&inst)) { + // Create graph node 'id' to represent top-level 'forInst' and record // all loads and store accesses it contains. LoopNestStateCollector collector; - collector.walkForStmt(forStmt); - // Return false if IfStmts are found (not currently supported). - if (collector.hasIfStmt) + collector.walkForInst(forInst); + // Return false if IfInsts are found (not currently supported). + if (collector.hasIfInst) return false; - Node node(id++, &stmt); - for (auto *opStmt : collector.loadOpStmts) { - node.loads.push_back(opStmt); - auto *memref = opStmt->cast<LoadOp>()->getMemRef(); + Node node(id++, &inst); + for (auto *opInst : collector.loadOpInsts) { + node.loads.push_back(opInst); + auto *memref = opInst->cast<LoadOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); } - for (auto *opStmt : collector.storeOpStmts) { - node.stores.push_back(opStmt); - auto *memref = opStmt->cast<StoreOp>()->getMemRef(); + for (auto *opInst : collector.storeOpInsts) { + node.stores.push_back(opInst); + auto *memref = opInst->cast<StoreOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); } nodes.insert({node.id, node}); } - if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) { - if (auto loadOp = opStmt->dyn_cast<LoadOp>()) { + if (auto *opInst = dyn_cast<OperationInst>(&inst)) { + if (auto loadOp = opInst->dyn_cast<LoadOp>()) { // Create graph node for top-level load op. - Node node(id++, &stmt); - node.loads.push_back(opStmt); - auto *memref = opStmt->cast<LoadOp>()->getMemRef(); + Node node(id++, &inst); + node.loads.push_back(opInst); + auto *memref = opInst->cast<LoadOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } - if (auto storeOp = opStmt->dyn_cast<StoreOp>()) { + if (auto storeOp = opInst->dyn_cast<StoreOp>()) { // Create graph node for top-level store op. - Node node(id++, &stmt); - node.stores.push_back(opStmt); - auto *memref = opStmt->cast<StoreOp>()->getMemRef(); + Node node(id++, &inst); + node.stores.push_back(opInst); + auto *memref = opInst->cast<StoreOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); } } - // Return false if IfStmts are found (not currently supported). - if (isa<IfStmt>(&stmt)) + // Return false if IfInsts are found (not currently supported). + if (isa<IfInst>(&inst)) return false; } @@ -421,9 +421,9 @@ bool MemRefDependenceGraph::init(Function *f) { // // *) A worklist is initialized with node ids from the dependence graph. // *) For each node id in the worklist: -// *) Pop a ForStmt of the worklist. This 'dstForStmt' will be a candidate -// destination ForStmt into which fusion will be attempted. -// *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'. +// *) Pop a ForInst off the worklist. This 'dstForInst' will be a candidate +// destination ForInst into which fusion will be attempted. +// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'. // *) For each LoadOp in 'dstLoadOps' do: // *) Look up dependent loop nests at earlier positions in the Function // which have a single store op to the same memref. 
@@ -434,12 +434,12 @@ bool MemRefDependenceGraph::init(Function *f) { // bounds to be functions of 'dstLoopNest' IVs and symbols. // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', // just before the dst load op user. -// *) Add the newly fused load/store operation statements to the state, +// *) Add the newly fused load/store operation instructions to the state, // and also add newly fused load ops to 'dstLoadOps' to be considered // as fusion dst load ops in another iteration. // *) Remove old src loop nest and its associated state. // -// Given a graph where top-level statements are vertices in the set 'V' and +// Given a graph where top-level instructions are vertices in the set 'V' and // edges in the set 'E' are dependences between vertices, this algorithm // takes O(V) time for initialization, and has runtime O(V + E). // @@ -471,14 +471,14 @@ public: // Get 'dstNode' into which to attempt fusion. auto *dstNode = mdg->getNode(dstId); // Skip if 'dstNode' is not a loop nest. - if (!isa<ForStmt>(dstNode->stmt)) + if (!isa<ForInst>(dstNode->inst)) continue; SmallVector<OperationInst *, 4> loads = dstNode->loads; while (!loads.empty()) { - auto *dstLoadOpStmt = loads.pop_back_val(); - auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef(); - // Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'. + auto *dstLoadOpInst = loads.pop_back_val(); + auto *memref = dstLoadOpInst->cast<LoadOp>()->getMemRef(); + // Skip 'dstLoadOpInst' if multiple loads to 'memref' in 'dstNode'. if (dstNode->getLoadOpCount(memref) != 1) continue; // Skip if no input edges along which to fuse. @@ -491,7 +491,7 @@ public: continue; auto *srcNode = mdg->getNode(srcEdge.id); // Skip if 'srcNode' is not a loop nest. - if (!isa<ForStmt>(srcNode->stmt)) + if (!isa<ForInst>(srcNode->inst)) continue; // Skip if 'srcNode' has more than one store to 'memref'. if (srcNode->getStoreOpCount(memref) != 1) @@ -508,17 +508,17 @@ public: if (mdg->getMinOutEdgeNodeId(srcNode->id) != dstId) continue; // Get unique 'srcNode' store op. - auto *srcStoreOpStmt = srcNode->stores.front(); - // Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'. + auto *srcStoreOpInst = srcNode->stores.front(); + // Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'. FusionCandidate candidate = - buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt); + buildFusionCandidate(srcStoreOpInst, dstLoadOpInst); // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0 ? clSrcLoopDepth - : getLoopDepth(srcStoreOpStmt); + : getLoopDepth(srcStoreOpInst); unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0 ? clDstLoopDepth - : getLoopDepth(dstLoadOpStmt); + : getLoopDepth(dstLoadOpInst); auto *sliceLoopNest = mlir::insertBackwardComputationSlice( &candidate.srcAccess, &candidate.dstAccess, srcLoopDepth, dstLoopDepth); @@ -527,19 +527,19 @@ public: mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id); // Record all load/store accesses in 'sliceLoopNest' at 'dstPos'. LoopNestStateCollector collector; - collector.walkForStmt(sliceLoopNest); - mdg->addToNode(dstId, collector.loadOpStmts, - collector.storeOpStmts); + collector.walkForInst(sliceLoopNest); + mdg->addToNode(dstId, collector.loadOpInsts, + collector.storeOpInsts); // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. 
- for (auto *loadOpStmt : collector.loadOpStmts) - loads.push_back(loadOpStmt); + for (auto *loadOpInst : collector.loadOpInsts) + loads.push_back(loadOpInst); // Promote single iteration loops to single IV value. - for (auto *forStmt : collector.forStmts) { - promoteIfSingleIteration(forStmt); + for (auto *forInst : collector.forInsts) { + promoteIfSingleIteration(forInst); } // Remove old src loop nest. - cast<ForStmt>(srcNode->stmt)->erase(); + cast<ForInst>(srcNode->inst)->erase(); } } } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 109953f2296..8f3be8a3d45 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -55,16 +55,16 @@ char LoopTiling::passID = 0; /// Function. FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } -// Move the loop body of ForStmt 'src' from 'src' into the specified location in +// Move the loop body of ForInst 'src' into the specified location in // destination's body. -static inline void moveLoopBody(ForStmt *src, ForStmt *dest, +static inline void moveLoopBody(ForInst *src, ForInst *dest, Block::iterator loc) { dest->getBody()->getInstructions().splice(loc, src->getBody()->getInstructions()); } -// Move the loop body of ForStmt 'src' from 'src' to the start of dest's body. -static inline void moveLoopBody(ForStmt *src, ForStmt *dest) { +// Move the loop body of ForInst 'src' to the start of dest's body. +static inline void moveLoopBody(ForInst *src, ForInst *dest) { moveLoopBody(src, dest, dest->getBody()->begin()); } @@ -73,8 +73,8 @@ static inline void moveLoopBody(ForStmt *src, ForStmt *dest) { /// depend on other dimensions. Bounds of each dimension can thus be treated /// independently, and deriving the new bounds is much simpler and faster /// than for the case of tiling arbitrary polyhedral shapes. -static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops, - ArrayRef<ForStmt *> newLoops, +static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops, + ArrayRef<ForInst *> newLoops, ArrayRef<unsigned> tileSizes) { assert(!origLoops.empty()); assert(origLoops.size() == tileSizes.size()); @@ -138,27 +138,27 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops, /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. // TODO(bondhugula): handle non hyper-rectangular spaces. -UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, +UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band, ArrayRef<unsigned> tileSizes) { assert(!band.empty()); assert(band.size() == tileSizes.size()); - // Check if the supplied for stmt's are all successively nested. + // Check if the supplied for inst's are all successively nested. for (unsigned i = 1, e = band.size(); i < e; i++) { - assert(band[i]->getParentStmt() == band[i - 1]); + assert(band[i]->getParentInst() == band[i - 1]); } auto origLoops = band; - ForStmt *rootForStmt = origLoops[0]; - auto loc = rootForStmt->getLoc(); + ForInst *rootForInst = origLoops[0]; + auto loc = rootForInst->getLoc(); // Note that width is at least one since band isn't empty. unsigned width = band.size(); - SmallVector<ForStmt *, 12> newLoops(2 * width); - ForStmt *innermostPointLoop; + SmallVector<ForInst *, 12> newLoops(2 * width); + ForInst *innermostPointLoop; // The outermost among the loops as we add more. 
- auto *topLoop = rootForStmt; + auto *topLoop = rootForInst; // Add intra-tile (or point) loops. for (unsigned i = 0; i < width; i++) { @@ -195,7 +195,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, getIndexSet(band, &cst); if (!cst.isHyperRectangular(0, width)) { - rootForStmt->emitError("tiled code generation unimplemented for the" + rootForInst->emitError("tiled code generation unimplemented for the " "non-hyperrectangular case"); return UtilResult::Failure; } @@ -207,7 +207,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, } // Erase the old loop nest. - rootForStmt->erase(); + rootForInst->erase(); return UtilResult::Success; } @@ -216,28 +216,28 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. static void getTileableBands(Function *f, - std::vector<SmallVector<ForStmt *, 6>> *bands) { - // Get maximal perfect nest of 'for' stmts starting from root (inclusive). - auto getMaximalPerfectLoopNest = [&](ForStmt *root) { - SmallVector<ForStmt *, 6> band; - ForStmt *currStmt = root; + std::vector<SmallVector<ForInst *, 6>> *bands) { + // Get maximal perfect nest of 'for' insts starting from root (inclusive). + auto getMaximalPerfectLoopNest = [&](ForInst *root) { + SmallVector<ForInst *, 6> band; + ForInst *currInst = root; do { - band.push_back(currStmt); - } while (currStmt->getBody()->getInstructions().size() == 1 && - (currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin()))); + band.push_back(currInst); + } while (currInst->getBody()->getInstructions().size() == 1 && + (currInst = dyn_cast<ForInst>(&*currInst->getBody()->begin()))); bands->push_back(band); }; - for (auto &stmt : *f->getBody()) { - auto *forStmt = dyn_cast<ForStmt>(&stmt); - if (!forStmt) + for (auto &inst : *f->getBody()) { + auto *forInst = dyn_cast<ForInst>(&inst); + if (!forInst) continue; - getMaximalPerfectLoopNest(forStmt); + getMaximalPerfectLoopNest(forInst); } } PassResult LoopTiling::runOnMLFunction(Function *f) { - std::vector<SmallVector<ForStmt *, 6>> bands; + std::vector<SmallVector<ForInst *, 6>> bands; getTileableBands(f, &bands); // Temporary tile sizes. diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 15ea0f841cc..69431bf6349 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -26,7 +26,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" @@ -62,18 +62,18 @@ struct LoopUnroll : public FunctionPass { const Optional<bool> unrollFull; // Callback to obtain unroll factors; if this has a callable target, takes // precedence over command-line argument or passed argument. - const std::function<unsigned(const ForStmt &)> getUnrollFactor; + const std::function<unsigned(const ForInst &)> getUnrollFactor; explicit LoopUnroll( Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None, - const std::function<unsigned(const ForStmt &)> &getUnrollFactor = nullptr) + const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr) : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} PassResult runOnMLFunction(Function *f) override; - /// Unroll this for stmt. 
Returns false if nothing was done. - bool runOnForStmt(ForStmt *forStmt); + /// Unroll this for inst. Returns false if nothing was done. + bool runOnForInst(ForInst *forInst); static const unsigned kDefaultUnrollFactor = 4; @@ -85,13 +85,13 @@ char LoopUnroll::passID = 0; PassResult LoopUnroll::runOnMLFunction(Function *f) { // Gathers all innermost loops through a post order pruned walk. - class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> { + class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> { public: // Store innermost loops as we walk. - std::vector<ForStmt *> loops; + std::vector<ForInst *> loops; // This method specialized to encode custom return logic. - using InstListType = llvm::iplist<Statement>; + using InstListType = llvm::iplist<Instruction>; bool walkPostOrder(InstListType::iterator Start, InstListType::iterator End) { bool hasInnerLoops = false; @@ -103,43 +103,43 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { return hasInnerLoops; } - bool walkForStmtPostOrder(ForStmt *forStmt) { + bool walkForInstPostOrder(ForInst *forInst) { bool hasInnerLoops = - walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end()); + walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end()); if (!hasInnerLoops) - loops.push_back(forStmt); + loops.push_back(forInst); return true; } - bool walkIfStmtPostOrder(IfStmt *ifStmt) { + bool walkIfInstPostOrder(IfInst *ifInst) { bool hasInnerLoops = - walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end()); - if (ifStmt->hasElse()) + walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end()); + if (ifInst->hasElse()) hasInnerLoops |= - walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end()); + walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end()); return hasInnerLoops; } - bool visitOperationInst(OperationInst *opStmt) { return false; } + bool visitOperationInst(OperationInst *opInst) { return false; } // FIXME: can't use base class method for this because that in turn would // need to use the derived class method above. CRTP doesn't allow it, and // the compiler error resulting from it is also misleading. - using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder; + using InstWalker<InnermostLoopGatherer, bool>::walkPostOrder; }; // Gathers all loops with trip count <= minTripCount. - class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> { + class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> { public: // Store short loops as we walk. - std::vector<ForStmt *> loops; + std::vector<ForInst *> loops; const unsigned minTripCount; ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitForStmt(ForStmt *forStmt) { - Optional<uint64_t> tripCount = getConstantTripCount(*forStmt); + void visitForInst(ForInst *forInst) { + Optional<uint64_t> tripCount = getConstantTripCount(*forInst); if (tripCount.hasValue() && tripCount.getValue() <= minTripCount) - loops.push_back(forStmt); + loops.push_back(forInst); } }; @@ -151,8 +151,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { // ones). 
slg.walkPostOrder(f); auto &loops = slg.loops; - for (auto *forStmt : loops) - loopUnrollFull(forStmt); + for (auto *forInst : loops) + loopUnrollFull(forInst); return success(); } @@ -167,8 +167,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { if (loops.empty()) break; bool unrolled = false; - for (auto *forStmt : loops) - unrolled |= runOnForStmt(forStmt); + for (auto *forInst : loops) + unrolled |= runOnForInst(forInst); if (!unrolled) // Break out if nothing was unrolled. break; @@ -176,31 +176,31 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { return success(); } -/// Unrolls a 'for' stmt. Returns true if the loop was unrolled, false +/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false /// otherwise. The default unroll factor is 4. -bool LoopUnroll::runOnForStmt(ForStmt *forStmt) { +bool LoopUnroll::runOnForInst(ForInst *forInst) { // Use the function callback if one was provided. if (getUnrollFactor) { - return loopUnrollByFactor(forStmt, getUnrollFactor(*forStmt)); + return loopUnrollByFactor(forInst, getUnrollFactor(*forInst)); } // Unroll by the factor passed, if any. if (unrollFactor.hasValue()) - return loopUnrollByFactor(forStmt, unrollFactor.getValue()); + return loopUnrollByFactor(forInst, unrollFactor.getValue()); // Unroll by the command line factor if one was specified. if (clUnrollFactor.getNumOccurrences() > 0) - return loopUnrollByFactor(forStmt, clUnrollFactor); + return loopUnrollByFactor(forInst, clUnrollFactor); // Unroll completely if full loop unroll was specified. if (clUnrollFull.getNumOccurrences() > 0 || (unrollFull.hasValue() && unrollFull.getValue())) - return loopUnrollFull(forStmt); + return loopUnrollFull(forInst); // Unroll by four otherwise. - return loopUnrollByFactor(forStmt, kDefaultUnrollFactor); + return loopUnrollByFactor(forInst, kDefaultUnrollFactor); } FunctionPass *mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, - const std::function<unsigned(const ForStmt &)> &getUnrollFactor) { + const std::function<unsigned(const ForInst &)> &getUnrollFactor) { return new LoopUnroll( unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 60e8d154f98..f59659cf234 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -40,7 +40,7 @@ // S6(i+1); // // Note: 'if/else' blocks are not jammed. So, if there are loops inside if -// stmt's, bodies of those loops will not be jammed. +// inst's, bodies of those loops will not be jammed. 
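// A compact reminder of jamming vs. plain unrolling (illustrative C, factor
// 2, even trip count):
//
//   for (i = 0; i < N; i++)    { S1(i); S2(i); }    // before
//   for (i = 0; i < N; i += 2) { S1(i); S1(i+1);    // unroll-and-jam by 2:
//                                S2(i); S2(i+1); }  // copies are jammed
//
// Plain unrolling by 2 would instead emit S1(i); S2(i); S1(i+1); S2(i+1).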
//===----------------------------------------------------------------------===// #include "mlir/Transforms/Passes.h" @@ -49,7 +49,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" @@ -75,7 +75,7 @@ struct LoopUnrollAndJam : public FunctionPass { unrollJamFactor(unrollJamFactor) {} PassResult runOnMLFunction(Function *f) override; - bool runOnForStmt(ForStmt *forStmt); + bool runOnForInst(ForInst *forInst); static char passID; }; @@ -90,79 +90,79 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { PassResult LoopUnrollAndJam::runOnMLFunction(Function *f) { // Currently, just the outermost loop from the first loop nest is - // unroll-and-jammed by this pass. However, runOnForStmt can be called on any - // for Stmt. - auto *forStmt = dyn_cast<ForStmt>(f->getBody()->begin()); - if (!forStmt) + // unroll-and-jammed by this pass. However, runOnForInst can be called on any + // for Inst. + auto *forInst = dyn_cast<ForInst>(f->getBody()->begin()); + if (!forInst) return success(); - runOnForStmt(forStmt); + runOnForInst(forInst); return success(); } -/// Unroll and jam a 'for' stmt. Default unroll jam factor is +/// Unroll and jam a 'for' inst. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return false if nothing was done. -bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) { +bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) { // Unroll and jam by the factor that was passed if any. if (unrollJamFactor.hasValue()) - return loopUnrollJamByFactor(forStmt, unrollJamFactor.getValue()); + return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue()); // Otherwise, unroll jam by the command-line factor if one was specified. if (clUnrollJamFactor.getNumOccurrences() > 0) - return loopUnrollJamByFactor(forStmt, clUnrollJamFactor); + return loopUnrollJamByFactor(forInst, clUnrollJamFactor); // Unroll and jam by four otherwise. - return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor); + return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor); } -bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { - Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt); +bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) { + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollJamFactor) - return loopUnrollJamByFactor(forStmt, mayBeConstantTripCount.getValue()); - return loopUnrollJamByFactor(forStmt, unrollJamFactor); + return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue()); + return loopUnrollJamByFactor(forInst, unrollJamFactor); } /// Unrolls and jams this loop by the specified factor. -bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { - // Gathers all maximal sub-blocks of statements that do not themselves include - // a for stmt (a statement could have a descendant for stmt though in its - // tree). - class JamBlockGatherer : public StmtWalker<JamBlockGatherer> { +bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { + // Gathers all maximal sub-blocks of instructions that do not themselves + // include a for inst (a instruction could have a descendant for inst though + // in its tree). 
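// Illustration (editorial): for a loop body {S1; S2; for-nest; S3}, the
// walker below records the sub-blocks [S1, S2] and [S3]; the inner for-nest
// is recursed into on its own and is never duplicated as part of a jammed
// sub-block.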
+ class JamBlockGatherer : public InstWalker<JamBlockGatherer> { public: - using InstListType = llvm::iplist<Statement>; + using InstListType = llvm::iplist<Instruction>; - // Store iterators to the first and last stmt of each sub-block found. + // Store iterators to the first and last inst of each sub-block found. std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks; // This is a linear time walk. void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; - while (it != End && !isa<ForStmt>(it)) + while (it != End && !isa<ForInst>(it)) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); - // Process all for stmts that appear next. - while (it != End && isa<ForStmt>(it)) - walkForStmt(cast<ForStmt>(it++)); + // Process all for insts that appear next. + while (it != End && isa<ForInst>(it)) + walkForInst(cast<ForInst>(it++)); } } }; assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); - if (unrollJamFactor == 1 || forStmt->getBody()->empty()) + if (unrollJamFactor == 1 || forInst->getBody()->empty()) return false; - Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt); + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst); if (!mayBeConstantTripCount.hasValue() && - getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0) + getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0) return false; - auto lbMap = forStmt->getLowerBoundMap(); - auto ubMap = forStmt->getUpperBoundMap(); + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be // expressed as a Function in the general case). However, the right way to @@ -173,7 +173,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // Same operand list for lower and upper bound for now. // TODO(bondhugula): handle bounds with different sets of operands. - if (!forStmt->matchingBoundOperandList()) + if (!forInst->matchingBoundOperandList()) return false; // If the trip count is lower than the unroll jam factor, no unroll jam. @@ -184,7 +184,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // Gather all sub-blocks to jam upon the loop being unrolled. JamBlockGatherer jbg; - jbg.walkForStmt(forStmt); + jbg.walkForInst(forInst); auto &subBlocks = jbg.subBlocks; // Generate the cleanup loop if trip count isn't a multiple of @@ -192,24 +192,24 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() % unrollJamFactor != 0) { DenseMap<const Value *, Value *> operandMap; - // Insert the cleanup loop right after 'forStmt'. - FuncBuilder builder(forStmt->getBlock(), - std::next(Block::iterator(forStmt))); - auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap)); - cleanupForStmt->setLowerBoundMap( - getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder)); + // Insert the cleanup loop right after 'forInst'. + FuncBuilder builder(forInst->getBlock(), + std::next(Block::iterator(forInst))); + auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst, operandMap)); + cleanupForInst->setLowerBoundMap( + getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder)); // The upper bound needs to be adjusted. 
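// Worked example (editorial): with lower bound 0, step 1, a constant trip
// count of 10 and unrollJamFactor 4, the unroll-jammed loop keeps the first
// tripCount - tripCount % 4 = 8 iterations (getUnrolledLoopUpperBound), and
// the cleanup loop's lower bound becomes lb + 8 * step = 8
// (getCleanupLoopLowerBound), covering the remaining two iterations.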
- forStmt->setUpperBoundMap( - getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder)); + forInst->setUpperBoundMap( + getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder)); // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(cleanupForStmt); + promoteIfSingleIteration(cleanupForInst); } // Scale the step of loop being unroll-jammed by the unroll-jam factor. - int64_t step = forStmt->getStep(); - forStmt->setStep(step * unrollJamFactor); + int64_t step = forInst->getStep(); + forInst->setStep(step * unrollJamFactor); for (auto &subBlock : subBlocks) { // Builder to insert unroll-jammed bodies. Insert right at the end of @@ -222,14 +222,14 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forStmt->use_empty()) { + if (!forInst->use_empty()) { // iv' = iv + i, i = 1 to unrollJamFactor-1. auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); auto *ivUnroll = - builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt) + builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst) ->getResult(0); - operandMapping[forStmt] = ivUnroll; + operandMapping[forInst] = ivUnroll; } // Clone the sub-block being unroll-jammed. for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { @@ -239,7 +239,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { } // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(forStmt); + promoteIfSingleIteration(forInst); return true; } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 51577009abb..bcb2abf11dd 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -110,7 +110,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // Get the ML function builder. // We need access to the Function builder stored internally in the // MLFunctionLoweringRewriter general rewriting API does not provide - // ML-specific functions (ForStmt and Block manipulation). While we could + // ML-specific functions (ForInst and Block manipulation). While we could // forward them or define a whole rewriting chain based on MLFunctionBuilder // instead of Builer, the code for it would be duplicate boilerplate. As we // go towards unifying ML and CFG functions, this separation will disappear. @@ -137,13 +137,13 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // memory. // TODO(ntv): Handle broadcast / slice properly. auto permutationMap = transfer->getPermutationMap(); - SetVector<ForStmt *> loops; + SetVector<ForInst *> loops; SmallVector<Value *, 8> accessIndices(transfer->getIndices()); for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) { auto composed = composeWithUnboundedMap( getAffineDimExpr(it.index(), b.getContext()), permutationMap); - auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value()); - loops.insert(forStmt); + auto *forInst = b.createFor(transfer->getLoc(), 0, it.value()); + loops.insert(forInst); // Setting the insertion point to the innermost loop achieves nesting. 
b.setInsertionPointToStart(loops.back()->getBody()); if (composed == getAffineConstantExpr(0, b.getContext())) { @@ -196,7 +196,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, b.setInsertionPoint(transfer->getInstruction()); b.create<DeallocOp>(transfer->getLoc(), tmpScalarAlloc); - // 7. It is now safe to erase the statement. + // 7. It is now safe to erase the instruction. rewriter->replaceOp(transfer->getInstruction(), newResults); } @@ -213,7 +213,7 @@ public: return matchFailure(); } - void rewriteOpStmt(OperationInst *op, + void rewriteOpInst(OperationInst *op, MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr<PatternState> opState, MLFuncLoweringRewriter *rewriter) const override { diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index a30e8164760..37f0f571a0f 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -73,7 +73,7 @@ /// Implementation details /// ====================== /// The current decisions made by the super-vectorization pass guarantee that -/// use-def chains do not escape an enclosing vectorized ForStmt. In other +/// use-def chains do not escape an enclosing vectorized ForInst. In other /// words, this pass operates on a scoped program slice. Furthermore, since we /// do not vectorize in the presence of conditionals for now, sliced chains are /// guaranteed not to escape the innermost scope, which has to be either the top @@ -247,7 +247,7 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex, } static OperationInst * -instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType, +instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType, DenseMap<const Value *, Value *> *substitutionsMap); /// Not all Values belong to a program slice scoped within the immediately @@ -263,10 +263,10 @@ static Value *substitute(Value *v, VectorType hwVectorType, DenseMap<const Value *, Value *> *substitutionsMap) { auto it = substitutionsMap->find(v); if (it == substitutionsMap->end()) { - auto *opStmt = v->getDefiningInst(); - if (opStmt->isa<ConstantOp>()) { - FuncBuilder b(opStmt); - auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap); + auto *opInst = v->getDefiningInst(); + if (opInst->isa<ConstantOp>()) { + FuncBuilder b(opInst); + auto *inst = instantiate(&b, opInst, hwVectorType, substitutionsMap); auto res = substitutionsMap->insert(std::make_pair(v, inst->getResult(0))); assert(res.second && "Insertion failed"); @@ -285,7 +285,7 @@ static Value *substitute(Value *v, VectorType hwVectorType, /// /// The general problem this pass solves is as follows: /// Assume a vector_transfer operation at the super-vector granularity that has -/// `l` enclosing loops (ForStmt). Assume the vector transfer operation operates +/// `l` enclosing loops (ForInst). Assume the vector transfer operation operates /// on a MemRef of rank `r`, a super-vector of rank `s` and a hardware vector of /// rank `h`. /// For the purpose of illustration assume l==4, r==3, s==2, h==1 and that the @@ -347,7 +347,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, SmallVector<AffineExpr, 8> affineExprs; // TODO(ntv): support a concrete map and composition. unsigned i = 0; - // The first numMemRefIndices correspond to ForStmt that have not been + // The first numMemRefIndices correspond to ForInst that have not been // vectorized, the transformation is the identity on those. 
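// Illustration (editorial): with numMemRefIndices = 2, the loop below first
// emits the identity exprs (d0, d1); only the remaining (vectorized)
// dimensions receive a non-trivial rewrite.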
for (i = 0; i < numMemRefIndices; ++i) { auto d_i = b->getAffineDimExpr(i); @@ -384,9 +384,9 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, /// - constant splat is replaced by constant splat of `hwVectorType`. /// TODO(ntv): add more substitutions on a per-need basis. static SmallVector<NamedAttribute, 1> -materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) { +materializeAttributes(OperationInst *opInst, VectorType hwVectorType) { SmallVector<NamedAttribute, 1> res; - for (auto a : opStmt->getAttrs()) { + for (auto a : opInst->getAttrs()) { if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) { auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue()); res.push_back(NamedAttribute(a.first, attr)); @@ -397,7 +397,7 @@ materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) { return res; } -/// Creates an instantiated version of `opStmt`. +/// Creates an instantiated version of `opInst`. /// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no /// affine reindexing. Just substitute their Value operands and be done. For /// this case the actual instance is irrelevant. Just use the values in @@ -405,11 +405,11 @@ materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) { /// /// If the underlying substitution fails, this fails too and returns nullptr. static OperationInst * -instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType, +instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType, DenseMap<const Value *, Value *> *substitutionsMap) { - assert(!opStmt->isa<VectorTransferReadOp>() && + assert(!opInst->isa<VectorTransferReadOp>() && "Should call the function specialized for VectorTransferReadOp"); - assert(!opStmt->isa<VectorTransferWriteOp>() && + assert(!opInst->isa<VectorTransferWriteOp>() && "Should call the function specialized for VectorTransferWriteOp"); bool fail = false; auto operands = map( @@ -419,14 +419,14 @@ instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType, fail |= !res; return res; }, - opStmt->getOperands()); + opInst->getOperands()); if (fail) return nullptr; - auto attrs = materializeAttributes(opStmt, hwVectorType); + auto attrs = materializeAttributes(opInst, hwVectorType); - OperationState state(b->getContext(), opStmt->getLoc(), - opStmt->getName().getStringRef(), operands, + OperationState state(b->getContext(), opInst->getLoc(), + opInst->getName().getStringRef(), operands, {hwVectorType}, attrs); return b->createOperation(state); } @@ -511,11 +511,11 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write, return cloned->getInstruction(); } -/// Returns `true` if stmt instance is properly cloned and inserted, false +/// Returns `true` if inst instance is properly cloned and inserted, false /// otherwise. /// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of /// super-vector type to hw vector type. -/// A cloned instance of `stmt` is formed as follows: +/// A cloned instance of `inst` is formed as follows: /// 1. vector_transfer_read: the return `superVectorType` is replaced by /// `hwVectorType`. Additionally, affine indices are reindexed with /// `reindexAffineIndices` using `hwVectorInstance` and vector type @@ -532,24 +532,24 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write, /// possible. /// /// Returns true on failure. 
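The cloning path above stands or falls on one lookup: every operand of an op being instantiated must already have a hardware-vector counterpart in substitutionsMap, with constants as the one exception that substitute() materializes on demand. A stripped-down sketch of that discipline (plain C++ stand-ins, not the MLIR classes; all names here are invented for the illustration):

  #include <unordered_map>

  struct Val {}; // stand-in for an SSA value

  // Returns the recorded hw-vector replacement for v, or nullptr to signal
  // that v's def lies outside the materialized slice.
  Val *lookupSubstitution(Val *v,
                          const std::unordered_map<Val *, Val *> &subs) {
    auto it = subs.find(v);
    return it == subs.end() ? nullptr : it->second;
  }

Because the slice is cloned in topological order, each clone's results are inserted into the map before any later user asks for them.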
-static bool instantiateMaterialization(Statement *stmt, +static bool instantiateMaterialization(Instruction *inst, MaterializationState *state) { - LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt); + LLVM_DEBUG(dbgs() << "\ninstantiate: " << *inst); - if (isa<ForStmt>(stmt)) - return stmt->emitError("NYI path ForStmt"); + if (isa<ForInst>(inst)) + return inst->emitError("NYI path ForInst"); - if (isa<IfStmt>(stmt)) - return stmt->emitError("NYI path IfStmt"); + if (isa<IfInst>(inst)) + return inst->emitError("NYI path IfInst"); // Create a builder here for unroll-and-jam effects. - FuncBuilder b(stmt); - auto *opStmt = cast<OperationInst>(stmt); - if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) { + FuncBuilder b(inst); + auto *opInst = cast<OperationInst>(inst); + if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) { instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); return false; - } else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) { + } else if (auto read = opInst->dyn_cast<VectorTransferReadOp>()) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); state->substitutionsMap->insert( @@ -559,17 +559,17 @@ static bool instantiateMaterialization(Statement *stmt, // The only op with 0 results reaching this point must, by construction, be // VectorTransferWriteOps and have been caught above. Ops with >= 2 results // are not yet supported. So just support 1 result. - if (opStmt->getNumResults() != 1) - return stmt->emitError("NYI: ops with != 1 results"); - if (opStmt->getResult(0)->getType() != state->superVectorType) - return stmt->emitError("Op does not return a supervector."); + if (opInst->getNumResults() != 1) + return inst->emitError("NYI: ops with != 1 results"); + if (opInst->getResult(0)->getType() != state->superVectorType) + return inst->emitError("Op does not return a supervector."); auto *clone = - instantiate(&b, opStmt, state->hwVectorType, state->substitutionsMap); + instantiate(&b, opInst, state->hwVectorType, state->substitutionsMap); if (!clone) { return true; } state->substitutionsMap->insert( - std::make_pair(opStmt->getResult(0), clone->getResult(0))); + std::make_pair(opInst->getResult(0), clone->getResult(0))); return false; } @@ -595,7 +595,7 @@ static bool instantiateMaterialization(Statement *stmt, /// TODO(ntv): full loops + materialized allocs. /// TODO(ntv): partial unrolling + materialized allocs. static bool emitSlice(MaterializationState *state, - SetVector<Statement *> *slice) { + SetVector<Instruction *> *slice) { auto ratio = shapeRatio(state->superVectorType, state->hwVectorType); assert(ratio.hasValue() && "ratio of super-vector to HW-vector shape is not integral"); @@ -610,10 +610,10 @@ static bool emitSlice(MaterializationState *state, DenseMap<const Value *, Value *> substitutionMap; scopedState.substitutionsMap = &substitutionMap; // slice are topologically sorted, we can just clone them in order. - for (auto *stmt : *slice) { - auto fail = instantiateMaterialization(stmt, &scopedState); + for (auto *inst : *slice) { + auto fail = instantiateMaterialization(inst, &scopedState); if (fail) { - stmt->emitError("Unhandled super-vector materialization failure"); + inst->emitError("Unhandled super-vector materialization failure"); return true; } } @@ -636,7 +636,7 @@ static bool emitSlice(MaterializationState *state, /// Materializes super-vector types into concrete hw vector types as follows: /// 1. 
start from super-vector terminators (current vector_transfer_write /// ops); -/// 2. collect all the statements that can be reached by transitive use-defs +/// 2. collect all the instructions that can be reached by transitive use-defs /// chains; /// 3. get the superVectorType for this particular terminator and the /// corresponding hardware vector type (for now limited to F32) @@ -647,13 +647,13 @@ static bool emitSlice(MaterializationState *state, /// Notes /// ===== /// The `slice` is sorted in topological order by construction. -/// Additionally, this set is limited to statements in the same lexical scope +/// Additionally, this set is limited to instructions in the same lexical scope /// because we currently disallow vectorization of defs that come from another /// scope. static bool materialize(Function *f, const SetVector<OperationInst *> &terminators, MaterializationState *state) { - DenseSet<Statement *> seen; + DenseSet<Instruction *> seen; for (auto *term : terminators) { // Short-circuit test, a given terminator may have been reached by some // other previous transitive use-def chains. @@ -668,16 +668,16 @@ static bool materialize(Function *f, // current enclosing scope of the terminator. See the top of the function // Note for the justification of this restriction. // TODO(ntv): relax scoping constraints. - auto *enclosingScope = term->getParentStmt(); - auto keepIfInSameScope = [enclosingScope](Statement *stmt) { - assert(stmt && "NULL stmt"); + auto *enclosingScope = term->getParentInst(); + auto keepIfInSameScope = [enclosingScope](Instruction *inst) { + assert(inst && "NULL inst"); if (!enclosingScope) { // by construction, everyone is always under the top scope (null scope). return true; } - return properlyDominates(*enclosingScope, *stmt); + return properlyDominates(*enclosingScope, *inst); }; - SetVector<Statement *> slice = + SetVector<Instruction *> slice = getSlice(term, keepIfInSameScope, keepIfInSameScope); assert(!slice.empty()); @@ -722,12 +722,12 @@ PassResult MaterializeVectorsPass::runOnMLFunction(Function *f) { // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. 
- auto filter = [subVectorType](const Statement &stmt) { - const auto &opStmt = cast<OperationInst>(stmt); - if (!opStmt.isa<VectorTransferWriteOp>()) { + auto filter = [subVectorType](const Instruction &inst) { + const auto &opInst = cast<OperationInst>(inst); + if (!opInst.isa<VectorTransferWriteOp>()) { return false; } - return matcher::operatesOnStrictSuperVectors(opStmt, subVectorType); + return matcher::operatesOnStrictSuperVectors(opInst, subVectorType); }; auto pat = Op(filter); auto matches = pat.match(f); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index c8a6ced4ed1..debaac3a33c 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -25,7 +25,7 @@ #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" @@ -39,14 +39,14 @@ using namespace mlir; namespace { struct PipelineDataTransfer : public FunctionPass, - StmtWalker<PipelineDataTransfer> { + InstWalker<PipelineDataTransfer> { PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} PassResult runOnMLFunction(Function *f) override; - PassResult runOnForStmt(ForStmt *forStmt); + PassResult runOnForInst(ForInst *forInst); - // Collect all 'for' statements. - void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); } - std::vector<ForStmt *> forStmts; + // Collect all 'for' instructions. + void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } + std::vector<ForInst *> forInsts; static char passID; }; @@ -61,26 +61,26 @@ FunctionPass *mlir::createPipelineDataTransferPass() { return new PipelineDataTransfer(); } -// Returns the position of the tag memref operand given a DMA statement. +// Returns the position of the tag memref operand given a DMA instruction. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getTagMemRefPos(const OperationInst &dmaStmt) { - assert(dmaStmt.isa<DmaStartOp>() || dmaStmt.isa<DmaWaitOp>()); - if (dmaStmt.isa<DmaStartOp>()) { +static unsigned getTagMemRefPos(const OperationInst &dmaInst) { + assert(dmaInst.isa<DmaStartOp>() || dmaInst.isa<DmaWaitOp>()); + if (dmaInst.isa<DmaStartOp>()) { // Second to last operand. - return dmaStmt.getNumOperands() - 2; + return dmaInst.getNumOperands() - 2; } - // First operand for a dma finish statement. + // First operand for a dma finish instruction. return 0; } -/// Doubles the buffer of the supplied memref on the specified 'for' statement +/// Doubles the buffer of the supplied memref on the specified 'for' instruction /// by adding a leading dimension of size two to the memref. Replaces all uses /// of the old memref by the new one while indexing the newly added dimension by -/// the loop IV of the specified 'for' statement modulo 2. Returns false if such -/// a replacement cannot be performed. -static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) { - auto *forBody = forStmt->getBody(); +/// the loop IV of the specified 'for' instruction modulo 2. Returns false if +/// such a replacement cannot be performed. 
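In concrete terms, doubleBuffer turns a memref<Nxf32> used in the loop into a memref<2xNxf32> and routes every access through the new leading dimension at index iv mod 2. A C-style analogue (illustration only; step 1, a single buffer, and the names buf/tripCount are invented):

  void consumeDoubleBuffered(float buf[2][128], int tripCount) {
    for (int i = 0; i < tripCount; ++i) {
      float *cur = buf[i % 2]; // the copy owned by iteration i
      // a DMA may fill buf[(i + 1) % 2] here while compute reads 'cur'
      (void)cur;
    }
  }

The AffineApplyOp built below with the map (d0) -> (d0 mod 2) is the IR counterpart of the i % 2 index.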
+static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { + auto *forBody = forInst->getBody(); FuncBuilder bInner(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin()); @@ -101,33 +101,33 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) { auto newMemRefType = doubleShape(oldMemRefType); // Put together alloc operands for the dynamic dimensions of the memref. - FuncBuilder bOuter(forStmt); + FuncBuilder bOuter(forInst); SmallVector<Value *, 4> allocOperands; unsigned dynamicDimCount = 0; for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) - allocOperands.push_back(bOuter.create<DimOp>(forStmt->getLoc(), oldMemRef, + allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef, dynamicDimCount++)); } - // Create and place the alloc right before the 'for' statement. + // Create and place the alloc right before the 'for' instruction. // TODO(mlir-team): we are assuming scoped allocation here, and aren't // inserting a dealloc -- this isn't the right thing. Value *newMemRef = - bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands); + bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands); // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0); auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {}); auto ivModTwoOp = - bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt); + bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap, forInst); - // replaceAllMemRefUsesWith will always succeed unless the forStmt body has + // replaceAllMemRefUsesWith will always succeed unless the forInst body has // non-deferencing uses of the memref. if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0), AffineMap::Null(), {}, - &*forStmt->getBody()->begin())) { + &*forInst->getBody()->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); ivModTwoOp->getInstruction()->erase(); @@ -139,15 +139,15 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) { /// Returns success if the IR is in a valid state. PassResult PipelineDataTransfer::runOnMLFunction(Function *f) { // Do a post order walk so that inner loop DMAs are processed first. This is - // necessary since 'for' statements nested within would otherwise become + // necessary since 'for' instructions nested within would otherwise become // invalid (erased) when the outer loop is pipelined (the pipelined one gets // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). - forStmts.clear(); + forInsts.clear(); walkPostOrder(f); bool ret = false; - for (auto *forStmt : forStmts) { - ret = ret | runOnForStmt(forStmt); + for (auto *forInst : forInsts) { + ret = ret | runOnForInst(forInst); } return ret ? failure() : success(); } @@ -176,36 +176,36 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp, return true; } -// Identify matching DMA start/finish statements to overlap computation with. -static void findMatchingStartFinishStmts( - ForStmt *forStmt, +// Identify matching DMA start/finish instructions to overlap computation with. +static void findMatchingStartFinishInsts( + ForInst *forInst, SmallVectorImpl<std::pair<OperationInst *, OperationInst *>> &startWaitPairs) { - // Collect outgoing DMA statements - needed to check for dependences below. + // Collect outgoing DMA instructions - needed to check for dependences below. 
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps; - for (auto &stmt : *forStmt->getBody()) { - auto *opStmt = dyn_cast<OperationInst>(&stmt); - if (!opStmt) + for (auto &inst : *forInst->getBody()) { + auto *opInst = dyn_cast<OperationInst>(&inst); + if (!opInst) continue; OpPointer<DmaStartOp> dmaStartOp; - if ((dmaStartOp = opStmt->dyn_cast<DmaStartOp>()) && + if ((dmaStartOp = opInst->dyn_cast<DmaStartOp>()) && dmaStartOp->isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } - SmallVector<OperationInst *, 4> dmaStartStmts, dmaFinishStmts; - for (auto &stmt : *forStmt->getBody()) { - auto *opStmt = dyn_cast<OperationInst>(&stmt); - if (!opStmt) + SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts; + for (auto &inst : *forInst->getBody()) { + auto *opInst = dyn_cast<OperationInst>(&inst); + if (!opInst) continue; - // Collect DMA finish statements. - if (opStmt->isa<DmaWaitOp>()) { - dmaFinishStmts.push_back(opStmt); + // Collect DMA finish instructions. + if (opInst->isa<DmaWaitOp>()) { + dmaFinishInsts.push_back(opInst); continue; } OpPointer<DmaStartOp> dmaStartOp; - if (!(dmaStartOp = opStmt->dyn_cast<DmaStartOp>())) + if (!(dmaStartOp = opInst->dyn_cast<DmaStartOp>())) continue; // Only DMAs incoming into higher memory spaces are pipelined for now. // TODO(bondhugula): handle outgoing DMA pipelining. @@ -227,7 +227,7 @@ static void findMatchingStartFinishStmts( auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()); bool escapingUses = false; for (const auto &use : memref->getUses()) { - if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) { + if (!dominates(*forInst->getBody()->begin(), *use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -235,15 +235,15 @@ static void findMatchingStartFinishStmts( } } if (!escapingUses) - dmaStartStmts.push_back(opStmt); + dmaStartInsts.push_back(opInst); } - // For each start statement, we look for a matching finish statement. - for (auto *dmaStartStmt : dmaStartStmts) { - for (auto *dmaFinishStmt : dmaFinishStmts) { - if (checkTagMatch(dmaStartStmt->cast<DmaStartOp>(), - dmaFinishStmt->cast<DmaWaitOp>())) { - startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt}); + // For each start instruction, we look for a matching finish instruction. + for (auto *dmaStartInst : dmaStartInsts) { + for (auto *dmaFinishInst : dmaFinishInsts) { + if (checkTagMatch(dmaStartInst->cast<DmaStartOp>(), + dmaFinishInst->cast<DmaWaitOp>())) { + startWaitPairs.push_back({dmaStartInst, dmaFinishInst}); break; } } @@ -251,17 +251,17 @@ static void findMatchingStartFinishStmts( } /// Overlap DMA transfers with computation in this loop. If successful, -/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are +/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are /// inserted right before where it was. 
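The net effect, in pseudocode (an editorial illustration; the pass realizes this via the shift assignment below and instBodySkew rather than by emitting it directly):

  // before: for i in [0, N): { dma_start(i); dma_wait(i); compute(i); }
  //
  // after:  dma_start(0);                                  // prologue
  //         for i in [0, N-1):                             // steady state
  //           { dma_start(i+1); dma_wait(i); compute(i); }
  //         dma_wait(N-1); compute(N-1);                   // epilogue

Giving the DMA starts (and their index computations) shift 0 while everything else gets shift 1 is what makes the start for iteration i+1 share a steady-state iteration with the wait and compute for iteration i.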
-PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { - auto mayBeConstTripCount = getConstantTripCount(*forStmt); +PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { + auto mayBeConstTripCount = getConstantTripCount(*forInst); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n"); return success(); } SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs; - findMatchingStartFinishStmts(forStmt, startWaitPairs); + findMatchingStartFinishInsts(forInst, startWaitPairs); if (startWaitPairs.empty()) { LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";); @@ -269,22 +269,22 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { } // Double the buffers for the higher memory space memref's. - // Identify memref's to replace by scanning through all DMA start statements. - // A DMA start statement has two memref's - the one from the higher level of - // memory hierarchy is the one to double buffer. + // Identify memref's to replace by scanning through all DMA start + // instructions. A DMA start instruction has two memref's - the one from the + // higher level of memory hierarchy is the one to double buffer. // TODO(bondhugula): check whether double-buffering is even necessary. // TODO(bondhugula): make this work with different layouts: assuming here that // the dimension we are adding here for the double buffering is the outermost // dimension. for (auto &pair : startWaitPairs) { - auto *dmaStartStmt = pair.first; - Value *oldMemRef = dmaStartStmt->getOperand( - dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos()); - if (!doubleBuffer(oldMemRef, forStmt)) { + auto *dmaStartInst = pair.first; + Value *oldMemRef = dmaStartInst->getOperand( + dmaStartInst->cast<DmaStartOp>()->getFasterMemPos()); + if (!doubleBuffer(oldMemRef, forInst)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); - LLVM_DEBUG(dmaStartStmt->dump()); + LLVM_DEBUG(dmaStartInst->dump()); // IR still in a valid state. return success(); } @@ -293,80 +293,80 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { // operation could have been used on it if it was dynamically shaped in // order to create the double buffer above) if (oldMemRef->use_empty()) - if (auto *allocStmt = oldMemRef->getDefiningInst()) - allocStmt->erase(); + if (auto *allocInst = oldMemRef->getDefiningInst()) + allocInst->erase(); } // Double the buffers for tag memrefs. for (auto &pair : startWaitPairs) { - auto *dmaFinishStmt = pair.second; + auto *dmaFinishInst = pair.second; Value *oldTagMemRef = - dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)); - if (!doubleBuffer(oldTagMemRef, forStmt)) { + dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); + if (!doubleBuffer(oldTagMemRef, forInst)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); return success(); } // If the old tag has no more uses, remove its 'dead' alloc if it was // alloc'ed. if (oldTagMemRef->use_empty()) - if (auto *allocStmt = oldTagMemRef->getDefiningInst()) - allocStmt->erase(); + if (auto *allocInst = oldTagMemRef->getDefiningInst()) + allocInst->erase(); } - // Double buffering would have invalidated all the old DMA start/wait stmts. + // Double buffering would have invalidated all the old DMA start/wait insts. 
startWaitPairs.clear(); - findMatchingStartFinishStmts(forStmt, startWaitPairs); + findMatchingStartFinishInsts(forInst, startWaitPairs); - // Store shift for statement for later lookup for AffineApplyOp's. - DenseMap<const Statement *, unsigned> stmtShiftMap; + // Store shift for instruction for later lookup for AffineApplyOp's. + DenseMap<const Instruction *, unsigned> instShiftMap; for (auto &pair : startWaitPairs) { - auto *dmaStartStmt = pair.first; - assert(dmaStartStmt->isa<DmaStartOp>()); - stmtShiftMap[dmaStartStmt] = 0; - // Set shifts for DMA start stmt's affine operand computation slices to 0. - if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) { - stmtShiftMap[slice] = 0; + auto *dmaStartInst = pair.first; + assert(dmaStartInst->isa<DmaStartOp>()); + instShiftMap[dmaStartInst] = 0; + // Set shifts for DMA start inst's affine operand computation slices to 0. + if (auto *slice = mlir::createAffineComputationSlice(dmaStartInst)) { + instShiftMap[slice] = 0; } else { // If a slice wasn't created, the reachable affine_apply op's from its // operands are the ones that go with it. - SmallVector<OperationInst *, 4> affineApplyStmts; - SmallVector<Value *, 4> operands(dmaStartStmt->getOperands()); - getReachableAffineApplyOps(operands, affineApplyStmts); - for (const auto *stmt : affineApplyStmts) { - stmtShiftMap[stmt] = 0; + SmallVector<OperationInst *, 4> affineApplyInsts; + SmallVector<Value *, 4> operands(dmaStartInst->getOperands()); + getReachableAffineApplyOps(operands, affineApplyInsts); + for (const auto *inst : affineApplyInsts) { + instShiftMap[inst] = 0; } } } // Everything else (including compute ops and dma finish) are shifted by one. - for (const auto &stmt : *forStmt->getBody()) { - if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) { - stmtShiftMap[&stmt] = 1; + for (const auto &inst : *forInst->getBody()) { + if (instShiftMap.find(&inst) == instShiftMap.end()) { + instShiftMap[&inst] = 1; } } // Get shifts stored in map. - std::vector<uint64_t> shifts(forStmt->getBody()->getInstructions().size()); + std::vector<uint64_t> shifts(forInst->getBody()->getInstructions().size()); unsigned s = 0; - for (auto &stmt : *forStmt->getBody()) { - assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end()); - shifts[s++] = stmtShiftMap[&stmt]; + for (auto &inst : *forInst->getBody()) { + assert(instShiftMap.find(&inst) != instShiftMap.end()); + shifts[s++] = instShiftMap[&inst]; LLVM_DEBUG( - // Tagging statements with shifts for debugging purposes. - if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) { - FuncBuilder b(opStmt); - opStmt->setAttr(b.getIdentifier("shift"), + // Tagging instructions with shifts for debugging purposes. + if (auto *opInst = dyn_cast<OperationInst>(&inst)) { + FuncBuilder b(opInst); + opInst->setAttr(b.getIdentifier("shift"), b.getI64IntegerAttr(shifts[s - 1])); }); } - if (!isStmtwiseShiftValid(*forStmt, shifts)) { + if (!isInstwiseShiftValid(*forInst, shifts)) { // Violates dependences. 
LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); return success(); } - if (stmtBodySkew(forStmt, shifts)) { - LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed - unexpected\n";); + if (instBodySkew(forInst, shifts)) { + LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";); return success(); } diff --git a/mlir/lib/Transforms/SimplifyAffineExpr.cpp b/mlir/lib/Transforms/SimplifyAffineExpr.cpp index 853a814e516..2a643eb690a 100644 --- a/mlir/lib/Transforms/SimplifyAffineExpr.cpp +++ b/mlir/lib/Transforms/SimplifyAffineExpr.cpp @@ -21,7 +21,7 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Function.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/Passes.h" @@ -32,12 +32,12 @@ using llvm::report_fatal_error; namespace { -/// Simplifies all affine expressions appearing in the operation statements of +/// Simplifies all affine expressions appearing in the operation instructions of /// the Function. This is mainly to test the simplifyAffineExpr method. // TODO(someone): Gradually, extend this to all affine map references found in // ML functions and CFG functions. struct SimplifyAffineStructures : public FunctionPass, - StmtWalker<SimplifyAffineStructures> { + InstWalker<SimplifyAffineStructures> { explicit SimplifyAffineStructures() : FunctionPass(&SimplifyAffineStructures::passID) {} @@ -46,8 +46,8 @@ struct SimplifyAffineStructures : public FunctionPass, // for this yet? TODO(someone). PassResult runOnCFGFunction(Function *f) override { return success(); } - void visitIfStmt(IfStmt *ifStmt); - void visitOperationInst(OperationInst *opStmt); + void visitIfInst(IfInst *ifInst); + void visitOperationInst(OperationInst *opInst); static char passID; }; @@ -70,18 +70,18 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) { - auto set = ifStmt->getCondition().getIntegerSet(); - ifStmt->setIntegerSet(simplifyIntegerSet(set)); +void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) { + auto set = ifInst->getCondition().getIntegerSet(); + ifInst->setIntegerSet(simplifyIntegerSet(set)); } -void SimplifyAffineStructures::visitOperationInst(OperationInst *opStmt) { - for (auto attr : opStmt->getAttrs()) { +void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) { MutableAffineMap mMap(mapAttr.getValue()); mMap.simplify(); auto map = mMap.getAffineMap(); - opStmt->setAttr(attr.first, AffineMapAttr::get(map)); + opInst->setAttr(attr.first, AffineMapAttr::get(map)); } } } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index a4116667794..6064d1feff3 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -271,7 +271,7 @@ static void processMLFunction(Function *fn, } void setInsertionPoint(OperationInst *op) override { - // Any new operations should be added before this statement. + // Any new operations should be added before this instruction. 
builder.setInsertionPoint(cast<OperationInst>(op)); } @@ -280,7 +280,7 @@ static void processMLFunction(Function *fn, }; GreedyPatternRewriteDriver driver(std::move(patterns)); - fn->walk([&](OperationInst *stmt) { driver.addToWorklist(stmt); }); + fn->walk([&](OperationInst *inst) { driver.addToWorklist(inst); }); FuncBuilder mlBuilder(fn); MLFuncRewriter rewriter(driver, mlBuilder); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 03b4bb29e19..93039372121 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -26,8 +26,8 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Statements.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Instructions.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Debug.h" @@ -38,22 +38,22 @@ using namespace mlir; /// Returns the upper bound of an unrolled loop with lower bound 'lb' and with /// the specified trip count, stride, and unroll factor. Returns nullptr when /// the trip count can't be expressed as an affine expression. -AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt, +AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst, unsigned unrollFactor, FuncBuilder *builder) { - auto lbMap = forStmt.getLowerBoundMap(); + auto lbMap = forInst.getLowerBoundMap(); // Single result lower bound map only. if (lbMap.getNumResults() != 1) return AffineMap::Null(); // Sometimes, the trip count cannot be expressed as an affine expression. - auto tripCount = getTripCountExpr(forStmt); + auto tripCount = getTripCountExpr(forInst); if (!tripCount) return AffineMap::Null(); AffineExpr lb(lbMap.getResult(0)); - unsigned step = forStmt.getStep(); + unsigned step = forInst.getStep(); auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step; return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), @@ -64,122 +64,122 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt, /// bound 'lb' and with the specified trip count, stride, and unroll factor. /// Returns an AffinMap with nullptr storage (that evaluates to false) /// when the trip count can't be expressed as an affine expression. -AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt, +AffineMap mlir::getCleanupLoopLowerBound(const ForInst &forInst, unsigned unrollFactor, FuncBuilder *builder) { - auto lbMap = forStmt.getLowerBoundMap(); + auto lbMap = forInst.getLowerBoundMap(); // Single result lower bound map only. if (lbMap.getNumResults() != 1) return AffineMap::Null(); // Sometimes the trip count cannot be expressed as an affine expression. - AffineExpr tripCount(getTripCountExpr(forStmt)); + AffineExpr tripCount(getTripCountExpr(forInst)); if (!tripCount) return AffineMap::Null(); AffineExpr lb(lbMap.getResult(0)); - unsigned step = forStmt.getStep(); + unsigned step = forInst.getStep(); auto newLb = lb + (tripCount - tripCount % unrollFactor) * step; return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), {newLb}, {}); } -/// Promotes the loop body of a forStmt to its containing block if the forStmt +/// Promotes the loop body of a forInst to its containing block if the forInst /// was known to have a single iteration. Returns false otherwise. // TODO(bondhugula): extend this for arbitrary affine bounds. 
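A before/after sketch of the promotion on the constant-lower-bound path (editorial, syntax approximated; the non-constant path materializes an affine_apply of the lower bound map instead):

  // before: a 'for' with constant trip count 1 and constant lower bound 5:
  //   for %i = ... { use(%i) }
  // after: the IV is replaced by its single value and the body is spliced
  // into the enclosing block:
  //   %c5 = constant 5 : index
  //   use(%c5)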
-bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { - Optional<uint64_t> tripCount = getConstantTripCount(*forStmt); +bool mlir::promoteIfSingleIteration(ForInst *forInst) { + Optional<uint64_t> tripCount = getConstantTripCount(*forInst); if (!tripCount.hasValue() || tripCount.getValue() != 1) return false; // TODO(mlir-team): there is no builder for a max. - if (forStmt->getLowerBoundMap().getNumResults() != 1) + if (forInst->getLowerBoundMap().getNumResults() != 1) return false; // Replaces all IV uses to its single iteration value. - if (!forStmt->use_empty()) { - if (forStmt->hasConstantLowerBound()) { - auto *mlFunc = forStmt->getFunction(); + if (!forInst->use_empty()) { + if (forInst->hasConstantLowerBound()) { + auto *mlFunc = forInst->getFunction(); FuncBuilder topBuilder(&mlFunc->getBody()->front()); auto constOp = topBuilder.create<ConstantIndexOp>( - forStmt->getLoc(), forStmt->getConstantLowerBound()); - forStmt->replaceAllUsesWith(constOp); + forInst->getLoc(), forInst->getConstantLowerBound()); + forInst->replaceAllUsesWith(constOp); } else { - const AffineBound lb = forStmt->getLowerBound(); + const AffineBound lb = forInst->getLowerBound(); SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end()); - FuncBuilder builder(forStmt->getBlock(), Block::iterator(forStmt)); + FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst)); auto affineApplyOp = builder.create<AffineApplyOp>( - forStmt->getLoc(), lb.getMap(), lbOperands); - forStmt->replaceAllUsesWith(affineApplyOp->getResult(0)); + forInst->getLoc(), lb.getMap(), lbOperands); + forInst->replaceAllUsesWith(affineApplyOp->getResult(0)); } } - // Move the loop body statements to the loop's containing block. - auto *block = forStmt->getBlock(); - block->getInstructions().splice(Block::iterator(forStmt), - forStmt->getBody()->getInstructions()); - forStmt->erase(); + // Move the loop body instructions to the loop's containing block. + auto *block = forInst->getBlock(); + block->getInstructions().splice(Block::iterator(forInst), + forInst->getBody()->getInstructions()); + forInst->erase(); return true; } -/// Promotes all single iteration for stmt's in the Function, i.e., moves +/// Promotes all single iteration for inst's in the Function, i.e., moves /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. - class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> { + class LoopBodyPromoter : public InstWalker<LoopBodyPromoter> { public: - void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); } + void visitForInst(ForInst *forInst) { promoteIfSingleIteration(forInst); } }; LoopBodyPromoter fsw; fsw.walkPostOrder(f); } -/// Generates a 'for' stmt with the specified lower and upper bounds while -/// generating the right IV remappings for the shifted statements. The -/// statement blocks that go into the loop are specified in stmtGroupQueue +/// Generates a 'for' inst with the specified lower and upper bounds while +/// generating the right IV remappings for the shifted instructions. The +/// instruction blocks that go into the loop are specified in instGroupQueue /// starting from the specified offset, and in that order; the first element of -/// the pair specifies the shift applied to that group of statements; note that -/// the shift is multiplied by the loop step before being applied. 
Returns +/// the pair specifies the shift applied to that group of instructions; note +/// that the shift is multiplied by the loop step before being applied. Returns /// nullptr if the generated loop simplifies to a single iteration one. -static ForStmt * +static ForInst * generateLoop(AffineMap lbMap, AffineMap ubMap, - const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> - &stmtGroupQueue, - unsigned offset, ForStmt *srcForStmt, FuncBuilder *b) { - SmallVector<Value *, 4> lbOperands(srcForStmt->getLowerBoundOperands()); - SmallVector<Value *, 4> ubOperands(srcForStmt->getUpperBoundOperands()); + const std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>> + &instGroupQueue, + unsigned offset, ForInst *srcForInst, FuncBuilder *b) { + SmallVector<Value *, 4> lbOperands(srcForInst->getLowerBoundOperands()); + SmallVector<Value *, 4> ubOperands(srcForInst->getUpperBoundOperands()); assert(lbMap.getNumInputs() == lbOperands.size()); assert(ubMap.getNumInputs() == ubOperands.size()); - auto *loopChunk = b->createFor(srcForStmt->getLoc(), lbOperands, lbMap, - ubOperands, ubMap, srcForStmt->getStep()); + auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap, + ubOperands, ubMap, srcForInst->getStep()); OperationInst::OperandMapTy operandMap; - for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end(); + for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end(); it != e; ++it) { uint64_t shift = it->first; - auto stmts = it->second; - // All 'same shift' statements get added with their operands being remapped - // to results of cloned statements, and their IV used remapped. + auto insts = it->second; + // All 'same shift' instructions get added with their operands being + // remapped to results of cloned instructions, and their IV used remapped. // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. - if (!srcForStmt->use_empty() && shift != 0) { - auto b = FuncBuilder::getForStmtBodyBuilder(loopChunk); + if (!srcForInst->use_empty() && shift != 0) { + auto b = FuncBuilder::getForInstBodyBuilder(loopChunk); auto *ivRemap = b.create<AffineApplyOp>( - srcForStmt->getLoc(), + srcForInst->getLoc(), b.getSingleDimShiftAffineMap(-static_cast<int64_t>( - srcForStmt->getStep() * shift)), + srcForInst->getStep() * shift)), loopChunk) ->getResult(0); - operandMap[srcForStmt] = ivRemap; + operandMap[srcForInst] = ivRemap; } else { - operandMap[srcForStmt] = loopChunk; + operandMap[srcForInst] = loopChunk; } - for (auto *stmt : stmts) { - loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext())); + for (auto *inst : insts) { + loopChunk->getBody()->push_back(inst->clone(operandMap, b->getContext())); } } if (promoteIfSingleIteration(loopChunk)) @@ -187,63 +187,63 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, return loopChunk; } -/// Skew the statements in the body of a 'for' statement with the specified -/// statement-wise shifts. The shifts are with respect to the original execution -/// order, and are multiplied by the loop 'step' before being applied. A shift -/// of zero for each statement will lead to no change. -// The skewing of statements with respect to one another can be used for example -// to allow overlap of asynchronous operations (such as DMA communication) with -// computation, or just relative shifting of statements for better register -// reuse, locality or parallelism. As such, the shifts are typically expected to -// be at most of the order of the number of statements. 
-/// Skew the statements in the body of a 'for' statement with the specified
-/// statement-wise shifts. The shifts are with respect to the original execution
-/// order, and are multiplied by the loop 'step' before being applied. A shift
-/// of zero for each statement will lead to no change.
-// The skewing of statements with respect to one another can be used for example
-// to allow overlap of asynchronous operations (such as DMA communication) with
-// computation, or just relative shifting of statements for better register
-// reuse, locality or parallelism. As such, the shifts are typically expected to
-// be at most of the order of the number of statements. This method should not
-// be used as a substitute for loop distribution/fission.
-// This method uses an algorithm// in time linear in the number of statements in
-// the body of the for loop - (using the 'sweep line' paradigm). This method
+/// Skew the instructions in the body of a 'for' instruction with the specified
+/// instruction-wise shifts. The shifts are with respect to the original
+/// execution order, and are multiplied by the loop 'step' before being applied.
+/// A shift of zero for each instruction will lead to no change.
+// The skewing of instructions with respect to one another can be used for
+// example to allow overlap of asynchronous operations (such as DMA
+// communication) with computation, or just relative shifting of instructions
+// for better register reuse, locality or parallelism. As such, the shifts are
+// typically expected to be at most of the order of the number of instructions.
+// This method should not be used as a substitute for loop distribution/fission.
+// This method uses an algorithm that runs in time linear in the number of
+// instructions in the body of the for loop (using the 'sweep line'
+// paradigm). This method
 // asserts preservation of SSA dominance. A check for that, as well as for
 // memory-based dependence preservation, rests with the users of this
 // method.
-UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
+UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
                               bool unrollPrologueEpilogue) {
-  if (forStmt->getBody()->empty())
+  if (forInst->getBody()->empty())
     return UtilResult::Success;
   // If the trip counts aren't constant, we would need versioning and
   // conditional guards (or context information to prevent such versioning). The
   // better way to pipeline for such loops is to first tile them and extract
   // constant trip count "full tiles" before applying this.
-  auto mayBeConstTripCount = getConstantTripCount(*forStmt);
+  auto mayBeConstTripCount = getConstantTripCount(*forInst);
   if (!mayBeConstTripCount.hasValue()) {
     LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
     return UtilResult::Success;
   }
   uint64_t tripCount = mayBeConstTripCount.getValue();
-  assert(isStmtwiseShiftValid(*forStmt, shifts) &&
+  assert(isInstwiseShiftValid(*forInst, shifts) &&
          "shifts will lead to an invalid transformation\n");
-  int64_t step = forStmt->getStep();
+  int64_t step = forInst->getStep();
-  unsigned numChildStmts = forStmt->getBody()->getInstructions().size();
+  unsigned numChildInsts = forInst->getBody()->getInstructions().size();
   // Do a linear time (counting) sort for the shifts.
   uint64_t maxShift = 0;
-  for (unsigned i = 0; i < numChildStmts; i++) {
+  for (unsigned i = 0; i < numChildInsts; i++) {
     maxShift = std::max(maxShift, shifts[i]);
   }
   // Such large shifts are not the typical use case.
-  if (maxShift >= numChildStmts) {
-    LLVM_DEBUG(llvm::dbgs() << "stmt shifts too large - unexpected\n";);
+  if (maxShift >= numChildInsts) {
+    LLVM_DEBUG(llvm::dbgs() << "inst shifts too large - unexpected\n";);
     return UtilResult::Success;
   }
-  // An array of statement groups sorted by shift amount; each group has all
-  // statements with the same shift in the order in which they appear in the
-  // body of the 'for' stmt.
-  std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1);
+  // An array of instruction groups sorted by shift amount; each group has all
+  // instructions with the same shift in the order in which they appear in the
+  // body of the 'for' inst.
+  std::vector<std::vector<Instruction *>> sortedInstGroups(maxShift + 1);
   unsigned pos = 0;
-  for (auto &stmt : *forStmt->getBody()) {
+  for (auto &inst : *forInst->getBody()) {
     auto shift = shifts[pos++];
-    sortedStmtGroups[shift].push_back(&stmt);
+    sortedInstGroups[shift].push_back(&inst);
   }

   // Unless the shifts have a specific pattern (which actually would be the
@@ -251,40 +251,40 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
   // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
   // loop generated as the prologue and the last as epilogue and unroll these
   // fully.
-  ForStmt *prologue = nullptr;
-  ForStmt *epilogue = nullptr;
+  ForInst *prologue = nullptr;
+  ForInst *epilogue = nullptr;

   // Do a sweep over the sorted shifts while storing open groups in a
   // vector, and generating loop portions as necessary during the sweep. A block
-  // of statements is paired with its shift.
-  std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
+  // of instructions is paired with its shift.
+  std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>> instGroupQueue;

-  auto origLbMap = forStmt->getLowerBoundMap();
+  auto origLbMap = forInst->getLowerBoundMap();
   uint64_t lbShift = 0;
-  FuncBuilder b(forStmt);
-  for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
+  FuncBuilder b(forInst);
+  for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) {
     // If nothing is shifted by d, continue.
-    if (sortedStmtGroups[d].empty())
+    if (sortedInstGroups[d].empty())
       continue;
-    if (!stmtGroupQueue.empty()) {
+    if (!instGroupQueue.empty()) {
       assert(d >= 1 &&
              "Queue expected to be empty when the first block is found");
       // The interval for which the loop needs to be generated here is:
       // [lbShift, min(lbShift + tripCount, d)) and the body of the
-      // loop needs to have all statements in stmtQueue in that order.
-      ForStmt *res;
+      // loop needs to have all instructions in instQueue in that order.
+      ForInst *res;
       if (lbShift + tripCount * step < d * step) {
         res = generateLoop(
             b.getShiftedAffineMap(origLbMap, lbShift),
             b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
-            stmtGroupQueue, 0, forStmt, &b);
-        // Entire loop for the queued stmt groups generated, empty it.
-        stmtGroupQueue.clear();
+            instGroupQueue, 0, forInst, &b);
+        // Entire loop for the queued inst groups generated, empty it.
+        instGroupQueue.clear();
         lbShift += tripCount * step;
       } else {
         res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
-                           b.getShiftedAffineMap(origLbMap, d), stmtGroupQueue,
-                           0, forStmt, &b);
+                           b.getShiftedAffineMap(origLbMap, d), instGroupQueue,
+                           0, forInst, &b);
         lbShift = d * step;
       }
       if (!prologue && res)
@@ -294,24 +294,24 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
       // Start of first interval.
       lbShift = d * step;
     }
-    // Augment the list of statements that get into the current open interval.
-    stmtGroupQueue.push_back({d, sortedStmtGroups[d]});
+    // Augment the list of instructions that get into the current open interval.
+    instGroupQueue.push_back({d, sortedInstGroups[d]});
   }

-  // Those statements groups left in the queue now need to be processed (FIFO)
+  // Those instruction groups left in the queue now need to be processed (FIFO)
   // and their loops completed.
-  for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) {
-    uint64_t ubShift = (stmtGroupQueue[i].first + tripCount) * step;
+  for (unsigned i = 0, e = instGroupQueue.size(); i < e; ++i) {
+    uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step;
     epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
                             b.getShiftedAffineMap(origLbMap, ubShift),
-                            stmtGroupQueue, i, forStmt, &b);
+                            instGroupQueue, i, forInst, &b);
     lbShift = ubShift;
     if (!prologue)
       prologue = epilogue;
   }

-  // Erase the original for stmt.
-  forStmt->erase();
+  // Erase the original for inst.
+  forInst->erase();

   if (unrollPrologueEpilogue && prologue)
     loopUnrollFull(prologue);
@@ -322,39 +322,39 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
 }

 /// Unrolls this loop completely.
-bool mlir::loopUnrollFull(ForStmt *forStmt) {
-  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+bool mlir::loopUnrollFull(ForInst *forInst) {
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
   if (mayBeConstantTripCount.hasValue()) {
     uint64_t tripCount = mayBeConstantTripCount.getValue();
     if (tripCount == 1) {
-      return promoteIfSingleIteration(forStmt);
+      return promoteIfSingleIteration(forInst);
     }
-    return loopUnrollByFactor(forStmt, tripCount);
+    return loopUnrollByFactor(forInst, tripCount);
   }
   return false;
 }

 /// Unrolls this loop by the specified factor or by the trip count (if
 /// constant), whichever is lower.
-bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
-  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+bool mlir::loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor) {
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
   if (mayBeConstantTripCount.hasValue() &&
       mayBeConstantTripCount.getValue() < unrollFactor)
-    return loopUnrollByFactor(forStmt, mayBeConstantTripCount.getValue());
-  return loopUnrollByFactor(forStmt, unrollFactor);
+    return loopUnrollByFactor(forInst, mayBeConstantTripCount.getValue());
+  return loopUnrollByFactor(forInst, unrollFactor);
 }

 /// Unrolls this loop by the specified factor. Returns true if the loop
 /// is successfully unrolled.
-bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
+bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
   assert(unrollFactor >= 1 && "unroll factor should be >= 1");
-  if (unrollFactor == 1 || forStmt->getBody()->empty())
+  if (unrollFactor == 1 || forInst->getBody()->empty())
     return false;
-  auto lbMap = forStmt->getLowerBoundMap();
-  auto ubMap = forStmt->getUpperBoundMap();
+  auto lbMap = forInst->getLowerBoundMap();
+  auto ubMap = forInst->getUpperBoundMap();

   // Loops with max/min expressions won't be unrolled here (the output can't be
   // expressed as a Function in the general case). However, the right way to
@@ -365,10 +365,10 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {

   // Same operand list for lower and upper bound for now.
   // TODO(bondhugula): handle bounds with different operand lists.
-  if (!forStmt->matchingBoundOperandList())
+  if (!forInst->matchingBoundOperandList())
     return false;

-  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);

   // If the trip count is lower than the unroll factor, no unrolled body.
   // TODO(bondhugula): option to specify cleanup loop unrolling.
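The counting sort and sweep in instBodySkew can be modeled without any IR. The sketch below (plain ints standing in for instructions; names are illustrative only) reproduces the grouping by shift, the FIFO queue of open groups, and the [lb, ub) intervals of the loops the sweep would generate:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const uint64_t tripCount = 4;
  const int64_t step = 1;
  // shifts[i] is the shift of body instruction i.
  std::vector<uint64_t> shifts = {0, 1, 1};

  // Counting sort: bucket instruction indices by shift, keeping body order.
  uint64_t maxShift = *std::max_element(shifts.begin(), shifts.end());
  std::vector<std::vector<int>> groups(maxShift + 1);
  for (size_t i = 0; i < shifts.size(); ++i)
    groups[shifts[i]].push_back(static_cast<int>(i));

  // Sweep over shift values; open groups wait in a FIFO queue until a loop
  // chunk covering their interval is emitted.
  std::vector<uint64_t> queue; // shift values of the open groups
  uint64_t lbShift = 0;
  for (uint64_t d = 0; d <= maxShift; ++d) {
    if (groups[d].empty())
      continue;
    if (!queue.empty()) {
      uint64_t ub = std::min(lbShift + tripCount * step,
                             static_cast<uint64_t>(d * step));
      std::printf("loop [%llu, %llu) holding %zu group(s)\n",
                  (unsigned long long)lbShift, (unsigned long long)ub,
                  queue.size());
      if (lbShift + tripCount * step < d * step)
        queue.clear(); // the full interval was emitted; those groups are done
      lbShift = ub;
    } else {
      lbShift = d * step; // start of the first open interval
    }
    queue.push_back(d);
  }
  // Drain the remaining groups FIFO, one epilogue loop per group.
  for (size_t i = 0; i < queue.size(); ++i) {
    uint64_t ub = (queue[i] + tripCount) * step;
    std::printf("epilogue loop [%llu, %llu)\n", (unsigned long long)lbShift,
                (unsigned long long)ub);
    lbShift = ub;
  }
}
```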
@@ -377,64 +377,64 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
     return false;

   // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
-  if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
+  if (getLargestDivisorOfTripCount(*forInst) % unrollFactor != 0) {
     DenseMap<const Value *, Value *> operandMap;
-    FuncBuilder builder(forStmt->getBlock(), ++Block::iterator(forStmt));
-    auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
-    auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
+    FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst));
+    auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst, operandMap));
+    auto clLbMap = getCleanupLoopLowerBound(*forInst, unrollFactor, &builder);
     assert(clLbMap &&
            "cleanup loop lower bound map for single result bound maps can "
           "always be determined");
-    cleanupForStmt->setLowerBoundMap(clLbMap);
+    cleanupForInst->setLowerBoundMap(clLbMap);
     // Promote the loop body up if this has turned into a single iteration loop.
-    promoteIfSingleIteration(cleanupForStmt);
+    promoteIfSingleIteration(cleanupForInst);

     // Adjust upper bound.
     auto unrolledUbMap =
-        getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder);
+        getUnrolledLoopUpperBound(*forInst, unrollFactor, &builder);
     assert(unrolledUbMap &&
            "upper bound map can always be determined for an unrolled loop "
            "with single result bounds");
-    forStmt->setUpperBoundMap(unrolledUbMap);
+    forInst->setUpperBoundMap(unrolledUbMap);
   }

   // Scale the step of loop being unrolled by unroll factor.
-  int64_t step = forStmt->getStep();
-  forStmt->setStep(step * unrollFactor);
+  int64_t step = forInst->getStep();
+  forInst->setStep(step * unrollFactor);

-  // Builder to insert unrolled bodies right after the last statement in the
-  // body of 'forStmt'.
-  FuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
+  // Builder to insert unrolled bodies right after the last instruction in the
+  // body of 'forInst'.
+  FuncBuilder builder(forInst->getBody(), forInst->getBody()->end());

-  // Keep a pointer to the last statement in the original block so that we know
-  // what to clone (since we are doing this in-place).
-  Block::iterator srcBlockEnd = std::prev(forStmt->getBody()->end());
+  // Keep a pointer to the last instruction in the original block so that we
+  // know what to clone (since we are doing this in-place).
+  Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end());

-  // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
+  // Unroll the contents of 'forInst' (append unrollFactor-1 additional copies).
   for (unsigned i = 1; i < unrollFactor; i++) {
     DenseMap<const Value *, Value *> operandMap;

     // If the induction variable is used, create a remapping to the value for
     // this unrolled instance.
-    if (!forStmt->use_empty()) {
+    if (!forInst->use_empty()) {
       // iv' = iv + 1/2/3...unrollFactor-1;
       auto d0 = builder.getAffineDimExpr(0);
       auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
       auto *ivUnroll =
-          builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
+          builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
               ->getResult(0);
-      operandMap[forStmt] = ivUnroll;
+      operandMap[forInst] = ivUnroll;
     }

-    // Clone the original body of 'forStmt'.
-    for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd);
+    // Clone the original body of 'forInst'.
+    for (auto it = forInst->getBody()->begin(); it != std::next(srcBlockEnd);
          it++) {
       builder.clone(*it, operandMap);
     }
   }

   // Promote the loop body up if this has turned into a single iteration loop.
-  promoteIfSingleIteration(forStmt);
+  promoteIfSingleIteration(forInst);
   return true;
 }
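loopUnrollByFactor's arithmetic is easy to verify in isolation: the main loop's step is scaled by the factor, each cloned copy bumps the IV by i * step (the bump map d0 + i * step above), and the remainder iterations go to a cleanup loop. A self-contained sketch under those assumptions (plain integers, hypothetical names):

```cpp
#include <cstdint>
#include <cstdio>
#include <functional>

// Runs `body` over [lb, ub) with positive `step`, unrolled by `factor`.
void unrolledFor(int64_t lb, int64_t ub, int64_t step, unsigned factor,
                 const std::function<void(int64_t)> &body) {
  int64_t tripCount = ub > lb ? (ub - lb + step - 1) / step : 0;
  int64_t mainIters = (tripCount / factor) * factor;
  int64_t cleanupLb = lb + mainIters * step; // cleanup loop lower bound
  // Main loop: step scaled by the unroll factor; the body is "cloned"
  // factor times per iteration with iv' = iv + i * step.
  for (int64_t iv = lb; iv < cleanupLb; iv += step * factor)
    for (unsigned i = 0; i < factor; ++i)
      body(iv + i * step);
  // Cleanup loop for the remaining tripCount % factor iterations.
  for (int64_t iv = cleanupLb; iv < ub; iv += step)
    body(iv);
}

int main() {
  // Prints 0..9: eight iterations in two unrolled chunks, two in the cleanup.
  unrolledFor(0, 10, 1, 4,
              [](int64_t iv) { std::printf("%lld ", (long long)iv); });
  std::printf("\n");
}
```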
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 3661c1bdbbc..8cfe2619e2a 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -26,8 +26,8 @@
 #include "mlir/Analysis/AffineStructures.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/InstVisitor.h"
 #include "mlir/IR/Module.h"
-#include "mlir/IR/StmtVisitor.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/DenseMap.h"
@@ -66,7 +66,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
                                     ArrayRef<Value *> extraIndices,
                                     AffineMap indexRemap,
                                     ArrayRef<Value *> extraOperands,
-                                    const Statement *domStmtFilter) {
+                                    const Instruction *domInstFilter) {
   unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
@@ -85,41 +85,41 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
   // Walk all uses of old memref. Operation using the memref gets replaced.
   for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
     InstOperand &use = *(it++);
-    auto *opStmt = cast<OperationInst>(use.getOwner());
+    auto *opInst = cast<OperationInst>(use.getOwner());

-    // Skip this use if it's not dominated by domStmtFilter.
-    if (domStmtFilter && !dominates(*domStmtFilter, *opStmt))
+    // Skip this use if it's not dominated by domInstFilter.
+    if (domInstFilter && !dominates(*domInstFilter, *opInst))
       continue;

     // Check if the memref was used in a non-dereferencing context. It is fine
     // for the memref to be used in a non-dereferencing way outside of the
     // region where this replacement is happening.
-    if (!isMemRefDereferencingOp(*opStmt))
+    if (!isMemRefDereferencingOp(*opInst))
      // Failure: memref used in a non-dereferencing op (potentially escapes);
      // no replacement in these cases.
       return false;

     auto getMemRefOperandPos = [&]() -> unsigned {
       unsigned i, e;
-      for (i = 0, e = opStmt->getNumOperands(); i < e; i++) {
-        if (opStmt->getOperand(i) == oldMemRef)
+      for (i = 0, e = opInst->getNumOperands(); i < e; i++) {
+        if (opInst->getOperand(i) == oldMemRef)
           break;
       }
-      assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
+      assert(i < opInst->getNumOperands() && "operand guaranteed to be found");
       return i;
     };
     unsigned memRefOperandPos = getMemRefOperandPos();

-    // Construct the new operation statement using this memref.
-    OperationState state(opStmt->getContext(), opStmt->getLoc(),
-                         opStmt->getName());
-    state.operands.reserve(opStmt->getNumOperands() + extraIndices.size());
+    // Construct the new operation instruction using this memref.
+    OperationState state(opInst->getContext(), opInst->getLoc(),
+                         opInst->getName());
+    state.operands.reserve(opInst->getNumOperands() + extraIndices.size());
     // Insert the non-memref operands.
-    state.operands.insert(state.operands.end(), opStmt->operand_begin(),
-                          opStmt->operand_begin() + memRefOperandPos);
+    state.operands.insert(state.operands.end(), opInst->operand_begin(),
+                          opInst->operand_begin() + memRefOperandPos);
     state.operands.push_back(newMemRef);

-    FuncBuilder builder(opStmt);
+    FuncBuilder builder(opInst);
     for (auto *extraIndex : extraIndices) {
       // TODO(mlir-team): An operation/SSA value should provide a method to
       // return the position of an SSA result in its defining
@@ -139,10 +139,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
     remapOperands.insert(remapOperands.end(), extraOperands.begin(),
                          extraOperands.end());
     remapOperands.insert(
-        remapOperands.end(), opStmt->operand_begin() + memRefOperandPos + 1,
-        opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
+        remapOperands.end(), opInst->operand_begin() + memRefOperandPos + 1,
+        opInst->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
     if (indexRemap) {
-      auto remapOp = builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap,
+      auto remapOp = builder.create<AffineApplyOp>(opInst->getLoc(), indexRemap,
                                                    remapOperands);
       // Remapped indices.
       for (auto *index : remapOp->getInstruction()->getResults())
@@ -155,27 +155,27 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,

     // Insert the remaining operands unmodified.
     state.operands.insert(state.operands.end(),
-                          opStmt->operand_begin() + memRefOperandPos + 1 +
+                          opInst->operand_begin() + memRefOperandPos + 1 +
                               oldMemRefRank,
-                          opStmt->operand_end());
+                          opInst->operand_end());

     // Result types don't change. Both memrefs are of the same elemental type.
-    state.types.reserve(opStmt->getNumResults());
-    for (const auto *result : opStmt->getResults())
+    state.types.reserve(opInst->getNumResults());
+    for (const auto *result : opInst->getResults())
       state.types.push_back(result->getType());

     // Attributes also do not change.
-    state.attributes.insert(state.attributes.end(), opStmt->getAttrs().begin(),
-                            opStmt->getAttrs().end());
+    state.attributes.insert(state.attributes.end(), opInst->getAttrs().begin(),
+                            opInst->getAttrs().end());

     // Create the new operation.
     auto *repOp = builder.createOperation(state);
     // Replace old memref's dereferencing op's uses.
     unsigned r = 0;
-    for (auto *res : opStmt->getResults()) {
+    for (auto *res : opInst->getResults()) {
       res->replaceAllUsesWith(repOp->getResult(r++));
     }
-    opStmt->erase();
+    opInst->erase();
   }
   return true;
 }
@@ -196,9 +196,9 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
   // Initialize AffineValueMap with identity map.
   AffineValueMap valueMap(map, operands);
-  for (auto *opStmt : affineApplyOps) {
-    assert(opStmt->isa<AffineApplyOp>());
-    auto affineApplyOp = opStmt->cast<AffineApplyOp>();
+  for (auto *opInst : affineApplyOps) {
+    assert(opInst->isa<AffineApplyOp>());
+    auto affineApplyOp = opInst->cast<AffineApplyOp>();
     // Forward substitute 'affineApplyOp' into 'valueMap'.
     valueMap.forwardSubstitute(*affineApplyOp);
   }
@@ -219,10 +219,10 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
   return affineApplyOp->getInstruction();
 }

-/// Given an operation statement, inserts a new single affine apply operation,
-/// that is exclusively used by this operation statement, and that provides all
-/// operands that are results of an affine_apply as a function of loop iterators
-/// and program parameters and whose results are.
+/// Given an operation instruction, inserts a new single affine apply operation,
+/// that is exclusively used by this operation instruction, and that provides
+/// all operands that are results of an affine_apply as a function of loop
+/// iterators and program parameters, and whose results are in turn used as
+/// operands of this operation instruction.
 ///
 /// Before
 ///
@@ -242,18 +242,18 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
 /// This allows applying different transformations on send and compute (for eg.
 /// different shifts/delays).
 ///
-/// Returns nullptr either if none of opStmt's operands were the result of an
+/// Returns nullptr either if none of opInst's operands were the result of an
 /// affine_apply and thus there was no affine computation slice to create, or if
-/// all the affine_apply op's supplying operands to this opStmt do not have any
-/// uses besides this opStmt. Returns the new affine_apply operation statement
+/// all the affine_apply ops supplying operands to this opInst do not have any
+/// uses besides this opInst. Returns the new affine_apply operation instruction
 /// otherwise.
-OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
+OperationInst *mlir::createAffineComputationSlice(OperationInst *opInst) {
   // Collect all operands that are results of affine apply ops.
   SmallVector<Value *, 4> subOperands;
-  subOperands.reserve(opStmt->getNumOperands());
-  for (auto *operand : opStmt->getOperands()) {
-    auto *defStmt = operand->getDefiningInst();
-    if (defStmt && defStmt->isa<AffineApplyOp>()) {
+  subOperands.reserve(opInst->getNumOperands());
+  for (auto *operand : opInst->getOperands()) {
+    auto *defInst = operand->getDefiningInst();
+    if (defInst && defInst->isa<AffineApplyOp>()) {
       subOperands.push_back(operand);
     }
   }
@@ -265,13 +265,13 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
   if (affineApplyOps.empty())
     return nullptr;

-  // Check if all uses of the affine apply op's lie only in this op stmt, in
+  // Check if all uses of the affine apply ops lie only in this op inst, in
   // which case there would be nothing to do.
   bool localized = true;
   for (auto *op : affineApplyOps) {
     for (auto *result : op->getResults()) {
       for (auto &use : result->getUses()) {
-        if (use.getOwner() != opStmt) {
+        if (use.getOwner() != opInst) {
           localized = false;
           break;
         }
@@ -281,18 +281,18 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
   if (localized)
     return nullptr;

-  FuncBuilder builder(opStmt);
+  FuncBuilder builder(opInst);
   SmallVector<Value *, 4> results;
-  auto *affineApplyStmt = createComposedAffineApplyOp(
-      &builder, opStmt->getLoc(), subOperands, affineApplyOps, &results);
+  auto *affineApplyInst = createComposedAffineApplyOp(
+      &builder, opInst->getLoc(), subOperands, affineApplyOps, &results);
   assert(results.size() == subOperands.size() &&
          "number of results should be the same as the number of subOperands");

   // Construct the new operands that include the results from the composed
   // affine apply op above instead of existing ones (subOperands). So, they
-  // differ from opStmt's operands only for those operands in 'subOperands', for
+  // differ from opInst's operands only for those operands in 'subOperands', for
   // which they will be replaced by the corresponding one from 'results'.
-  SmallVector<Value *, 4> newOperands(opStmt->getOperands());
+  SmallVector<Value *, 4> newOperands(opInst->getOperands());
   for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
     // Replace the subOperands from among the new operands.
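Stripped of the IR details, the operand-list surgery in replaceAllMemRefUsesWith above is a splice: the operands before the memref, then the new memref, then any extra indices, then everything that followed. A toy model with strings standing in for SSA values (illustrative names, not the MLIR API):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Swap the memref operand at `memRefPos` for `newMemRef` and splice any
// `extraIndices` in right after it, ahead of the original access indices.
std::vector<std::string>
rebuildOperands(const std::vector<std::string> &oldOperands, size_t memRefPos,
                const std::string &newMemRef,
                const std::vector<std::string> &extraIndices) {
  assert(memRefPos < oldOperands.size());
  std::vector<std::string> out(oldOperands.begin(),
                               oldOperands.begin() + memRefPos);
  out.push_back(newMemRef);
  out.insert(out.end(), extraIndices.begin(), extraIndices.end());
  out.insert(out.end(), oldOperands.begin() + memRefPos + 1,
             oldOperands.end());
  return out;
}

// rebuildOperands({"%v", "%A", "%i", "%j"}, 1, "%A2", {"%tag"})
// yields {"%v", "%A2", "%tag", "%i", "%j"}.
```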
unsigned j, f; @@ -306,10 +306,10 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) { } for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) { - opStmt->setOperand(idx, newOperands[idx]); + opInst->setOperand(idx, newOperands[idx]); } - return affineApplyStmt; + return affineApplyInst; } void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) { @@ -317,26 +317,26 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) { // TODO: Support forward substitution for CFG style functions. return; } - auto *opStmt = affineApplyOp->getInstruction(); - // Iterate through all uses of all results of 'opStmt', forward substituting + auto *opInst = affineApplyOp->getInstruction(); + // Iterate through all uses of all results of 'opInst', forward substituting // into any uses which are AffineApplyOps. - for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e; + for (unsigned resultIndex = 0, e = opInst->getNumResults(); resultIndex < e; ++resultIndex) { - const Value *result = opStmt->getResult(resultIndex); + const Value *result = opInst->getResult(resultIndex); for (auto it = result->use_begin(); it != result->use_end();) { InstOperand &use = *(it++); - auto *useStmt = use.getOwner(); - auto *useOpStmt = dyn_cast<OperationInst>(useStmt); + auto *useInst = use.getOwner(); + auto *useOpInst = dyn_cast<OperationInst>(useInst); // Skip if use is not AffineApplyOp. - if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>()) + if (useOpInst == nullptr || !useOpInst->isa<AffineApplyOp>()) continue; - // Advance iterator past 'opStmt' operands which also use 'result'. - while (it != result->use_end() && it->getOwner() == useStmt) + // Advance iterator past 'opInst' operands which also use 'result'. + while (it != result->use_end() && it->getOwner() == useInst) ++it; - FuncBuilder builder(useOpStmt); + FuncBuilder builder(useOpInst); // Initialize AffineValueMap with 'affineApplyOp' which uses 'result'. - auto oldAffineApplyOp = useOpStmt->cast<AffineApplyOp>(); + auto oldAffineApplyOp = useOpInst->cast<AffineApplyOp>(); AffineValueMap valueMap(*oldAffineApplyOp); // Forward substitute 'result' at index 'i' into 'valueMap'. valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex); @@ -348,10 +348,10 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) { operands[i] = valueMap.getOperand(i); } auto newAffineApplyOp = builder.create<AffineApplyOp>( - useOpStmt->getLoc(), valueMap.getAffineMap(), operands); + useOpInst->getLoc(), valueMap.getAffineMap(), operands); // Update all uses to use results from 'newAffineApplyOp'. - for (unsigned i = 0, e = useOpStmt->getNumResults(); i < e; ++i) { + for (unsigned i = 0, e = useOpInst->getNumResults(); i < e; ++i) { oldAffineApplyOp->getResult(i)->replaceAllUsesWith( newAffineApplyOp->getResult(i)); } @@ -364,19 +364,19 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) { /// Folds the specified (lower or upper) bound to a constant if possible /// considering its operands. Returns false if the folding happens for any of /// the bounds, true otherwise. -bool mlir::constantFoldBounds(ForStmt *forStmt) { - auto foldLowerOrUpperBound = [forStmt](bool lower) { +bool mlir::constantFoldBounds(ForInst *forInst) { + auto foldLowerOrUpperBound = [forInst](bool lower) { // Check if the bound is already a constant. 
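Forward substitution of one affine_apply into another, as forwardSubstitute and AffineValueMap do in this file, is function composition at its core. A one-dimensional sketch with plain integer functions (not real AffineMap semantics):

```cpp
#include <cstdint>
#include <cstdio>
#include <functional>

using Affine1D = std::function<int64_t(int64_t)>;

// Substituting the producer g into the consumer f yields their composition.
Affine1D compose(const Affine1D &f, const Affine1D &g) {
  return [f, g](int64_t x) { return f(g(x)); };
}

int main() {
  Affine1D g = [](int64_t d0) { return d0 + 2; };     // %1 = apply d0 + 2 to %0
  Affine1D f = [](int64_t d0) { return 3 * d0 - 1; }; // %2 = apply 3*d0 - 1 to %1
  Affine1D fused = compose(f, g); // %2 as a direct function of %0
  std::printf("%lld\n", (long long)fused(5)); // 3 * (5 + 2) - 1 = 20
}
```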
- if (lower && forStmt->hasConstantLowerBound()) + if (lower && forInst->hasConstantLowerBound()) return true; - if (!lower && forStmt->hasConstantUpperBound()) + if (!lower && forInst->hasConstantUpperBound()) return true; // Check to see if each of the operands is the result of a constant. If so, // get the value. If not, ignore it. SmallVector<Attribute, 8> operandConstants; - auto boundOperands = lower ? forStmt->getLowerBoundOperands() - : forStmt->getUpperBoundOperands(); + auto boundOperands = lower ? forInst->getLowerBoundOperands() + : forInst->getUpperBoundOperands(); for (const auto *operand : boundOperands) { Attribute operandCst; if (auto *operandOp = operand->getDefiningInst()) { @@ -387,7 +387,7 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) { } AffineMap boundMap = - lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap(); + lower ? forInst->getLowerBoundMap() : forInst->getUpperBoundMap(); assert(boundMap.getNumResults() >= 1 && "bound maps should have at least one result"); SmallVector<Attribute, 4> foldedResults; @@ -402,8 +402,8 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) { maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) : llvm::APIntOps::smin(maxOrMin, foldedResult); } - lower ? forStmt->setConstantLowerBound(maxOrMin.getSExtValue()) - : forStmt->setConstantUpperBound(maxOrMin.getSExtValue()); + lower ? forInst->setConstantLowerBound(maxOrMin.getSExtValue()) + : forInst->setConstantUpperBound(maxOrMin.getSExtValue()); // Return false on success. return false; @@ -449,11 +449,11 @@ void mlir::remapFunctionAttrs( if (!fn.isML()) return; - struct MLFnWalker : public StmtWalker<MLFnWalker> { + struct MLFnWalker : public InstWalker<MLFnWalker> { MLFnWalker(const DenseMap<Attribute, FunctionAttr> &remappingTable) : remappingTable(remappingTable) {} - void visitOperationInst(OperationInst *opStmt) { - remapFunctionAttrs(*opStmt, remappingTable); + void visitOperationInst(OperationInst *opInst) { + remapFunctionAttrs(*opInst, remappingTable); } const DenseMap<Attribute, FunctionAttr> &remappingTable; diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 78d048b4778..9aa11682ebb 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -95,20 +95,20 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { SmallVector<int, 8> shape(clTestVectorShapeRatio.begin(), clTestVectorShapeRatio.end()); auto subVectorType = VectorType::get(shape, Type::getF32(f->getContext())); - // Only filter statements that operate on a strict super-vector and have one + // Only filter instructions that operate on a strict super-vector and have one // return. This makes testing easier. 
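constantFoldBounds keeps the maximum of a lower-bound map's folded results and the minimum of an upper-bound map's, since a multi-result lower bound acts as a max and an upper bound as a min. The same selection on plain integers (a toy stand-in for the folded attribute results):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int64_t foldBound(const std::vector<int64_t> &foldedResults, bool lower) {
  int64_t v = foldedResults.front();
  for (int64_t r : foldedResults)
    v = lower ? std::max(v, r) : std::min(v, r); // smax for lb, smin for ub
  return v;
}

int main() {
  std::printf("lb = %lld\n", (long long)foldBound({0, 4, 2}, /*lower=*/true));
  std::printf("ub = %lld\n", (long long)foldBound({9, 7, 8}, /*lower=*/false));
  // lb = 4, ub = 7
}
```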
- auto filter = [subVectorType](const Statement &stmt) { - auto *opStmt = dyn_cast<OperationInst>(&stmt); - if (!opStmt) { + auto filter = [subVectorType](const Instruction &inst) { + auto *opInst = dyn_cast<OperationInst>(&inst); + if (!opInst) { return false; } assert(subVectorType.getElementType() == Type::getF32(subVectorType.getContext()) && "Only f32 supported for now"); - if (!matcher::operatesOnStrictSuperVectors(*opStmt, subVectorType)) { + if (!matcher::operatesOnStrictSuperVectors(*opInst, subVectorType)) { return false; } - if (opStmt->getNumResults() != 1) { + if (opInst->getNumResults() != 1) { return false; } return true; @@ -116,26 +116,26 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { auto pat = Op(filter); auto matches = pat.match(f); for (auto m : matches) { - auto *opStmt = cast<OperationInst>(m.first); + auto *opInst = cast<OperationInst>(m.first); // This is a unit test that only checks and prints shape ratio. // As a consequence we write only Ops with a single return type for the // purpose of this test. If we need to test more intricate behavior in the // future we can always extend. - auto superVectorType = opStmt->getResult(0)->getType().cast<VectorType>(); + auto superVectorType = opInst->getResult(0)->getType().cast<VectorType>(); auto ratio = shapeRatio(superVectorType, subVectorType); if (!ratio.hasValue()) { - opStmt->emitNote("NOT MATCHED"); + opInst->emitNote("NOT MATCHED"); } else { - outs() << "\nmatched: " << *opStmt << " with shape ratio: "; + outs() << "\nmatched: " << *opInst << " with shape ratio: "; interleaveComma(MutableArrayRef<unsigned>(*ratio), outs()); } } } -static std::string toString(Statement *stmt) { +static std::string toString(Instruction *inst) { std::string res; auto os = llvm::raw_string_ostream(res); - stmt->print(os); + inst->print(os); return res; } @@ -144,10 +144,10 @@ static MLFunctionMatches matchTestSlicingOps(Function *f) { constexpr auto kTestSlicingOpName = "slicing-test-op"; using functional::map; using matcher::Op; - // Match all OpStatements with the kTestSlicingOpName name. - auto filter = [](const Statement &stmt) { - const auto &opStmt = cast<OperationInst>(stmt); - return opStmt.getName().getStringRef() == kTestSlicingOpName; + // Match all OpInstructions with the kTestSlicingOpName name. 
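The shape ratio printed by the test can be sketched independently of VectorType: divide the trailing dimensions of the super-vector shape elementwise by the sub-vector shape, carry the leading dimensions through, and fail on any non-divisible dimension. A standalone approximation (assumed semantics, toy types):

```cpp
#include <cstdio>
#include <optional>
#include <vector>

std::optional<std::vector<int>> shapeRatio(const std::vector<int> &superShape,
                                           const std::vector<int> &subShape) {
  if (superShape.size() < subShape.size())
    return std::nullopt;
  // Leading super dimensions pass through unchanged.
  std::vector<int> ratio(superShape.begin(),
                         superShape.end() - subShape.size());
  for (size_t i = 0; i < subShape.size(); ++i) {
    int sup = superShape[superShape.size() - subShape.size() + i];
    if (sup % subShape[i] != 0)
      return std::nullopt; // not a strict super-vector of this sub-vector
    ratio.push_back(sup / subShape[i]);
  }
  return ratio;
}

int main() {
  auto r = shapeRatio({4, 32, 16}, {8, 16}); // -> {4, 4, 1}
  if (r)
    for (int d : *r)
      std::printf("%d ", d);
  std::printf("\n");
}
```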
+  auto filter = [](const Instruction &inst) {
+    const auto &opInst = cast<OperationInst>(inst);
+    return opInst.getName().getStringRef() == kTestSlicingOpName;
   };
   auto pat = Op(filter);
   return pat.match(f);
@@ -156,7 +156,7 @@ static MLFunctionMatches matchTestSlicingOps(Function *f) {
 void VectorizerTestPass::testBackwardSlicing(Function *f) {
   auto matches = matchTestSlicingOps(f);
   for (auto m : matches) {
-    SetVector<Statement *> backwardSlice;
+    SetVector<Instruction *> backwardSlice;
     getBackwardSlice(m.first, &backwardSlice);
     auto strs = map(toString, backwardSlice);
     outs() << "\nmatched: " << *m.first << " backward static slice: ";
@@ -169,7 +169,7 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) {
 void VectorizerTestPass::testForwardSlicing(Function *f) {
   auto matches = matchTestSlicingOps(f);
   for (auto m : matches) {
-    SetVector<Statement *> forwardSlice;
+    SetVector<Instruction *> forwardSlice;
     getForwardSlice(m.first, &forwardSlice);
     auto strs = map(toString, forwardSlice);
     outs() << "\nmatched: " << *m.first << " forward static slice: ";
@@ -182,7 +182,7 @@ void VectorizerTestPass::testForwardSlicing(Function *f) {
 void VectorizerTestPass::testSlicing(Function *f) {
   auto matches = matchTestSlicingOps(f);
   for (auto m : matches) {
-    SetVector<Statement *> staticSlice = getSlice(m.first);
+    SetVector<Instruction *> staticSlice = getSlice(m.first);
     auto strs = map(toString, staticSlice);
     outs() << "\nmatched: " << *m.first << " static slice: ";
     for (const auto &s : strs) {
@@ -191,9 +191,9 @@ void VectorizerTestPass::testSlicing(Function *f) {
   }
 }

-bool customOpWithAffineMapAttribute(const Statement &stmt) {
-  const auto &opStmt = cast<OperationInst>(stmt);
-  return opStmt.getName().getStringRef() ==
+bool customOpWithAffineMapAttribute(const Instruction &inst) {
+  const auto &opInst = cast<OperationInst>(inst);
+  return opInst.getName().getStringRef() ==
          VectorizerTestPass::kTestAffineMapOpName;
 }

@@ -205,8 +205,8 @@ void VectorizerTestPass::testComposeMaps(Function *f) {
   maps.reserve(matches.size());
   std::reverse(matches.begin(), matches.end());
   for (auto m : matches) {
-    auto *opStmt = cast<OperationInst>(m.first);
-    auto map = opStmt->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
+    auto *opInst = cast<OperationInst>(m.first);
+    auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
                    .cast<AffineMapAttr>()
                    .getValue();
     maps.push_back(map);
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index ddbd6256782..bbb703cd627 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -252,7 +252,7 @@ using namespace mlir;
 /// ==========
 /// The algorithm proceeds in a few steps:
 /// 1. defining super-vectorization patterns and matching them on the tree of
-///    ForStmt. A super-vectorization pattern is defined as a recursive data
+///    ForInst. A super-vectorization pattern is defined as a recursive data
 ///    structure that matches and captures nested, imperfectly-nested loops
 ///    that have a. conformable loop annotations attached (e.g. parallel,
 ///    reduction, vectorizable, ...) as well as b. all contiguous load/store
@@ -279,7 +279,7 @@ using namespace mlir;
 ///    it by its vector form. Otherwise, if the scalar value is a constant,
 ///    it is vectorized into a splat. In all other cases, vectorization for
 ///    the pattern currently fails.
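The backward and forward slices exercised by the tests above are reachability over the use-def graph. A toy adjacency-list version of the backward case, recursing through operands in DFS post-order (illustrative names, not the getBackwardSlice API):

```cpp
#include <cstdio>
#include <set>
#include <vector>

// defs[n] lists the operands of node n (edges toward definitions).
void backwardSlice(int n, const std::vector<std::vector<int>> &defs,
                   std::set<int> &slice) {
  for (int d : defs[n]) {
    if (slice.count(d))
      continue;
    backwardSlice(d, defs, slice); // visit transitive defs first
    slice.insert(d);
  }
}

int main() {
  // Nodes 0 and 1 feed node 2, which feeds node 3.
  std::vector<std::vector<int>> defs = {{}, {}, {0, 1}, {2}};
  std::set<int> slice;
  backwardSlice(3, defs, slice); // backward slice of 3 = {0, 1, 2}
  for (int n : slice)
    std::printf("%d ", n);
  std::printf("\n");
}
```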
-/// e. if everything under the root ForStmt in the current pattern vectorizes
+/// e. if everything under the root ForInst in the current pattern vectorizes
 ///    properly, we commit that loop to the IR. Otherwise we discard it and
 ///    restore a previously cloned version of the loop. Thanks to the
 ///    recursive scoping nature of matchers and captured patterns, this is
@@ -668,12 +668,12 @@ namespace {

 struct VectorizationStrategy {
   ArrayRef<int> vectorSizes;
-  DenseMap<ForStmt *, unsigned> loopToVectorDim;
+  DenseMap<ForInst *, unsigned> loopToVectorDim;
 };

 } // end anonymous namespace

-static void vectorizeLoopIfProfitable(ForStmt *loop, unsigned depthInPattern,
+static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
                                       unsigned patternDepth,
                                       VectorizationStrategy *strategy) {
   assert(patternDepth > depthInPattern &&
@@ -705,7 +705,7 @@ static bool analyzeProfitability(MLFunctionMatches matches,
                                  unsigned depthInPattern, unsigned patternDepth,
                                  VectorizationStrategy *strategy) {
   for (auto m : matches) {
-    auto *loop = cast<ForStmt>(m.first);
+    auto *loop = cast<ForInst>(m.first);
     bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth,
                                      strategy);
     if (fail) {
@@ -721,7 +721,7 @@ static bool analyzeProfitability(MLFunctionMatches matches,
 namespace {

 struct VectorizationState {
-  /// Adds an entry of pre/post vectorization statements in the state.
+  /// Adds an entry of pre/post vectorization instructions in the state.
   void registerReplacement(OperationInst *key, OperationInst *value);
   /// When the current vectorization pattern is successful, this erases the
   /// instructions that were marked for erasure in the proper order and resets
@@ -733,7 +733,7 @@ struct VectorizationState {
   SmallVector<OperationInst *, 16> toErase;
   // Set of OperationInst that have been vectorized (the values in the
   // vectorizationMap for hashed access). The vectorizedSet is used in
-  // particular to filter the statements that have already been vectorized by
+  // particular to filter the instructions that have already been vectorized by
   // this pattern, when iterating over nested loops in this pattern.
   DenseSet<OperationInst *> vectorizedSet;
   // Map of old scalar OperationInst to new vectorized OperationInst.
@@ -747,16 +747,16 @@ struct VectorizationState {
   // that have been vectorized. They can be retrieved from `vectorizationMap`
   // but it is convenient to keep track of them in a separate data structure.
   DenseSet<OperationInst *> roots;
-  // Terminator statements for the worklist in the vectorizeOperations function.
-  // They consist of the subset of store operations that have been vectorized.
-  // They can be retrieved from `vectorizationMap` but it is convenient to keep
-  // track of them in a separate data structure. Since they do not necessarily
-  // belong to use-def chains starting from loads (e.g storing a constant), we
-  // need to handle them in a post-pass.
+  // Terminator instructions for the worklist in the vectorizeOperations
+  // function. They consist of the subset of store operations that have been
+  // vectorized. They can be retrieved from `vectorizationMap` but it is
+  // convenient to keep track of them in a separate data structure. Since they
+  // do not necessarily belong to use-def chains starting from loads (e.g.
+  // storing a constant), we need to handle them in a post-pass.
   DenseSet<OperationInst *> terminators;
-  // Checks that the type of `stmt` is StoreOp and adds it to the terminators
+  // Checks that the type of `inst` is StoreOp and adds it to the terminators
   // set.
-  void registerTerminator(OperationInst *stmt);
+  void registerTerminator(OperationInst *inst);

 private:
   void registerReplacement(const Value *key, Value *value);
@@ -784,19 +784,19 @@ void VectorizationState::registerReplacement(OperationInst *key,
   }
 }

-void VectorizationState::registerTerminator(OperationInst *stmt) {
-  assert(stmt->isa<StoreOp>() && "terminator must be a StoreOp");
-  assert(terminators.count(stmt) == 0 &&
+void VectorizationState::registerTerminator(OperationInst *inst) {
+  assert(inst->isa<StoreOp>() && "terminator must be a StoreOp");
+  assert(terminators.count(inst) == 0 &&
          "terminator was already inserted previously");
-  terminators.insert(stmt);
+  terminators.insert(inst);
 }

 void VectorizationState::finishVectorizationPattern() {
   while (!toErase.empty()) {
-    auto *stmt = toErase.pop_back_val();
+    auto *inst = toErase.pop_back_val();
     LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: ");
-    LLVM_DEBUG(stmt->print(dbgs()));
-    stmt->erase();
+    LLVM_DEBUG(inst->print(dbgs()));
+    inst->erase();
   }
 }

@@ -832,23 +832,23 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
   auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType);

   // Materialize a MemRef with 1 vector.
-  auto *opStmt = memoryOp->getInstruction();
+  auto *opInst = memoryOp->getInstruction();
   // For now, vector_transfers must be aligned, operate only on indices with an
   // identity subset of AffineMap and do not change layout.
   // TODO(ntv): increase the expressiveness power of vector_transfer operations
   // as needed by various targets.
-  if (opStmt->template isa<LoadOp>()) {
+  if (opInst->template isa<LoadOp>()) {
     auto permutationMap =
-        makePermutationMap(opStmt, state->strategy->loopToVectorDim);
+        makePermutationMap(opInst, state->strategy->loopToVectorDim);
     LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
     LLVM_DEBUG(permutationMap.print(dbgs()));
-    FuncBuilder b(opStmt);
+    FuncBuilder b(opInst);
     auto transfer = b.create<VectorTransferReadOp>(
-        opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
+        opInst->getLoc(), vectorType, memoryOp->getMemRef(),
         map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap);
-    state->registerReplacement(opStmt, transfer->getInstruction());
+    state->registerReplacement(opInst, transfer->getInstruction());
   } else {
-    state->registerTerminator(opStmt);
+    state->registerTerminator(opInst);
   }
   return false;
 }
@@ -856,28 +856,29 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,

 /// Coarsens the loop's bounds and transforms all remaining load and store
 /// operations into the appropriate vector_transfer.
-static bool vectorizeForStmt(ForStmt *loop, int64_t step, +static bool vectorizeForInst(ForInst *loop, int64_t step, VectorizationState *state) { using namespace functional; loop->setStep(step); - FilterFunctionType notVectorizedThisPattern = [state](const Statement &stmt) { - if (!matcher::isLoadOrStore(stmt)) { - return false; - } - auto *opStmt = cast<OperationInst>(&stmt); - return state->vectorizationMap.count(opStmt) == 0 && - state->vectorizedSet.count(opStmt) == 0 && - state->roots.count(opStmt) == 0 && - state->terminators.count(opStmt) == 0; - }; + FilterFunctionType notVectorizedThisPattern = + [state](const Instruction &inst) { + if (!matcher::isLoadOrStore(inst)) { + return false; + } + auto *opInst = cast<OperationInst>(&inst); + return state->vectorizationMap.count(opInst) == 0 && + state->vectorizedSet.count(opInst) == 0 && + state->roots.count(opInst) == 0 && + state->terminators.count(opInst) == 0; + }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); auto matches = loadAndStores.match(loop); for (auto ls : matches) { - auto *opStmt = cast<OperationInst>(ls.first); - auto load = opStmt->dyn_cast<LoadOp>(); - auto store = opStmt->dyn_cast<StoreOp>(); - LLVM_DEBUG(opStmt->print(dbgs())); + auto *opInst = cast<OperationInst>(ls.first); + auto load = opInst->dyn_cast<LoadOp>(); + auto store = opInst->dyn_cast<StoreOp>(); + LLVM_DEBUG(opInst->print(dbgs())); auto fail = load ? vectorizeRootOrTerminal(loop, load, state) : vectorizeRootOrTerminal(loop, store, state); if (fail) { @@ -895,8 +896,8 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step, /// we can build a cost model and a search procedure. static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { - return [fastestVaryingMemRefDimension](const Statement &forStmt) { - const auto &loop = cast<ForStmt>(forStmt); + return [fastestVaryingMemRefDimension](const Instruction &forInst) { + const auto &loop = cast<ForInst>(forInst); return isVectorizableLoopAlongFastestVaryingMemRefDim( loop, fastestVaryingMemRefDimension); }; @@ -911,7 +912,7 @@ static bool vectorizeNonRoot(MLFunctionMatches matches, /// recursively in DFS post-order. static bool doVectorize(MLFunctionMatches::EntryType oneMatch, VectorizationState *state) { - ForStmt *loop = cast<ForStmt>(oneMatch.first); + ForInst *loop = cast<ForInst>(oneMatch.first); MLFunctionMatches childrenMatches = oneMatch.second; // 1. DFS postorder recursion, if any of my children fails, I fail too. @@ -938,10 +939,10 @@ static bool doVectorize(MLFunctionMatches::EntryType oneMatch, // exploratory tradeoffs (see top of the file). Apply coarsening, i.e.: // | ub -> ub // | step -> step * vectorSize - LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForStmt by " << vectorSize + LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForInst by " << vectorSize << " : "); LLVM_DEBUG(loop->print(dbgs())); - return vectorizeForStmt(loop, loop->getStep() * vectorSize, state); + return vectorizeForInst(loop, loop->getStep() * vectorSize, state); } /// Non-root pattern iterates over the matches at this level, calls doVectorize @@ -963,20 +964,20 @@ static bool vectorizeNonRoot(MLFunctionMatches matches, /// element type. /// If `type` is not a valid vector type or if the scalar constant is not a /// valid vector element type, returns nullptr. 
-static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
+static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant,
                                 Type type) {
   if (!type || !type.isa<VectorType>() ||
       !VectorType::isValidElementType(constant.getType())) {
     return nullptr;
   }
-  FuncBuilder b(stmt);
-  Location loc = stmt->getLoc();
+  FuncBuilder b(inst);
+  Location loc = inst->getLoc();
   auto vectorType = type.cast<VectorType>();
   auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
-  auto *constantOpStmt = cast<OperationInst>(constant.getInstruction());
+  auto *constantOpInst = cast<OperationInst>(constant.getInstruction());

   OperationState state(
-      b.getContext(), loc, constantOpStmt->getName().getStringRef(), {},
+      b.getContext(), loc, constantOpInst->getName().getStringRef(), {},
      {vectorType},
      {make_pair(Identifier::get("value", b.getContext()), attr)});

@@ -985,7 +986,7 @@ static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
 }

 /// Returns a uniqu'ed VectorType.
-/// In the case `v`'s defining statement is already part of the `state`'s
+/// In the case `v`'s defining instruction is already part of the `state`'s
 /// vectorizedSet, just returns the type of `v`.
 /// Otherwise, constructs a new VectorType of shape defined by `state.strategy`
 /// and of elemental type the type of `v`.
@@ -993,17 +994,17 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
   if (!VectorType::isValidElementType(v->getType())) {
     return Type();
   }
-  auto *definingOpStmt = cast<OperationInst>(v->getDefiningInst());
-  if (state.vectorizedSet.count(definingOpStmt) > 0) {
+  auto *definingOpInst = cast<OperationInst>(v->getDefiningInst());
+  if (state.vectorizedSet.count(definingOpInst) > 0) {
     return v->getType().cast<VectorType>();
   }
   return VectorType::get(state.strategy->vectorSizes, v->getType());
 };

-/// Tries to vectorize a given operand `op` of Statement `stmt` during def-chain
-/// propagation or during terminator vectorization, by applying the following
-/// logic:
-/// 1. if the defining statement is part of the vectorizedSet (i.e. vectorized
+/// Tries to vectorize a given operand `op` of Instruction `inst` during
+/// def-chain propagation or during terminator vectorization, by applying the
+/// following logic:
+/// 1. if the defining instruction is part of the vectorizedSet (i.e. vectorized
 /// by use-def propagation), `op` is already in the proper vector form;
 /// 2. otherwise, the `op` may be in some other vector form that fails to
 /// vectorize atm (i.e. broadcasting required), returns nullptr to indicate
@@ -1021,13 +1022,13 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
 /// vectorization is possible with the above logic. Returns nullptr otherwise.
 ///
 /// TODO(ntv): handle more complex cases.
-static Value *vectorizeOperand(Value *operand, Statement *stmt,
+static Value *vectorizeOperand(Value *operand, Instruction *inst,
                                VectorizationState *state) {
   LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
   LLVM_DEBUG(operand->print(dbgs()));
-  auto *definingStatement = cast<OperationInst>(operand->getDefiningInst());
+  auto *definingInstruction = cast<OperationInst>(operand->getDefiningInst());
   // 1. If this value has already been vectorized this round, we are done.
- if (state->vectorizedSet.count(definingStatement) > 0) { + if (state->vectorizedSet.count(definingInstruction) > 0) { LLVM_DEBUG(dbgs() << " -> already vector operand"); return operand; } @@ -1049,7 +1050,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, } // 3. vectorize constant. if (auto constant = operand->getDefiningInst()->dyn_cast<ConstantOp>()) { - return vectorizeConstant(stmt, *constant, + return vectorizeConstant(inst, *constant, getVectorType(operand, *state).cast<VectorType>()); } // 4. currently non-vectorizable. @@ -1068,41 +1069,41 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, /// Maybe some Ops are not vectorizable or require some tricky logic, we cannot /// do one-off logic here; ideally it would be TableGen'd. static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, - OperationInst *opStmt, + OperationInst *opInst, VectorizationState *state) { // Sanity checks. - assert(!opStmt->isa<LoadOp>() && + assert(!opInst->isa<LoadOp>() && "all loads must have already been fully vectorized independently"); - assert(!opStmt->isa<VectorTransferReadOp>() && + assert(!opInst->isa<VectorTransferReadOp>() && "vector_transfer_read cannot be further vectorized"); - assert(!opStmt->isa<VectorTransferWriteOp>() && + assert(!opInst->isa<VectorTransferWriteOp>() && "vector_transfer_write cannot be further vectorized"); - if (auto store = opStmt->dyn_cast<StoreOp>()) { + if (auto store = opInst->dyn_cast<StoreOp>()) { auto *memRef = store->getMemRef(); auto *value = store->getValueToStore(); - auto *vectorValue = vectorizeOperand(value, opStmt, state); + auto *vectorValue = vectorizeOperand(value, opInst, state); auto indices = map(makePtrDynCaster<Value>(), store->getIndices()); - FuncBuilder b(opStmt); + FuncBuilder b(opInst); auto permutationMap = - makePermutationMap(opStmt, state->strategy->loopToVectorDim); + makePermutationMap(opInst, state->strategy->loopToVectorDim); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create<VectorTransferWriteOp>( - opStmt->getLoc(), vectorValue, memRef, indices, permutationMap); + opInst->getLoc(), vectorValue, memRef, indices, permutationMap); auto *res = cast<OperationInst>(transfer->getInstruction()); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); // "Terminators" (i.e. StoreOps) are erased on the spot. - opStmt->erase(); + opInst->erase(); return res; } auto types = map([state](Value *v) { return getVectorType(v, *state); }, - opStmt->getResults()); - auto vectorizeOneOperand = [opStmt, state](Value *op) -> Value * { - return vectorizeOperand(op, opStmt, state); + opInst->getResults()); + auto vectorizeOneOperand = [opInst, state](Value *op) -> Value * { + return vectorizeOperand(op, opInst, state); }; - auto operands = map(vectorizeOneOperand, opStmt->getOperands()); + auto operands = map(vectorizeOneOperand, opInst->getOperands()); // Check whether a single operand is null. If so, vectorization failed. bool success = llvm::all_of(operands, [](Value *op) { return op; }); if (!success) { @@ -1116,9 +1117,9 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, // TODO(ntv): Is it worth considering an OperationInst.clone operation // which changes the type so we can promote an OperationInst with less // boilerplate? 
- OperationState newOp(b->getContext(), opStmt->getLoc(), - opStmt->getName().getStringRef(), operands, types, - opStmt->getAttrs()); + OperationState newOp(b->getContext(), opInst->getLoc(), + opInst->getName().getStringRef(), operands, types, + opInst->getAttrs()); return b->createOperation(newOp); } @@ -1137,13 +1138,13 @@ static bool vectorizeOperations(VectorizationState *state) { auto insertUsesOf = [&worklist, state](OperationInst *vectorized) { for (auto *r : vectorized->getResults()) for (auto &u : r->getUses()) { - auto *stmt = cast<OperationInst>(u.getOwner()); + auto *inst = cast<OperationInst>(u.getOwner()); // Don't propagate to terminals, a separate pass is needed for those. // TODO(ntv)[b/119759136]: use isa<> once Op is implemented. - if (state->terminators.count(stmt) > 0) { + if (state->terminators.count(inst) > 0) { continue; } - worklist.insert(stmt); + worklist.insert(inst); } }; apply(insertUsesOf, state->roots); @@ -1152,15 +1153,15 @@ static bool vectorizeOperations(VectorizationState *state) { // size again. By construction, the order of elements in the worklist is // consistent across iterations. for (unsigned i = 0; i < worklist.size(); ++i) { - auto *stmt = worklist[i]; + auto *inst = worklist[i]; LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: "); - LLVM_DEBUG(stmt->print(dbgs())); + LLVM_DEBUG(inst->print(dbgs())); - // 2. Create vectorized form of the statement. - // Insert it just before stmt, on success register stmt as replaced. - FuncBuilder b(stmt); - auto *vectorizedStmt = vectorizeOneOperationInst(&b, stmt, state); - if (!vectorizedStmt) { + // 2. Create vectorized form of the instruction. + // Insert it just before inst, on success register inst as replaced. + FuncBuilder b(inst); + auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state); + if (!vectorizedInst) { return true; } @@ -1168,11 +1169,11 @@ static bool vectorizeOperations(VectorizationState *state) { // Note that we cannot just call replaceAllUsesWith because it may // result in ops with mixed types, for ops whose operands have not all // yet been vectorized. This would be invalid IR. - state->registerReplacement(stmt, vectorizedStmt); + state->registerReplacement(inst, vectorizedInst); - // 4. Augment the worklist with uses of the statement we just vectorized. + // 4. Augment the worklist with uses of the instruction we just vectorized. // This preserves the proper order in the worklist. - apply(insertUsesOf, ArrayRef<OperationInst *>{stmt}); + apply(insertUsesOf, ArrayRef<OperationInst *>{inst}); } return false; } @@ -1184,7 +1185,7 @@ static bool vectorizeOperations(VectorizationState *state) { static bool vectorizeRootMatches(MLFunctionMatches matches, VectorizationStrategy *strategy) { for (auto m : matches) { - auto *loop = cast<ForStmt>(m.first); + auto *loop = cast<ForInst>(m.first); VectorizationState state; state.strategy = strategy; @@ -1201,7 +1202,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, } FuncBuilder builder(loop); // builder to insert in place of loop DenseMap<const Value *, Value *> nomap; - ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap)); + ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop, nomap)); auto fail = doVectorize(m, &state); /// Sets up error handling for this root loop. 
This is how the root match
    /// maintains a clone for handling failure and restores the proper state via
    /// RAII.
@@ -1230,8 +1231,8 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
    auto roots = map(getDefiningInst, map(getKey, state.replacementMap));

    // Vectorize the root operations and everything reached by use-def chains
-    // except the terminators (store statements) that need to be post-processed
-    // separately.
+    // except the terminators (store instructions) that need to be
+    // post-processed separately.
    fail = vectorizeOperations(&state);
    if (fail) {
      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations");
      continue;
    }

    // Finally, vectorize the terminators. If anything fails to vectorize, skip.
-    auto vectorizeOrFail = [&fail, &state](OperationInst *stmt) {
+    auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
      if (fail) {
        return;
      }
-      FuncBuilder b(stmt);
-      auto *res = vectorizeOneOperationInst(&b, stmt, &state);
+      FuncBuilder b(inst);
+      auto *res = vectorizeOneOperationInst(&b, inst, &state);
      if (res == nullptr) {
        fail = true;
      }
@@ -1284,7 +1285,7 @@ PassResult Vectorize::runOnMLFunction(Function *f) {
    if (fail) {
      continue;
    }
-    auto *loop = cast<ForStmt>(m.first);
+    auto *loop = cast<ForInst>(m.first);
    vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy);
    // TODO(ntv): if pattern does not apply, report it; alter the
    // cost/benefit.
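The use-def propagation in vectorizeOperations boils down to a position-indexed worklist that keeps growing while it is swept, with terminators filtered out for the separate post-pass. A minimal standalone model (ints for instructions, hypothetical data):

```cpp
#include <cstdio>
#include <set>
#include <vector>

int main() {
  // uses[n] lists the consumers of node n; terminator[n] marks store-like
  // nodes that are handled in the post-pass instead.
  std::vector<std::vector<int>> uses = {{2}, {2}, {3, 4}, {}, {}};
  std::vector<bool> terminator = {false, false, false, true, false};
  std::vector<int> roots = {0, 1};

  std::vector<int> worklist;
  std::set<int> seen;
  auto push = [&](int n) {
    if (!terminator[n] && seen.insert(n).second) // skip terminators and dups
      worklist.push_back(n);
  };
  for (int r : roots)
    for (int u : uses[r])
      push(u);
  // Index-based sweep: appending during the sweep preserves a consistent
  // visitation order, as in the pass.
  for (size_t i = 0; i < worklist.size(); ++i) {
    int n = worklist[i];
    std::printf("vectorize %d\n", n); // stand-in for vectorizeOneOperationInst
    for (int u : uses[n])
      push(u);
  }
}
```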

