summaryrefslogtreecommitdiffstats
path: root/mlir/lib/AffineOps/AffineOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/AffineOps/AffineOps.cpp')
-rw-r--r--mlir/lib/AffineOps/AffineOps.cpp443
1 files changed, 442 insertions, 1 deletions
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp
index 5b29467fc44..f1693c8e449 100644
--- a/mlir/lib/AffineOps/AffineOps.cpp
+++ b/mlir/lib/AffineOps/AffineOps.cpp
@@ -17,7 +17,10 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/InstVisitor.h"
+#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
using namespace mlir;
@@ -27,7 +30,445 @@ using namespace mlir;
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
: Dialect(/*namePrefix=*/"", context) {
- addOperations<AffineIfOp>();
+ addOperations<AffineForOp, AffineIfOp>();
+}
+
+//===----------------------------------------------------------------------===//
+// AffineForOp
+//===----------------------------------------------------------------------===//
+
+void AffineForOp::build(Builder *builder, OperationState *result,
+ ArrayRef<Value *> lbOperands, AffineMap lbMap,
+ ArrayRef<Value *> ubOperands, AffineMap ubMap,
+ int64_t step) {
+ assert((!lbMap && lbOperands.empty()) ||
+ lbOperands.size() == lbMap.getNumInputs() &&
+ "lower bound operand count does not match the affine map");
+ assert((!ubMap && ubOperands.empty()) ||
+ ubOperands.size() == ubMap.getNumInputs() &&
+ "upper bound operand count does not match the affine map");
+ assert(step > 0 && "step has to be a positive integer constant");
+
+ // Add an attribute for the step.
+ result->addAttribute(getStepAttrName(),
+ builder->getIntegerAttr(builder->getIndexType(), step));
+
+ // Add the lower bound.
+ result->addAttribute(getLowerBoundAttrName(),
+ builder->getAffineMapAttr(lbMap));
+ result->addOperands(lbOperands);
+
+ // Add the upper bound.
+ result->addAttribute(getUpperBoundAttrName(),
+ builder->getAffineMapAttr(ubMap));
+ result->addOperands(ubOperands);
+
+ // Reserve a block list for the body.
+ result->reserveBlockLists(/*numReserved=*/1);
+
+ // Set the operands list as resizable so that we can freely modify the bounds.
+ result->setOperandListToResizable();
+}
+
+void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb,
+ int64_t ub, int64_t step) {
+ auto lbMap = AffineMap::getConstantMap(lb, builder->getContext());
+ auto ubMap = AffineMap::getConstantMap(ub, builder->getContext());
+ return build(builder, result, {}, lbMap, {}, ubMap, step);
+}
+
+bool AffineForOp::verify() const {
+ const auto &bodyBlockList = getInstruction()->getBlockList(0);
+
+ // The body block list must contain a single basic block.
+ if (bodyBlockList.empty() ||
+ std::next(bodyBlockList.begin()) != bodyBlockList.end())
+ return emitOpError("expected body block list to have a single block");
+
+ // Check that the body defines as single block argument for the induction
+ // variable.
+ const auto *body = getBody();
+ if (body->getNumArguments() != 1 ||
+ !body->getArgument(0)->getType().isIndex())
+ return emitOpError("expected body to have a single index argument for the "
+ "induction variable");
+
+ // TODO: check that loop bounds are properly formed.
+ return false;
+}
+
+/// Parse a for operation loop bounds.
+static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
+ // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
+ // the map has multiple results.
+ bool failedToParsedMinMax = p->parseOptionalKeyword(isLower ? "max" : "min");
+
+ auto &builder = p->getBuilder();
+ auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
+ : AffineForOp::getUpperBoundAttrName();
+
+ // Parse ssa-id as identity map.
+ SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
+ if (p->parseOperandList(boundOpInfos))
+ return true;
+
+ if (!boundOpInfos.empty()) {
+ // Check that only one operand was parsed.
+ if (boundOpInfos.size() > 1)
+ return p->emitError(p->getNameLoc(),
+ "expected only one loop bound operand");
+
+ // TODO: improve error message when SSA value is not an affine integer.
+ // Currently it is 'use of value ... expects different type than prior uses'
+ if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(),
+ result->operands))
+ return true;
+
+ // Create an identity map using symbol id. This representation is optimized
+ // for storage. Analysis passes may expand it into a multi-dimensional map
+ // if desired.
+ AffineMap map = builder.getSymbolIdentityMap();
+ result->addAttribute(boundAttrName, builder.getAffineMapAttr(map));
+ return false;
+ }
+
+ Attribute boundAttr;
+ if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName.data(),
+ result->attributes))
+ return true;
+
+ // Parse full form - affine map followed by dim and symbol list.
+ if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
+ unsigned currentNumOperands = result->operands.size();
+ unsigned numDims;
+ if (parseDimAndSymbolList(p, result->operands, numDims))
+ return true;
+
+ auto map = affineMapAttr.getValue();
+ if (map.getNumDims() != numDims)
+ return p->emitError(
+ p->getNameLoc(),
+ "dim operand count and integer set dim count must match");
+
+ unsigned numDimAndSymbolOperands =
+ result->operands.size() - currentNumOperands;
+ if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
+ return p->emitError(
+ p->getNameLoc(),
+ "symbol operand count and integer set symbol count must match");
+
+ // If the map has multiple results, make sure that we parsed the min/max
+ // prefix.
+ if (map.getNumResults() > 1 && failedToParsedMinMax) {
+ if (isLower) {
+ return p->emitError(p->getNameLoc(),
+ "lower loop bound affine map with multiple results "
+ "requires 'max' prefix");
+ }
+ return p->emitError(p->getNameLoc(),
+ "upper loop bound affine map with multiple results "
+ "requires 'min' prefix");
+ }
+ return false;
+ }
+
+ // Parse custom assembly form.
+ if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
+ result->attributes.pop_back();
+ result->addAttribute(
+ boundAttrName, builder.getAffineMapAttr(
+ builder.getConstantAffineMap(integerAttr.getInt())));
+ return false;
+ }
+
+ return p->emitError(
+ p->getNameLoc(),
+ "expected valid affine map representation for loop bounds");
+}
+
+bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) {
+ auto &builder = parser->getBuilder();
+ // Parse the induction variable followed by '='.
+ if (parser->parseBlockListEntryBlockArgument(builder.getIndexType()) ||
+ parser->parseEqual())
+ return true;
+
+ // Parse loop bounds.
+ if (parseBound(/*isLower=*/true, result, parser) ||
+ parser->parseKeyword("to", " between bounds") ||
+ parseBound(/*isLower=*/false, result, parser))
+ return true;
+
+ // Parse the optional loop step, we default to 1 if one is not present.
+ if (parser->parseOptionalKeyword("step")) {
+ result->addAttribute(
+ getStepAttrName(),
+ builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
+ } else {
+ llvm::SMLoc stepLoc;
+ IntegerAttr stepAttr;
+ if (parser->getCurrentLocation(&stepLoc) ||
+ parser->parseAttribute(stepAttr, builder.getIndexType(),
+ getStepAttrName().data(), result->attributes))
+ return true;
+
+ if (stepAttr.getValue().getSExtValue() < 0)
+ return parser->emitError(
+ stepLoc,
+ "expected step to be representable as a positive signed integer");
+ }
+
+ // Parse the body block list.
+ result->reserveBlockLists(/*numReserved=*/1);
+ if (parser->parseBlockList())
+ return true;
+
+ // Set the operands list as resizable so that we can freely modify the bounds.
+ result->setOperandListToResizable();
+ return false;
+}
+
+static void printBound(AffineBound bound, const char *prefix, OpAsmPrinter *p) {
+ AffineMap map = bound.getMap();
+
+ // Check if this bound should be printed using custom assembly form.
+ // The decision to restrict printing custom assembly form to trivial cases
+ // comes from the will to roundtrip MLIR binary -> text -> binary in a
+ // lossless way.
+ // Therefore, custom assembly form parsing and printing is only supported for
+ // zero-operand constant maps and single symbol operand identity maps.
+ if (map.getNumResults() == 1) {
+ AffineExpr expr = map.getResult(0);
+
+ // Print constant bound.
+ if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
+ if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ *p << constExpr.getValue();
+ return;
+ }
+ }
+
+ // Print bound that consists of a single SSA symbol if the map is over a
+ // single symbol.
+ if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
+ if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
+ p->printOperand(bound.getOperand(0));
+ return;
+ }
+ }
+ } else {
+ // Map has multiple results. Print 'min' or 'max' prefix.
+ *p << prefix << ' ';
+ }
+
+ // Print the map and its operands.
+ p->printAffineMap(map);
+ printDimAndSymbolList(bound.operand_begin(), bound.operand_end(),
+ map.getNumDims(), p);
+}
+
+void AffineForOp::print(OpAsmPrinter *p) const {
+ *p << "for ";
+ p->printOperand(getBody()->getArgument(0));
+ *p << " = ";
+ printBound(getLowerBound(), "max", p);
+ *p << " to ";
+ printBound(getUpperBound(), "min", p);
+
+ if (getStep() != 1)
+ *p << " step " << getStep();
+ p->printBlockList(getInstruction()->getBlockList(0),
+ /*printEntryBlockArgs=*/false);
+}
+
+Block *AffineForOp::createBody() {
+ auto &bodyBlockList = getBlockList();
+ assert(bodyBlockList.empty() && "expected no existing body blocks");
+
+ // Create a new block for the body, and add an argument for the induction
+ // variable.
+ Block *body = new Block();
+ body->addArgument(IndexType::get(getInstruction()->getContext()));
+ bodyBlockList.push_back(body);
+ return body;
+}
+
+const AffineBound AffineForOp::getLowerBound() const {
+ auto lbMap = getLowerBoundMap();
+ return AffineBound(ConstOpPointer<AffineForOp>(*this), 0,
+ lbMap.getNumInputs(), lbMap);
+}
+
+const AffineBound AffineForOp::getUpperBound() const {
+ auto lbMap = getLowerBoundMap();
+ auto ubMap = getUpperBoundMap();
+ return AffineBound(ConstOpPointer<AffineForOp>(*this), lbMap.getNumInputs(),
+ getNumOperands(), ubMap);
+}
+
+void AffineForOp::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
+ assert(lbOperands.size() == map.getNumInputs());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+
+ SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end());
+
+ auto ubOperands = getUpperBoundOperands();
+ newOperands.append(ubOperands.begin(), ubOperands.end());
+ getInstruction()->setOperands(newOperands);
+
+ setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()),
+ AffineMapAttr::get(map));
+}
+
+void AffineForOp::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
+ assert(ubOperands.size() == map.getNumInputs());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+
+ SmallVector<Value *, 4> newOperands(getLowerBoundOperands());
+ newOperands.append(ubOperands.begin(), ubOperands.end());
+ getInstruction()->setOperands(newOperands);
+
+ setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()),
+ AffineMapAttr::get(map));
+}
+
+void AffineForOp::setLowerBoundMap(AffineMap map) {
+ auto lbMap = getLowerBoundMap();
+ assert(lbMap.getNumDims() == map.getNumDims() &&
+ lbMap.getNumSymbols() == map.getNumSymbols());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+ (void)lbMap;
+ setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()),
+ AffineMapAttr::get(map));
+}
+
+void AffineForOp::setUpperBoundMap(AffineMap map) {
+ auto ubMap = getUpperBoundMap();
+ assert(ubMap.getNumDims() == map.getNumDims() &&
+ ubMap.getNumSymbols() == map.getNumSymbols());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+ (void)ubMap;
+ setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()),
+ AffineMapAttr::get(map));
+}
+
+bool AffineForOp::hasConstantLowerBound() const {
+ return getLowerBoundMap().isSingleConstant();
+}
+
+bool AffineForOp::hasConstantUpperBound() const {
+ return getUpperBoundMap().isSingleConstant();
+}
+
+int64_t AffineForOp::getConstantLowerBound() const {
+ return getLowerBoundMap().getSingleConstantResult();
+}
+
+int64_t AffineForOp::getConstantUpperBound() const {
+ return getUpperBoundMap().getSingleConstantResult();
+}
+
+void AffineForOp::setConstantLowerBound(int64_t value) {
+ setLowerBound(
+ {}, AffineMap::getConstantMap(value, getInstruction()->getContext()));
+}
+
+void AffineForOp::setConstantUpperBound(int64_t value) {
+ setUpperBound(
+ {}, AffineMap::getConstantMap(value, getInstruction()->getContext()));
+}
+
+AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
+ return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
+}
+
+AffineForOp::const_operand_range AffineForOp::getLowerBoundOperands() const {
+ return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
+}
+
+AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
+ return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
+}
+
+AffineForOp::const_operand_range AffineForOp::getUpperBoundOperands() const {
+ return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
+}
+
+bool AffineForOp::matchingBoundOperandList() const {
+ auto lbMap = getLowerBoundMap();
+ auto ubMap = getUpperBoundMap();
+ if (lbMap.getNumDims() != ubMap.getNumDims() ||
+ lbMap.getNumSymbols() != ubMap.getNumSymbols())
+ return false;
+
+ unsigned numOperands = lbMap.getNumInputs();
+ for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
+ // Compare Value *'s.
+ if (getOperand(i) != getOperand(numOperands + i))
+ return false;
+ }
+ return true;
+}
+
+void AffineForOp::walkOps(std::function<void(OperationInst *)> callback) {
+ struct Walker : public InstWalker<Walker> {
+ std::function<void(OperationInst *)> const &callback;
+ Walker(std::function<void(OperationInst *)> const &callback)
+ : callback(callback) {}
+
+ void visitOperationInst(OperationInst *opInst) { callback(opInst); }
+ };
+
+ Walker w(callback);
+ w.walk(getInstruction());
+}
+
+void AffineForOp::walkOpsPostOrder(
+ std::function<void(OperationInst *)> callback) {
+ struct Walker : public InstWalker<Walker> {
+ std::function<void(OperationInst *)> const &callback;
+ Walker(std::function<void(OperationInst *)> const &callback)
+ : callback(callback) {}
+
+ void visitOperationInst(OperationInst *opInst) { callback(opInst); }
+ };
+
+ Walker v(callback);
+ v.walkPostOrder(getInstruction());
+}
+
+/// Returns the induction variable for this loop.
+Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); }
+
+/// Returns if the provided value is the induction variable of a AffineForOp.
+bool mlir::isForInductionVar(const Value *val) {
+ return getForInductionVarOwner(val) != nullptr;
+}
+
+/// Returns the loop parent of an induction variable. If the provided value is
+/// not an induction variable, then return nullptr.
+OpPointer<AffineForOp> mlir::getForInductionVarOwner(Value *val) {
+ const BlockArgument *ivArg = dyn_cast<BlockArgument>(val);
+ if (!ivArg || !ivArg->getOwner())
+ return OpPointer<AffineForOp>();
+ auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst();
+ if (!containingInst)
+ return OpPointer<AffineForOp>();
+ return cast<OperationInst>(containingInst)->dyn_cast<AffineForOp>();
+}
+ConstOpPointer<AffineForOp> mlir::getForInductionVarOwner(const Value *val) {
+ auto nonConstOwner = getForInductionVarOwner(const_cast<Value *>(val));
+ return ConstOpPointer<AffineForOp>(nonConstOwner);
+}
+
+/// Extracts the induction variables from a list of AffineForOps and returns
+/// them.
+SmallVector<Value *, 8> mlir::extractForInductionVars(
+ MutableArrayRef<OpPointer<AffineForOp>> forInsts) {
+ SmallVector<Value *, 8> results;
+ for (auto forInst : forInsts)
+ results.push_back(forInst->getInductionVar());
+ return results;
}
//===----------------------------------------------------------------------===//
OpenPOWER on IntegriCloud