Diffstat (limited to 'mlir/lib')
29 files changed, 1221 insertions, 1250 deletions
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 5b29467fc44..f1693c8e449 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -17,7 +17,10 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/OpImplementation.h" using namespace mlir; @@ -27,7 +30,445 @@ using namespace mlir; AffineOpsDialect::AffineOpsDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) { - addOperations<AffineIfOp>(); + addOperations<AffineForOp, AffineIfOp>(); +} + +//===----------------------------------------------------------------------===// +// AffineForOp +//===----------------------------------------------------------------------===// + +void AffineForOp::build(Builder *builder, OperationState *result, + ArrayRef<Value *> lbOperands, AffineMap lbMap, + ArrayRef<Value *> ubOperands, AffineMap ubMap, + int64_t step) { + assert((!lbMap && lbOperands.empty()) || + lbOperands.size() == lbMap.getNumInputs() && + "lower bound operand count does not match the affine map"); + assert((!ubMap && ubOperands.empty()) || + ubOperands.size() == ubMap.getNumInputs() && + "upper bound operand count does not match the affine map"); + assert(step > 0 && "step has to be a positive integer constant"); + + // Add an attribute for the step. + result->addAttribute(getStepAttrName(), + builder->getIntegerAttr(builder->getIndexType(), step)); + + // Add the lower bound. + result->addAttribute(getLowerBoundAttrName(), + builder->getAffineMapAttr(lbMap)); + result->addOperands(lbOperands); + + // Add the upper bound. + result->addAttribute(getUpperBoundAttrName(), + builder->getAffineMapAttr(ubMap)); + result->addOperands(ubOperands); + + // Reserve a block list for the body. + result->reserveBlockLists(/*numReserved=*/1); + + // Set the operands list as resizable so that we can freely modify the bounds. + result->setOperandListToResizable(); +} + +void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb, + int64_t ub, int64_t step) { + auto lbMap = AffineMap::getConstantMap(lb, builder->getContext()); + auto ubMap = AffineMap::getConstantMap(ub, builder->getContext()); + return build(builder, result, {}, lbMap, {}, ubMap, step); +} + +bool AffineForOp::verify() const { + const auto &bodyBlockList = getInstruction()->getBlockList(0); + + // The body block list must contain a single basic block. + if (bodyBlockList.empty() || + std::next(bodyBlockList.begin()) != bodyBlockList.end()) + return emitOpError("expected body block list to have a single block"); + + // Check that the body defines as single block argument for the induction + // variable. + const auto *body = getBody(); + if (body->getNumArguments() != 1 || + !body->getArgument(0)->getType().isIndex()) + return emitOpError("expected body to have a single index argument for the " + "induction variable"); + + // TODO: check that loop bounds are properly formed. + return false; +} + +/// Parse a for operation loop bounds. +static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { + // 'min' / 'max' prefixes are generally syntactic sugar, but are required if + // the map has multiple results. + bool failedToParsedMinMax = p->parseOptionalKeyword(isLower ? "max" : "min"); + + auto &builder = p->getBuilder(); + auto boundAttrName = isLower ? 
AffineForOp::getLowerBoundAttrName() + : AffineForOp::getUpperBoundAttrName(); + + // Parse ssa-id as identity map. + SmallVector<OpAsmParser::OperandType, 1> boundOpInfos; + if (p->parseOperandList(boundOpInfos)) + return true; + + if (!boundOpInfos.empty()) { + // Check that only one operand was parsed. + if (boundOpInfos.size() > 1) + return p->emitError(p->getNameLoc(), + "expected only one loop bound operand"); + + // TODO: improve error message when SSA value is not an affine integer. + // Currently it is 'use of value ... expects different type than prior uses' + if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(), + result->operands)) + return true; + + // Create an identity map using symbol id. This representation is optimized + // for storage. Analysis passes may expand it into a multi-dimensional map + // if desired. + AffineMap map = builder.getSymbolIdentityMap(); + result->addAttribute(boundAttrName, builder.getAffineMapAttr(map)); + return false; + } + + Attribute boundAttr; + if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName.data(), + result->attributes)) + return true; + + // Parse full form - affine map followed by dim and symbol list. + if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) { + unsigned currentNumOperands = result->operands.size(); + unsigned numDims; + if (parseDimAndSymbolList(p, result->operands, numDims)) + return true; + + auto map = affineMapAttr.getValue(); + if (map.getNumDims() != numDims) + return p->emitError( + p->getNameLoc(), + "dim operand count and integer set dim count must match"); + + unsigned numDimAndSymbolOperands = + result->operands.size() - currentNumOperands; + if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) + return p->emitError( + p->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // If the map has multiple results, make sure that we parsed the min/max + // prefix. + if (map.getNumResults() > 1 && failedToParsedMinMax) { + if (isLower) { + return p->emitError(p->getNameLoc(), + "lower loop bound affine map with multiple results " + "requires 'max' prefix"); + } + return p->emitError(p->getNameLoc(), + "upper loop bound affine map with multiple results " + "requires 'min' prefix"); + } + return false; + } + + // Parse custom assembly form. + if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) { + result->attributes.pop_back(); + result->addAttribute( + boundAttrName, builder.getAffineMapAttr( + builder.getConstantAffineMap(integerAttr.getInt()))); + return false; + } + + return p->emitError( + p->getNameLoc(), + "expected valid affine map representation for loop bounds"); +} + +bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + // Parse the induction variable followed by '='. + if (parser->parseBlockListEntryBlockArgument(builder.getIndexType()) || + parser->parseEqual()) + return true; + + // Parse loop bounds. + if (parseBound(/*isLower=*/true, result, parser) || + parser->parseKeyword("to", " between bounds") || + parseBound(/*isLower=*/false, result, parser)) + return true; + + // Parse the optional loop step, we default to 1 if one is not present. 
+ if (parser->parseOptionalKeyword("step")) { + result->addAttribute( + getStepAttrName(), + builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); + } else { + llvm::SMLoc stepLoc; + IntegerAttr stepAttr; + if (parser->getCurrentLocation(&stepLoc) || + parser->parseAttribute(stepAttr, builder.getIndexType(), + getStepAttrName().data(), result->attributes)) + return true; + + if (stepAttr.getValue().getSExtValue() < 0) + return parser->emitError( + stepLoc, + "expected step to be representable as a positive signed integer"); + } + + // Parse the body block list. + result->reserveBlockLists(/*numReserved=*/1); + if (parser->parseBlockList()) + return true; + + // Set the operands list as resizable so that we can freely modify the bounds. + result->setOperandListToResizable(); + return false; +} + +static void printBound(AffineBound bound, const char *prefix, OpAsmPrinter *p) { + AffineMap map = bound.getMap(); + + // Check if this bound should be printed using custom assembly form. + // The decision to restrict printing custom assembly form to trivial cases + // comes from the will to roundtrip MLIR binary -> text -> binary in a + // lossless way. + // Therefore, custom assembly form parsing and printing is only supported for + // zero-operand constant maps and single symbol operand identity maps. + if (map.getNumResults() == 1) { + AffineExpr expr = map.getResult(0); + + // Print constant bound. + if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { + if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { + *p << constExpr.getValue(); + return; + } + } + + // Print bound that consists of a single SSA symbol if the map is over a + // single symbol. + if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { + if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) { + p->printOperand(bound.getOperand(0)); + return; + } + } + } else { + // Map has multiple results. Print 'min' or 'max' prefix. + *p << prefix << ' '; + } + + // Print the map and its operands. + p->printAffineMap(map); + printDimAndSymbolList(bound.operand_begin(), bound.operand_end(), + map.getNumDims(), p); +} + +void AffineForOp::print(OpAsmPrinter *p) const { + *p << "for "; + p->printOperand(getBody()->getArgument(0)); + *p << " = "; + printBound(getLowerBound(), "max", p); + *p << " to "; + printBound(getUpperBound(), "min", p); + + if (getStep() != 1) + *p << " step " << getStep(); + p->printBlockList(getInstruction()->getBlockList(0), + /*printEntryBlockArgs=*/false); +} + +Block *AffineForOp::createBody() { + auto &bodyBlockList = getBlockList(); + assert(bodyBlockList.empty() && "expected no existing body blocks"); + + // Create a new block for the body, and add an argument for the induction + // variable. 
+ Block *body = new Block(); + body->addArgument(IndexType::get(getInstruction()->getContext())); + bodyBlockList.push_back(body); + return body; +} + +const AffineBound AffineForOp::getLowerBound() const { + auto lbMap = getLowerBoundMap(); + return AffineBound(ConstOpPointer<AffineForOp>(*this), 0, + lbMap.getNumInputs(), lbMap); +} + +const AffineBound AffineForOp::getUpperBound() const { + auto lbMap = getLowerBoundMap(); + auto ubMap = getUpperBoundMap(); + return AffineBound(ConstOpPointer<AffineForOp>(*this), lbMap.getNumInputs(), + getNumOperands(), ubMap); +} + +void AffineForOp::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) { + assert(lbOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end()); + + auto ubOperands = getUpperBoundOperands(); + newOperands.append(ubOperands.begin(), ubOperands.end()); + getInstruction()->setOperands(newOperands); + + setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) { + assert(ubOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector<Value *, 4> newOperands(getLowerBoundOperands()); + newOperands.append(ubOperands.begin(), ubOperands.end()); + getInstruction()->setOperands(newOperands); + + setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setLowerBoundMap(AffineMap map) { + auto lbMap = getLowerBoundMap(); + assert(lbMap.getNumDims() == map.getNumDims() && + lbMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + (void)lbMap; + setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setUpperBoundMap(AffineMap map) { + auto ubMap = getUpperBoundMap(); + assert(ubMap.getNumDims() == map.getNumDims() && + ubMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + (void)ubMap; + setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +bool AffineForOp::hasConstantLowerBound() const { + return getLowerBoundMap().isSingleConstant(); +} + +bool AffineForOp::hasConstantUpperBound() const { + return getUpperBoundMap().isSingleConstant(); +} + +int64_t AffineForOp::getConstantLowerBound() const { + return getLowerBoundMap().getSingleConstantResult(); +} + +int64_t AffineForOp::getConstantUpperBound() const { + return getUpperBoundMap().getSingleConstantResult(); +} + +void AffineForOp::setConstantLowerBound(int64_t value) { + setLowerBound( + {}, AffineMap::getConstantMap(value, getInstruction()->getContext())); +} + +void AffineForOp::setConstantUpperBound(int64_t value) { + setUpperBound( + {}, AffineMap::getConstantMap(value, getInstruction()->getContext())); +} + +AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; +} + +AffineForOp::const_operand_range AffineForOp::getLowerBoundOperands() const { + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; +} + +AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { + return {operand_begin() + getLowerBoundMap().getNumInputs(), 
operand_end()}; +} + +AffineForOp::const_operand_range AffineForOp::getUpperBoundOperands() const { + return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; +} + +bool AffineForOp::matchingBoundOperandList() const { + auto lbMap = getLowerBoundMap(); + auto ubMap = getUpperBoundMap(); + if (lbMap.getNumDims() != ubMap.getNumDims() || + lbMap.getNumSymbols() != ubMap.getNumSymbols()) + return false; + + unsigned numOperands = lbMap.getNumInputs(); + for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { + // Compare Value *'s. + if (getOperand(i) != getOperand(numOperands + i)) + return false; + } + return true; +} + +void AffineForOp::walkOps(std::function<void(OperationInst *)> callback) { + struct Walker : public InstWalker<Walker> { + std::function<void(OperationInst *)> const &callback; + Walker(std::function<void(OperationInst *)> const &callback) + : callback(callback) {} + + void visitOperationInst(OperationInst *opInst) { callback(opInst); } + }; + + Walker w(callback); + w.walk(getInstruction()); +} + +void AffineForOp::walkOpsPostOrder( + std::function<void(OperationInst *)> callback) { + struct Walker : public InstWalker<Walker> { + std::function<void(OperationInst *)> const &callback; + Walker(std::function<void(OperationInst *)> const &callback) + : callback(callback) {} + + void visitOperationInst(OperationInst *opInst) { callback(opInst); } + }; + + Walker v(callback); + v.walkPostOrder(getInstruction()); +} + +/// Returns the induction variable for this loop. +Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); } + +/// Returns if the provided value is the induction variable of a AffineForOp. +bool mlir::isForInductionVar(const Value *val) { + return getForInductionVarOwner(val) != nullptr; +} + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +OpPointer<AffineForOp> mlir::getForInductionVarOwner(Value *val) { + const BlockArgument *ivArg = dyn_cast<BlockArgument>(val); + if (!ivArg || !ivArg->getOwner()) + return OpPointer<AffineForOp>(); + auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst(); + if (!containingInst) + return OpPointer<AffineForOp>(); + return cast<OperationInst>(containingInst)->dyn_cast<AffineForOp>(); +} +ConstOpPointer<AffineForOp> mlir::getForInductionVarOwner(const Value *val) { + auto nonConstOwner = getForInductionVarOwner(const_cast<Value *>(val)); + return ConstOpPointer<AffineForOp>(nonConstOwner); +} + +/// Extracts the induction variables from a list of AffineForOps and returns +/// them. 
+SmallVector<Value *, 8> mlir::extractForInductionVars( + MutableArrayRef<OpPointer<AffineForOp>> forInsts) { + SmallVector<Value *, 8> results; + for (auto forInst : forInsts) + results.push_back(forInst->getInductionVar()); + return results; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 0153546a4c6..d2366f1ce81 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -21,12 +21,14 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" @@ -519,7 +521,7 @@ void mlir::getReachableAffineApplyOps( State &state = worklist.back(); auto *opInst = state.value->getDefiningInst(); // Note: getDefiningInst will return nullptr if the operand is not an - // OperationInst (i.e. ForInst), which is a terminator for the search. + // OperationInst (i.e. AffineForOp), which is a terminator for the search. if (opInst == nullptr || !opInst->isa<AffineApplyOp>()) { worklist.pop_back(); continue; @@ -546,21 +548,21 @@ void mlir::getReachableAffineApplyOps( } // Builds a system of constraints with dimensional identifiers corresponding to -// the loop IVs of the forInsts appearing in that order. Any symbols founds in +// the loop IVs of the forOps appearing in that order. Any symbols founds in // the bound operands are added as symbols in the system. Returns false for the // yet unimplemented cases. // TODO(andydavis,bondhugula) Handle non-unit steps through local variables or // stride information in FlatAffineConstraints. (For eg., by using iv - lb % // step = 0 and/or by introducing a method in FlatAffineConstraints // setExprStride(ArrayRef<int64_t> expr, int64_t stride) -bool mlir::getIndexSet(ArrayRef<ForInst *> forInsts, +bool mlir::getIndexSet(MutableArrayRef<OpPointer<AffineForOp>> forOps, FlatAffineConstraints *domain) { - auto indices = extractForInductionVars(forInsts); + auto indices = extractForInductionVars(forOps); // Reset while associated Values in 'indices' to the domain. - domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); - for (auto *forInst : forInsts) { - // Add constraints from forInst's bounds. - if (!domain->addForInstDomain(*forInst)) + domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); + for (auto forOp : forOps) { + // Add constraints from forOp's bounds. + if (!domain->addAffineForOpDomain(forOp)) return false; } return true; @@ -576,7 +578,7 @@ static bool getInstIndexSet(const Instruction *inst, FlatAffineConstraints *indexSet) { // TODO(andydavis) Extend this to gather enclosing IfInsts and consider // factoring it out into a utility function. 
- SmallVector<ForInst *, 4> loops; + SmallVector<OpPointer<AffineForOp>, 4> loops; getLoopIVs(*inst, &loops); return getIndexSet(loops, indexSet); } @@ -998,9 +1000,9 @@ static const Block *getCommonBlock(const MemRefAccess &srcAccess, return block; } auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1); - auto *forInst = getForInductionVarOwner(commonForValue); - assert(forInst && "commonForValue was not an induction variable"); - return forInst->getBody(); + auto forOp = getForInductionVarOwner(commonForValue); + assert(forOp && "commonForValue was not an induction variable"); + return forOp->getBody(); } // Returns true if the ancestor operation instruction of 'srcAccess' appears @@ -1195,7 +1197,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { // until operands of the AffineValueMap are loop IVs or symbols. // *) Build iteration domain constraints for each access. Iteration domain // constraints are pairs of inequality contraints representing the -// upper/lower loop bounds for each ForInst in the loop nest associated +// upper/lower loop bounds for each AffineForOp in the loop nest associated // with each access. // *) Build dimension and symbol position maps for each access, which map // Values from access functions and iteration domains to their position diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 5e7f8e3243c..c794899d3e1 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -20,6 +20,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineStructures.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" @@ -1247,22 +1248,23 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { numSymbols = newSymbolCount; } -bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) { +bool FlatAffineConstraints::addAffineForOpDomain( + ConstOpPointer<AffineForOp> forOp) { unsigned pos; // Pre-condition for this method. - if (!findId(*forInst.getInductionVar(), &pos)) { + if (!findId(*forOp->getInductionVar(), &pos)) { assert(0 && "Value not found"); return false; } - if (forInst.getStep() != 1) + if (forOp->getStep() != 1) LLVM_DEBUG(llvm::dbgs() << "Domain conservative: non-unit stride not handled\n"); // Adds a lower or upper bound when the bounds aren't constant. auto addLowerOrUpperBound = [&](bool lower) -> bool { - auto operands = lower ? forInst.getLowerBoundOperands() - : forInst.getUpperBoundOperands(); + auto operands = + lower ? forOp->getLowerBoundOperands() : forOp->getUpperBoundOperands(); for (const auto &operand : operands) { unsigned loc; if (!findId(*operand, &loc)) { @@ -1291,7 +1293,7 @@ bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) { } auto boundMap = - lower ? forInst.getLowerBoundMap() : forInst.getUpperBoundMap(); + lower ? forOp->getLowerBoundMap() : forOp->getUpperBoundMap(); FlatAffineConstraints localVarCst; std::vector<SmallVector<int64_t, 8>> flatExprs; @@ -1321,16 +1323,16 @@ bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) { return true; }; - if (forInst.hasConstantLowerBound()) { - addConstantLowerBound(pos, forInst.getConstantLowerBound()); + if (forOp->hasConstantLowerBound()) { + addConstantLowerBound(pos, forOp->getConstantLowerBound()); } else { // Non-constant lower bound case. 
if (!addLowerOrUpperBound(/*lower=*/true)) return false; } - if (forInst.hasConstantUpperBound()) { - addConstantUpperBound(pos, forInst.getConstantUpperBound() - 1); + if (forOp->hasConstantUpperBound()) { + addConstantUpperBound(pos, forOp->getConstantUpperBound() - 1); return true; } // Non-constant upper bound case. diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 7d88a3d9b9f..249776d42c9 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -43,27 +43,27 @@ using namespace mlir; /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count /// expression is simplified before returning. -AffineExpr mlir::getTripCountExpr(const ForInst &forInst) { +AffineExpr mlir::getTripCountExpr(ConstOpPointer<AffineForOp> forOp) { // upper_bound - lower_bound int64_t loopSpan; - int64_t step = forInst.getStep(); - auto *context = forInst.getContext(); + int64_t step = forOp->getStep(); + auto *context = forOp->getInstruction()->getContext(); - if (forInst.hasConstantBounds()) { - int64_t lb = forInst.getConstantLowerBound(); - int64_t ub = forInst.getConstantUpperBound(); + if (forOp->hasConstantBounds()) { + int64_t lb = forOp->getConstantLowerBound(); + int64_t ub = forOp->getConstantUpperBound(); loopSpan = ub - lb; } else { - auto lbMap = forInst.getLowerBoundMap(); - auto ubMap = forInst.getUpperBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); + auto ubMap = forOp->getUpperBoundMap(); // TODO(bondhugula): handle max/min of multiple expressions. if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1) return nullptr; // TODO(bondhugula): handle bounds with different operands. // Bounds have different operands, unhandled for now. - if (!forInst.matchingBoundOperandList()) + if (!forOp->matchingBoundOperandList()) return nullptr; // ub_expr - lb_expr @@ -89,8 +89,9 @@ AffineExpr mlir::getTripCountExpr(const ForInst &forInst) { /// Returns the trip count of the loop if it's a constant, None otherwise. This /// method uses affine expression analysis (in turn using getTripCount) and is /// able to determine constant trip count in non-trivial cases. -llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForInst &forInst) { - auto tripCountExpr = getTripCountExpr(forInst); +llvm::Optional<uint64_t> +mlir::getConstantTripCount(ConstOpPointer<AffineForOp> forOp) { + auto tripCountExpr = getTripCountExpr(forOp); if (!tripCountExpr) return None; @@ -104,8 +105,8 @@ llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForInst &forInst) { /// Returns the greatest known integral divisor of the trip count. Affine /// expression analysis is used (indirectly through getTripCount), and /// this method is thus able to determine non-trivial divisors. 
-uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) { - auto tripCountExpr = getTripCountExpr(forInst); +uint64_t mlir::getLargestDivisorOfTripCount(ConstOpPointer<AffineForOp> forOp) { + auto tripCountExpr = getTripCountExpr(forOp); if (!tripCountExpr) return 1; @@ -126,7 +127,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) { } bool mlir::isAccessInvariant(const Value &iv, const Value &index) { - assert(isForInductionVar(&iv) && "iv must be a ForInst"); + assert(isForInductionVar(&iv) && "iv must be a AffineForOp"); assert(index.getType().isa<IndexType>() && "index must be of IndexType"); SmallVector<OperationInst *, 4> affineApplyOps; getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps); @@ -163,7 +164,7 @@ mlir::getInvariantAccesses(const Value &iv, } /// Given: -/// 1. an induction variable `iv` of type ForInst; +/// 1. an induction variable `iv` of type AffineForOp; /// 2. a `memoryOp` of type const LoadOp& or const StoreOp&; /// 3. the index of the `fastestVaryingDim` along which to check; /// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access @@ -231,17 +232,18 @@ static bool isVectorTransferReadOrWrite(const Instruction &inst) { } using VectorizableInstFun = - std::function<bool(const ForInst &, const OperationInst &)>; + std::function<bool(ConstOpPointer<AffineForOp>, const OperationInst &)>; -static bool isVectorizableLoopWithCond(const ForInst &loop, +static bool isVectorizableLoopWithCond(ConstOpPointer<AffineForOp> loop, VectorizableInstFun isVectorizableInst) { - if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) { + auto *forInst = const_cast<OperationInst *>(loop->getInstruction()); + if (!matcher::isParallelLoop(*forInst) && + !matcher::isReductionLoop(*forInst)) { return false; } // No vectorization across conditionals for now. auto conditionals = matcher::If(); - auto *forInst = const_cast<ForInst *>(&loop); SmallVector<NestedMatch, 8> conditionalsMatched; conditionals.match(forInst, &conditionalsMatched); if (!conditionalsMatched.empty()) { @@ -251,7 +253,8 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, // No vectorization across unknown regions. auto regions = matcher::Op([](const Instruction &inst) -> bool { auto &opInst = cast<OperationInst>(inst); - return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>(); + return opInst.getNumBlockLists() != 0 && + !(opInst.isa<AffineIfOp>() || opInst.isa<AffineForOp>()); }); SmallVector<NestedMatch, 8> regionsMatched; regions.match(forInst, ®ionsMatched); @@ -288,23 +291,25 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, } bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( - const ForInst &loop, unsigned fastestVaryingDim) { - VectorizableInstFun fun( - [fastestVaryingDim](const ForInst &loop, const OperationInst &op) { - auto load = op.dyn_cast<LoadOp>(); - auto store = op.dyn_cast<StoreOp>(); - return load ? isContiguousAccess(*loop.getInductionVar(), *load, - fastestVaryingDim) - : isContiguousAccess(*loop.getInductionVar(), *store, - fastestVaryingDim); - }); + ConstOpPointer<AffineForOp> loop, unsigned fastestVaryingDim) { + VectorizableInstFun fun([fastestVaryingDim](ConstOpPointer<AffineForOp> loop, + const OperationInst &op) { + auto load = op.dyn_cast<LoadOp>(); + auto store = op.dyn_cast<StoreOp>(); + return load ? 
isContiguousAccess(*loop->getInductionVar(), *load, + fastestVaryingDim) + : isContiguousAccess(*loop->getInductionVar(), *store, + fastestVaryingDim); + }); return isVectorizableLoopWithCond(loop, fun); } -bool mlir::isVectorizableLoop(const ForInst &loop) { +bool mlir::isVectorizableLoop(ConstOpPointer<AffineForOp> loop) { VectorizableInstFun fun( // TODO: implement me - [](const ForInst &loop, const OperationInst &op) { return true; }); + [](ConstOpPointer<AffineForOp> loop, const OperationInst &op) { + return true; + }); return isVectorizableLoopWithCond(loop, fun); } @@ -313,9 +318,9 @@ bool mlir::isVectorizableLoop(const ForInst &loop) { /// 'def' and all its uses have the same shift factor. // TODO(mlir-team): extend this to check for memory-based dependence // violation when we have the support. -bool mlir::isInstwiseShiftValid(const ForInst &forInst, +bool mlir::isInstwiseShiftValid(ConstOpPointer<AffineForOp> forOp, ArrayRef<uint64_t> shifts) { - auto *forBody = forInst.getBody(); + auto *forBody = forOp->getBody(); assert(shifts.size() == forBody->getInstructions().size()); unsigned s = 0; for (const auto &inst : *forBody) { @@ -325,7 +330,7 @@ bool mlir::isInstwiseShiftValid(const ForInst &forInst, for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) { const Value *result = opInst->getResult(i); for (const InstOperand &use : result->getUses()) { - // If an ancestor instruction doesn't lie in the block of forInst, + // If an ancestor instruction doesn't lie in the block of forOp, // there is no shift to check. This is a naive way. If performance // becomes an issue, a map can be used to store 'shifts' - to look up // the shift for a instruction in constant time. diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 46bf5ad0b97..214b4ce403c 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -115,6 +115,10 @@ void NestedPattern::matchOne(Instruction *inst, } } +static bool isAffineForOp(const Instruction &inst) { + return cast<OperationInst>(inst).isa<AffineForOp>(); +} + static bool isAffineIfOp(const Instruction &inst) { return isa<OperationInst>(inst) && cast<OperationInst>(inst).isa<AffineIfOp>(); @@ -147,28 +151,34 @@ NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { } NestedPattern For(NestedPattern child) { - return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, child, isAffineForOp); } NestedPattern For(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::For, child, filter); + return NestedPattern(Instruction::Kind::OperationInst, child, + [=](const Instruction &inst) { + return isAffineForOp(inst) && filter(inst); + }); } NestedPattern For(ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineForOp); } NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::For, nested, filter); + return NestedPattern(Instruction::Kind::OperationInst, nested, + [=](const Instruction &inst) { + return isAffineForOp(inst) && filter(inst); + }); } // TODO(ntv): parallel annotation on loops. 
bool isParallelLoop(const Instruction &inst) { - const auto *loop = cast<ForInst>(&inst); - return (void *)loop || true; // loop->isParallel(); + auto loop = cast<OperationInst>(inst).cast<AffineForOp>(); + return loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. bool isReductionLoop(const Instruction &inst) { - const auto *loop = cast<ForInst>(&inst); - return (void *)loop || true; // loop->isReduction(); + auto loop = cast<OperationInst>(inst).cast<AffineForOp>(); + return loop || true; // loop->isReduction(); }; bool isLoadOrStore(const Instruction &inst) { diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index d16a7fcb1b3..4025af936f3 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -20,6 +20,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" @@ -52,7 +53,16 @@ void mlir::getForwardSlice(Instruction *inst, return; } - if (auto *opInst = dyn_cast<OperationInst>(inst)) { + auto *opInst = cast<OperationInst>(inst); + if (auto forOp = opInst->dyn_cast<AffineForOp>()) { + for (auto &u : forOp->getInductionVar()->getUses()) { + auto *ownerInst = u.getOwner(); + if (forwardSlice->count(ownerInst) == 0) { + getForwardSlice(ownerInst, forwardSlice, filter, + /*topLevel=*/false); + } + } + } else { assert(opInst->getNumResults() <= 1 && "NYI: multiple results"); if (opInst->getNumResults() > 0) { for (auto &u : opInst->getResult(0)->getUses()) { @@ -63,16 +73,6 @@ void mlir::getForwardSlice(Instruction *inst, } } } - } else if (auto *forInst = dyn_cast<ForInst>(inst)) { - for (auto &u : forInst->getInductionVar()->getUses()) { - auto *ownerInst = u.getOwner(); - if (forwardSlice->count(ownerInst) == 0) { - getForwardSlice(ownerInst, forwardSlice, filter, - /*topLevel=*/false); - } - } - } else { - assert(false && "NYI slicing case"); } // At the top level we reverse to get back the actual topological order. diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 0e77d4d9084..4b8afd9a620 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -38,15 +38,17 @@ using namespace mlir; /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. void mlir::getLoopIVs(const Instruction &inst, - SmallVectorImpl<ForInst *> *loops) { + SmallVectorImpl<OpPointer<AffineForOp>> *loops) { auto *currInst = inst.getParentInst(); - ForInst *currForInst; + OpPointer<AffineForOp> currAffineForOp; // Traverse up the hierarchy collecing all 'for' instruction while skipping // over 'if' instructions. - while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) || - cast<OperationInst>(currInst)->isa<AffineIfOp>())) { - if (currForInst) - loops->push_back(currForInst); + while (currInst && + ((currAffineForOp = + cast<OperationInst>(currInst)->dyn_cast<AffineForOp>()) || + cast<OperationInst>(currInst)->isa<AffineIfOp>())) { + if (currAffineForOp) + loops->push_back(currAffineForOp); currInst = currInst->getParentInst(); } std::reverse(loops->begin(), loops->end()); @@ -148,7 +150,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, if (rank == 0) { // A rank 0 memref has a 0-d region. 
- SmallVector<ForInst *, 4> ivs; + SmallVector<OpPointer<AffineForOp>, 4> ivs; getLoopIVs(*opInst, &ivs); SmallVector<Value *, 8> regionSymbols = extractForInductionVars(ivs); @@ -174,12 +176,12 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, unsigned numSymbols = accessMap.getNumSymbols(); // Add inequalties for loop lower/upper bounds. for (unsigned i = 0; i < numDims + numSymbols; ++i) { - if (auto *loop = getForInductionVarOwner(accessValueMap.getOperand(i))) { + if (auto loop = getForInductionVarOwner(accessValueMap.getOperand(i))) { // Note that regionCst can now have more dimensions than accessMap if the // bounds expressions involve outer loops or other symbols. // TODO(bondhugula): rewrite this to use getInstIndexSet; this way // conditionals will be handled when the latter supports it. - if (!regionCst->addForInstDomain(*loop)) + if (!regionCst->addAffineForOpDomain(loop)) return false; } else { // Has to be a valid symbol. @@ -203,14 +205,14 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which // this memref region is symbolic. - SmallVector<ForInst *, 4> outerIVs; + SmallVector<OpPointer<AffineForOp>, 4> outerIVs; getLoopIVs(*opInst, &outerIVs); assert(loopDepth <= outerIVs.size() && "invalid loop depth"); outerIVs.resize(loopDepth); for (auto *operand : accessValueMap.getOperands()) { - ForInst *iv; + OpPointer<AffineForOp> iv; if ((iv = getForInductionVarOwner(operand)) && - std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) { + llvm::is_contained(outerIVs, iv) == false) { regionCst->projectOut(operand); } } @@ -357,8 +359,10 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions, } if (level == positions.size() - 1) return &inst; - if (auto *childForInst = dyn_cast<ForInst>(&inst)) - return getInstAtPosition(positions, level + 1, childForInst->getBody()); + if (auto childAffineForOp = + cast<OperationInst>(inst).dyn_cast<AffineForOp>()) + return getInstAtPosition(positions, level + 1, + childAffineForOp->getBody()); for (auto &blockList : cast<OperationInst>(&inst)->getBlockLists()) { for (auto &b : blockList) @@ -385,12 +389,12 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, return false; } // Get loop nest surrounding src operation. - SmallVector<ForInst *, 4> srcLoopIVs; + SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs; getLoopIVs(*srcAccess.opInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Get loop nest surrounding dst operation. - SmallVector<ForInst *, 4> dstLoopIVs; + SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs; getLoopIVs(*dstAccess.opInst, &dstLoopIVs); unsigned numDstLoopIVs = dstLoopIVs.size(); if (dstLoopDepth > numDstLoopIVs) { @@ -437,38 +441,41 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, // solution. // TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project // out loop IVs we don't care about and produce smaller slice. -ForInst *mlir::insertBackwardComputationSlice( +OpPointer<AffineForOp> mlir::insertBackwardComputationSlice( OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState) { // Get loop nest surrounding src operation. - SmallVector<ForInst *, 4> srcLoopIVs; + SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Get loop nest surrounding dst operation. 
- SmallVector<ForInst *, 4> dstLoopIVs; + SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs; getLoopIVs(*dstOpInst, &dstLoopIVs); unsigned dstLoopIVsSize = dstLoopIVs.size(); if (dstLoopDepth > dstLoopIVsSize) { dstOpInst->emitError("invalid destination loop depth"); - return nullptr; + return OpPointer<AffineForOp>(); } // Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'. SmallVector<unsigned, 4> positions; // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d. - findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions); + findInstPosition(srcOpInst, srcLoopIVs[0]->getInstruction()->getBlock(), + &positions); // Clone src loop nest and insert it a the beginning of the instruction block // of the loop at 'dstLoopDepth' in 'dstLoopIVs'. - auto *dstForInst = dstLoopIVs[dstLoopDepth - 1]; - FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin()); - auto *sliceLoopNest = cast<ForInst>(b.clone(*srcLoopIVs[0])); + auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1]; + FuncBuilder b(dstAffineForOp->getBody(), dstAffineForOp->getBody()->begin()); + auto sliceLoopNest = + cast<OperationInst>(b.clone(*srcLoopIVs[0]->getInstruction())) + ->cast<AffineForOp>(); Instruction *sliceInst = getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); // Get loop nest surrounding 'sliceInst'. - SmallVector<ForInst *, 4> sliceSurroundingLoops; + SmallVector<OpPointer<AffineForOp>, 4> sliceSurroundingLoops; getLoopIVs(*sliceInst, &sliceSurroundingLoops); // Sanity check. @@ -481,11 +488,11 @@ ForInst *mlir::insertBackwardComputationSlice( // Update loop bounds for loops in 'sliceLoopNest'. for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - auto *forInst = sliceSurroundingLoops[dstLoopDepth + i]; + auto forOp = sliceSurroundingLoops[dstLoopDepth + i]; if (AffineMap lbMap = sliceState->lbs[i]) - forInst->setLowerBound(sliceState->lbOperands[i], lbMap); + forOp->setLowerBound(sliceState->lbOperands[i], lbMap); if (AffineMap ubMap = sliceState->ubs[i]) - forInst->setUpperBound(sliceState->ubOperands[i], ubMap); + forOp->setUpperBound(sliceState->ubOperands[i], ubMap); } return sliceLoopNest; } @@ -520,7 +527,7 @@ unsigned mlir::getNestingDepth(const Instruction &stmt) { const Instruction *currInst = &stmt; unsigned depth = 0; while ((currInst = currInst->getParentInst())) { - if (isa<ForInst>(currInst)) + if (cast<OperationInst>(currInst)->isa<AffineForOp>()) depth++; } return depth; @@ -530,14 +537,14 @@ unsigned mlir::getNestingDepth(const Instruction &stmt) { /// where each lists loops from outer-most to inner-most in loop nest. 
unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A, const Instruction &B) { - SmallVector<ForInst *, 4> loopsA, loopsB; + SmallVector<OpPointer<AffineForOp>, 4> loopsA, loopsB; getLoopIVs(A, &loopsA); getLoopIVs(B, &loopsB); unsigned minNumLoops = std::min(loopsA.size(), loopsB.size()); unsigned numCommonLoops = 0; for (unsigned i = 0; i < minNumLoops; ++i) { - if (loopsA[i] != loopsB[i]) + if (loopsA[i]->getInstruction() != loopsB[i]->getInstruction()) break; ++numCommonLoops; } @@ -571,13 +578,14 @@ static Optional<int64_t> getRegionSize(const MemRefRegion ®ion) { return getMemRefEltSizeInBytes(memRefType) * numElements.getValue(); } -Optional<int64_t> mlir::getMemoryFootprintBytes(const ForInst &forInst, - int memorySpace) { +Optional<int64_t> +mlir::getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp, + int memorySpace) { std::vector<std::unique_ptr<MemRefRegion>> regions; // Walk this 'for' instruction to gather all memory regions. bool error = false; - const_cast<ForInst *>(&forInst)->walkOps([&](OperationInst *opInst) { + const_cast<AffineForOp &>(*forOp).walkOps([&](OperationInst *opInst) { if (!opInst->isa<LoadOp>() && !opInst->isa<StoreOp>()) { // Neither load nor a store op. return; diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 125020e92a3..4865cb03bb4 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -16,10 +16,12 @@ // ============================================================================= #include "mlir/Analysis/VectorAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" @@ -105,7 +107,7 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType, static AffineMap makePermutationMap( MLIRContext *context, llvm::iterator_range<OperationInst::operand_iterator> indices, - const DenseMap<ForInst *, unsigned> &enclosingLoopToVectorDim) { + const DenseMap<Instruction *, unsigned> &enclosingLoopToVectorDim) { using functional::makePtrDynCaster; using functional::map; auto unwrappedIndices = map(makePtrDynCaster<Value, Value>(), indices); @@ -113,8 +115,9 @@ static AffineMap makePermutationMap( getAffineConstantExpr(0, context)); for (auto kvp : enclosingLoopToVectorDim) { assert(kvp.second < perm.size()); - auto invariants = - getInvariantAccesses(*kvp.first->getInductionVar(), unwrappedIndices); + auto invariants = getInvariantAccesses( + *cast<OperationInst>(kvp.first)->cast<AffineForOp>()->getInductionVar(), + unwrappedIndices); unsigned numIndices = unwrappedIndices.size(); unsigned countInvariantIndices = 0; for (unsigned dim = 0; dim < numIndices; ++dim) { @@ -139,30 +142,30 @@ static AffineMap makePermutationMap( /// TODO(ntv): could also be implemented as a collect parents followed by a /// filter and made available outside this file. 
template <typename T> -static SetVector<T *> getParentsOfType(Instruction *inst) { - SetVector<T *> res; +static SetVector<OperationInst *> getParentsOfType(Instruction *inst) { + SetVector<OperationInst *> res; auto *current = inst; while (auto *parent = current->getParentInst()) { - auto *typedParent = dyn_cast<T>(parent); - if (typedParent) { - assert(res.count(typedParent) == 0 && "Already inserted"); - res.insert(typedParent); + if (auto typedParent = + cast<OperationInst>(parent)->template dyn_cast<T>()) { + assert(res.count(cast<OperationInst>(parent)) == 0 && "Already inserted"); + res.insert(cast<OperationInst>(parent)); } current = parent; } return res; } -/// Returns the enclosing ForInst, from closest to farthest. -static SetVector<ForInst *> getEnclosingforInsts(Instruction *inst) { - return getParentsOfType<ForInst>(inst); +/// Returns the enclosing AffineForOp, from closest to farthest. +static SetVector<OperationInst *> getEnclosingforOps(Instruction *inst) { + return getParentsOfType<AffineForOp>(inst); } -AffineMap -mlir::makePermutationMap(OperationInst *opInst, - const DenseMap<ForInst *, unsigned> &loopToVectorDim) { - DenseMap<ForInst *, unsigned> enclosingLoopToVectorDim; - auto enclosingLoops = getEnclosingforInsts(opInst); +AffineMap mlir::makePermutationMap( + OperationInst *opInst, + const DenseMap<Instruction *, unsigned> &loopToVectorDim) { + DenseMap<Instruction *, unsigned> enclosingLoopToVectorDim; + auto enclosingLoops = getEnclosingforOps(opInst); for (auto *forInst : enclosingLoops) { auto it = loopToVectorDim.find(forInst); if (it != loopToVectorDim.end()) { diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 474eeb2a28e..a69831053ad 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -72,7 +72,6 @@ public: bool verify(); bool verifyBlock(const Block &block, bool isTopLevel); bool verifyOperation(const OperationInst &op); - bool verifyForInst(const ForInst &forInst); bool verifyDominance(const Block &block); bool verifyInstDominance(const Instruction &inst); @@ -175,10 +174,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { if (verifyOperation(cast<OperationInst>(inst))) return true; break; - case Instruction::Kind::For: - if (verifyForInst(cast<ForInst>(inst))) - return true; - break; } } @@ -240,11 +235,6 @@ bool FuncVerifier::verifyOperation(const OperationInst &op) { return false; } -bool FuncVerifier::verifyForInst(const ForInst &forInst) { - // TODO: check that loop bounds are properly formed. - return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false); -} - bool FuncVerifier::verifyDominance(const Block &block) { for (auto &inst : block) { // Check that all operands on the instruction are ok. 
@@ -262,10 +252,6 @@ bool FuncVerifier::verifyDominance(const Block &block) { return true; break; } - case Instruction::Kind::For: - if (verifyDominance(*cast<ForInst>(inst).getBody())) - return true; - break; } } return false; diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index dc85c5ed682..f4d5d36d25b 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -21,12 +21,14 @@ #include "llvm/Support/raw_ostream.h" #include "mlir-c/Core.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/EDSC/MLIREmitter.h" #include "mlir/EDSC/Types.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/Value.h" #include "mlir/StandardOps/StandardOps.h" @@ -133,8 +135,8 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) { inst->print(os); return; } - if (auto *forInst = getForInductionVarOwner(&v)) { - forInst->print(os); + if (auto forInst = getForInductionVarOwner(&v)) { + forInst->getInstruction()->print(os); } else { os << "unknown_ssa_value"; } @@ -300,7 +302,9 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { exprs[1]->getDefiningInst()->cast<ConstantIndexOp>()->getValue(); auto step = exprs[2]->getDefiningInst()->cast<ConstantIndexOp>()->getValue(); - res = builder->createFor(location, lb, ub, step)->getInductionVar(); + auto forOp = builder->create<AffineForOp>(location, lb, ub, step); + forOp->createBody(); + res = forOp->getInductionVar(); } } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index cb4c1f0edce..0fb18fa0004 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -130,21 +130,8 @@ private: void recordTypeReference(Type ty) { usedTypes.insert(ty); } - // Return true if this map could be printed using the custom assembly form. - static bool hasCustomForm(AffineMap boundMap) { - if (boundMap.isSingleConstant()) - return true; - - // Check if the affine map is single dim id or single symbol identity - - // (i)->(i) or ()[s]->(i) - return boundMap.getNumInputs() == 1 && boundMap.getNumResults() == 1 && - (boundMap.getResult(0).isa<AffineDimExpr>() || - boundMap.getResult(0).isa<AffineSymbolExpr>()); - } - // Visit functions. void visitInstruction(const Instruction *inst); - void visitForInst(const ForInst *forInst); void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -196,16 +183,6 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitForInst(const ForInst *forInst) { - AffineMap lbMap = forInst->getLowerBoundMap(); - if (!hasCustomForm(lbMap)) - recordAffineMapReference(lbMap); - - AffineMap ubMap = forInst->getUpperBoundMap(); - if (!hasCustomForm(ubMap)) - recordAffineMapReference(ubMap); -} - void ModuleState::visitOperationInst(const OperationInst *op) { // Visit all the types used in the operation. for (auto *operand : op->getOperands()) @@ -220,8 +197,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) { void ModuleState::visitInstruction(const Instruction *inst) { switch (inst->getKind()) { - case Instruction::Kind::For: - return visitForInst(cast<ForInst>(inst)); case Instruction::Kind::OperationInst: return visitOperationInst(cast<OperationInst>(inst)); } @@ -1069,7 +1044,6 @@ public: // Methods to print instructions. 
void print(const Instruction *inst); void print(const OperationInst *inst); - void print(const ForInst *inst); void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); @@ -1117,10 +1091,8 @@ public: unsigned index) override; /// Print a block list. - void printBlockList(const BlockList &blocks) override { - printBlockList(blocks, /*printEntryBlockArgs=*/true); - } - void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { + void printBlockList(const BlockList &blocks, + bool printEntryBlockArgs) override { os << " {\n"; if (!blocks.empty()) { auto *entryBlock = &blocks.front(); @@ -1132,10 +1104,6 @@ public: os.indent(currentIndent) << "}"; } - // Print if and loop bounds. - void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims); - void printBound(AffineBound bound, const char *prefix); - // Number of spaces used for indenting nested instructions. const static unsigned indentWidth = 2; @@ -1205,10 +1173,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { numberValuesInBlock(block); break; } - case Instruction::Kind::For: - // Recursively number the stuff in the body. - numberValuesInBlock(*cast<ForInst>(&inst)->getBody()); - break; } } } @@ -1404,8 +1368,6 @@ void FunctionPrinter::print(const Instruction *inst) { switch (inst->getKind()) { case Instruction::Kind::OperationInst: return print(cast<OperationInst>(inst)); - case Instruction::Kind::For: - return print(cast<ForInst>(inst)); } } @@ -1415,24 +1377,6 @@ void FunctionPrinter::print(const OperationInst *inst) { printTrailingLocation(inst->getLoc()); } -void FunctionPrinter::print(const ForInst *inst) { - os.indent(currentIndent) << "for "; - printOperand(inst->getInductionVar()); - os << " = "; - printBound(inst->getLowerBound(), "max"); - os << " to "; - printBound(inst->getUpperBound(), "min"); - - if (inst->getStep() != 1) - os << " step " << inst->getStep(); - - printTrailingLocation(inst->getLoc()); - - os << " {\n"; - print(inst->getBody(), /*printBlockArgs=*/false); - os.indent(currentIndent) << "}"; -} - void FunctionPrinter::printValueID(const Value *value, bool printResultNo) const { int resultNo = -1; @@ -1560,62 +1504,6 @@ void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term, os << ')'; } -void FunctionPrinter::printDimAndSymbolList(ArrayRef<InstOperand> ops, - unsigned numDims) { - auto printComma = [&]() { os << ", "; }; - os << '('; - interleave( - ops.begin(), ops.begin() + numDims, - [&](const InstOperand &v) { printOperand(v.get()); }, printComma); - os << ')'; - - if (numDims < ops.size()) { - os << '['; - interleave( - ops.begin() + numDims, ops.end(), - [&](const InstOperand &v) { printOperand(v.get()); }, printComma); - os << ']'; - } -} - -void FunctionPrinter::printBound(AffineBound bound, const char *prefix) { - AffineMap map = bound.getMap(); - - // Check if this bound should be printed using custom assembly form. - // The decision to restrict printing custom assembly form to trivial cases - // comes from the will to roundtrip MLIR binary -> text -> binary in a - // lossless way. - // Therefore, custom assembly form parsing and printing is only supported for - // zero-operand constant maps and single symbol operand identity maps. - if (map.getNumResults() == 1) { - AffineExpr expr = map.getResult(0); - - // Print constant bound. 
- if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { - if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { - os << constExpr.getValue(); - return; - } - } - - // Print bound that consists of a single SSA symbol if the map is over a - // single symbol. - if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { - if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) { - printOperand(bound.getOperand(0)); - return; - } - } - } else { - // Map has multiple results. Print 'min' or 'max' prefix. - os << prefix << ' '; - } - - // Print the map and its operands. - printAffineMapReference(map); - printDimAndSymbolList(bound.getInstOperands(), map.getNumDims()); -} - // Prints function with initialized module state. void ModulePrinter::print(const Function *fn) { FunctionPrinter(fn, *this).print(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index ffeb4e0317f..68fbef2d27a 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -312,19 +312,3 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) { block->getInstructions().insert(insertPoint, op); return op; } - -ForInst *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands, - AffineMap lbMap, ArrayRef<Value *> ubOperands, - AffineMap ubMap, int64_t step) { - auto *inst = - ForInst::create(location, lbOperands, lbMap, ubOperands, ubMap, step); - block->getInstructions().insert(insertPoint, inst); - return inst; -} - -ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, - int64_t step) { - auto lbMap = AffineMap::getConstantMap(lb, context); - auto ubMap = AffineMap::getConstantMap(ub, context); - return createFor(location, {}, lbMap, {}, ubMap, step); -} diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 8d43e3a783d..03f1a2702c9 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -143,9 +143,6 @@ void Instruction::destroy() { case Kind::OperationInst: cast<OperationInst>(this)->destroy(); break; - case Kind::For: - cast<ForInst>(this)->destroy(); - break; } } @@ -209,8 +206,6 @@ unsigned Instruction::getNumOperands() const { switch (getKind()) { case Kind::OperationInst: return cast<OperationInst>(this)->getNumOperands(); - case Kind::For: - return cast<ForInst>(this)->getNumOperands(); } } @@ -218,8 +213,6 @@ MutableArrayRef<InstOperand> Instruction::getInstOperands() { switch (getKind()) { case Kind::OperationInst: return cast<OperationInst>(this)->getInstOperands(); - case Kind::For: - return cast<ForInst>(this)->getInstOperands(); } } @@ -349,10 +342,6 @@ void Instruction::dropAllReferences() { op.drop(); switch (getKind()) { - case Kind::For: - // Make sure to drop references held by instructions within the body. 
- cast<ForInst>(this)->getBody()->dropAllReferences(); - break; case Kind::OperationInst: { auto *opInst = cast<OperationInst>(this); if (isTerminator()) @@ -656,217 +645,6 @@ bool OperationInst::emitOpError(const Twine &message) const { } //===----------------------------------------------------------------------===// -// ForInst -//===----------------------------------------------------------------------===// - -ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands, - AffineMap lbMap, ArrayRef<Value *> ubOperands, - AffineMap ubMap, int64_t step) { - assert((!lbMap && lbOperands.empty()) || - lbOperands.size() == lbMap.getNumInputs() && - "lower bound operand count does not match the affine map"); - assert((!ubMap && ubOperands.empty()) || - ubOperands.size() == ubMap.getNumInputs() && - "upper bound operand count does not match the affine map"); - assert(step > 0 && "step has to be a positive integer constant"); - - // Compute the byte size for the instruction and the operand storage. - unsigned numOperands = lbOperands.size() + ubOperands.size(); - auto byteSize = totalSizeToAlloc<detail::OperandStorage>( - /*detail::OperandStorage*/ 1); - byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize( - numOperands, /*resizable=*/true), - alignof(ForInst)); - void *rawMem = malloc(byteSize); - - // Initialize the OperationInst part of the instruction. - ForInst *inst = ::new (rawMem) ForInst(location, lbMap, ubMap, step); - new (&inst->getOperandStorage()) - detail::OperandStorage(numOperands, /*resizable=*/true); - - auto operands = inst->getInstOperands(); - unsigned i = 0; - for (unsigned e = lbOperands.size(); i != e; ++i) - new (&operands[i]) InstOperand(inst, lbOperands[i]); - - for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j) - new (&operands[i]) InstOperand(inst, ubOperands[j]); - - return inst; -} - -ForInst::ForInst(Location location, AffineMap lbMap, AffineMap ubMap, - int64_t step) - : Instruction(Instruction::Kind::For, location), body(this), lbMap(lbMap), - ubMap(ubMap), step(step) { - - // The body of a for inst always has one block. - auto *bodyEntry = new Block(); - body.push_back(bodyEntry); - - // Add an argument to the block for the induction variable. 
- bodyEntry->addArgument(Type::getIndex(lbMap.getResult(0).getContext())); -} - -ForInst::~ForInst() { getOperandStorage().~OperandStorage(); } - -const AffineBound ForInst::getLowerBound() const { - return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap); -} - -const AffineBound ForInst::getUpperBound() const { - return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap); -} - -void ForInst::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) { - assert(lbOperands.size() == map.getNumInputs()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - - SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end()); - - auto ubOperands = getUpperBoundOperands(); - newOperands.append(ubOperands.begin(), ubOperands.end()); - getOperandStorage().setOperands(this, newOperands); - - this->lbMap = map; -} - -void ForInst::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) { - assert(ubOperands.size() == map.getNumInputs()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - - SmallVector<Value *, 4> newOperands(getLowerBoundOperands()); - newOperands.append(ubOperands.begin(), ubOperands.end()); - getOperandStorage().setOperands(this, newOperands); - - this->ubMap = map; -} - -void ForInst::setLowerBoundMap(AffineMap map) { - assert(lbMap.getNumDims() == map.getNumDims() && - lbMap.getNumSymbols() == map.getNumSymbols()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - this->lbMap = map; -} - -void ForInst::setUpperBoundMap(AffineMap map) { - assert(ubMap.getNumDims() == map.getNumDims() && - ubMap.getNumSymbols() == map.getNumSymbols()); - assert(map.getNumResults() >= 1 && "bound map has at least one result"); - this->ubMap = map; -} - -bool ForInst::hasConstantLowerBound() const { return lbMap.isSingleConstant(); } - -bool ForInst::hasConstantUpperBound() const { return ubMap.isSingleConstant(); } - -int64_t ForInst::getConstantLowerBound() const { - return lbMap.getSingleConstantResult(); -} - -int64_t ForInst::getConstantUpperBound() const { - return ubMap.getSingleConstantResult(); -} - -void ForInst::setConstantLowerBound(int64_t value) { - setLowerBound({}, AffineMap::getConstantMap(value, getContext())); -} - -void ForInst::setConstantUpperBound(int64_t value) { - setUpperBound({}, AffineMap::getConstantMap(value, getContext())); -} - -ForInst::operand_range ForInst::getLowerBoundOperands() { - return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; -} - -ForInst::const_operand_range ForInst::getLowerBoundOperands() const { - return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; -} - -ForInst::operand_range ForInst::getUpperBoundOperands() { - return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; -} - -ForInst::const_operand_range ForInst::getUpperBoundOperands() const { - return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; -} - -bool ForInst::matchingBoundOperandList() const { - if (lbMap.getNumDims() != ubMap.getNumDims() || - lbMap.getNumSymbols() != ubMap.getNumSymbols()) - return false; - - unsigned numOperands = lbMap.getNumInputs(); - for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { - // Compare Value *'s. 
- if (getOperand(i) != getOperand(numOperands + i)) - return false; - } - return true; -} - -void ForInst::walkOps(std::function<void(OperationInst *)> callback) { - struct Walker : public InstWalker<Walker> { - std::function<void(OperationInst *)> const &callback; - Walker(std::function<void(OperationInst *)> const &callback) - : callback(callback) {} - - void visitOperationInst(OperationInst *opInst) { callback(opInst); } - }; - - Walker w(callback); - w.walk(this); -} - -void ForInst::walkOpsPostOrder(std::function<void(OperationInst *)> callback) { - struct Walker : public InstWalker<Walker> { - std::function<void(OperationInst *)> const &callback; - Walker(std::function<void(OperationInst *)> const &callback) - : callback(callback) {} - - void visitOperationInst(OperationInst *opInst) { callback(opInst); } - }; - - Walker v(callback); - v.walkPostOrder(this); -} - -/// Returns the induction variable for this loop. -Value *ForInst::getInductionVar() { return getBody()->getArgument(0); } - -void ForInst::destroy() { - this->~ForInst(); - free(this); -} - -/// Returns if the provided value is the induction variable of a ForInst. -bool mlir::isForInductionVar(const Value *val) { - return getForInductionVarOwner(val) != nullptr; -} - -/// Returns the loop parent of an induction variable. If the provided value is -/// not an induction variable, then return nullptr. -ForInst *mlir::getForInductionVarOwner(Value *val) { - const BlockArgument *ivArg = dyn_cast<BlockArgument>(val); - if (!ivArg || !ivArg->getOwner()) - return nullptr; - return dyn_cast_or_null<ForInst>( - ivArg->getOwner()->getParent()->getContainingInst()); -} -const ForInst *mlir::getForInductionVarOwner(const Value *val) { - return getForInductionVarOwner(const_cast<Value *>(val)); -} - -/// Extracts the induction variables from a list of ForInsts and returns them. -SmallVector<Value *, 8> -mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) { - SmallVector<Value *, 8> results; - for (auto *forInst : forInsts) - results.push_back(forInst->getInductionVar()); - return results; -} -//===----------------------------------------------------------------------===// // Instruction Cloning //===----------------------------------------------------------------------===// @@ -879,84 +657,59 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, MLIRContext *context) const { SmallVector<Value *, 8> operands; SmallVector<Block *, 2> successors; - if (auto *opInst = dyn_cast<OperationInst>(this)) { - operands.reserve(getNumOperands() + opInst->getNumSuccessors()); - if (!opInst->isTerminator()) { - // Non-terminators just add all the operands. - for (auto *opValue : getOperands()) + auto *opInst = cast<OperationInst>(this); + operands.reserve(getNumOperands() + opInst->getNumSuccessors()); + + if (!opInst->isTerminator()) { + // Non-terminators just add all the operands. + for (auto *opValue : getOperands()) + operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue))); + } else { + // We add the operands separated by nullptr's for each successor. + unsigned firstSuccOperand = opInst->getNumSuccessors() + ? 
opInst->getSuccessorOperandIndex(0) + : opInst->getNumOperands(); + auto InstOperands = opInst->getInstOperands(); + + unsigned i = 0; + for (; i != firstSuccOperand; ++i) + operands.push_back( + mapper.lookupOrDefault(const_cast<Value *>(InstOperands[i].get()))); + + successors.reserve(opInst->getNumSuccessors()); + for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e; ++succ) { + successors.push_back(mapper.lookupOrDefault( + const_cast<Block *>(opInst->getSuccessor(succ)))); + + // Add sentinel to delineate successor operands. + operands.push_back(nullptr); + + // Remap the successors operands. + for (auto *operand : opInst->getSuccessorOperands(succ)) operands.push_back( - mapper.lookupOrDefault(const_cast<Value *>(opValue))); - } else { - // We add the operands separated by nullptr's for each successor. - unsigned firstSuccOperand = opInst->getNumSuccessors() - ? opInst->getSuccessorOperandIndex(0) - : opInst->getNumOperands(); - auto InstOperands = opInst->getInstOperands(); - - unsigned i = 0; - for (; i != firstSuccOperand; ++i) - operands.push_back( - mapper.lookupOrDefault(const_cast<Value *>(InstOperands[i].get()))); - - successors.reserve(opInst->getNumSuccessors()); - for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e; - ++succ) { - successors.push_back(mapper.lookupOrDefault( - const_cast<Block *>(opInst->getSuccessor(succ)))); - - // Add sentinel to delineate successor operands. - operands.push_back(nullptr); - - // Remap the successors operands. - for (auto *operand : opInst->getSuccessorOperands(succ)) - operands.push_back( - mapper.lookupOrDefault(const_cast<Value *>(operand))); - } + mapper.lookupOrDefault(const_cast<Value *>(operand))); } - - SmallVector<Type, 8> resultTypes; - resultTypes.reserve(opInst->getNumResults()); - for (auto *result : opInst->getResults()) - resultTypes.push_back(result->getType()); - - unsigned numBlockLists = opInst->getNumBlockLists(); - auto *newOp = OperationInst::create( - getLoc(), opInst->getName(), operands, resultTypes, opInst->getAttrs(), - successors, numBlockLists, opInst->hasResizableOperandsList(), context); - - // Clone the block lists. - for (unsigned i = 0; i != numBlockLists; ++i) - opInst->getBlockList(i).cloneInto(&newOp->getBlockList(i), mapper, - context); - - // Remember the mapping of any results. - for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i) - mapper.map(opInst->getResult(i), newOp->getResult(i)); - return newOp; } - operands.reserve(getNumOperands()); - for (auto *opValue : getOperands()) - operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue))); + SmallVector<Type, 8> resultTypes; + resultTypes.reserve(opInst->getNumResults()); + for (auto *result : opInst->getResults()) + resultTypes.push_back(result->getType()); - // Otherwise, this must be a ForInst. - auto *forInst = cast<ForInst>(this); - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + unsigned numBlockLists = opInst->getNumBlockLists(); + auto *newOp = OperationInst::create( + getLoc(), opInst->getName(), operands, resultTypes, opInst->getAttrs(), + successors, numBlockLists, opInst->hasResizableOperandsList(), context); - auto *newFor = ForInst::create( - getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()), - lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), ubMap, - forInst->getStep()); - - // Remember the induction variable mapping. 
- mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); + // Clone the block lists. + for (unsigned i = 0; i != numBlockLists; ++i) + opInst->getBlockList(i).cloneInto(&newOp->getBlockList(i), mapper, context); - // Recursively clone the body of the for loop. - for (auto &subInst : *forInst->getBody()) - newFor->getBody()->push_back(subInst.clone(mapper, context)); - return newFor; + // Remember the mapping of any results. + for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i) + mapper.map(opInst->getResult(i), newOp->getResult(i)); + return newOp; } Instruction *Instruction::clone(MLIRContext *context) const { diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 7103eeb7389..a9c046dc7b1 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -64,8 +64,6 @@ MLIRContext *IROperandOwner::getContext() const { switch (getKind()) { case Kind::OperationInst: return cast<OperationInst>(this)->getContext(); - case Kind::ForInst: - return cast<ForInst>(this)->getContext(); } } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index f0c140166ed..a9c62767734 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2128,23 +2128,6 @@ public: parseSuccessors(SmallVectorImpl<Block *> &destinations, SmallVectorImpl<SmallVector<Value *, 4>> &operands); - ParseResult - parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results, - Block *owner); - - ParseResult parseOperationBlockList(SmallVectorImpl<Block *> &results); - ParseResult parseBlockListBody(SmallVectorImpl<Block *> &results); - ParseResult parseBlock(Block *&block); - ParseResult parseBlockBody(Block *block); - - /// Cleans up the memory for allocated blocks when a parser error occurs. - void cleanupInvalidBlocks(ArrayRef<Block *> invalidBlocks) { - // Add the referenced blocks to the function so that they can be properly - // cleaned up when the function is destroyed. - for (auto *block : invalidBlocks) - function->push_back(block); - } - /// After the function is finished parsing, this function checks to see if /// there are any remaining issues. ParseResult finalizeFunction(SMLoc loc); @@ -2187,6 +2170,25 @@ public: // Block references. + ParseResult + parseOperationBlockList(SmallVectorImpl<Block *> &results, + ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments); + ParseResult parseBlockListBody(SmallVectorImpl<Block *> &results); + ParseResult parseBlock(Block *&block); + ParseResult parseBlockBody(Block *block); + + ParseResult + parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results, + Block *owner); + + /// Cleans up the memory for allocated blocks when a parser error occurs. + void cleanupInvalidBlocks(ArrayRef<Block *> invalidBlocks) { + // Add the referenced blocks to the function so that they can be properly + // cleaned up when the function is destroyed. + for (auto *block : invalidBlocks) + function->push_back(block); + } + /// Get the block with the specified name, creating it if it doesn't /// already exist. The location specified is the point of use, which allows /// us to diagnose references to blocks that are not defined precisely. 
@@ -2201,13 +2203,6 @@ public: OperationInst *parseGenericOperation(); OperationInst *parseCustomOperation(); - ParseResult parseForInst(); - ParseResult parseIntConstant(int64_t &val); - ParseResult parseDimAndSymbolList(SmallVectorImpl<Value *> &operands, - unsigned numDims, unsigned numOperands, - const char *affineStructName); - ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map, - bool isLower); ParseResult parseInstructions(Block *block); private: @@ -2287,25 +2282,43 @@ ParseResult FunctionParser::parseFunctionBody(bool hadNamedArguments) { /// /// block-list ::= '{' block-list-body /// -ParseResult -FunctionParser::parseOperationBlockList(SmallVectorImpl<Block *> &results) { +ParseResult FunctionParser::parseOperationBlockList( + SmallVectorImpl<Block *> &results, + ArrayRef<std::pair<FunctionParser::SSAUseInfo, Type>> entryArguments) { // Parse the '{'. if (parseToken(Token::l_brace, "expected '{' to begin block list")) return ParseFailure; + // Check for an empty block list. - if (consumeIf(Token::r_brace)) + if (entryArguments.empty() && consumeIf(Token::r_brace)) return ParseSuccess; Block *currentBlock = builder.getInsertionBlock(); // Parse the first block directly to allow for it to be unnamed. Block *block = new Block(); + + // Add arguments to the entry block. + for (auto &placeholderArgPair : entryArguments) + if (addDefinition(placeholderArgPair.first, + block->addArgument(placeholderArgPair.second))) { + delete block; + return ParseFailure; + } + if (parseBlock(block)) { - cleanupInvalidBlocks(block); + delete block; return ParseFailure; } - results.push_back(block); + + // Verify that no other arguments were parsed. + if (!entryArguments.empty() && + block->getNumArguments() > entryArguments.size()) { + delete block; + return emitError("entry block arguments were already defined"); + } // Parse the rest of the block list. + results.push_back(block); if (parseBlockListBody(results)) return ParseFailure; @@ -2385,10 +2398,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) { if (parseOperation()) return ParseFailure; break; - case Token::kw_for: - if (parseForInst()) - return ParseFailure; - break; } } @@ -2859,7 +2868,7 @@ OperationInst *FunctionParser::parseGenericOperation() { std::vector<SmallVector<Block *, 2>> blocks; while (getToken().is(Token::l_brace)) { SmallVector<Block *, 2> newBlocks; - if (parseOperationBlockList(newBlocks)) { + if (parseOperationBlockList(newBlocks, /*entryArguments=*/llvm::None)) { for (auto &blockList : blocks) cleanupInvalidBlocks(blockList); return nullptr; @@ -2884,6 +2893,27 @@ public: CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser) : nameLoc(nameLoc), opName(opName), parser(parser) {} + bool parseOperation(const AbstractOperation *opDefinition, + OperationState *opState) { + if (opDefinition->parseAssembly(this, opState)) + return true; + + // Check that enough block lists were reserved for those that were parsed. + if (parsedBlockLists.size() > opState->numBlockLists) { + return emitError( + nameLoc, + "parsed more block lists than those reserved in the operation state"); + } + + // Check there were no dangling entry block arguments. + if (!parsedBlockListEntryArguments.empty()) { + return emitError( + nameLoc, + "no block list was attached to parsed entry block arguments"); + } + return false; + } + //===--------------------------------------------------------------------===// // High level parsing methods. 
//===--------------------------------------------------------------------===// @@ -2895,6 +2925,9 @@ public: bool parseComma() override { return parser.parseToken(Token::comma, "expected ','"); } + bool parseEqual() override { + return parser.parseToken(Token::equal, "expected '='"); + } bool parseType(Type &result) override { return !(result = parser.parseType()); @@ -3083,13 +3116,35 @@ public: /// Parses a list of blocks. bool parseBlockList() override { + // Parse the block list. SmallVector<Block *, 2> results; - if (parser.parseOperationBlockList(results)) + if (parser.parseOperationBlockList(results, parsedBlockListEntryArguments)) return true; + + parsedBlockListEntryArguments.clear(); parsedBlockLists.emplace_back(results); return false; } + /// Parses an argument for the entry block of the next block list to be + /// parsed. + bool parseBlockListEntryBlockArgument(Type argType) override { + SmallVector<Value *, 1> argValues; + OperandType operand; + if (parseOperand(operand)) + return true; + + // Create a place holder for this argument. + FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number, + operand.location}; + if (auto *value = parser.resolveSSAUse(operandInfo, argType)) { + parsedBlockListEntryArguments.emplace_back(operandInfo, argType); + return false; + } + + return true; + } + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// @@ -3130,6 +3185,8 @@ public: private: std::vector<SmallVector<Block *, 2>> parsedBlockLists; + SmallVector<std::pair<FunctionParser::SSAUseInfo, Type>, 2> + parsedBlockListEntryArguments; SMLoc nameLoc; StringRef opName; FunctionParser &parser; @@ -3161,26 +3218,18 @@ OperationInst *FunctionParser::parseCustomOperation() { // Have the op implementation take a crack and parsing this. OperationState opState(builder.getContext(), srcLocation, opName); - if (opDefinition->parseAssembly(&opAsmParser, &opState)) + if (opAsmParser.parseOperation(opDefinition, &opState)) return nullptr; // If it emitted an error, we failed. if (opAsmParser.didEmitError()) return nullptr; - // Check that enough block lists were reserved for those that were parsed. - auto parsedBlockLists = opAsmParser.getParsedBlockLists(); - if (parsedBlockLists.size() > opState.numBlockLists) { - opAsmParser.emitError( - opLoc, - "parsed more block lists than those reserved in the operation state"); - return nullptr; - } - // Otherwise, we succeeded. Use the state it parsed as our op information. auto *opInst = builder.createOperation(opState); // Resolve any parsed block lists. + auto parsedBlockLists = opAsmParser.getParsedBlockLists(); for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) { auto &opBlockList = opInst->getBlockList(i).getBlocks(); opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(), @@ -3189,213 +3238,6 @@ OperationInst *FunctionParser::parseCustomOperation() { return opInst; } -/// For instruction. -/// -/// ml-for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound -/// (`step` integer-literal)? trailing-location? `{` inst* `}` -/// -ParseResult FunctionParser::parseForInst() { - consumeToken(Token::kw_for); - - // Parse induction variable. 
- if (getToken().isNot(Token::percent_identifier)) - return emitError("expected SSA identifier for the loop variable"); - - auto loc = getToken().getLoc(); - StringRef inductionVariableName = getTokenSpelling(); - consumeToken(Token::percent_identifier); - - if (parseToken(Token::equal, "expected '='")) - return ParseFailure; - - // Parse lower bound. - SmallVector<Value *, 4> lbOperands; - AffineMap lbMap; - if (parseBound(lbOperands, lbMap, /*isLower*/ true)) - return ParseFailure; - - if (parseToken(Token::kw_to, "expected 'to' between bounds")) - return ParseFailure; - - // Parse upper bound. - SmallVector<Value *, 4> ubOperands; - AffineMap ubMap; - if (parseBound(ubOperands, ubMap, /*isLower*/ false)) - return ParseFailure; - - // Parse step. - int64_t step = 1; - if (consumeIf(Token::kw_step) && parseIntConstant(step)) - return ParseFailure; - - // The loop step is a positive integer constant. Since index is stored as an - // int64_t type, we restrict step to be in the set of positive integers that - // int64_t can represent. - if (step < 1) { - return emitError("step has to be a positive integer"); - } - - // Create for instruction. - ForInst *forInst = - builder.createFor(getEncodedSourceLocation(loc), lbOperands, lbMap, - ubOperands, ubMap, step); - - // Create SSA value definition for the induction variable. - if (addDefinition({inductionVariableName, 0, loc}, - forInst->getInductionVar())) - return ParseFailure; - - // Try to parse the optional trailing location. - if (parseOptionalTrailingLocation(forInst)) - return ParseFailure; - - // If parsing of the for instruction body fails, - // MLIR contains for instruction with those nested instructions that have been - // successfully parsed. - auto *forBody = forInst->getBody(); - if (parseToken(Token::l_brace, "expected '{' before instruction list") || - parseBlock(forBody) || - parseToken(Token::r_brace, "expected '}' after instruction list")) - return ParseFailure; - - // Reset insertion point to the current block. - builder.setInsertionPointToEnd(forInst->getBlock()); - - return ParseSuccess; -} - -/// Parse integer constant as affine constant expression. -ParseResult FunctionParser::parseIntConstant(int64_t &val) { - bool negate = consumeIf(Token::minus); - - if (getToken().isNot(Token::integer)) - return emitError("expected integer"); - - auto uval = getToken().getUInt64IntegerValue(); - - if (!uval.hasValue() || (int64_t)uval.getValue() < 0) { - return emitError("bound or step is too large for index"); - } - - val = (int64_t)uval.getValue(); - if (negate) - val = -val; - consumeToken(); - - return ParseSuccess; -} - -/// Dimensions and symbol use list. -/// -/// dim-use-list ::= `(` ssa-use-list? `)` -/// symbol-use-list ::= `[` ssa-use-list? `]` -/// dim-and-symbol-use-list ::= dim-use-list symbol-use-list? 
-/// -ParseResult -FunctionParser::parseDimAndSymbolList(SmallVectorImpl<Value *> &operands, - unsigned numDims, unsigned numOperands, - const char *affineStructName) { - if (parseToken(Token::l_paren, "expected '('")) - return ParseFailure; - - SmallVector<SSAUseInfo, 4> opInfo; - parseOptionalSSAUseList(opInfo); - - if (parseToken(Token::r_paren, "expected ')'")) - return ParseFailure; - - if (numDims != opInfo.size()) - return emitError("dim operand count and " + Twine(affineStructName) + - " dim count must match"); - - if (consumeIf(Token::l_square)) { - parseOptionalSSAUseList(opInfo); - if (parseToken(Token::r_square, "expected ']'")) - return ParseFailure; - } - - if (numOperands != opInfo.size()) - return emitError("symbol operand count and " + Twine(affineStructName) + - " symbol count must match"); - - // Resolve SSA uses. - Type indexType = builder.getIndexType(); - for (unsigned i = 0, e = opInfo.size(); i != e; ++i) { - Value *sval = resolveSSAUse(opInfo[i], indexType); - if (!sval) - return ParseFailure; - - if (i < numDims && !sval->isValidDim()) - return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() + - "' cannot be used as a dimension id"); - if (i >= numDims && !sval->isValidSymbol()) - return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() + - "' cannot be used as a symbol"); - operands.push_back(sval); - } - - return ParseSuccess; -} - -// Loop bound. -/// -/// lower-bound ::= `max`? affine-map dim-and-symbol-use-list | -/// shorthand-bound upper-bound ::= `min`? affine-map dim-and-symbol-use-list -/// | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal -/// -ParseResult FunctionParser::parseBound(SmallVectorImpl<Value *> &operands, - AffineMap &map, bool isLower) { - // 'min' / 'max' prefixes are syntactic sugar. Ignore them. - if (isLower) - consumeIf(Token::kw_max); - else - consumeIf(Token::kw_min); - - // Parse full form - affine map followed by dim and symbol list. - if (getToken().isAny(Token::hash_identifier, Token::l_paren)) { - map = parseAffineMapReference(); - if (!map) - return ParseFailure; - - if (parseDimAndSymbolList(operands, map.getNumDims(), map.getNumInputs(), - "affine map")) - return ParseFailure; - return ParseSuccess; - } - - // Parse custom assembly form. - if (getToken().isAny(Token::minus, Token::integer)) { - int64_t val; - if (!parseIntConstant(val)) { - map = builder.getConstantAffineMap(val); - return ParseSuccess; - } - return ParseFailure; - } - - // Parse ssa-id as identity map. - SSAUseInfo opInfo; - if (parseSSAUse(opInfo)) - return ParseFailure; - - // TODO: improve error message when SSA value is not an affine integer. - // Currently it is 'use of value ... expects different type than prior uses' - if (auto *value = resolveSSAUse(opInfo, builder.getIndexType())) - operands.push_back(value); - else - return ParseFailure; - - // Create an identity map using dim id for an induction variable and - // symbol otherwise. This representation is optimized for storage. - // Analysis passes may expand it into a multi-dimensional map if desired. - if (isForInductionVar(operands[0])) - map = builder.getDimIdentityMap(); - else - map = builder.getSymbolIdentityMap(); - - return ParseSuccess; -} - /// Parse an affine constraint. 
/// affine-constraint ::= affine-expr `>=` `0` /// | affine-expr `==` `0` diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index afd18a49b79..e471b6792c5 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -183,11 +183,6 @@ void CSE::simplifyBlock(Block *bb) { } break; } - case Instruction::Kind::For: { - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(cast<ForInst>(i).getBody()); - break; - } } } } diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index f9d02f7a47a..9c20e79180a 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -15,6 +15,7 @@ // limitations under the License. // ============================================================================= +#include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/InstVisitor.h" @@ -37,7 +38,6 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> { bool foldOperation(OperationInst *op, SmallVectorImpl<Value *> &existingConstants); void visitOperationInst(OperationInst *inst); - void visitForInst(ForInst *inst); PassResult runOnFunction(Function *f) override; static char passID; @@ -50,6 +50,12 @@ char ConstantFold::passID = 0; /// constants are found, we keep track of them in the existingConstants list. /// void ConstantFold::visitOperationInst(OperationInst *op) { + // If this operation is an AffineForOp, then fold the bounds. + if (auto forOp = op->dyn_cast<AffineForOp>()) { + constantFoldBounds(forOp); + return; + } + // If this operation is already a constant, just remember it for cleanup // later, and don't try to fold it. if (auto constant = op->dyn_cast<ConstantOp>()) { @@ -98,11 +104,6 @@ void ConstantFold::visitOperationInst(OperationInst *op) { opInstsToErase.push_back(op); } -// Override the walker's 'for' instruction visit for constant folding. -void ConstantFold::visitForInst(ForInst *forInst) { - constantFoldBounds(forInst); -} - // For now, we do a simple top-down pass over a function folding constants. We // don't handle conditional control flow, block arguments, folding // conditional branches, or anything else fancy. diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 5c3a66208ec..83ec726ec2a 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -21,6 +21,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" @@ -71,9 +72,9 @@ struct DmaGeneration : public FunctionPass { } PassResult runOnFunction(Function *f) override; - void runOnForInst(ForInst *forInst); + void runOnAffineForOp(OpPointer<AffineForOp> forOp); - bool generateDma(const MemRefRegion &region, ForInst *forInst, + bool generateDma(const MemRefRegion &region, OpPointer<AffineForOp> forOp, uint64_t *sizeInBytes); // List of memory regions to DMA for. We need a map vector to have a @@ -174,7 +175,7 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, // Just get the first numSymbols IVs, which the memref region is parametric // on.
- SmallVector<ForInst *, 4> ivs; + SmallVector<OpPointer<AffineForOp>, 4> ivs; getLoopIVs(*opInst, &ivs); ivs.resize(numParamLoopIVs); SmallVector<Value *, 4> symbols = extractForInductionVars(ivs); @@ -195,8 +196,10 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, // generates a DMA from the lower memory space to this one, and replaces all // loads to load from that buffer. Returns false if DMAs could not be generated // due to yet unimplemented cases. -bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst, +bool DmaGeneration::generateDma(const MemRefRegion &region, + OpPointer<AffineForOp> forOp, uint64_t *sizeInBytes) { + auto *forInst = forOp->getInstruction(); // DMAs for read regions are going to be inserted just before the for loop. FuncBuilder prologue(forInst); @@ -386,39 +389,43 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst, remapExprs.push_back(dimExpr - offsets[i]); } auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); - // *Only* those uses within the body of 'forInst' are replaced. + // *Only* those uses within the body of 'forOp' are replaced. replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/outerIVs, - /*domInstFilter=*/&*forInst->getBody()->begin()); + /*domInstFilter=*/&*forOp->getBody()->begin()); return true; } // TODO(bondhugula): make this run on a Block instead of a 'for' inst. -void DmaGeneration::runOnForInst(ForInst *forInst) { +void DmaGeneration::runOnAffineForOp(OpPointer<AffineForOp> forOp) { // For now (for testing purposes), we'll run this on the outermost among 'for' // inst's with unit stride, i.e., right at the top of the tile if tiling has // been done. In the future, the DMA generation has to be done at a level // where the generated data fits in a higher level of the memory hierarchy; so // the pass has to be instantiated with additional information that we aren't // provided with at the moment. - if (forInst->getStep() != 1) { - if (auto *innerFor = dyn_cast<ForInst>(&*forInst->getBody()->begin())) { - runOnForInst(innerFor); + if (forOp->getStep() != 1) { + auto *forBody = forOp->getBody(); + if (forBody->empty()) + return; + if (auto innerFor = + cast<OperationInst>(forBody->front()).dyn_cast<AffineForOp>()) { + runOnAffineForOp(innerFor); } return; } // DMAs will be generated for this depth, i.e., for all data accessed by this // loop. - unsigned dmaDepth = getNestingDepth(*forInst); + unsigned dmaDepth = getNestingDepth(*forOp->getInstruction()); readRegions.clear(); writeRegions.clear(); fastBufferMap.clear(); // Walk this 'for' instruction to gather all memory regions. - forInst->walkOps([&](OperationInst *opInst) { + forOp->walkOps([&](OperationInst *opInst) { // Gather regions to promote to buffers in faster memory space. // TODO(bondhugula): handle store op's; only load's handled for now. if (auto loadOp = opInst->dyn_cast<LoadOp>()) { @@ -443,7 +450,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n"); if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) { LLVM_DEBUG( - forInst->emitError("Non-constant memref sizes not yet supported")); + forOp->emitError("Non-constant memref sizes not yet supported")); return; } } @@ -472,10 +479,10 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { // Perform a union with the existing region.
if (!(*it).second->unionBoundingBox(*region)) { LLVM_DEBUG(llvm::dbgs() - << "Memory region bounding box failed; " + << "Memory region bounding box failed" "over-approximating to the entire memref\n"); if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) { - LLVM_DEBUG(forInst->emitError( + LLVM_DEBUG(forOp->emitError( "Non-constant memref sizes not yet supported")); } } @@ -501,7 +508,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { &regions) { for (const auto &regionEntry : regions) { uint64_t sizeInBytes; - bool iRet = generateDma(*regionEntry.second, forInst, &sizeInBytes); + bool iRet = generateDma(*regionEntry.second, forOp, &sizeInBytes); if (iRet) totalSizeInBytes += sizeInBytes; ret = ret & iRet; @@ -510,7 +517,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { processRegions(readRegions); processRegions(writeRegions); if (!ret) { - forInst->emitError("DMA generation failed for one or more memref's\n"); + forOp->emitError("DMA generation failed for one or more memref's\n"); return; } LLVM_DEBUG(llvm::dbgs() << Twine(llvm::divideCeil(totalSizeInBytes, 1024)) @@ -519,7 +526,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) { if (clFastMemoryCapacity && totalSizeInBytes > clFastMemoryCapacity) { // TODO(bondhugula): selecting the DMA depth so that the result DMA buffers // fit in fast memory is a TODO - not complex. - forInst->emitError( + forOp->emitError( "Total size of all DMA buffers' exceeds memory capacity\n"); } } @@ -531,8 +538,8 @@ PassResult DmaGeneration::runOnFunction(Function *f) { for (auto &block : *f) { for (auto &inst : block) { - if (auto *forInst = dyn_cast<ForInst>(&inst)) { - runOnForInst(forInst); + if (auto forOp = cast<OperationInst>(inst).dyn_cast<AffineForOp>()) { + runOnAffineForOp(forOp); } } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index fa0e3b51de3..7d4ff03e306 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -97,15 +97,15 @@ namespace { // operations, and whether or not an IfInst was encountered in the loop nest. class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> { public: - SmallVector<ForInst *, 4> forInsts; + SmallVector<OpPointer<AffineForOp>, 4> forOps; SmallVector<OperationInst *, 4> loadOpInsts; SmallVector<OperationInst *, 4> storeOpInsts; bool hasNonForRegion = false; - void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - void visitOperationInst(OperationInst *opInst) { - if (opInst->getNumBlockLists() != 0) + if (opInst->isa<AffineForOp>()) + forOps.push_back(opInst->cast<AffineForOp>()); + else if (opInst->getNumBlockLists() != 0) hasNonForRegion = true; else if (opInst->isa<LoadOp>()) loadOpInsts.push_back(opInst); @@ -491,14 +491,14 @@ bool MemRefDependenceGraph::init(Function *f) { if (f->getBlocks().size() != 1) return false; - DenseMap<ForInst *, unsigned> forToNodeMap; + DenseMap<Instruction *, unsigned> forToNodeMap; for (auto &inst : f->front()) { - if (auto *forInst = dyn_cast<ForInst>(&inst)) { - // Create graph node 'id' to represent top-level 'forInst' and record + if (auto forOp = cast<OperationInst>(&inst)->dyn_cast<AffineForOp>()) { + // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; - collector.walkForInst(forInst); - // Return false if IfInsts are found (not currently supported). + collector.walk(&inst); + // Return false if a non 'for' region was found (not currently supported).
if (collector.hasNonForRegion) return false; Node node(nextNodeId++, &inst); @@ -512,10 +512,9 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = opInst->cast<StoreOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); } - forToNodeMap[forInst] = node.id; + forToNodeMap[&inst] = node.id; nodes.insert({node.id, node}); - } - if (auto *opInst = dyn_cast<OperationInst>(&inst)) { + } else if (auto *opInst = dyn_cast<OperationInst>(&inst)) { if (auto loadOp = opInst->dyn_cast<LoadOp>()) { // Create graph node for top-level load op. Node node(nextNodeId++, &inst); @@ -552,12 +551,12 @@ bool MemRefDependenceGraph::init(Function *f) { for (auto *value : opInst->getResults()) { for (auto &use : value->getUses()) { auto *userOpInst = cast<OperationInst>(use.getOwner()); - SmallVector<ForInst *, 4> loops; + SmallVector<OpPointer<AffineForOp>, 4> loops; getLoopIVs(*userOpInst, &loops); if (loops.empty()) continue; - assert(forToNodeMap.count(loops[0]) > 0); - unsigned userLoopNestId = forToNodeMap[loops[0]]; + assert(forToNodeMap.count(loops[0]->getInstruction()) > 0); + unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()]; addEdge(node.id, userLoopNestId, value); } } @@ -587,12 +586,12 @@ namespace { // LoopNestStats aggregates various per-loop statistics (eg. loop trip count // and operation count) for a loop nest up until the innermost loop body. struct LoopNestStats { - // Map from ForInst to immediate child ForInsts in its loop body. - DenseMap<ForInst *, SmallVector<ForInst *, 2>> loopMap; - // Map from ForInst to count of operations in its loop body. - DenseMap<ForInst *, uint64_t> opCountMap; - // Map from ForInst to its constant trip count. - DenseMap<ForInst *, uint64_t> tripCountMap; + // Map from AffineForOp to immediate child AffineForOps in its loop body. + DenseMap<Instruction *, SmallVector<OpPointer<AffineForOp>, 2>> loopMap; + // Map from AffineForOp to count of operations in its loop body. + DenseMap<Instruction *, uint64_t> opCountMap; + // Map from AffineForOp to its constant trip count. + DenseMap<Instruction *, uint64_t> tripCountMap; }; // LoopNestStatsCollector walks a single loop nest and gathers per-loop @@ -604,23 +603,31 @@ public: LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} - void visitForInst(ForInst *forInst) { - auto *parentInst = forInst->getParentInst(); + void visitOperationInst(OperationInst *opInst) { + auto forOp = opInst->dyn_cast<AffineForOp>(); + if (!forOp) + return; + + auto *forInst = forOp->getInstruction(); + auto *parentInst = forOp->getInstruction()->getParentInst(); if (parentInst != nullptr) { - assert(isa<ForInst>(parentInst) && "Expected parent ForInst"); - // Add mapping to 'forInst' from its parent ForInst. - stats->loopMap[cast<ForInst>(parentInst)].push_back(forInst); + assert(cast<OperationInst>(parentInst)->isa<AffineForOp>() && + "Expected parent AffineForOp"); + // Add mapping to 'forOp' from its parent AffineForOp. + stats->loopMap[parentInst].push_back(forOp); } - // Record the number of op instructions in the body of 'forInst'. + + // Record the number of op instructions in the body of 'forOp'. unsigned count = 0; stats->opCountMap[forInst] = 0; - for (auto &inst : *forInst->getBody()) { - if (isa<OperationInst>(&inst)) + for (auto &inst : *forOp->getBody()) { + if (!(cast<OperationInst>(inst).isa<AffineForOp>() || + cast<OperationInst>(inst).isa<AffineIfOp>())) ++count; } stats->opCountMap[forInst] = count; - // Record trip count for 'forInst'. Set flag if trip count is not constant. 
- Optional<uint64_t> maybeConstTripCount = getConstantTripCount(*forInst); + // Record trip count for 'forOp'. Set flag if trip count is not constant. + Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); if (!maybeConstTripCount.hasValue()) { hasLoopWithNonConstTripCount = true; return; @@ -629,7 +636,7 @@ public: } }; -// Computes the total cost of the loop nest rooted at 'forInst'. +// Computes the total cost of the loop nest rooted at 'forOp'. // Currently, the total cost is computed by counting the total operation // instance count (i.e. total number of operations in the loop bodyloop // operation count * loop trip count) for the entire loop nest. @@ -637,7 +644,7 @@ public: // specified in the map when computing the total op instance count. // NOTE: this is used to compute the cost of computation slices, which are // sliced along the iteration dimension, and thus reduce the trip count. -// If 'computeCostMap' is non-null, the total op count for forInsts specified +// If 'computeCostMap' is non-null, the total op count for forOps specified // in the map is increased (not overridden) by adding the op count from the // map to the existing op count for the for loop. This is done before // multiplying by the loop's trip count, and is used to model the cost of @@ -645,15 +652,15 @@ public: // NOTE: this is used to compute the cost of fusing a slice of some loop nest // within another loop. static int64_t getComputeCost( - ForInst *forInst, LoopNestStats *stats, - llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap, - DenseMap<ForInst *, int64_t> *computeCostMap) { - // 'opCount' is the total number operations in one iteration of 'forInst' body + Instruction *forInst, LoopNestStats *stats, + llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountOverrideMap, + DenseMap<Instruction *, int64_t> *computeCostMap) { + // 'opCount' is the total number operations in one iteration of 'forOp' body int64_t opCount = stats->opCountMap[forInst]; if (stats->loopMap.count(forInst) > 0) { - for (auto *childForInst : stats->loopMap[forInst]) { - opCount += getComputeCost(childForInst, stats, tripCountOverrideMap, - computeCostMap); + for (auto childForOp : stats->loopMap[forInst]) { + opCount += getComputeCost(childForOp->getInstruction(), stats, + tripCountOverrideMap, computeCostMap); } } // Add in additional op instances from slice (if specified in map). @@ -694,18 +701,18 @@ static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) { return cExpr.getValue(); } -// Builds a map 'tripCountMap' from ForInst to constant trip count for loop +// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop // nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'. // Returns true on success, false otherwise (if a non-constant trip count // was encountered). // TODO(andydavis) Make this work with non-unit step loops. 
static bool buildSliceTripCountMap( OperationInst *srcOpInst, ComputationSliceState *sliceState, - llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) { - SmallVector<ForInst *, 4> srcLoopIVs; + llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) { + SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); - // Populate map from ForInst -> trip count + // Populate map from AffineForOp -> trip count for (unsigned i = 0; i < numSrcLoopIVs; ++i) { AffineMap lbMap = sliceState->lbs[i]; AffineMap ubMap = sliceState->ubs[i]; @@ -713,7 +720,7 @@ static bool buildSliceTripCountMap( // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. if (srcLoopIVs[i]->hasConstantLowerBound() && srcLoopIVs[i]->hasConstantUpperBound()) { - (*tripCountMap)[srcLoopIVs[i]] = + (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = srcLoopIVs[i]->getConstantUpperBound() - srcLoopIVs[i]->getConstantLowerBound(); continue; @@ -723,7 +730,7 @@ static bool buildSliceTripCountMap( Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap); if (!tripCount.hasValue()) return false; - (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue(); + (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue(); } return true; } @@ -750,7 +757,7 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) { unsigned numOps = ops.size(); assert(numOps > 0); - std::vector<SmallVector<ForInst *, 4>> loops(numOps); + std::vector<SmallVector<OpPointer<AffineForOp>, 4>> loops(numOps); unsigned loopDepthLimit = std::numeric_limits<unsigned>::max(); for (unsigned i = 0; i < numOps; ++i) { getLoopIVs(*ops[i], &loops[i]); @@ -762,9 +769,8 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) { for (unsigned d = 0; d < loopDepthLimit; ++d) { unsigned i; for (i = 1; i < numOps; ++i) { - if (loops[i - 1][d] != loops[i][d]) { + if (loops[i - 1][d] != loops[i][d]) break; - } } if (i != numOps) break; @@ -871,14 +877,16 @@ static bool getSliceUnion(const ComputationSliceState &sliceStateA, } // Creates and returns a private (single-user) memref for fused loop rooted -// at 'forInst', with (potentially reduced) memref size based on the +// at 'forOp', with (potentially reduced) memref size based on the // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. // TODO(bondhugula): consider refactoring the common code from generateDma and // this one. -static Value *createPrivateMemRef(ForInst *forInst, +static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp, OperationInst *srcStoreOpInst, unsigned dstLoopDepth) { - // Create builder to insert alloc op just before 'forInst'. + auto *forInst = forOp->getInstruction(); + + // Create builder to insert alloc op just before 'forOp'. FuncBuilder b(forInst); // Builder to create constants at the top level. FuncBuilder top(forInst->getFunction()); @@ -934,16 +942,16 @@ static Value *createPrivateMemRef(ForInst *forInst, for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) allocOperands.push_back( - top.create<DimOp>(forInst->getLoc(), oldMemRef, dynamicDimCount++)); + top.create<DimOp>(forOp->getLoc(), oldMemRef, dynamicDimCount++)); } - // Create new private memref for fused loop 'forInst'. + // Create new private memref for fused loop 'forOp'. // TODO(andydavis) Create/move alloc ops for private memrefs closer to their // consumer loop nests to reduce their live range. 
Currently they are added // at the beginning of the function, because loop nests can be reordered // during the fusion pass. Value *newMemRef = - top.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands); + top.create<AllocOp>(forOp->getLoc(), newMemRefType, allocOperands); // Build an AffineMap to remap access functions based on lower bound offsets. SmallVector<AffineExpr, 4> remapExprs; @@ -967,7 +975,7 @@ static Value *createPrivateMemRef(ForInst *forInst, bool ret = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, /*extraOperands=*/outerIVs, - /*domInstFilter=*/&*forInst->getBody()->begin()); + /*domInstFilter=*/&*forOp->getBody()->begin()); assert(ret && "replaceAllMemrefUsesWith should always succeed here"); (void)ret; return newMemRef; @@ -975,7 +983,7 @@ static Value *createPrivateMemRef(ForInst *forInst, // Does the slice have a single iteration? static uint64_t getSliceIterationCount( - const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) { + const llvm::SmallDenseMap<Instruction *, uint64_t, 8> &sliceTripCountMap) { uint64_t iterCount = 1; for (const auto &count : sliceTripCountMap) { iterCount *= count.second; @@ -1030,25 +1038,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst, }); // Compute cost of sliced and unsliced src loop nest. - SmallVector<ForInst *, 4> srcLoopIVs; + SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); unsigned numSrcLoopIVs = srcLoopIVs.size(); // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats); - srcStatsCollector.walk(srcLoopIVs[0]); + srcStatsCollector.walk(srcLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. if (srcStatsCollector.hasLoopWithNonConstTripCount) return false; // Compute cost of dst loop nest. - SmallVector<ForInst *, 4> dstLoopIVs; + SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs; getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); LoopNestStats dstLoopNestStats; LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); - dstStatsCollector.walk(dstLoopIVs[0]); + dstStatsCollector.walk(dstLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. if (dstStatsCollector.hasLoopWithNonConstTripCount) return false; @@ -1075,17 +1083,19 @@ static bool isFusionProfitable(OperationInst *srcOpInst, Optional<unsigned> bestDstLoopDepth = None; // Compute op instance count for the src loop nest without iteration slicing. - uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); + uint64_t srcLoopNestCost = + getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); // Compute op instance count for the src loop nest. 
- uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); + uint64_t dstLoopNestCost = + getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); - llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap; - DenseMap<ForInst *, int64_t> computeCostMap; + llvm::SmallDenseMap<Instruction *, uint64_t, 8> sliceTripCountMap; + DenseMap<Instruction *, int64_t> computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { MemRefAccess srcAccess(srcOpInst); // Handle the common case of one dst load without a copy. @@ -1121,24 +1131,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // The store and loads to this memref will disappear. if (storeLoadFwdGuaranteed) { // A single store disappears: -1 for that. - computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1; + computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1; for (auto *loadOp : dstLoadOpInsts) { - if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst())) - computeCostMap[loadLoop] = -1; + auto *parentInst = loadOp->getParentInst(); + if (parentInst && cast<OperationInst>(parentInst)->isa<AffineForOp>()) + computeCostMap[parentInst] = -1; } } // Compute op instance count for the src loop nest with iteration slicing. int64_t sliceComputeCost = - getComputeCost(srcLoopIVs[0], &srcLoopNestStats, + getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats, /*tripCountOverrideMap=*/&sliceTripCountMap, /*computeCostMap=*/&computeCostMap); // Compute cost of fusion for this depth. - computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost; + computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost; int64_t fusedLoopNestComputeCost = - getComputeCost(dstLoopIVs[0], &dstLoopNestStats, + getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats, /*tripCountOverrideMap=*/nullptr, &computeCostMap); double additionalComputeFraction = @@ -1211,8 +1222,8 @@ static bool isFusionProfitable(OperationInst *srcOpInst, << "\n fused loop nest compute cost: " << minFusedLoopNestComputeCost << "\n"); - auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]); - auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]); + auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]); + auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]); Optional<double> storageReduction = None; @@ -1292,9 +1303,9 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // // *) A worklist is initialized with node ids from the dependence graph. // *) For each node id in the worklist: -// *) Pop a ForInst of the worklist. This 'dstForInst' will be a candidate -// destination ForInst into which fusion will be attempted. -// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'. +// *) Pop a AffineForOp of the worklist. This 'dstAffineForOp' will be a +// candidate destination AffineForOp into which fusion will be attempted. +// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'. // *) For each LoadOp in 'dstLoadOps' do: // *) Lookup dependent loop nests at earlier positions in the Function // which have a single store op to the same memref. @@ -1342,7 +1353,7 @@ public: // Get 'dstNode' into which to attempt fusion. auto *dstNode = mdg->getNode(dstId); // Skip if 'dstNode' is not a loop nest. 
- if (!isa<ForInst>(dstNode->inst)) + if (!cast<OperationInst>(dstNode->inst)->isa<AffineForOp>()) continue; SmallVector<OperationInst *, 4> loads = dstNode->loads; @@ -1375,7 +1386,7 @@ public: // Get 'srcNode' from which to attempt fusion into 'dstNode'. auto *srcNode = mdg->getNode(srcId); // Skip if 'srcNode' is not a loop nest. - if (!isa<ForInst>(srcNode->inst)) + if (!cast<OperationInst>(srcNode->inst)->isa<AffineForOp>()) continue; // Skip if 'srcNode' has more than one store to any memref. // TODO(andydavis) Support fusing multi-output src loop nests. @@ -1417,25 +1428,26 @@ public: continue; // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. - auto *sliceLoopNest = mlir::insertBackwardComputationSlice( + auto sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { - // Move 'dstForInst' before 'insertPointInst' if needed. - auto *dstForInst = cast<ForInst>(dstNode->inst); - if (insertPointInst != dstForInst) { - dstForInst->moveBefore(insertPointInst); + // Move 'dstAffineForOp' before 'insertPointInst' if needed. + auto dstAffineForOp = + cast<OperationInst>(dstNode->inst)->cast<AffineForOp>(); + if (insertPointInst != dstAffineForOp->getInstruction()) { + dstAffineForOp->getInstruction()->moveBefore(insertPointInst); } // Update edges between 'srcNode' and 'dstNode'. mdg->updateEdges(srcNode->id, dstNode->id, memref); // Collect slice loop stats. LoopNestStateCollector sliceCollector; - sliceCollector.walkForInst(sliceLoopNest); + sliceCollector.walk(sliceLoopNest->getInstruction()); // Promote single iteration slice loops to single IV value. - for (auto *forInst : sliceCollector.forInsts) { - promoteIfSingleIteration(forInst); + for (auto forOp : sliceCollector.forOps) { + promoteIfSingleIteration(forOp); } - // Create private memref for 'memref' in 'dstForInst'. + // Create private memref for 'memref' in 'dstAffineForOp'. SmallVector<OperationInst *, 4> storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { if (storeOpInst->cast<StoreOp>()->getMemRef() == memref) @@ -1443,7 +1455,7 @@ public: } assert(storesForMemref.size() == 1); auto *newMemRef = createPrivateMemRef( - dstForInst, storesForMemref[0], bestDstLoopDepth); + dstAffineForOp, storesForMemref[0], bestDstLoopDepth); visitedMemrefs.insert(newMemRef); // Create new node in dependence graph for 'newMemRef' alloc op. unsigned newMemRefNodeId = @@ -1453,7 +1465,7 @@ public: // Collect dst loop stats after memref privatizaton transformation. LoopNestStateCollector dstLoopCollector; - dstLoopCollector.walkForInst(dstForInst); + dstLoopCollector.walk(dstAffineForOp->getInstruction()); // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. @@ -1472,7 +1484,7 @@ public: // function. if (mdg->canRemoveNode(srcNode->id)) { mdg->removeNode(srcNode->id); - cast<ForInst>(srcNode->inst)->erase(); + srcNode->inst->erase(); } } } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 396fc8eb658..f1ee7fd1853 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -60,16 +61,17 @@ char LoopTiling::passID = 0; /// Function. 
FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } -// Move the loop body of ForInst 'src' from 'src' into the specified location in -// destination's body. -static inline void moveLoopBody(ForInst *src, ForInst *dest, +// Move the loop body of AffineForOp 'src' from 'src' into the specified +// location in destination's body. +static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest, Block::iterator loc) { dest->getBody()->getInstructions().splice(loc, src->getBody()->getInstructions()); } -// Move the loop body of ForInst 'src' from 'src' to the start of dest's body. -static inline void moveLoopBody(ForInst *src, ForInst *dest) { +// Move the loop body of AffineForOp 'src' from 'src' to the start of dest's +// body. +static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest) { moveLoopBody(src, dest, dest->getBody()->begin()); } @@ -78,13 +80,14 @@ static inline void moveLoopBody(ForInst *src, ForInst *dest) { /// depend on other dimensions. Bounds of each dimension can thus be treated /// independently, and deriving the new bounds is much simpler and faster /// than for the case of tiling arbitrary polyhedral shapes. -static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops, - ArrayRef<ForInst *> newLoops, - ArrayRef<unsigned> tileSizes) { +static void constructTiledIndexSetHyperRect( + MutableArrayRef<OpPointer<AffineForOp>> origLoops, + MutableArrayRef<OpPointer<AffineForOp>> newLoops, + ArrayRef<unsigned> tileSizes) { assert(!origLoops.empty()); assert(origLoops.size() == tileSizes.size()); - FuncBuilder b(origLoops[0]); + FuncBuilder b(origLoops[0]->getInstruction()); unsigned width = origLoops.size(); // Bounds for tile space loops. @@ -99,8 +102,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops, } // Bounds for intra-tile loops. for (unsigned i = 0; i < width; i++) { - int64_t largestDiv = getLargestDivisorOfTripCount(*origLoops[i]); - auto mayBeConstantCount = getConstantTripCount(*origLoops[i]); + int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]); + auto mayBeConstantCount = getConstantTripCount(origLoops[i]); // The lower bound is just the tile-space loop. AffineMap lbMap = b.getDimIdentityMap(); newLoops[width + i]->setLowerBound( @@ -144,38 +147,40 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops, /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. // TODO(bondhugula): handle non hyper-rectangular spaces. -UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band, +UtilResult mlir::tileCodeGen(MutableArrayRef<OpPointer<AffineForOp>> band, ArrayRef<unsigned> tileSizes) { assert(!band.empty()); assert(band.size() == tileSizes.size()); // Check if the supplied for inst's are all successively nested. for (unsigned i = 1, e = band.size(); i < e; i++) { - assert(band[i]->getParentInst() == band[i - 1]); + assert(band[i]->getInstruction()->getParentInst() == + band[i - 1]->getInstruction()); } auto origLoops = band; - ForInst *rootForInst = origLoops[0]; - auto loc = rootForInst->getLoc(); + OpPointer<AffineForOp> rootAffineForOp = origLoops[0]; + auto loc = rootAffineForOp->getLoc(); // Note that width is at least one since band isn't empty. 
unsigned width = band.size(); - SmallVector<ForInst *, 12> newLoops(2 * width); - ForInst *innermostPointLoop; + SmallVector<OpPointer<AffineForOp>, 12> newLoops(2 * width); + OpPointer<AffineForOp> innermostPointLoop; // The outermost among the loops as we add more.. - auto *topLoop = rootForInst; + auto *topLoop = rootAffineForOp->getInstruction(); // Add intra-tile (or point) loops. for (unsigned i = 0; i < width; i++) { FuncBuilder b(topLoop); // Loop bounds will be set later. - auto *pointLoop = b.createFor(loc, 0, 0); + auto pointLoop = b.create<AffineForOp>(loc, 0, 0); + pointLoop->createBody(); pointLoop->getBody()->getInstructions().splice( pointLoop->getBody()->begin(), topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - 1 - i] = pointLoop; - topLoop = pointLoop; + topLoop = pointLoop->getInstruction(); if (i == 0) innermostPointLoop = pointLoop; } @@ -184,12 +189,13 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band, for (unsigned i = width; i < 2 * width; i++) { FuncBuilder b(topLoop); // Loop bounds will be set later. - auto *tileSpaceLoop = b.createFor(loc, 0, 0); + auto tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0); + tileSpaceLoop->createBody(); tileSpaceLoop->getBody()->getInstructions().splice( tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getInstructions(), topLoop); newLoops[2 * width - i - 1] = tileSpaceLoop; - topLoop = tileSpaceLoop; + topLoop = tileSpaceLoop->getInstruction(); } // Move the loop body of the original nest to the new one. @@ -201,8 +207,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band, getIndexSet(band, &cst); if (!cst.isHyperRectangular(0, width)) { - rootForInst->emitError("tiled code generation unimplemented for the" - "non-hyperrectangular case"); + rootAffineForOp->emitError("tiled code generation unimplemented for the" + "non-hyperrectangular case"); return UtilResult::Failure; } @@ -213,7 +219,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band, } // Erase the old loop nest. - rootForInst->erase(); + rootAffineForOp->erase(); return UtilResult::Success; } @@ -221,38 +227,36 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band, // Identify valid and profitable bands of loops to tile. This is currently just // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. -static void getTileableBands(Function *f, - std::vector<SmallVector<ForInst *, 6>> *bands) { +static void +getTileableBands(Function *f, + std::vector<SmallVector<OpPointer<AffineForOp>, 6>> *bands) { // Get maximal perfect nest of 'for' insts starting from root (inclusive). 
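The tile-space and intra-tile loops above are now built through the generic op builder rather than createFor. A condensed sketch of that construction idiom, assuming the usual AffineOps/Builders headers and a FuncBuilder positioned where the loop should be created (the zero bounds are placeholders, just as in the pass itself):

```c++
// Sketch only: create an empty affine loop, then materialize its body block
// before splicing or inserting instructions into it.
static OpPointer<AffineForOp> createEmptyLoop(FuncBuilder &b, Location loc) {
  auto forOp = b.create<AffineForOp>(loc, /*lb=*/0, /*ub=*/0);
  forOp->createBody();
  // A body builder like this is how the pass then populates the loop.
  FuncBuilder bodyBuilder(forOp->getBody());
  return forOp;
}
```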
- auto getMaximalPerfectLoopNest = [&](ForInst *root) { - SmallVector<ForInst *, 6> band; - ForInst *currInst = root; + auto getMaximalPerfectLoopNest = [&](OpPointer<AffineForOp> root) { + SmallVector<OpPointer<AffineForOp>, 6> band; + OpPointer<AffineForOp> currInst = root; do { band.push_back(currInst); } while (currInst->getBody()->getInstructions().size() == 1 && - (currInst = dyn_cast<ForInst>(&currInst->getBody()->front()))); + (currInst = cast<OperationInst>(currInst->getBody()->front()) + .dyn_cast<AffineForOp>())); bands->push_back(band); }; - for (auto &block : *f) { - for (auto &inst : block) { - auto *forInst = dyn_cast<ForInst>(&inst); - if (!forInst) - continue; - getMaximalPerfectLoopNest(forInst); - } - } + for (auto &block : *f) + for (auto &inst : block) + if (auto forOp = cast<OperationInst>(inst).dyn_cast<AffineForOp>()) + getMaximalPerfectLoopNest(forOp); } PassResult LoopTiling::runOnFunction(Function *f) { - std::vector<SmallVector<ForInst *, 6>> bands; + std::vector<SmallVector<OpPointer<AffineForOp>, 6>> bands; getTileableBands(f, &bands); // Temporary tile sizes. unsigned tileSize = clTileSize.getNumOccurrences() > 0 ? clTileSize : kDefaultTileSize; - for (const auto &band : bands) { + for (auto &band : bands) { SmallVector<unsigned, 6> tileSizes(band.size(), tileSize); if (tileCodeGen(band, tileSizes)) { return failure(); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 6d63e4afd2d..86e913bd71f 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -21,6 +21,7 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -70,18 +71,19 @@ struct LoopUnroll : public FunctionPass { const Optional<bool> unrollFull; // Callback to obtain unroll factors; if this has a callable target, takes // precedence over command-line argument or passed argument. - const std::function<unsigned(const ForInst &)> getUnrollFactor; + const std::function<unsigned(ConstOpPointer<AffineForOp>)> getUnrollFactor; - explicit LoopUnroll( - Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None, - const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr) + explicit LoopUnroll(Optional<unsigned> unrollFactor = None, + Optional<bool> unrollFull = None, + const std::function<unsigned(ConstOpPointer<AffineForOp>)> + &getUnrollFactor = nullptr) : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} PassResult runOnFunction(Function *f) override; /// Unroll this for inst. Returns false if nothing was done. - bool runOnForInst(ForInst *forInst); + bool runOnAffineForOp(OpPointer<AffineForOp> forOp); static const unsigned kDefaultUnrollFactor = 4; @@ -96,7 +98,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> { public: // Store innermost loops as we walk. - std::vector<ForInst *> loops; + std::vector<OpPointer<AffineForOp>> loops; // This method specialized to encode custom return logic. 
using InstListType = llvm::iplist<Instruction>; @@ -111,20 +113,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return hasInnerLoops; } - bool walkForInstPostOrder(ForInst *forInst) { - bool hasInnerLoops = - walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end()); - if (!hasInnerLoops) - loops.push_back(forInst); - return true; - } - bool walkOpInstPostOrder(OperationInst *opInst) { + bool hasInnerLoops = false; for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) - if (walkPostOrder(block.begin(), block.end())) - return true; - return false; + hasInnerLoops |= walkPostOrder(block.begin(), block.end()); + if (opInst->isa<AffineForOp>()) { + if (!hasInnerLoops) + loops.push_back(opInst->cast<AffineForOp>()); + return true; + } + return hasInnerLoops; } // FIXME: can't use base class method for this because that in turn would @@ -137,14 +136,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) { class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> { public: // Store short loops as we walk. - std::vector<ForInst *> loops; + std::vector<OpPointer<AffineForOp>> loops; const unsigned minTripCount; ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitForInst(ForInst *forInst) { - Optional<uint64_t> tripCount = getConstantTripCount(*forInst); + void visitOperationInst(OperationInst *opInst) { + auto forOp = opInst->dyn_cast<AffineForOp>(); + if (!forOp) + return; + Optional<uint64_t> tripCount = getConstantTripCount(forOp); if (tripCount.hasValue() && tripCount.getValue() <= minTripCount) - loops.push_back(forInst); + loops.push_back(forOp); } }; @@ -156,8 +158,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) { // ones). slg.walkPostOrder(f); auto &loops = slg.loops; - for (auto *forInst : loops) - loopUnrollFull(forInst); + for (auto forOp : loops) + loopUnrollFull(forOp); return success(); } @@ -172,8 +174,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) { if (loops.empty()) break; bool unrolled = false; - for (auto *forInst : loops) - unrolled |= runOnForInst(forInst); + for (auto forOp : loops) + unrolled |= runOnAffineForOp(forOp); if (!unrolled) // Break out if nothing was unrolled. break; @@ -183,29 +185,30 @@ PassResult LoopUnroll::runOnFunction(Function *f) { /// Unrolls a 'for' inst. Returns true if the loop was unrolled, false /// otherwise. The default unroll factor is 4. -bool LoopUnroll::runOnForInst(ForInst *forInst) { +bool LoopUnroll::runOnAffineForOp(OpPointer<AffineForOp> forOp) { // Use the function callback if one was provided. if (getUnrollFactor) { - return loopUnrollByFactor(forInst, getUnrollFactor(*forInst)); + return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); } // Unroll by the factor passed, if any. if (unrollFactor.hasValue()) - return loopUnrollByFactor(forInst, unrollFactor.getValue()); + return loopUnrollByFactor(forOp, unrollFactor.getValue()); // Unroll by the command line factor if one was specified. if (clUnrollFactor.getNumOccurrences() > 0) - return loopUnrollByFactor(forInst, clUnrollFactor); + return loopUnrollByFactor(forOp, clUnrollFactor); // Unroll completely if full loop unroll was specified. if (clUnrollFull.getNumOccurrences() > 0 || (unrollFull.hasValue() && unrollFull.getValue())) - return loopUnrollFull(forInst); + return loopUnrollFull(forOp); // Unroll by four otherwise. 
- return loopUnrollByFactor(forInst, kDefaultUnrollFactor); + return loopUnrollByFactor(forOp, kDefaultUnrollFactor); } FunctionPass *mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, - const std::function<unsigned(const ForInst &)> &getUnrollFactor) { + const std::function<unsigned(ConstOpPointer<AffineForOp>)> + &getUnrollFactor) { return new LoopUnroll( unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 7deaf850362..7327a37ee3a 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -43,6 +43,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -80,7 +81,7 @@ struct LoopUnrollAndJam : public FunctionPass { unrollJamFactor(unrollJamFactor) {} PassResult runOnFunction(Function *f) override; - bool runOnForInst(ForInst *forInst); + bool runOnAffineForOp(OpPointer<AffineForOp> forOp); static char passID; }; @@ -95,47 +96,51 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { PassResult LoopUnrollAndJam::runOnFunction(Function *f) { // Currently, just the outermost loop from the first loop nest is - // unroll-and-jammed by this pass. However, runOnForInst can be called on any - // for Inst. + // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on + // any for Inst. auto &entryBlock = f->front(); if (!entryBlock.empty()) - if (auto *forInst = dyn_cast<ForInst>(&entryBlock.front())) - runOnForInst(forInst); + if (auto forOp = + cast<OperationInst>(entryBlock.front()).dyn_cast<AffineForOp>()) + runOnAffineForOp(forOp); return success(); } /// Unroll and jam a 'for' inst. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return false if nothing was done. -bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) { +bool LoopUnrollAndJam::runOnAffineForOp(OpPointer<AffineForOp> forOp) { // Unroll and jam by the factor that was passed if any. if (unrollJamFactor.hasValue()) - return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue()); + return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue()); // Otherwise, unroll jam by the command-line factor if one was specified. if (clUnrollJamFactor.getNumOccurrences() > 0) - return loopUnrollJamByFactor(forInst, clUnrollJamFactor); + return loopUnrollJamByFactor(forOp, clUnrollJamFactor); // Unroll and jam by four otherwise. 
- return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor); + return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor); } -bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) { - Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst); +bool mlir::loopUnrollJamUpToFactor(OpPointer<AffineForOp> forOp, + uint64_t unrollJamFactor) { + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollJamFactor) - return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue()); - return loopUnrollJamByFactor(forInst, unrollJamFactor); + return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue()); + return loopUnrollJamByFactor(forOp, unrollJamFactor); } /// Unrolls and jams this loop by the specified factor. -bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { +bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp, + uint64_t unrollJamFactor) { // Gathers all maximal sub-blocks of instructions that do not themselves // include a for inst (a instruction could have a descendant for inst though // in its tree). class JamBlockGatherer : public InstWalker<JamBlockGatherer> { public: using InstListType = llvm::iplist<Instruction>; + using InstWalker<JamBlockGatherer>::walk; // Store iterators to the first and last inst of each sub-block found. std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks; @@ -144,30 +149,30 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; - while (it != End && !isa<ForInst>(it)) + while (it != End && !cast<OperationInst>(it)->isa<AffineForOp>()) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); // Process all for insts that appear next. - while (it != End && isa<ForInst>(it)) - walkForInst(cast<ForInst>(it++)); + while (it != End && cast<OperationInst>(it)->isa<AffineForOp>()) + walk(&*it++); } } }; assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); - if (unrollJamFactor == 1 || forInst->getBody()->empty()) + if (unrollJamFactor == 1 || forOp->getBody()->empty()) return false; - Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst); + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); if (!mayBeConstantTripCount.hasValue() && - getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0) + getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) return false; - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); + auto ubMap = forOp->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be // expressed as a Function in the general case). However, the right way to @@ -178,7 +183,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { // Same operand list for lower and upper bound for now. // TODO(bondhugula): handle bounds with different sets of operands. - if (!forInst->matchingBoundOperandList()) + if (!forOp->matchingBoundOperandList()) return false; // If the trip count is lower than the unroll jam factor, no unroll jam. 
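The unroll pass above also changes how clients supply a per-loop unroll factor: the callback passed to createLoopUnrollPass now takes a ConstOpPointer<AffineForOp> instead of a const ForInst reference. A minimal sketch of a caller under the new signature, with invented names and thresholds, and assuming getConstantTripCount accepts the const op pointer as its uses in this patch suggest:

```c++
// Hypothetical caller of createLoopUnrollPass with the new callback type.
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

static FunctionPass *createSmallLoopUnrollPass() {
  auto factorFn = [](ConstOpPointer<AffineForOp> forOp) -> unsigned {
    // Fully unroll short constant-trip-count loops; unroll the rest by 2.
    auto tripCount = getConstantTripCount(forOp);
    if (tripCount.hasValue() && tripCount.getValue() <= 4)
      return static_cast<unsigned>(tripCount.getValue());
    return 2;
  };
  // -1 leaves the factor/full-unroll settings unset so the callback decides.
  return createLoopUnrollPass(/*unrollFactor=*/-1, /*unrollFull=*/-1, factorFn);
}
```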
@@ -187,35 +192,38 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { mayBeConstantTripCount.getValue() < unrollJamFactor) return false; + auto *forInst = forOp->getInstruction(); + // Gather all sub-blocks to jam upon the loop being unrolled. JamBlockGatherer jbg; - jbg.walkForInst(forInst); + jbg.walkOpInst(forInst); auto &subBlocks = jbg.subBlocks; // Generate the cleanup loop if trip count isn't a multiple of // unrollJamFactor. if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() % unrollJamFactor != 0) { - // Insert the cleanup loop right after 'forInst'. + // Insert the cleanup loop right after 'forOp'. FuncBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); - auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst)); - cleanupForInst->setLowerBoundMap( - getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder)); + auto cleanupAffineForOp = + cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>(); + cleanupAffineForOp->setLowerBoundMap( + getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder)); // The upper bound needs to be adjusted. - forInst->setUpperBoundMap( - getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder)); + forOp->setUpperBoundMap( + getUnrolledLoopUpperBound(forOp, unrollJamFactor, &builder)); // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(cleanupForInst); + promoteIfSingleIteration(cleanupAffineForOp); } // Scale the step of loop being unroll-jammed by the unroll-jam factor. - int64_t step = forInst->getStep(); - forInst->setStep(step * unrollJamFactor); + int64_t step = forOp->getStep(); + forOp->setStep(step * unrollJamFactor); - auto *forInstIV = forInst->getInductionVar(); + auto *forOpIV = forOp->getInductionVar(); for (auto &subBlock : subBlocks) { // Builder to insert unroll-jammed bodies. Insert right at the end of // sub-block. @@ -227,13 +235,13 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forInstIV->use_empty()) { + if (!forOpIV->use_empty()) { // iv' = iv + i, i = 1 to unrollJamFactor-1. auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); - auto ivUnroll = builder.create<AffineApplyOp>(forInst->getLoc(), - bumpMap, forInstIV); - operandMapping.map(forInstIV, ivUnroll); + auto ivUnroll = + builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forOpIV); + operandMapping.map(forOpIV, ivUnroll); } // Clone the sub-block being unroll-jammed. for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { @@ -243,7 +251,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { } // Promote the loop body up if this has turned into a single iteration loop. 
- promoteIfSingleIteration(forInst); + promoteIfSingleIteration(forOp); return true; } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index f770684f519..24ca4e95082 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" @@ -246,7 +247,7 @@ public: LowerAffinePass() : FunctionPass(&passID) {} PassResult runOnFunction(Function *function) override; - bool lowerForInst(ForInst *forInst); + bool lowerAffineFor(OpPointer<AffineForOp> forOp); bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); @@ -295,11 +296,11 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // a nested loop). Induction variable modification is appended to the body SESE // region that always loops back to the condition block. // -// +--------------------------------+ -// | <code before the ForInst> | -// | <compute initial %iv value> | -// | br cond(%iv) | -// +--------------------------------+ +// +---------------------------------+ +// | <code before the AffineForOp> | +// | <compute initial %iv value> | +// | br cond(%iv) | +// +---------------------------------+ // | // -------| | // | v v @@ -322,11 +323,12 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, // v // +--------------------------------+ // | end: | -// | <code after the ForInst> | +// | <code after the AffineForOp> | // +--------------------------------+ // -bool LowerAffinePass::lowerForInst(ForInst *forInst) { - auto loc = forInst->getLoc(); +bool LowerAffinePass::lowerAffineFor(OpPointer<AffineForOp> forOp) { + auto loc = forOp->getLoc(); + auto *forInst = forOp->getInstruction(); // Start by splitting the block containing the 'for' into two parts. The part // before will get the init code, the part after will be the end point. @@ -339,23 +341,23 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { conditionBlock->insertBefore(endBlock); auto *iv = conditionBlock->addArgument(IndexType::get(forInst->getContext())); - // Create the body block, moving the body of the forInst over to it. + // Create the body block, moving the body of the forOp over to it. auto *bodyBlock = new Block(); bodyBlock->insertBefore(endBlock); - auto *oldBody = forInst->getBody(); + auto *oldBody = forOp->getBody(); bodyBlock->getInstructions().splice(bodyBlock->begin(), oldBody->getInstructions(), oldBody->begin(), oldBody->end()); - // The code in the body of the forInst now uses 'iv' as its indvar. - forInst->getInductionVar()->replaceAllUsesWith(iv); + // The code in the body of the forOp now uses 'iv' as its indvar. + forOp->getInductionVar()->replaceAllUsesWith(iv); // Append the induction variable stepping logic and branch back to the exit // condition block. Construct an affine expression f : (x -> x+step) and // apply this expression to the induction variable. 
FuncBuilder builder(bodyBlock); - auto affStep = builder.getAffineConstantExpr(forInst->getStep()); + auto affStep = builder.getAffineConstantExpr(forOp->getStep()); auto affDim = builder.getAffineDimExpr(0); auto stepped = expandAffineExpr(&builder, loc, affDim + affStep, iv, {}); if (!stepped) @@ -368,18 +370,18 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { builder.setInsertionPointToEnd(initBlock); // Compute loop bounds. - SmallVector<Value *, 8> operands(forInst->getLowerBoundOperands()); + SmallVector<Value *, 8> operands(forOp->getLowerBoundOperands()); auto lbValues = expandAffineMap(&builder, forInst->getLoc(), - forInst->getLowerBoundMap(), operands); + forOp->getLowerBoundMap(), operands); if (!lbValues) return true; Value *lowerBound = buildMinMaxReductionSeq(loc, CmpIPredicate::SGT, *lbValues, builder); - operands.assign(forInst->getUpperBoundOperands().begin(), - forInst->getUpperBoundOperands().end()); + operands.assign(forOp->getUpperBoundOperands().begin(), + forOp->getUpperBoundOperands().end()); auto ubValues = expandAffineMap(&builder, forInst->getLoc(), - forInst->getUpperBoundMap(), operands); + forOp->getUpperBoundMap(), operands); if (!ubValues) return true; Value *upperBound = @@ -394,7 +396,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { endBlock, ArrayRef<Value *>()); // Ok, we're done! - forInst->erase(); + forOp->erase(); return false; } @@ -614,28 +616,26 @@ PassResult LowerAffinePass::runOnFunction(Function *function) { // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. // We do this as a prepass to avoid invalidating the walker with our rewrite. function->walkInsts([&](Instruction *inst) { - if (isa<ForInst>(inst)) - instsToRewrite.push_back(inst); - auto op = dyn_cast<OperationInst>(inst); - if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>())) + auto op = cast<OperationInst>(inst); + if (op->isa<AffineApplyOp>() || op->isa<AffineForOp>() || + op->isa<AffineIfOp>()) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. - for (auto *inst : instsToRewrite) - if (auto *forInst = dyn_cast<ForInst>(inst)) { - if (lowerForInst(forInst)) + for (auto *inst : instsToRewrite) { + auto op = cast<OperationInst>(inst); + if (auto ifOp = op->dyn_cast<AffineIfOp>()) { + if (lowerAffineIf(ifOp)) return failure(); - } else { - auto op = cast<OperationInst>(inst); - if (auto ifOp = op->dyn_cast<AffineIfOp>()) { - if (lowerAffineIf(ifOp)) - return failure(); - } else if (lowerAffineApply(op->cast<AffineApplyOp>())) { + } else if (auto forOp = op->dyn_cast<AffineForOp>()) { + if (lowerAffineFor(forOp)) return failure(); - } + } else if (lowerAffineApply(op->cast<AffineApplyOp>())) { + return failure(); } + } return success(); } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 432ad1f39b8..f2dae11112b 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -75,7 +75,7 @@ /// Implementation details /// ====================== /// The current decisions made by the super-vectorization pass guarantee that -/// use-def chains do not escape an enclosing vectorized ForInst. In other +/// use-def chains do not escape an enclosing vectorized AffineForOp. In other /// words, this pass operates on a scoped program slice. 
Furthermore, since we /// do not vectorize in the presence of conditionals for now, sliced chains are /// guaranteed not to escape the innermost scope, which has to be either the top @@ -285,13 +285,12 @@ static Value *substitute(Value *v, VectorType hwVectorType, /// /// The general problem this function solves is as follows: /// Assume a vector_transfer operation at the super-vector granularity that has -/// `l` enclosing loops (ForInst). Assume the vector transfer operation operates -/// on a MemRef of rank `r`, a super-vector of rank `s` and a hardware vector of -/// rank `h`. -/// For the purpose of illustration assume l==4, r==3, s==2, h==1 and that the -/// super-vector is vector<3x32xf32> and the hardware vector is vector<8xf32>. -/// Assume the following MLIR snippet after super-vectorization has been -/// applied: +/// `l` enclosing loops (AffineForOp). Assume the vector transfer operation +/// operates on a MemRef of rank `r`, a super-vector of rank `s` and a hardware +/// vector of rank `h`. For the purpose of illustration assume l==4, r==3, s==2, +/// h==1 and that the super-vector is vector<3x32xf32> and the hardware vector +/// is vector<8xf32>. Assume the following MLIR snippet after +/// super-vectorization has been applied: /// /// ```mlir /// for %i0 = 0 to %M { @@ -351,7 +350,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, SmallVector<AffineExpr, 8> affineExprs; // TODO(ntv): support a concrete map and composition. unsigned i = 0; - // The first numMemRefIndices correspond to ForInst that have not been + // The first numMemRefIndices correspond to AffineForOp that have not been // vectorized, the transformation is the identity on those. for (i = 0; i < numMemRefIndices; ++i) { auto d_i = b->getAffineDimExpr(i); @@ -554,9 +553,6 @@ static bool instantiateMaterialization(Instruction *inst, MaterializationState *state) { LLVM_DEBUG(dbgs() << "\ninstantiate: " << *inst); - if (isa<ForInst>(inst)) - return inst->emitError("NYI path ForInst"); - // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); auto *opInst = cast<OperationInst>(inst); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 811741d08d1..2e083bbfd79 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -21,11 +21,11 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" @@ -38,15 +38,12 @@ using namespace mlir; namespace { -struct PipelineDataTransfer : public FunctionPass, - InstWalker<PipelineDataTransfer> { +struct PipelineDataTransfer : public FunctionPass { PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} PassResult runOnFunction(Function *f) override; - PassResult runOnForInst(ForInst *forInst); + PassResult runOnAffineForOp(OpPointer<AffineForOp> forOp); - // Collect all 'for' instructions. 
- void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - std::vector<ForInst *> forInsts; + std::vector<OpPointer<AffineForOp>> forOps; static char passID; }; @@ -79,8 +76,8 @@ static unsigned getTagMemRefPos(const OperationInst &dmaInst) { /// of the old memref by the new one while indexing the newly added dimension by /// the loop IV of the specified 'for' instruction modulo 2. Returns false if /// such a replacement cannot be performed. -static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { - auto *forBody = forInst->getBody(); +static bool doubleBuffer(Value *oldMemRef, OpPointer<AffineForOp> forOp) { + auto *forBody = forOp->getBody(); FuncBuilder bInner(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin()); @@ -101,6 +98,7 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { auto newMemRefType = doubleShape(oldMemRefType); // Put together alloc operands for the dynamic dimensions of the memref. + auto *forInst = forOp->getInstruction(); FuncBuilder bOuter(forInst); SmallVector<Value *, 4> allocOperands; unsigned dynamicDimCount = 0; @@ -118,16 +116,16 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0); - int64_t step = forInst->getStep(); + int64_t step = forOp->getStep(); auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0.floorDiv(step) % 2}, {}); - auto ivModTwoOp = bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap, - forInst->getInductionVar()); + auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp->getLoc(), modTwoMap, + forOp->getInductionVar()); - // replaceAllMemRefUsesWith will always succeed unless the forInst body has + // replaceAllMemRefUsesWith will always succeed unless the forOp body has // non-deferencing uses of the memref. if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, {ivModTwoOp}, AffineMap(), - {}, &*forInst->getBody()->begin())) { + {}, &*forOp->getBody()->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); ivModTwoOp->getInstruction()->erase(); @@ -143,11 +141,14 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) { // invalid (erased) when the outer loop is pipelined (the pipelined one gets // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). - forInsts.clear(); - walkPostOrder(f); + forOps.clear(); + f->walkOpsPostOrder([&](OperationInst *opInst) { + if (auto forOp = opInst->dyn_cast<AffineForOp>()) + forOps.push_back(forOp); + }); bool ret = false; - for (auto *forInst : forInsts) { - ret = ret | runOnForInst(forInst); + for (auto forOp : forOps) { + ret = ret | runOnAffineForOp(forOp); } return ret ? failure() : success(); } @@ -178,13 +179,13 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp, // Identify matching DMA start/finish instructions to overlap computation with. static void findMatchingStartFinishInsts( - ForInst *forInst, + OpPointer<AffineForOp> forOp, SmallVectorImpl<std::pair<OperationInst *, OperationInst *>> &startWaitPairs) { // Collect outgoing DMA instructions - needed to check for dependences below. 
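As a worked illustration of the 'iv mod 2' map built in doubleBuffer above (numbers invented for illustration): the new leading dimension is indexed by (iv floordiv step) mod 2, so for a loop with step 4 the successive induction values 0, 4, 8, 12 select buffer halves 0, 1, 0, 1; each iteration computes out of one half while the DMA for the next iteration fills the other.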
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { auto *opInst = dyn_cast<OperationInst>(&inst); if (!opInst) continue; @@ -195,7 +196,7 @@ static void findMatchingStartFinishInsts( } SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { auto *opInst = dyn_cast<OperationInst>(&inst); if (!opInst) continue; @@ -227,7 +228,7 @@ static void findMatchingStartFinishInsts( auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()); bool escapingUses = false; for (const auto &use : memref->getUses()) { - if (!forInst->getBody()->findAncestorInstInBlock(*use.getOwner())) { + if (!forOp->getBody()->findAncestorInstInBlock(*use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -251,17 +252,18 @@ static void findMatchingStartFinishInsts( } /// Overlap DMA transfers with computation in this loop. If successful, -/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are +/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are /// inserted right before where it was. -PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { - auto mayBeConstTripCount = getConstantTripCount(*forInst); +PassResult +PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) { + auto mayBeConstTripCount = getConstantTripCount(forOp); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n"); return success(); } SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs; - findMatchingStartFinishInsts(forInst, startWaitPairs); + findMatchingStartFinishInsts(forOp, startWaitPairs); if (startWaitPairs.empty()) { LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";); @@ -280,7 +282,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { auto *dmaStartInst = pair.first; Value *oldMemRef = dmaStartInst->getOperand( dmaStartInst->cast<DmaStartOp>()->getFasterMemPos()); - if (!doubleBuffer(oldMemRef, forInst)) { + if (!doubleBuffer(oldMemRef, forOp)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); @@ -302,7 +304,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { auto *dmaFinishInst = pair.second; Value *oldTagMemRef = dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); - if (!doubleBuffer(oldTagMemRef, forInst)) { + if (!doubleBuffer(oldTagMemRef, forOp)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); return success(); } @@ -315,7 +317,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { // Double buffering would have invalidated all the old DMA start/wait insts. startWaitPairs.clear(); - findMatchingStartFinishInsts(forInst, startWaitPairs); + findMatchingStartFinishInsts(forOp, startWaitPairs); // Store shift for instruction for later lookup for AffineApplyOp's. DenseMap<const Instruction *, unsigned> instShiftMap; @@ -342,16 +344,16 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { } } // Everything else (including compute ops and dma finish) are shifted by one. 
- for (const auto &inst : *forInst->getBody()) { + for (const auto &inst : *forOp->getBody()) { if (instShiftMap.find(&inst) == instShiftMap.end()) { instShiftMap[&inst] = 1; } } // Get shifts stored in map. - std::vector<uint64_t> shifts(forInst->getBody()->getInstructions().size()); + std::vector<uint64_t> shifts(forOp->getBody()->getInstructions().size()); unsigned s = 0; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { assert(instShiftMap.find(&inst) != instShiftMap.end()); shifts[s++] = instShiftMap[&inst]; LLVM_DEBUG( @@ -363,13 +365,13 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { }); } - if (!isInstwiseShiftValid(*forInst, shifts)) { + if (!isInstwiseShiftValid(forOp, shifts)) { // Violates dependences. LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); return success(); } - if (instBodySkew(forInst, shifts)) { + if (instBodySkew(forOp, shifts)) { LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";); return success(); } diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index ba59123c700..ae003b3e495 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -22,6 +22,7 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Function.h" #include "mlir/IR/Instructions.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/Pass.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 59da2b0a56e..ce16656243d 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -21,6 +21,7 @@ #include "mlir/Transforms/LoopUtils.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -39,22 +40,22 @@ using namespace mlir; /// Returns the upper bound of an unrolled loop with lower bound 'lb' and with /// the specified trip count, stride, and unroll factor. Returns nullptr when /// the trip count can't be expressed as an affine expression. -AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst, +AffineMap mlir::getUnrolledLoopUpperBound(ConstOpPointer<AffineForOp> forOp, unsigned unrollFactor, FuncBuilder *builder) { - auto lbMap = forInst.getLowerBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); // Single result lower bound map only. if (lbMap.getNumResults() != 1) return AffineMap(); // Sometimes, the trip count cannot be expressed as an affine expression. - auto tripCount = getTripCountExpr(forInst); + auto tripCount = getTripCountExpr(forOp); if (!tripCount) return AffineMap(); AffineExpr lb(lbMap.getResult(0)); - unsigned step = forInst.getStep(); + unsigned step = forOp->getStep(); auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step; return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), @@ -65,50 +66,51 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst, /// bound 'lb' and with the specified trip count, stride, and unroll factor. /// Returns an AffinMap with nullptr storage (that evaluates to false) /// when the trip count can't be expressed as an affine expression. 
-AffineMap mlir::getCleanupLoopLowerBound(const ForInst &forInst, +AffineMap mlir::getCleanupLoopLowerBound(ConstOpPointer<AffineForOp> forOp, unsigned unrollFactor, FuncBuilder *builder) { - auto lbMap = forInst.getLowerBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); // Single result lower bound map only. if (lbMap.getNumResults() != 1) return AffineMap(); // Sometimes the trip count cannot be expressed as an affine expression. - AffineExpr tripCount(getTripCountExpr(forInst)); + AffineExpr tripCount(getTripCountExpr(forOp)); if (!tripCount) return AffineMap(); AffineExpr lb(lbMap.getResult(0)); - unsigned step = forInst.getStep(); + unsigned step = forOp->getStep(); auto newLb = lb + (tripCount - tripCount % unrollFactor) * step; return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(), {newLb}, {}); } -/// Promotes the loop body of a forInst to its containing block if the forInst +/// Promotes the loop body of a forOp to its containing block if the forOp /// was known to have a single iteration. Returns false otherwise. // TODO(bondhugula): extend this for arbitrary affine bounds. -bool mlir::promoteIfSingleIteration(ForInst *forInst) { - Optional<uint64_t> tripCount = getConstantTripCount(*forInst); +bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) { + Optional<uint64_t> tripCount = getConstantTripCount(forOp); if (!tripCount.hasValue() || tripCount.getValue() != 1) return false; // TODO(mlir-team): there is no builder for a max. - if (forInst->getLowerBoundMap().getNumResults() != 1) + if (forOp->getLowerBoundMap().getNumResults() != 1) return false; // Replaces all IV uses to its single iteration value. - auto *iv = forInst->getInductionVar(); + auto *iv = forOp->getInductionVar(); + OperationInst *forInst = forOp->getInstruction(); if (!iv->use_empty()) { - if (forInst->hasConstantLowerBound()) { + if (forOp->hasConstantLowerBound()) { auto *mlFunc = forInst->getFunction(); FuncBuilder topBuilder(mlFunc); auto constOp = topBuilder.create<ConstantIndexOp>( - forInst->getLoc(), forInst->getConstantLowerBound()); + forOp->getLoc(), forOp->getConstantLowerBound()); iv->replaceAllUsesWith(constOp); } else { - const AffineBound lb = forInst->getLowerBound(); + const AffineBound lb = forOp->getLowerBound(); SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end()); FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst)); if (lb.getMap() == builder.getDimIdentityMap()) { @@ -124,8 +126,8 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) { // Move the loop body instructions to the loop's containing block. auto *block = forInst->getBlock(); block->getInstructions().splice(Block::iterator(forInst), - forInst->getBody()->getInstructions()); - forInst->erase(); + forOp->getBody()->getInstructions()); + forOp->erase(); return true; } @@ -133,13 +135,10 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) { /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. 
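As a quick worked check of the two bound helpers above (values invented for illustration): for a loop with constant lower bound 0, step 1, trip count 10 and unrollFactor 4, getUnrolledLoopUpperBound yields 0 + (10 - 10 % 4 - 1) * 1 = 7, so the unrolled loop, whose step is later scaled by the unroll factor, covers iterations 0 through 7, while getCleanupLoopLowerBound yields 0 + (10 - 10 % 4) * 1 = 8, so the cleanup loop picks up the remaining iterations 8 and 9.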
- class LoopBodyPromoter : public InstWalker<LoopBodyPromoter> { - public: - void visitForInst(ForInst *forInst) { promoteIfSingleIteration(forInst); } - }; - - LoopBodyPromoter fsw; - fsw.walkPostOrder(f); + f->walkOpsPostOrder([](OperationInst *inst) { + if (auto forOp = inst->dyn_cast<AffineForOp>()) + promoteIfSingleIteration(forOp); + }); } /// Generates a 'for' inst with the specified lower and upper bounds while @@ -149,19 +148,22 @@ void mlir::promoteSingleIterationLoops(Function *f) { /// the pair specifies the shift applied to that group of instructions; note /// that the shift is multiplied by the loop step before being applied. Returns /// nullptr if the generated loop simplifies to a single iteration one. -static ForInst * +static OpPointer<AffineForOp> generateLoop(AffineMap lbMap, AffineMap ubMap, const std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>> &instGroupQueue, - unsigned offset, ForInst *srcForInst, FuncBuilder *b) { + unsigned offset, OpPointer<AffineForOp> srcForInst, + FuncBuilder *b) { SmallVector<Value *, 4> lbOperands(srcForInst->getLowerBoundOperands()); SmallVector<Value *, 4> ubOperands(srcForInst->getUpperBoundOperands()); assert(lbMap.getNumInputs() == lbOperands.size()); assert(ubMap.getNumInputs() == ubOperands.size()); - auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap, - ubOperands, ubMap, srcForInst->getStep()); + auto loopChunk = + b->create<AffineForOp>(srcForInst->getLoc(), lbOperands, lbMap, + ubOperands, ubMap, srcForInst->getStep()); + loopChunk->createBody(); auto *loopChunkIV = loopChunk->getInductionVar(); auto *srcIV = srcForInst->getInductionVar(); @@ -176,7 +178,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. if (!srcIV->use_empty() && shift != 0) { - auto b = FuncBuilder::getForInstBodyBuilder(loopChunk); + FuncBuilder b(loopChunk->getBody()); auto ivRemap = b.create<AffineApplyOp>( srcForInst->getLoc(), b.getSingleDimShiftAffineMap( @@ -191,7 +193,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, } } if (promoteIfSingleIteration(loopChunk)) - return nullptr; + return OpPointer<AffineForOp>(); return loopChunk; } @@ -210,28 +212,29 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // asserts preservation of SSA dominance. A check for that as well as that for // memory-based depedence preservation check rests with the users of this // method. -UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts, +UtilResult mlir::instBodySkew(OpPointer<AffineForOp> forOp, + ArrayRef<uint64_t> shifts, bool unrollPrologueEpilogue) { - if (forInst->getBody()->empty()) + if (forOp->getBody()->empty()) return UtilResult::Success; // If the trip counts aren't constant, we would need versioning and // conditional guards (or context information to prevent such versioning). The // better way to pipeline for such loops is to first tile them and extract // constant trip count "full tiles" before applying this. 
- auto mayBeConstTripCount = getConstantTripCount(*forInst); + auto mayBeConstTripCount = getConstantTripCount(forOp); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";); return UtilResult::Success; } uint64_t tripCount = mayBeConstTripCount.getValue(); - assert(isInstwiseShiftValid(*forInst, shifts) && + assert(isInstwiseShiftValid(forOp, shifts) && "shifts will lead to an invalid transformation\n"); - int64_t step = forInst->getStep(); + int64_t step = forOp->getStep(); - unsigned numChildInsts = forInst->getBody()->getInstructions().size(); + unsigned numChildInsts = forOp->getBody()->getInstructions().size(); // Do a linear time (counting) sort for the shifts. uint64_t maxShift = 0; @@ -249,7 +252,7 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts, // body of the 'for' inst. std::vector<std::vector<Instruction *>> sortedInstGroups(maxShift + 1); unsigned pos = 0; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { auto shift = shifts[pos++]; sortedInstGroups[shift].push_back(&inst); } @@ -259,17 +262,17 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts, // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first // loop generated as the prologue and the last as epilogue and unroll these // fully. - ForInst *prologue = nullptr; - ForInst *epilogue = nullptr; + OpPointer<AffineForOp> prologue; + OpPointer<AffineForOp> epilogue; // Do a sweep over the sorted shifts while storing open groups in a // vector, and generating loop portions as necessary during the sweep. A block // of instructions is paired with its shift. std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>> instGroupQueue; - auto origLbMap = forInst->getLowerBoundMap(); + auto origLbMap = forOp->getLowerBoundMap(); uint64_t lbShift = 0; - FuncBuilder b(forInst); + FuncBuilder b(forOp->getInstruction()); for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) { // If nothing is shifted by d, continue. if (sortedInstGroups[d].empty()) @@ -280,19 +283,19 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts, // The interval for which the loop needs to be generated here is: // [lbShift, min(lbShift + tripCount, d)) and the body of the // loop needs to have all instructions in instQueue in that order. - ForInst *res; + OpPointer<AffineForOp> res; if (lbShift + tripCount * step < d * step) { res = generateLoop( b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step), - instGroupQueue, 0, forInst, &b); + instGroupQueue, 0, forOp, &b); // Entire loop for the queued inst groups generated, empty it. instGroupQueue.clear(); lbShift += tripCount * step; } else { res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, d), instGroupQueue, - 0, forInst, &b); + 0, forOp, &b); lbShift = d * step; } if (!prologue && res) @@ -312,60 +315,63 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts, uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step; epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift), b.getShiftedAffineMap(origLbMap, ubShift), - instGroupQueue, i, forInst, &b); + instGroupQueue, i, forOp, &b); lbShift = ubShift; if (!prologue) prologue = epilogue; } // Erase the original for inst. 
- forInst->erase(); + forOp->erase(); if (unrollPrologueEpilogue && prologue) loopUnrollFull(prologue); - if (unrollPrologueEpilogue && !epilogue && epilogue != prologue) + if (unrollPrologueEpilogue && !epilogue && + epilogue->getInstruction() != prologue->getInstruction()) loopUnrollFull(epilogue); return UtilResult::Success; } /// Unrolls this loop completely. -bool mlir::loopUnrollFull(ForInst *forInst) { - Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst); +bool mlir::loopUnrollFull(OpPointer<AffineForOp> forOp) { + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue()) { uint64_t tripCount = mayBeConstantTripCount.getValue(); if (tripCount == 1) { - return promoteIfSingleIteration(forInst); + return promoteIfSingleIteration(forOp); } - return loopUnrollByFactor(forInst, tripCount); + return loopUnrollByFactor(forOp, tripCount); } return false; } /// Unrolls and jams this loop by the specified factor or by the trip count (if /// constant) whichever is lower. -bool mlir::loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor) { - Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst); +bool mlir::loopUnrollUpToFactor(OpPointer<AffineForOp> forOp, + uint64_t unrollFactor) { + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollFactor) - return loopUnrollByFactor(forInst, mayBeConstantTripCount.getValue()); - return loopUnrollByFactor(forInst, unrollFactor); + return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue()); + return loopUnrollByFactor(forOp, unrollFactor); } /// Unrolls this loop by the specified factor. Returns true if the loop /// is successfully unrolled. -bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { +bool mlir::loopUnrollByFactor(OpPointer<AffineForOp> forOp, + uint64_t unrollFactor) { assert(unrollFactor >= 1 && "unroll factor should be >= 1"); if (unrollFactor == 1) - return promoteIfSingleIteration(forInst); + return promoteIfSingleIteration(forOp); - if (forInst->getBody()->empty()) + if (forOp->getBody()->empty()) return false; - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); + auto ubMap = forOp->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be // expressed as a Function in the general case). However, the right way to @@ -376,10 +382,10 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { // Same operand list for lower and upper bound for now. // TODO(bondhugula): handle bounds with different operand lists. - if (!forInst->matchingBoundOperandList()) + if (!forOp->matchingBoundOperandList()) return false; - Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst); + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); // If the trip count is lower than the unroll factor, no unrolled body. // TODO(bondhugula): option to specify cleanup loop unrolling. @@ -388,10 +394,12 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { return false; // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. 
- if (getLargestDivisorOfTripCount(*forInst) % unrollFactor != 0) { + OperationInst *forInst = forOp->getInstruction(); + if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst)); - auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst)); - auto clLbMap = getCleanupLoopLowerBound(*forInst, unrollFactor, &builder); + auto cleanupForInst = + cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>(); + auto clLbMap = getCleanupLoopLowerBound(forOp, unrollFactor, &builder); assert(clLbMap && "cleanup loop lower bound map for single result bound maps can " "always be determined"); @@ -401,50 +409,50 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { // Adjust upper bound. auto unrolledUbMap = - getUnrolledLoopUpperBound(*forInst, unrollFactor, &builder); + getUnrolledLoopUpperBound(forOp, unrollFactor, &builder); assert(unrolledUbMap && "upper bound map can alwayys be determined for an unrolled loop " "with single result bounds"); - forInst->setUpperBoundMap(unrolledUbMap); + forOp->setUpperBoundMap(unrolledUbMap); } // Scale the step of loop being unrolled by unroll factor. - int64_t step = forInst->getStep(); - forInst->setStep(step * unrollFactor); + int64_t step = forOp->getStep(); + forOp->setStep(step * unrollFactor); // Builder to insert unrolled bodies right after the last instruction in the - // body of 'forInst'. - FuncBuilder builder(forInst->getBody(), forInst->getBody()->end()); + // body of 'forOp'. + FuncBuilder builder(forOp->getBody(), forOp->getBody()->end()); // Keep a pointer to the last instruction in the original block so that we // know what to clone (since we are doing this in-place). - Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end()); + Block::iterator srcBlockEnd = std::prev(forOp->getBody()->end()); - // Unroll the contents of 'forInst' (append unrollFactor-1 additional copies). - auto *forInstIV = forInst->getInductionVar(); + // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies). + auto *forOpIV = forOp->getInductionVar(); for (unsigned i = 1; i < unrollFactor; i++) { BlockAndValueMapping operandMap; // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forInstIV->use_empty()) { + if (!forOpIV->use_empty()) { // iv' = iv + 1/2/3...unrollFactor-1; auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); auto ivUnroll = - builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInstIV); - operandMap.map(forInstIV, ivUnroll); + builder.create<AffineApplyOp>(forOp->getLoc(), bumpMap, forOpIV); + operandMap.map(forOpIV, ivUnroll); } - // Clone the original body of 'forInst'. - for (auto it = forInst->getBody()->begin(); it != std::next(srcBlockEnd); + // Clone the original body of 'forOp'. + for (auto it = forOp->getBody()->begin(); it != std::next(srcBlockEnd); it++) { builder.clone(*it, operandMap); } } // Promote the loop body up if this has turned into a single iteration loop. 
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index d3689d056d6..819f1a59b6f 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Transforms/Utils.h"
 
+#include "mlir/AffineOps/AffineOps.h"
 #include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/AffineStructures.h"
 #include "mlir/Analysis/Dominance.h"
@@ -278,8 +279,8 @@ void mlir::createAffineComputationSlice(
 /// Folds the specified (lower or upper) bound to a constant if possible
 /// considering its operands. Returns false if the folding happens for any of
 /// the bounds, true otherwise.
-bool mlir::constantFoldBounds(ForInst *forInst) {
-  auto foldLowerOrUpperBound = [forInst](bool lower) {
+bool mlir::constantFoldBounds(OpPointer<AffineForOp> forInst) {
+  auto foldLowerOrUpperBound = [&forInst](bool lower) {
     // Check if the bound is already a constant.
     if (lower && forInst->hasConstantLowerBound())
       return true;
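constantFoldBounds keeps its somewhat inverted return convention (false when a bound was folded) but now operates on the loop op rather than a ForInst. A hypothetical call site, following the cast pattern used elsewhere in this change (the helper name and dispatch are assumptions, not part of the patch):

    // Fold the bounds of a loop given only the generic instruction.
    static void foldBoundsIfAffineFor(Instruction *inst) {
      if (auto *opInst = dyn_cast<OperationInst>(inst))
        if (auto forOp = opInst->dyn_cast<AffineForOp>())
          // A false return means at least one bound was folded to a constant.
          (void)mlir::constantFoldBounds(forOp);
    }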
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index ac551d7c20c..7f26161e520 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -20,6 +20,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/AffineOps/AffineOps.h"
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/NestedMatcher.h"
 #include "mlir/Analysis/VectorAnalysis.h"
@@ -252,9 +253,9 @@ using namespace mlir;
 /// ==========
 /// The algorithm proceeds in a few steps:
 ///  1. defining super-vectorization patterns and matching them on the tree of
-///     ForInst. A super-vectorization pattern is defined as a recursive data
-///     structures that matches and captures nested, imperfectly-nested loops
-///     that have a. comformable loop annotations attached (e.g. parallel,
+///     AffineForOp. A super-vectorization pattern is defined as a recursive
+///     data structures that matches and captures nested, imperfectly-nested
+///     loops that have a. comformable loop annotations attached (e.g. parallel,
 ///     reduction, vectoriable, ...) as well as b. all contiguous load/store
 ///     operations along a specified minor dimension (not necessarily the
 ///     fastest varying) ;
@@ -279,11 +280,11 @@ using namespace mlir;
 ///     it by its vector form. Otherwise, if the scalar value is a constant,
 ///     it is vectorized into a splat. In all other cases, vectorization for
 ///     the pattern currently fails.
-///  e. if everything under the root ForInst in the current pattern vectorizes
-///     properly, we commit that loop to the IR. Otherwise we discard it and
-///     restore a previously cloned version of the loop. Thanks to the
-///     recursive scoping nature of matchers and captured patterns, this is
-///     transparently achieved by a simple RAII implementation.
+///  e. if everything under the root AffineForOp in the current pattern
+///     vectorizes properly, we commit that loop to the IR. Otherwise we
+///     discard it and restore a previously cloned version of the loop. Thanks
+///     to the recursive scoping nature of matchers and captured patterns,
+///     this is transparently achieved by a simple RAII implementation.
 ///  f. vectorization is applied on the next pattern in the list. Because
 ///     pattern interference avoidance is not yet implemented and that we do
 ///     not support further vectorizing an already vector load we need to
@@ -667,12 +668,13 @@
 namespace {
 
 struct VectorizationStrategy {
   SmallVector<int64_t, 8> vectorSizes;
-  DenseMap<ForInst *, unsigned> loopToVectorDim;
+  DenseMap<Instruction *, unsigned> loopToVectorDim;
 };
 
 } // end anonymous namespace
 
-static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
+static void vectorizeLoopIfProfitable(Instruction *loop,
+                                      unsigned depthInPattern,
                                       unsigned patternDepth,
                                       VectorizationStrategy *strategy) {
   assert(patternDepth > depthInPattern &&
@@ -704,13 +706,13 @@ static bool analyzeProfitability(ArrayRef<NestedMatch> matches,
                                  unsigned depthInPattern,
                                  unsigned patternDepth,
                                  VectorizationStrategy *strategy) {
   for (auto m : matches) {
-    auto *loop = cast<ForInst>(m.getMatchedInstruction());
     bool fail = analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1,
                                      patternDepth, strategy);
     if (fail) {
       return fail;
     }
-    vectorizeLoopIfProfitable(loop, depthInPattern, patternDepth, strategy);
+    vectorizeLoopIfProfitable(m.getMatchedInstruction(), depthInPattern,
+                              patternDepth, strategy);
   }
   return false;
 }
@@ -855,8 +857,8 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
 
 /// Coarsens the loops bounds and transforms all remaining load and store
 /// operations into the appropriate vector_transfer.
-static bool vectorizeForInst(ForInst *loop, int64_t step,
-                             VectorizationState *state) {
+static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step,
+                                 VectorizationState *state) {
   using namespace functional;
   loop->setStep(step);
 
@@ -873,7 +875,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
   };
   auto loadAndStores = matcher::Op(notVectorizedThisPattern);
   SmallVector<NestedMatch, 8> loadAndStoresMatches;
-  loadAndStores.match(loop, &loadAndStoresMatches);
+  loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches);
   for (auto ls : loadAndStoresMatches) {
     auto *opInst = cast<OperationInst>(ls.getMatchedInstruction());
     auto load = opInst->dyn_cast<LoadOp>();
@@ -898,7 +900,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
 static FilterFunctionType
 isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
   return [fastestVaryingMemRefDimension](const Instruction &forInst) {
-    const auto &loop = cast<ForInst>(forInst);
+    auto loop = cast<OperationInst>(forInst).cast<AffineForOp>();
     return isVectorizableLoopAlongFastestVaryingMemRefDim(
         loop, fastestVaryingMemRefDimension);
   };
@@ -912,7 +914,8 @@ static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
 /// if all vectorizations in `childrenMatches` have already succeeded
 /// recursively in DFS post-order.
 static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
-  ForInst *loop = cast<ForInst>(oneMatch.getMatchedInstruction());
+  auto *loopInst = oneMatch.getMatchedInstruction();
+  auto loop = cast<OperationInst>(loopInst)->cast<AffineForOp>();
   auto childrenMatches = oneMatch.getMatchedChildren();
 
   // 1. DFS postorder recursion, if any of my children fails, I fail too.
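With loopToVectorDim now keyed on Instruction * while the transformation code still wants the typed op, matched loops are unpacked in two steps. A hypothetical fragment (not the patch's code; oneMatch and strategy are assumed to be in scope) showing both halves together:

    auto *loopInst = oneMatch.getMatchedInstruction();               // generic Instruction *
    auto loop = cast<OperationInst>(loopInst)->cast<AffineForOp>();  // typed op for transforms
    // The strategy map is queried with the generic instruction, not the op.
    auto it = strategy->loopToVectorDim.find(loopInst);
    bool hasVectorDim = it != strategy->loopToVectorDim.end();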
@@ -924,7 +927,7 @@ static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
 
   // 2. This loop may have been omitted from vectorization for various reasons
   // (e.g. due to the performance model or pattern depth > vector size).
-  auto it = state->strategy->loopToVectorDim.find(loop);
+  auto it = state->strategy->loopToVectorDim.find(loopInst);
   if (it == state->strategy->loopToVectorDim.end()) {
     return false;
   }
@@ -939,10 +942,10 @@ static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
   // exploratory tradeoffs (see top of the file). Apply coarsening, i.e.:
   //   | ub -> ub
   //   | step -> step * vectorSize
-  LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForInst by " << vectorSize
+  LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize
                     << " : ");
-  LLVM_DEBUG(loop->print(dbgs()));
-  return vectorizeForInst(loop, loop->getStep() * vectorSize, state);
+  LLVM_DEBUG(loopInst->print(dbgs()));
+  return vectorizeAffineForOp(loop, loop->getStep() * vectorSize, state);
 }
 
 /// Non-root pattern iterates over the matches at this level, calls doVectorize
@@ -1186,7 +1189,8 @@ static bool vectorizeOperations(VectorizationState *state) {
 /// Each root may succeed independently but will otherwise clean after itself if
 /// anything below it fails.
 static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
-  auto *loop = cast<ForInst>(m.getMatchedInstruction());
+  auto loop =
+      cast<OperationInst>(m.getMatchedInstruction())->cast<AffineForOp>();
   VectorizationState state;
   state.strategy = strategy;
 
@@ -1197,17 +1201,20 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
   // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
   // TODO(ntv): implement a non-greedy profitability analysis that keeps only
   // non-intersecting patterns.
-  if (!isVectorizableLoop(*loop)) {
+  if (!isVectorizableLoop(loop)) {
     LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
     return true;
   }
-  FuncBuilder builder(loop); // builder to insert in place of loop
-  ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop));
+  auto *loopInst = loop->getInstruction();
+  FuncBuilder builder(loopInst);
+  auto clonedLoop =
+      cast<OperationInst>(builder.clone(*loopInst))->cast<AffineForOp>();
+
+  auto fail = doVectorize(m, &state);
   /// Sets up error handling for this root loop. This is how the root match
   /// maintains a clone for handling failure and restores the proper state via
   /// RAII.
-  ScopeGuard sg2([&fail, loop, clonedLoop]() {
+  ScopeGuard sg2([&fail, &loop, &clonedLoop]() {
     if (fail) {
       loop->getInductionVar()->replaceAllUsesWith(
           clonedLoop->getInductionVar());
@@ -1291,8 +1298,8 @@ PassResult Vectorize::runOnFunction(Function *f) {
     if (fail) {
      continue;
     }
-    auto *loop = cast<ForInst>(m.getMatchedInstruction());
-    vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy);
+    vectorizeLoopIfProfitable(m.getMatchedInstruction(), 0, patternDepth,
+                              &strategy);
     // TODO(ntv): if pattern does not apply, report it; alter the
     // cost/benefit.
     fail = vectorizeRootMatch(m, &strategy);
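The clone-for-rollback idiom in vectorizeRootMatch survives the ForInst removal essentially unchanged: the whole loop instruction is cloned up front, and the ScopeGuard redirects induction-variable uses back to the clone if vectorization fails. A condensed, hypothetical restatement (not the patch's exact code; what happens to the dead copy afterwards lies outside the hunk shown above):

    auto *loopInst = loop->getInstruction();
    FuncBuilder builder(loopInst);   // inserts the clone next to the original
    auto clonedLoop =
        cast<OperationInst>(builder.clone(*loopInst))->cast<AffineForOp>();
    bool failed = doVectorize(m, &state);
    if (failed)
      loop->getInductionVar()->replaceAllUsesWith(clonedLoop->getInductionVar());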

