diff options
Diffstat (limited to 'mlir/lib/AffineOps/AffineOps.cpp')
| -rw-r--r-- | mlir/lib/AffineOps/AffineOps.cpp | 443 |
1 files changed, 442 insertions, 1 deletions
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 5b29467fc44..f1693c8e449 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -17,7 +17,10 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/OpImplementation.h" using namespace mlir; @@ -27,7 +30,445 @@ using namespace mlir; AffineOpsDialect::AffineOpsDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) { - addOperations<AffineIfOp>(); + addOperations<AffineForOp, AffineIfOp>(); +} + +//===----------------------------------------------------------------------===// +// AffineForOp +//===----------------------------------------------------------------------===// + +void AffineForOp::build(Builder *builder, OperationState *result, + ArrayRef<Value *> lbOperands, AffineMap lbMap, + ArrayRef<Value *> ubOperands, AffineMap ubMap, + int64_t step) { + assert((!lbMap && lbOperands.empty()) || + lbOperands.size() == lbMap.getNumInputs() && + "lower bound operand count does not match the affine map"); + assert((!ubMap && ubOperands.empty()) || + ubOperands.size() == ubMap.getNumInputs() && + "upper bound operand count does not match the affine map"); + assert(step > 0 && "step has to be a positive integer constant"); + + // Add an attribute for the step. + result->addAttribute(getStepAttrName(), + builder->getIntegerAttr(builder->getIndexType(), step)); + + // Add the lower bound. + result->addAttribute(getLowerBoundAttrName(), + builder->getAffineMapAttr(lbMap)); + result->addOperands(lbOperands); + + // Add the upper bound. + result->addAttribute(getUpperBoundAttrName(), + builder->getAffineMapAttr(ubMap)); + result->addOperands(ubOperands); + + // Reserve a block list for the body. + result->reserveBlockLists(/*numReserved=*/1); + + // Set the operands list as resizable so that we can freely modify the bounds. + result->setOperandListToResizable(); +} + +void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb, + int64_t ub, int64_t step) { + auto lbMap = AffineMap::getConstantMap(lb, builder->getContext()); + auto ubMap = AffineMap::getConstantMap(ub, builder->getContext()); + return build(builder, result, {}, lbMap, {}, ubMap, step); +} + +bool AffineForOp::verify() const { + const auto &bodyBlockList = getInstruction()->getBlockList(0); + + // The body block list must contain a single basic block. + if (bodyBlockList.empty() || + std::next(bodyBlockList.begin()) != bodyBlockList.end()) + return emitOpError("expected body block list to have a single block"); + + // Check that the body defines as single block argument for the induction + // variable. + const auto *body = getBody(); + if (body->getNumArguments() != 1 || + !body->getArgument(0)->getType().isIndex()) + return emitOpError("expected body to have a single index argument for the " + "induction variable"); + + // TODO: check that loop bounds are properly formed. + return false; +} + +/// Parse a for operation loop bounds. +static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { + // 'min' / 'max' prefixes are generally syntactic sugar, but are required if + // the map has multiple results. + bool failedToParsedMinMax = p->parseOptionalKeyword(isLower ? "max" : "min"); + + auto &builder = p->getBuilder(); + auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() + : AffineForOp::getUpperBoundAttrName(); + + // Parse ssa-id as identity map. + SmallVector<OpAsmParser::OperandType, 1> boundOpInfos; + if (p->parseOperandList(boundOpInfos)) + return true; + + if (!boundOpInfos.empty()) { + // Check that only one operand was parsed. + if (boundOpInfos.size() > 1) + return p->emitError(p->getNameLoc(), + "expected only one loop bound operand"); + + // TODO: improve error message when SSA value is not an affine integer. + // Currently it is 'use of value ... expects different type than prior uses' + if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(), + result->operands)) + return true; + + // Create an identity map using symbol id. This representation is optimized + // for storage. Analysis passes may expand it into a multi-dimensional map + // if desired. + AffineMap map = builder.getSymbolIdentityMap(); + result->addAttribute(boundAttrName, builder.getAffineMapAttr(map)); + return false; + } + + Attribute boundAttr; + if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName.data(), + result->attributes)) + return true; + + // Parse full form - affine map followed by dim and symbol list. + if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) { + unsigned currentNumOperands = result->operands.size(); + unsigned numDims; + if (parseDimAndSymbolList(p, result->operands, numDims)) + return true; + + auto map = affineMapAttr.getValue(); + if (map.getNumDims() != numDims) + return p->emitError( + p->getNameLoc(), + "dim operand count and integer set dim count must match"); + + unsigned numDimAndSymbolOperands = + result->operands.size() - currentNumOperands; + if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) + return p->emitError( + p->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // If the map has multiple results, make sure that we parsed the min/max + // prefix. + if (map.getNumResults() > 1 && failedToParsedMinMax) { + if (isLower) { + return p->emitError(p->getNameLoc(), + "lower loop bound affine map with multiple results " + "requires 'max' prefix"); + } + return p->emitError(p->getNameLoc(), + "upper loop bound affine map with multiple results " + "requires 'min' prefix"); + } + return false; + } + + // Parse custom assembly form. + if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) { + result->attributes.pop_back(); + result->addAttribute( + boundAttrName, builder.getAffineMapAttr( + builder.getConstantAffineMap(integerAttr.getInt()))); + return false; + } + + return p->emitError( + p->getNameLoc(), + "expected valid affine map representation for loop bounds"); +} + +bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + // Parse the induction variable followed by '='. + if (parser->parseBlockListEntryBlockArgument(builder.getIndexType()) || + parser->parseEqual()) + return true; + + // Parse loop bounds. + if (parseBound(/*isLower=*/true, result, parser) || + parser->parseKeyword("to", " between bounds") || + parseBound(/*isLower=*/false, result, parser)) + return true; + + // Parse the optional loop step, we default to 1 if one is not present. + if (parser->parseOptionalKeyword("step")) { + result->addAttribute( + getStepAttrName(), + builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); + } else { + llvm::SMLoc stepLoc; + IntegerAttr stepAttr; + if (parser->getCurrentLocation(&stepLoc) || + parser->parseAttribute(stepAttr, builder.getIndexType(), + getStepAttrName().data(), result->attributes)) + return true; + + if (stepAttr.getValue().getSExtValue() < 0) + return parser->emitError( + stepLoc, + "expected step to be representable as a positive signed integer"); + } + + // Parse the body block list. + result->reserveBlockLists(/*numReserved=*/1); + if (parser->parseBlockList()) + return true; + + // Set the operands list as resizable so that we can freely modify the bounds. + result->setOperandListToResizable(); + return false; +} + +static void printBound(AffineBound bound, const char *prefix, OpAsmPrinter *p) { + AffineMap map = bound.getMap(); + + // Check if this bound should be printed using custom assembly form. + // The decision to restrict printing custom assembly form to trivial cases + // comes from the will to roundtrip MLIR binary -> text -> binary in a + // lossless way. + // Therefore, custom assembly form parsing and printing is only supported for + // zero-operand constant maps and single symbol operand identity maps. + if (map.getNumResults() == 1) { + AffineExpr expr = map.getResult(0); + + // Print constant bound. + if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { + if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { + *p << constExpr.getValue(); + return; + } + } + + // Print bound that consists of a single SSA symbol if the map is over a + // single symbol. + if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { + if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) { + p->printOperand(bound.getOperand(0)); + return; + } + } + } else { + // Map has multiple results. Print 'min' or 'max' prefix. + *p << prefix << ' '; + } + + // Print the map and its operands. + p->printAffineMap(map); + printDimAndSymbolList(bound.operand_begin(), bound.operand_end(), + map.getNumDims(), p); +} + +void AffineForOp::print(OpAsmPrinter *p) const { + *p << "for "; + p->printOperand(getBody()->getArgument(0)); + *p << " = "; + printBound(getLowerBound(), "max", p); + *p << " to "; + printBound(getUpperBound(), "min", p); + + if (getStep() != 1) + *p << " step " << getStep(); + p->printBlockList(getInstruction()->getBlockList(0), + /*printEntryBlockArgs=*/false); +} + +Block *AffineForOp::createBody() { + auto &bodyBlockList = getBlockList(); + assert(bodyBlockList.empty() && "expected no existing body blocks"); + + // Create a new block for the body, and add an argument for the induction + // variable. + Block *body = new Block(); + body->addArgument(IndexType::get(getInstruction()->getContext())); + bodyBlockList.push_back(body); + return body; +} + +const AffineBound AffineForOp::getLowerBound() const { + auto lbMap = getLowerBoundMap(); + return AffineBound(ConstOpPointer<AffineForOp>(*this), 0, + lbMap.getNumInputs(), lbMap); +} + +const AffineBound AffineForOp::getUpperBound() const { + auto lbMap = getLowerBoundMap(); + auto ubMap = getUpperBoundMap(); + return AffineBound(ConstOpPointer<AffineForOp>(*this), lbMap.getNumInputs(), + getNumOperands(), ubMap); +} + +void AffineForOp::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) { + assert(lbOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end()); + + auto ubOperands = getUpperBoundOperands(); + newOperands.append(ubOperands.begin(), ubOperands.end()); + getInstruction()->setOperands(newOperands); + + setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) { + assert(ubOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + + SmallVector<Value *, 4> newOperands(getLowerBoundOperands()); + newOperands.append(ubOperands.begin(), ubOperands.end()); + getInstruction()->setOperands(newOperands); + + setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setLowerBoundMap(AffineMap map) { + auto lbMap = getLowerBoundMap(); + assert(lbMap.getNumDims() == map.getNumDims() && + lbMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + (void)lbMap; + setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +void AffineForOp::setUpperBoundMap(AffineMap map) { + auto ubMap = getUpperBoundMap(); + assert(ubMap.getNumDims() == map.getNumDims() && + ubMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); + (void)ubMap; + setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()), + AffineMapAttr::get(map)); +} + +bool AffineForOp::hasConstantLowerBound() const { + return getLowerBoundMap().isSingleConstant(); +} + +bool AffineForOp::hasConstantUpperBound() const { + return getUpperBoundMap().isSingleConstant(); +} + +int64_t AffineForOp::getConstantLowerBound() const { + return getLowerBoundMap().getSingleConstantResult(); +} + +int64_t AffineForOp::getConstantUpperBound() const { + return getUpperBoundMap().getSingleConstantResult(); +} + +void AffineForOp::setConstantLowerBound(int64_t value) { + setLowerBound( + {}, AffineMap::getConstantMap(value, getInstruction()->getContext())); +} + +void AffineForOp::setConstantUpperBound(int64_t value) { + setUpperBound( + {}, AffineMap::getConstantMap(value, getInstruction()->getContext())); +} + +AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; +} + +AffineForOp::const_operand_range AffineForOp::getLowerBoundOperands() const { + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; +} + +AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { + return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; +} + +AffineForOp::const_operand_range AffineForOp::getUpperBoundOperands() const { + return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; +} + +bool AffineForOp::matchingBoundOperandList() const { + auto lbMap = getLowerBoundMap(); + auto ubMap = getUpperBoundMap(); + if (lbMap.getNumDims() != ubMap.getNumDims() || + lbMap.getNumSymbols() != ubMap.getNumSymbols()) + return false; + + unsigned numOperands = lbMap.getNumInputs(); + for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { + // Compare Value *'s. + if (getOperand(i) != getOperand(numOperands + i)) + return false; + } + return true; +} + +void AffineForOp::walkOps(std::function<void(OperationInst *)> callback) { + struct Walker : public InstWalker<Walker> { + std::function<void(OperationInst *)> const &callback; + Walker(std::function<void(OperationInst *)> const &callback) + : callback(callback) {} + + void visitOperationInst(OperationInst *opInst) { callback(opInst); } + }; + + Walker w(callback); + w.walk(getInstruction()); +} + +void AffineForOp::walkOpsPostOrder( + std::function<void(OperationInst *)> callback) { + struct Walker : public InstWalker<Walker> { + std::function<void(OperationInst *)> const &callback; + Walker(std::function<void(OperationInst *)> const &callback) + : callback(callback) {} + + void visitOperationInst(OperationInst *opInst) { callback(opInst); } + }; + + Walker v(callback); + v.walkPostOrder(getInstruction()); +} + +/// Returns the induction variable for this loop. +Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); } + +/// Returns if the provided value is the induction variable of a AffineForOp. +bool mlir::isForInductionVar(const Value *val) { + return getForInductionVarOwner(val) != nullptr; +} + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +OpPointer<AffineForOp> mlir::getForInductionVarOwner(Value *val) { + const BlockArgument *ivArg = dyn_cast<BlockArgument>(val); + if (!ivArg || !ivArg->getOwner()) + return OpPointer<AffineForOp>(); + auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst(); + if (!containingInst) + return OpPointer<AffineForOp>(); + return cast<OperationInst>(containingInst)->dyn_cast<AffineForOp>(); +} +ConstOpPointer<AffineForOp> mlir::getForInductionVarOwner(const Value *val) { + auto nonConstOwner = getForInductionVarOwner(const_cast<Value *>(val)); + return ConstOpPointer<AffineForOp>(nonConstOwner); +} + +/// Extracts the induction variables from a list of AffineForOps and returns +/// them. +SmallVector<Value *, 8> mlir::extractForInductionVars( + MutableArrayRef<OpPointer<AffineForOp>> forInsts) { + SmallVector<Value *, 8> results; + for (auto forInst : forInsts) + results.push_back(forInst->getInductionVar()); + return results; } //===----------------------------------------------------------------------===// |

