Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/AffineOps/AffineOps.cpp  443
-rw-r--r--  mlir/lib/Analysis/AffineAnalysis.cpp  28
-rw-r--r--  mlir/lib/Analysis/AffineStructures.cpp  22
-rw-r--r--  mlir/lib/Analysis/LoopAnalysis.cpp  75
-rw-r--r--  mlir/lib/Analysis/NestedMatcher.cpp  26
-rw-r--r--  mlir/lib/Analysis/SliceAnalysis.cpp  22
-rw-r--r--  mlir/lib/Analysis/Utils.cpp  76
-rw-r--r--  mlir/lib/Analysis/VectorAnalysis.cpp  37
-rw-r--r--  mlir/lib/Analysis/Verifier.cpp  14
-rw-r--r--  mlir/lib/EDSC/MLIREmitter.cpp  10
-rw-r--r--  mlir/lib/IR/AsmPrinter.cpp  116
-rw-r--r--  mlir/lib/IR/Builders.cpp  16
-rw-r--r--  mlir/lib/IR/Instruction.cpp  337
-rw-r--r--  mlir/lib/IR/Value.cpp  2
-rw-r--r--  mlir/lib/Parser/Parser.cpp  346
-rw-r--r--  mlir/lib/Transforms/CSE.cpp  5
-rw-r--r--  mlir/lib/Transforms/ConstantFold.cpp  13
-rw-r--r--  mlir/lib/Transforms/DmaGeneration.cpp  47
-rw-r--r--  mlir/lib/Transforms/LoopFusion.cpp  200
-rw-r--r--  mlir/lib/Transforms/LoopTiling.cpp  86
-rw-r--r--  mlir/lib/Transforms/LoopUnroll.cpp  67
-rw-r--r--  mlir/lib/Transforms/LoopUnrollAndJam.cpp  86
-rw-r--r--  mlir/lib/Transforms/LowerAffine.cpp  68
-rw-r--r--  mlir/lib/Transforms/MaterializeVectors.cpp  20
-rw-r--r--  mlir/lib/Transforms/PipelineDataTransfer.cpp  70
-rw-r--r--  mlir/lib/Transforms/SimplifyAffineStructures.cpp  1
-rw-r--r--  mlir/lib/Transforms/Utils/LoopUtils.cpp  170
-rw-r--r--  mlir/lib/Transforms/Utils/Utils.cpp  5
-rw-r--r--  mlir/lib/Transforms/Vectorize.cpp  63
29 files changed, 1221 insertions, 1250 deletions
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp
index 5b29467fc44..f1693c8e449 100644
--- a/mlir/lib/AffineOps/AffineOps.cpp
+++ b/mlir/lib/AffineOps/AffineOps.cpp
@@ -17,7 +17,10 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/InstVisitor.h"
+#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
using namespace mlir;
@@ -27,7 +30,445 @@ using namespace mlir;
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
: Dialect(/*namePrefix=*/"", context) {
- addOperations<AffineIfOp>();
+ addOperations<AffineForOp, AffineIfOp>();
+}
+
+//===----------------------------------------------------------------------===//
+// AffineForOp
+//===----------------------------------------------------------------------===//
+
+void AffineForOp::build(Builder *builder, OperationState *result,
+ ArrayRef<Value *> lbOperands, AffineMap lbMap,
+ ArrayRef<Value *> ubOperands, AffineMap ubMap,
+ int64_t step) {
+ assert((!lbMap && lbOperands.empty()) ||
+ lbOperands.size() == lbMap.getNumInputs() &&
+ "lower bound operand count does not match the affine map");
+ assert((!ubMap && ubOperands.empty()) ||
+ ubOperands.size() == ubMap.getNumInputs() &&
+ "upper bound operand count does not match the affine map");
+ assert(step > 0 && "step has to be a positive integer constant");
+
+ // Add an attribute for the step.
+ result->addAttribute(getStepAttrName(),
+ builder->getIntegerAttr(builder->getIndexType(), step));
+
+ // Add the lower bound.
+ result->addAttribute(getLowerBoundAttrName(),
+ builder->getAffineMapAttr(lbMap));
+ result->addOperands(lbOperands);
+
+ // Add the upper bound.
+ result->addAttribute(getUpperBoundAttrName(),
+ builder->getAffineMapAttr(ubMap));
+ result->addOperands(ubOperands);
+
+ // Reserve a block list for the body.
+ result->reserveBlockLists(/*numReserved=*/1);
+
+ // Set the operands list as resizable so that we can freely modify the bounds.
+ result->setOperandListToResizable();
+}
+
+void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb,
+ int64_t ub, int64_t step) {
+ auto lbMap = AffineMap::getConstantMap(lb, builder->getContext());
+ auto ubMap = AffineMap::getConstantMap(ub, builder->getContext());
+ return build(builder, result, {}, lbMap, {}, ubMap, step);
+}
+
+bool AffineForOp::verify() const {
+ const auto &bodyBlockList = getInstruction()->getBlockList(0);
+
+ // The body block list must contain a single basic block.
+ if (bodyBlockList.empty() ||
+ std::next(bodyBlockList.begin()) != bodyBlockList.end())
+ return emitOpError("expected body block list to have a single block");
+
+ // Check that the body defines a single block argument for the induction
+ // variable.
+ const auto *body = getBody();
+ if (body->getNumArguments() != 1 ||
+ !body->getArgument(0)->getType().isIndex())
+ return emitOpError("expected body to have a single index argument for the "
+ "induction variable");
+
+ // TODO: check that loop bounds are properly formed.
+ return false;
+}
+
+/// Parse a for operation loop bounds.
+static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
+ // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
+ // the map has multiple results.
+ bool failedToParsedMinMax = p->parseOptionalKeyword(isLower ? "max" : "min");
+
+ auto &builder = p->getBuilder();
+ auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
+ : AffineForOp::getUpperBoundAttrName();
+
+ // Parse ssa-id as identity map.
+ SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
+ if (p->parseOperandList(boundOpInfos))
+ return true;
+
+ if (!boundOpInfos.empty()) {
+ // Check that only one operand was parsed.
+ if (boundOpInfos.size() > 1)
+ return p->emitError(p->getNameLoc(),
+ "expected only one loop bound operand");
+
+ // TODO: improve error message when SSA value is not an affine integer.
+ // Currently it is 'use of value ... expects different type than prior uses'
+ if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(),
+ result->operands))
+ return true;
+
+ // Create an identity map using symbol id. This representation is optimized
+ // for storage. Analysis passes may expand it into a multi-dimensional map
+ // if desired.
+ AffineMap map = builder.getSymbolIdentityMap();
+ result->addAttribute(boundAttrName, builder.getAffineMapAttr(map));
+ return false;
+ }
+
+ Attribute boundAttr;
+ if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName.data(),
+ result->attributes))
+ return true;
+
+ // Parse full form - affine map followed by dim and symbol list.
+ if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
+ unsigned currentNumOperands = result->operands.size();
+ unsigned numDims;
+ if (parseDimAndSymbolList(p, result->operands, numDims))
+ return true;
+
+ auto map = affineMapAttr.getValue();
+ if (map.getNumDims() != numDims)
+ return p->emitError(
+ p->getNameLoc(),
+ "dim operand count and integer set dim count must match");
+
+ unsigned numDimAndSymbolOperands =
+ result->operands.size() - currentNumOperands;
+ if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
+ return p->emitError(
+ p->getNameLoc(),
+ "symbol operand count and integer set symbol count must match");
+
+ // If the map has multiple results, make sure that we parsed the min/max
+ // prefix.
+ if (map.getNumResults() > 1 && failedToParsedMinMax) {
+ if (isLower) {
+ return p->emitError(p->getNameLoc(),
+ "lower loop bound affine map with multiple results "
+ "requires 'max' prefix");
+ }
+ return p->emitError(p->getNameLoc(),
+ "upper loop bound affine map with multiple results "
+ "requires 'min' prefix");
+ }
+ return false;
+ }
+
+ // Parse custom assembly form.
+ if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
+ result->attributes.pop_back();
+ result->addAttribute(
+ boundAttrName, builder.getAffineMapAttr(
+ builder.getConstantAffineMap(integerAttr.getInt())));
+ return false;
+ }
+
+ return p->emitError(
+ p->getNameLoc(),
+ "expected valid affine map representation for loop bounds");
+}
+
+bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) {
+ auto &builder = parser->getBuilder();
+ // Parse the induction variable followed by '='.
+ if (parser->parseBlockListEntryBlockArgument(builder.getIndexType()) ||
+ parser->parseEqual())
+ return true;
+
+ // Parse loop bounds.
+ if (parseBound(/*isLower=*/true, result, parser) ||
+ parser->parseKeyword("to", " between bounds") ||
+ parseBound(/*isLower=*/false, result, parser))
+ return true;
+
+ // Parse the optional loop step; we default to 1 if one is not present.
+ if (parser->parseOptionalKeyword("step")) {
+ result->addAttribute(
+ getStepAttrName(),
+ builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
+ } else {
+ llvm::SMLoc stepLoc;
+ IntegerAttr stepAttr;
+ if (parser->getCurrentLocation(&stepLoc) ||
+ parser->parseAttribute(stepAttr, builder.getIndexType(),
+ getStepAttrName().data(), result->attributes))
+ return true;
+
+ if (stepAttr.getValue().getSExtValue() < 0)
+ return parser->emitError(
+ stepLoc,
+ "expected step to be representable as a positive signed integer");
+ }
+
+ // Parse the body block list.
+ result->reserveBlockLists(/*numReserved=*/1);
+ if (parser->parseBlockList())
+ return true;
+
+ // Set the operands list as resizable so that we can freely modify the bounds.
+ result->setOperandListToResizable();
+ return false;
+}
+
+static void printBound(AffineBound bound, const char *prefix, OpAsmPrinter *p) {
+ AffineMap map = bound.getMap();
+
+ // Check if this bound should be printed using custom assembly form.
+ // The decision to restrict printing custom assembly form to trivial cases
+ // comes from the will to roundtrip MLIR binary -> text -> binary in a
+ // lossless way.
+ // Therefore, custom assembly form parsing and printing is only supported for
+ // zero-operand constant maps and single symbol operand identity maps.
+ if (map.getNumResults() == 1) {
+ AffineExpr expr = map.getResult(0);
+
+ // Print constant bound.
+ if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
+ if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ *p << constExpr.getValue();
+ return;
+ }
+ }
+
+ // Print bound that consists of a single SSA symbol if the map is over a
+ // single symbol.
+ if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
+ if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
+ p->printOperand(bound.getOperand(0));
+ return;
+ }
+ }
+ } else {
+ // Map has multiple results. Print 'min' or 'max' prefix.
+ *p << prefix << ' ';
+ }
+
+ // Print the map and its operands.
+ p->printAffineMap(map);
+ printDimAndSymbolList(bound.operand_begin(), bound.operand_end(),
+ map.getNumDims(), p);
+}
+
+void AffineForOp::print(OpAsmPrinter *p) const {
+ *p << "for ";
+ p->printOperand(getBody()->getArgument(0));
+ *p << " = ";
+ printBound(getLowerBound(), "max", p);
+ *p << " to ";
+ printBound(getUpperBound(), "min", p);
+
+ if (getStep() != 1)
+ *p << " step " << getStep();
+ p->printBlockList(getInstruction()->getBlockList(0),
+ /*printEntryBlockArgs=*/false);
+}
+
+Block *AffineForOp::createBody() {
+ auto &bodyBlockList = getBlockList();
+ assert(bodyBlockList.empty() && "expected no existing body blocks");
+
+ // Create a new block for the body, and add an argument for the induction
+ // variable.
+ Block *body = new Block();
+ body->addArgument(IndexType::get(getInstruction()->getContext()));
+ bodyBlockList.push_back(body);
+ return body;
+}
+
+const AffineBound AffineForOp::getLowerBound() const {
+ auto lbMap = getLowerBoundMap();
+ return AffineBound(ConstOpPointer<AffineForOp>(*this), 0,
+ lbMap.getNumInputs(), lbMap);
+}
+
+const AffineBound AffineForOp::getUpperBound() const {
+ auto lbMap = getLowerBoundMap();
+ auto ubMap = getUpperBoundMap();
+ return AffineBound(ConstOpPointer<AffineForOp>(*this), lbMap.getNumInputs(),
+ getNumOperands(), ubMap);
+}
+
+void AffineForOp::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
+ assert(lbOperands.size() == map.getNumInputs());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+
+ SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end());
+
+ auto ubOperands = getUpperBoundOperands();
+ newOperands.append(ubOperands.begin(), ubOperands.end());
+ getInstruction()->setOperands(newOperands);
+
+ setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()),
+ AffineMapAttr::get(map));
+}
+
+void AffineForOp::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
+ assert(ubOperands.size() == map.getNumInputs());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+
+ SmallVector<Value *, 4> newOperands(getLowerBoundOperands());
+ newOperands.append(ubOperands.begin(), ubOperands.end());
+ getInstruction()->setOperands(newOperands);
+
+ setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()),
+ AffineMapAttr::get(map));
+}
+
+void AffineForOp::setLowerBoundMap(AffineMap map) {
+ auto lbMap = getLowerBoundMap();
+ assert(lbMap.getNumDims() == map.getNumDims() &&
+ lbMap.getNumSymbols() == map.getNumSymbols());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+ (void)lbMap;
+ setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()),
+ AffineMapAttr::get(map));
+}
+
+void AffineForOp::setUpperBoundMap(AffineMap map) {
+ auto ubMap = getUpperBoundMap();
+ assert(ubMap.getNumDims() == map.getNumDims() &&
+ ubMap.getNumSymbols() == map.getNumSymbols());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+ (void)ubMap;
+ setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()),
+ AffineMapAttr::get(map));
+}
+
+bool AffineForOp::hasConstantLowerBound() const {
+ return getLowerBoundMap().isSingleConstant();
+}
+
+bool AffineForOp::hasConstantUpperBound() const {
+ return getUpperBoundMap().isSingleConstant();
+}
+
+int64_t AffineForOp::getConstantLowerBound() const {
+ return getLowerBoundMap().getSingleConstantResult();
+}
+
+int64_t AffineForOp::getConstantUpperBound() const {
+ return getUpperBoundMap().getSingleConstantResult();
+}
+
+void AffineForOp::setConstantLowerBound(int64_t value) {
+ setLowerBound(
+ {}, AffineMap::getConstantMap(value, getInstruction()->getContext()));
+}
+
+void AffineForOp::setConstantUpperBound(int64_t value) {
+ setUpperBound(
+ {}, AffineMap::getConstantMap(value, getInstruction()->getContext()));
+}
+
+AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
+ return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
+}
+
+AffineForOp::const_operand_range AffineForOp::getLowerBoundOperands() const {
+ return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
+}
+
+AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
+ return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
+}
+
+AffineForOp::const_operand_range AffineForOp::getUpperBoundOperands() const {
+ return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
+}
+
+bool AffineForOp::matchingBoundOperandList() const {
+ auto lbMap = getLowerBoundMap();
+ auto ubMap = getUpperBoundMap();
+ if (lbMap.getNumDims() != ubMap.getNumDims() ||
+ lbMap.getNumSymbols() != ubMap.getNumSymbols())
+ return false;
+
+ unsigned numOperands = lbMap.getNumInputs();
+ for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
+ // Compare Value *'s.
+ if (getOperand(i) != getOperand(numOperands + i))
+ return false;
+ }
+ return true;
+}
+
+void AffineForOp::walkOps(std::function<void(OperationInst *)> callback) {
+ struct Walker : public InstWalker<Walker> {
+ std::function<void(OperationInst *)> const &callback;
+ Walker(std::function<void(OperationInst *)> const &callback)
+ : callback(callback) {}
+
+ void visitOperationInst(OperationInst *opInst) { callback(opInst); }
+ };
+
+ Walker w(callback);
+ w.walk(getInstruction());
+}
+
+void AffineForOp::walkOpsPostOrder(
+ std::function<void(OperationInst *)> callback) {
+ struct Walker : public InstWalker<Walker> {
+ std::function<void(OperationInst *)> const &callback;
+ Walker(std::function<void(OperationInst *)> const &callback)
+ : callback(callback) {}
+
+ void visitOperationInst(OperationInst *opInst) { callback(opInst); }
+ };
+
+ Walker v(callback);
+ v.walkPostOrder(getInstruction());
+}
+
+/// Returns the induction variable for this loop.
+Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); }
+
+/// Returns true if the provided value is the induction variable of an AffineForOp.
+bool mlir::isForInductionVar(const Value *val) {
+ return getForInductionVarOwner(val) != nullptr;
+}
+
+/// Returns the loop parent of an induction variable. If the provided value is
+/// not an induction variable, then return nullptr.
+OpPointer<AffineForOp> mlir::getForInductionVarOwner(Value *val) {
+ const BlockArgument *ivArg = dyn_cast<BlockArgument>(val);
+ if (!ivArg || !ivArg->getOwner())
+ return OpPointer<AffineForOp>();
+ auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst();
+ if (!containingInst)
+ return OpPointer<AffineForOp>();
+ return cast<OperationInst>(containingInst)->dyn_cast<AffineForOp>();
+}
+ConstOpPointer<AffineForOp> mlir::getForInductionVarOwner(const Value *val) {
+ auto nonConstOwner = getForInductionVarOwner(const_cast<Value *>(val));
+ return ConstOpPointer<AffineForOp>(nonConstOwner);
+}
+
+/// Extracts the induction variables from a list of AffineForOps and returns
+/// them.
+SmallVector<Value *, 8> mlir::extractForInductionVars(
+ MutableArrayRef<OpPointer<AffineForOp>> forInsts) {
+ SmallVector<Value *, 8> results;
+ for (auto forInst : forInsts)
+ results.push_back(forInst->getInductionVar());
+ return results;
}
//===----------------------------------------------------------------------===//
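
With the AffineOps.cpp change above, the dedicated ForInst is replaced by an AffineForOp operation. A minimal sketch of how a constant-bound loop is now built through FuncBuilder; the local names here are assumptions, but the call sequence mirrors the MLIREmitter change further down in this diff:

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Builders.h"
using namespace mlir;

// Emits `for %i = 0 to 42 step 2` and returns its induction variable.
static Value *emitConstantLoop(FuncBuilder &b, Location loc) {
  // The int64_t overload of AffineForOp::build creates constant bound maps.
  auto forOp = b.create<AffineForOp>(loc, /*lb=*/0, /*ub=*/42, /*step=*/2);
  // Unlike the old ForInst, the body block is created explicitly.
  forOp->createBody();
  return forOp->getInductionVar();
}

Note the explicit createBody() call, which the MLIREmitter hunk below performs as well.
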
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 0153546a4c6..d2366f1ce81 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -21,12 +21,14 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instructions.h"
+#include "mlir/IR/IntegerSet.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
@@ -519,7 +521,7 @@ void mlir::getReachableAffineApplyOps(
State &state = worklist.back();
auto *opInst = state.value->getDefiningInst();
// Note: getDefiningInst will return nullptr if the operand is not an
- // OperationInst (i.e. ForInst), which is a terminator for the search.
+ // OperationInst (i.e. AffineForOp), which is a terminator for the search.
if (opInst == nullptr || !opInst->isa<AffineApplyOp>()) {
worklist.pop_back();
continue;
@@ -546,21 +548,21 @@ void mlir::getReachableAffineApplyOps(
}
// Builds a system of constraints with dimensional identifiers corresponding to
-// the loop IVs of the forInsts appearing in that order. Any symbols founds in
+// the loop IVs of the forOps appearing in that order. Any symbols found in
// the bound operands are added as symbols in the system. Returns false for the
// yet unimplemented cases.
// TODO(andydavis,bondhugula) Handle non-unit steps through local variables or
// stride information in FlatAffineConstraints. (For eg., by using iv - lb %
// step = 0 and/or by introducing a method in FlatAffineConstraints
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
-bool mlir::getIndexSet(ArrayRef<ForInst *> forInsts,
+bool mlir::getIndexSet(MutableArrayRef<OpPointer<AffineForOp>> forOps,
FlatAffineConstraints *domain) {
- auto indices = extractForInductionVars(forInsts);
+ auto indices = extractForInductionVars(forOps);
// Reset while associating Values in 'indices' to the domain.
- domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
- for (auto *forInst : forInsts) {
- // Add constraints from forInst's bounds.
- if (!domain->addForInstDomain(*forInst))
+ domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
+ for (auto forOp : forOps) {
+ // Add constraints from forOp's bounds.
+ if (!domain->addAffineForOpDomain(forOp))
return false;
}
return true;
@@ -576,7 +578,7 @@ static bool getInstIndexSet(const Instruction *inst,
FlatAffineConstraints *indexSet) {
// TODO(andydavis) Extend this to gather enclosing IfInsts and consider
// factoring it out into a utility function.
- SmallVector<ForInst *, 4> loops;
+ SmallVector<OpPointer<AffineForOp>, 4> loops;
getLoopIVs(*inst, &loops);
return getIndexSet(loops, indexSet);
}
@@ -998,9 +1000,9 @@ static const Block *getCommonBlock(const MemRefAccess &srcAccess,
return block;
}
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
- auto *forInst = getForInductionVarOwner(commonForValue);
- assert(forInst && "commonForValue was not an induction variable");
- return forInst->getBody();
+ auto forOp = getForInductionVarOwner(commonForValue);
+ assert(forOp && "commonForValue was not an induction variable");
+ return forOp->getBody();
}
// Returns true if the ancestor operation instruction of 'srcAccess' appears
@@ -1195,7 +1197,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
// until operands of the AffineValueMap are loop IVs or symbols.
// *) Build iteration domain constraints for each access. Iteration domain
// constraints are pairs of inequality constraints representing the
-// upper/lower loop bounds for each ForInst in the loop nest associated
+// upper/lower loop bounds for each AffineForOp in the loop nest associated
// with each access.
// *) Build dimension and symbol position maps for each access, which map
// Values from access functions and iteration domains to their position
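
Per the updated getIndexSet signature above, callers now pass OpPointer<AffineForOp>s rather than ForInst pointers. A rough sketch of the pattern used by getInstIndexSet (helper and variable names are assumed):

#include "llvm/ADT/SmallVector.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
using namespace mlir;

// Builds the iteration domain of all loops enclosing 'inst'.
static bool buildEnclosingDomain(const Instruction &inst,
                                 FlatAffineConstraints *domain) {
  // Enclosing AffineForOps, ordered from outermost to innermost.
  SmallVector<OpPointer<AffineForOp>, 4> loops;
  getLoopIVs(inst, &loops);
  // One dimension per IV; each loop's bound constraints are added.
  return getIndexSet(loops, domain);
}
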
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 5e7f8e3243c..c794899d3e1 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -20,6 +20,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
@@ -1247,22 +1248,23 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
numSymbols = newSymbolCount;
}
-bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
+bool FlatAffineConstraints::addAffineForOpDomain(
+ ConstOpPointer<AffineForOp> forOp) {
unsigned pos;
// Pre-condition for this method.
- if (!findId(*forInst.getInductionVar(), &pos)) {
+ if (!findId(*forOp->getInductionVar(), &pos)) {
assert(0 && "Value not found");
return false;
}
- if (forInst.getStep() != 1)
+ if (forOp->getStep() != 1)
LLVM_DEBUG(llvm::dbgs()
<< "Domain conservative: non-unit stride not handled\n");
// Adds a lower or upper bound when the bounds aren't constant.
auto addLowerOrUpperBound = [&](bool lower) -> bool {
- auto operands = lower ? forInst.getLowerBoundOperands()
- : forInst.getUpperBoundOperands();
+ auto operands =
+ lower ? forOp->getLowerBoundOperands() : forOp->getUpperBoundOperands();
for (const auto &operand : operands) {
unsigned loc;
if (!findId(*operand, &loc)) {
@@ -1291,7 +1293,7 @@ bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
}
auto boundMap =
- lower ? forInst.getLowerBoundMap() : forInst.getUpperBoundMap();
+ lower ? forOp->getLowerBoundMap() : forOp->getUpperBoundMap();
FlatAffineConstraints localVarCst;
std::vector<SmallVector<int64_t, 8>> flatExprs;
@@ -1321,16 +1323,16 @@ bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
return true;
};
- if (forInst.hasConstantLowerBound()) {
- addConstantLowerBound(pos, forInst.getConstantLowerBound());
+ if (forOp->hasConstantLowerBound()) {
+ addConstantLowerBound(pos, forOp->getConstantLowerBound());
} else {
// Non-constant lower bound case.
if (!addLowerOrUpperBound(/*lower=*/true))
return false;
}
- if (forInst.hasConstantUpperBound()) {
- addConstantUpperBound(pos, forInst.getConstantUpperBound() - 1);
+ if (forOp->hasConstantUpperBound()) {
+ addConstantUpperBound(pos, forOp->getConstantUpperBound() - 1);
return true;
}
// Non-constant upper bound case.
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index 7d88a3d9b9f..249776d42c9 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -43,27 +43,27 @@ using namespace mlir;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
-AffineExpr mlir::getTripCountExpr(const ForInst &forInst) {
+AffineExpr mlir::getTripCountExpr(ConstOpPointer<AffineForOp> forOp) {
// upper_bound - lower_bound
int64_t loopSpan;
- int64_t step = forInst.getStep();
- auto *context = forInst.getContext();
+ int64_t step = forOp->getStep();
+ auto *context = forOp->getInstruction()->getContext();
- if (forInst.hasConstantBounds()) {
- int64_t lb = forInst.getConstantLowerBound();
- int64_t ub = forInst.getConstantUpperBound();
+ if (forOp->hasConstantBounds()) {
+ int64_t lb = forOp->getConstantLowerBound();
+ int64_t ub = forOp->getConstantUpperBound();
loopSpan = ub - lb;
} else {
- auto lbMap = forInst.getLowerBoundMap();
- auto ubMap = forInst.getUpperBoundMap();
+ auto lbMap = forOp->getLowerBoundMap();
+ auto ubMap = forOp->getUpperBoundMap();
// TODO(bondhugula): handle max/min of multiple expressions.
if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
return nullptr;
// TODO(bondhugula): handle bounds with different operands.
// Bounds have different operands, unhandled for now.
- if (!forInst.matchingBoundOperandList())
+ if (!forOp->matchingBoundOperandList())
return nullptr;
// ub_expr - lb_expr
@@ -89,8 +89,9 @@ AffineExpr mlir::getTripCountExpr(const ForInst &forInst) {
/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// method uses affine expression analysis (in turn using getTripCount) and is
/// able to determine constant trip count in non-trivial cases.
-llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForInst &forInst) {
- auto tripCountExpr = getTripCountExpr(forInst);
+llvm::Optional<uint64_t>
+mlir::getConstantTripCount(ConstOpPointer<AffineForOp> forOp) {
+ auto tripCountExpr = getTripCountExpr(forOp);
if (!tripCountExpr)
return None;
@@ -104,8 +105,8 @@ llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForInst &forInst) {
/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
-uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) {
- auto tripCountExpr = getTripCountExpr(forInst);
+uint64_t mlir::getLargestDivisorOfTripCount(ConstOpPointer<AffineForOp> forOp) {
+ auto tripCountExpr = getTripCountExpr(forOp);
if (!tripCountExpr)
return 1;
@@ -126,7 +127,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) {
}
bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
- assert(isForInductionVar(&iv) && "iv must be a ForInst");
+ assert(isForInductionVar(&iv) && "iv must be an AffineForOp");
assert(index.getType().isa<IndexType>() && "index must be of IndexType");
SmallVector<OperationInst *, 4> affineApplyOps;
getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
@@ -163,7 +164,7 @@ mlir::getInvariantAccesses(const Value &iv,
}
/// Given:
-/// 1. an induction variable `iv` of type ForInst;
+/// 1. an induction variable `iv` of type AffineForOp;
/// 2. a `memoryOp` of type const LoadOp& or const StoreOp&;
/// 3. the index of the `fastestVaryingDim` along which to check;
/// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access
@@ -231,17 +232,18 @@ static bool isVectorTransferReadOrWrite(const Instruction &inst) {
}
using VectorizableInstFun =
- std::function<bool(const ForInst &, const OperationInst &)>;
+ std::function<bool(ConstOpPointer<AffineForOp>, const OperationInst &)>;
-static bool isVectorizableLoopWithCond(const ForInst &loop,
+static bool isVectorizableLoopWithCond(ConstOpPointer<AffineForOp> loop,
VectorizableInstFun isVectorizableInst) {
- if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) {
+ auto *forInst = const_cast<OperationInst *>(loop->getInstruction());
+ if (!matcher::isParallelLoop(*forInst) &&
+ !matcher::isReductionLoop(*forInst)) {
return false;
}
// No vectorization across conditionals for now.
auto conditionals = matcher::If();
- auto *forInst = const_cast<ForInst *>(&loop);
SmallVector<NestedMatch, 8> conditionalsMatched;
conditionals.match(forInst, &conditionalsMatched);
if (!conditionalsMatched.empty()) {
@@ -251,7 +253,8 @@ static bool isVectorizableLoopWithCond(const ForInst &loop,
// No vectorization across unknown regions.
auto regions = matcher::Op([](const Instruction &inst) -> bool {
auto &opInst = cast<OperationInst>(inst);
- return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>();
+ return opInst.getNumBlockLists() != 0 &&
+ !(opInst.isa<AffineIfOp>() || opInst.isa<AffineForOp>());
});
SmallVector<NestedMatch, 8> regionsMatched;
regions.match(forInst, &regionsMatched);
@@ -288,23 +291,25 @@ static bool isVectorizableLoopWithCond(const ForInst &loop,
}
bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
- const ForInst &loop, unsigned fastestVaryingDim) {
- VectorizableInstFun fun(
- [fastestVaryingDim](const ForInst &loop, const OperationInst &op) {
- auto load = op.dyn_cast<LoadOp>();
- auto store = op.dyn_cast<StoreOp>();
- return load ? isContiguousAccess(*loop.getInductionVar(), *load,
- fastestVaryingDim)
- : isContiguousAccess(*loop.getInductionVar(), *store,
- fastestVaryingDim);
- });
+ ConstOpPointer<AffineForOp> loop, unsigned fastestVaryingDim) {
+ VectorizableInstFun fun([fastestVaryingDim](ConstOpPointer<AffineForOp> loop,
+ const OperationInst &op) {
+ auto load = op.dyn_cast<LoadOp>();
+ auto store = op.dyn_cast<StoreOp>();
+ return load ? isContiguousAccess(*loop->getInductionVar(), *load,
+ fastestVaryingDim)
+ : isContiguousAccess(*loop->getInductionVar(), *store,
+ fastestVaryingDim);
+ });
return isVectorizableLoopWithCond(loop, fun);
}
-bool mlir::isVectorizableLoop(const ForInst &loop) {
+bool mlir::isVectorizableLoop(ConstOpPointer<AffineForOp> loop) {
VectorizableInstFun fun(
// TODO: implement me
- [](const ForInst &loop, const OperationInst &op) { return true; });
+ [](ConstOpPointer<AffineForOp> loop, const OperationInst &op) {
+ return true;
+ });
return isVectorizableLoopWithCond(loop, fun);
}
@@ -313,9 +318,9 @@ bool mlir::isVectorizableLoop(const ForInst &loop) {
/// 'def' and all its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence
// violation when we have the support.
-bool mlir::isInstwiseShiftValid(const ForInst &forInst,
+bool mlir::isInstwiseShiftValid(ConstOpPointer<AffineForOp> forOp,
ArrayRef<uint64_t> shifts) {
- auto *forBody = forInst.getBody();
+ auto *forBody = forOp->getBody();
assert(shifts.size() == forBody->getInstructions().size());
unsigned s = 0;
for (const auto &inst : *forBody) {
@@ -325,7 +330,7 @@ bool mlir::isInstwiseShiftValid(const ForInst &forInst,
for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) {
const Value *result = opInst->getResult(i);
for (const InstOperand &use : result->getUses()) {
- // If an ancestor instruction doesn't lie in the block of forInst,
+ // If an ancestor instruction doesn't lie in the block of forOp,
// there is no shift to check. This is a naive way. If performance
// becomes an issue, a map can be used to store 'shifts' - to look up
// the shift for an instruction in constant time.
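
The trip-count helpers in this file now take a ConstOpPointer<AffineForOp>. A small usage sketch; the loop pointer is assumed to come from getForInductionVarOwner or a walk, and the raw_ostream reporting is illustrative only:

#include "llvm/Support/raw_ostream.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
using namespace mlir;

static void reportTripCount(ConstOpPointer<AffineForOp> loop) {
  // Constant trip count, when the bounds allow one to be derived.
  if (auto tripCount = getConstantTripCount(loop))
    llvm::outs() << "constant trip count: " << *tripCount << "\n";
  // The largest known divisor is computable even for symbolic bounds.
  llvm::outs() << "known trip count divisor: "
               << getLargestDivisorOfTripCount(loop) << "\n";
}
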
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp
index 46bf5ad0b97..214b4ce403c 100644
--- a/mlir/lib/Analysis/NestedMatcher.cpp
+++ b/mlir/lib/Analysis/NestedMatcher.cpp
@@ -115,6 +115,10 @@ void NestedPattern::matchOne(Instruction *inst,
}
}
+static bool isAffineForOp(const Instruction &inst) {
+ return cast<OperationInst>(inst).isa<AffineForOp>();
+}
+
static bool isAffineIfOp(const Instruction &inst) {
return isa<OperationInst>(inst) &&
cast<OperationInst>(inst).isa<AffineIfOp>();
@@ -147,28 +151,34 @@ NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
}
NestedPattern For(NestedPattern child) {
- return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction);
+ return NestedPattern(Instruction::Kind::OperationInst, child, isAffineForOp);
}
NestedPattern For(FilterFunctionType filter, NestedPattern child) {
- return NestedPattern(Instruction::Kind::For, child, filter);
+ return NestedPattern(Instruction::Kind::OperationInst, child,
+ [=](const Instruction &inst) {
+ return isAffineForOp(inst) && filter(inst);
+ });
}
NestedPattern For(ArrayRef<NestedPattern> nested) {
- return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction);
+ return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineForOp);
}
NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
- return NestedPattern(Instruction::Kind::For, nested, filter);
+ return NestedPattern(Instruction::Kind::OperationInst, nested,
+ [=](const Instruction &inst) {
+ return isAffineForOp(inst) && filter(inst);
+ });
}
// TODO(ntv): parallel annotation on loops.
bool isParallelLoop(const Instruction &inst) {
- const auto *loop = cast<ForInst>(&inst);
- return (void *)loop || true; // loop->isParallel();
+ auto loop = cast<OperationInst>(inst).cast<AffineForOp>();
+ return loop || true; // loop->isParallel();
};
// TODO(ntv): reduction annotation on loops.
bool isReductionLoop(const Instruction &inst) {
- const auto *loop = cast<ForInst>(&inst);
- return (void *)loop || true; // loop->isReduction();
+ auto loop = cast<OperationInst>(inst).cast<AffineForOp>();
+ return loop || true; // loop->isReduction();
};
bool isLoadOrStore(const Instruction &inst) {
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index d16a7fcb1b3..4025af936f3 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -20,6 +20,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instructions.h"
@@ -52,7 +53,16 @@ void mlir::getForwardSlice(Instruction *inst,
return;
}
- if (auto *opInst = dyn_cast<OperationInst>(inst)) {
+ auto *opInst = cast<OperationInst>(inst);
+ if (auto forOp = opInst->dyn_cast<AffineForOp>()) {
+ for (auto &u : forOp->getInductionVar()->getUses()) {
+ auto *ownerInst = u.getOwner();
+ if (forwardSlice->count(ownerInst) == 0) {
+ getForwardSlice(ownerInst, forwardSlice, filter,
+ /*topLevel=*/false);
+ }
+ }
+ } else {
assert(opInst->getNumResults() <= 1 && "NYI: multiple results");
if (opInst->getNumResults() > 0) {
for (auto &u : opInst->getResult(0)->getUses()) {
@@ -63,16 +73,6 @@ void mlir::getForwardSlice(Instruction *inst,
}
}
}
- } else if (auto *forInst = dyn_cast<ForInst>(inst)) {
- for (auto &u : forInst->getInductionVar()->getUses()) {
- auto *ownerInst = u.getOwner();
- if (forwardSlice->count(ownerInst) == 0) {
- getForwardSlice(ownerInst, forwardSlice, filter,
- /*topLevel=*/false);
- }
- }
- } else {
- assert(false && "NYI slicing case");
}
// At the top level we reverse to get back the actual topological order.
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 0e77d4d9084..4b8afd9a620 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -38,15 +38,17 @@ using namespace mlir;
/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from
/// the outermost 'for' instruction to the innermost one.
void mlir::getLoopIVs(const Instruction &inst,
- SmallVectorImpl<ForInst *> *loops) {
+ SmallVectorImpl<OpPointer<AffineForOp>> *loops) {
auto *currInst = inst.getParentInst();
- ForInst *currForInst;
+ OpPointer<AffineForOp> currAffineForOp;
// Traverse up the hierarchy collecting all 'for' instructions while skipping
// over 'if' instructions.
- while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) ||
- cast<OperationInst>(currInst)->isa<AffineIfOp>())) {
- if (currForInst)
- loops->push_back(currForInst);
+ while (currInst &&
+ ((currAffineForOp =
+ cast<OperationInst>(currInst)->dyn_cast<AffineForOp>()) ||
+ cast<OperationInst>(currInst)->isa<AffineIfOp>())) {
+ if (currAffineForOp)
+ loops->push_back(currAffineForOp);
currInst = currInst->getParentInst();
}
std::reverse(loops->begin(), loops->end());
@@ -148,7 +150,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
if (rank == 0) {
// A rank 0 memref has a 0-d region.
- SmallVector<ForInst *, 4> ivs;
+ SmallVector<OpPointer<AffineForOp>, 4> ivs;
getLoopIVs(*opInst, &ivs);
SmallVector<Value *, 8> regionSymbols = extractForInductionVars(ivs);
@@ -174,12 +176,12 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
unsigned numSymbols = accessMap.getNumSymbols();
// Add inequalties for loop lower/upper bounds.
for (unsigned i = 0; i < numDims + numSymbols; ++i) {
- if (auto *loop = getForInductionVarOwner(accessValueMap.getOperand(i))) {
+ if (auto loop = getForInductionVarOwner(accessValueMap.getOperand(i))) {
// Note that regionCst can now have more dimensions than accessMap if the
// bounds expressions involve outer loops or other symbols.
// TODO(bondhugula): rewrite this to use getInstIndexSet; this way
// conditionals will be handled when the latter supports it.
- if (!regionCst->addForInstDomain(*loop))
+ if (!regionCst->addAffineForOpDomain(loop))
return false;
} else {
// Has to be a valid symbol.
@@ -203,14 +205,14 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
// Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
// this memref region is symbolic.
- SmallVector<ForInst *, 4> outerIVs;
+ SmallVector<OpPointer<AffineForOp>, 4> outerIVs;
getLoopIVs(*opInst, &outerIVs);
assert(loopDepth <= outerIVs.size() && "invalid loop depth");
outerIVs.resize(loopDepth);
for (auto *operand : accessValueMap.getOperands()) {
- ForInst *iv;
+ OpPointer<AffineForOp> iv;
if ((iv = getForInductionVarOwner(operand)) &&
- std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) {
+ llvm::is_contained(outerIVs, iv) == false) {
regionCst->projectOut(operand);
}
}
@@ -357,8 +359,10 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
}
if (level == positions.size() - 1)
return &inst;
- if (auto *childForInst = dyn_cast<ForInst>(&inst))
- return getInstAtPosition(positions, level + 1, childForInst->getBody());
+ if (auto childAffineForOp =
+ cast<OperationInst>(inst).dyn_cast<AffineForOp>())
+ return getInstAtPosition(positions, level + 1,
+ childAffineForOp->getBody());
for (auto &blockList : cast<OperationInst>(&inst)->getBlockLists()) {
for (auto &b : blockList)
@@ -385,12 +389,12 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
return false;
}
// Get loop nest surrounding src operation.
- SmallVector<ForInst *, 4> srcLoopIVs;
+ SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcAccess.opInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
// Get loop nest surrounding dst operation.
- SmallVector<ForInst *, 4> dstLoopIVs;
+ SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
getLoopIVs(*dstAccess.opInst, &dstLoopIVs);
unsigned numDstLoopIVs = dstLoopIVs.size();
if (dstLoopDepth > numDstLoopIVs) {
@@ -437,38 +441,41 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
// solution.
// TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project
// out loop IVs we don't care about and produce smaller slice.
-ForInst *mlir::insertBackwardComputationSlice(
+OpPointer<AffineForOp> mlir::insertBackwardComputationSlice(
OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth,
ComputationSliceState *sliceState) {
// Get loop nest surrounding src operation.
- SmallVector<ForInst *, 4> srcLoopIVs;
+ SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
// Get loop nest surrounding dst operation.
- SmallVector<ForInst *, 4> dstLoopIVs;
+ SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
getLoopIVs(*dstOpInst, &dstLoopIVs);
unsigned dstLoopIVsSize = dstLoopIVs.size();
if (dstLoopDepth > dstLoopIVsSize) {
dstOpInst->emitError("invalid destination loop depth");
- return nullptr;
+ return OpPointer<AffineForOp>();
}
// Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'.
SmallVector<unsigned, 4> positions;
// TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d.
- findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
+ findInstPosition(srcOpInst, srcLoopIVs[0]->getInstruction()->getBlock(),
+ &positions);
// Clone src loop nest and insert it at the beginning of the instruction block
// of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
- auto *dstForInst = dstLoopIVs[dstLoopDepth - 1];
- FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin());
- auto *sliceLoopNest = cast<ForInst>(b.clone(*srcLoopIVs[0]));
+ auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
+ FuncBuilder b(dstAffineForOp->getBody(), dstAffineForOp->getBody()->begin());
+ auto sliceLoopNest =
+ cast<OperationInst>(b.clone(*srcLoopIVs[0]->getInstruction()))
+ ->cast<AffineForOp>();
Instruction *sliceInst =
getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
// Get loop nest surrounding 'sliceInst'.
- SmallVector<ForInst *, 4> sliceSurroundingLoops;
+ SmallVector<OpPointer<AffineForOp>, 4> sliceSurroundingLoops;
getLoopIVs(*sliceInst, &sliceSurroundingLoops);
// Sanity check.
@@ -481,11 +488,11 @@ ForInst *mlir::insertBackwardComputationSlice(
// Update loop bounds for loops in 'sliceLoopNest'.
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
- auto *forInst = sliceSurroundingLoops[dstLoopDepth + i];
+ auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
if (AffineMap lbMap = sliceState->lbs[i])
- forInst->setLowerBound(sliceState->lbOperands[i], lbMap);
+ forOp->setLowerBound(sliceState->lbOperands[i], lbMap);
if (AffineMap ubMap = sliceState->ubs[i])
- forInst->setUpperBound(sliceState->ubOperands[i], ubMap);
+ forOp->setUpperBound(sliceState->ubOperands[i], ubMap);
}
return sliceLoopNest;
}
@@ -520,7 +527,7 @@ unsigned mlir::getNestingDepth(const Instruction &stmt) {
const Instruction *currInst = &stmt;
unsigned depth = 0;
while ((currInst = currInst->getParentInst())) {
- if (isa<ForInst>(currInst))
+ if (cast<OperationInst>(currInst)->isa<AffineForOp>())
depth++;
}
return depth;
@@ -530,14 +537,14 @@ unsigned mlir::getNestingDepth(const Instruction &stmt) {
/// where each lists loops from outer-most to inner-most in loop nest.
unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A,
const Instruction &B) {
- SmallVector<ForInst *, 4> loopsA, loopsB;
+ SmallVector<OpPointer<AffineForOp>, 4> loopsA, loopsB;
getLoopIVs(A, &loopsA);
getLoopIVs(B, &loopsB);
unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
unsigned numCommonLoops = 0;
for (unsigned i = 0; i < minNumLoops; ++i) {
- if (loopsA[i] != loopsB[i])
+ if (loopsA[i]->getInstruction() != loopsB[i]->getInstruction())
break;
++numCommonLoops;
}
@@ -571,13 +578,14 @@ static Optional<int64_t> getRegionSize(const MemRefRegion &region) {
return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
}
-Optional<int64_t> mlir::getMemoryFootprintBytes(const ForInst &forInst,
- int memorySpace) {
+Optional<int64_t>
+mlir::getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp,
+ int memorySpace) {
std::vector<std::unique_ptr<MemRefRegion>> regions;
// Walk this 'for' instruction to gather all memory regions.
bool error = false;
- const_cast<ForInst *>(&forInst)->walkOps([&](OperationInst *opInst) {
+ const_cast<AffineForOp &>(*forOp).walkOps([&](OperationInst *opInst) {
if (!opInst->isa<LoadOp>() && !opInst->isa<StoreOp>()) {
// Neither load nor a store op.
return;
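
getMemoryFootprintBytes likewise switches from a ForInst reference to a ConstOpPointer<AffineForOp>. A sketch of a caller, with memory space 0 assumed for the example:

#include "llvm/Support/raw_ostream.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/Utils.h"
using namespace mlir;

static void printFootprint(ConstOpPointer<AffineForOp> forOp) {
  // Total bytes touched by load/store memrefs in the nest, if computable.
  if (auto bytes = getMemoryFootprintBytes(forOp, /*memorySpace=*/0))
    llvm::outs() << "footprint: " << *bytes << " bytes\n";
  else
    llvm::outs() << "footprint: unknown\n";
}
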
diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp
index 125020e92a3..4865cb03bb4 100644
--- a/mlir/lib/Analysis/VectorAnalysis.cpp
+++ b/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -16,10 +16,12 @@
// =============================================================================
#include "mlir/Analysis/VectorAnalysis.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instructions.h"
+#include "mlir/IR/IntegerSet.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
@@ -105,7 +107,7 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType,
static AffineMap makePermutationMap(
MLIRContext *context,
llvm::iterator_range<OperationInst::operand_iterator> indices,
- const DenseMap<ForInst *, unsigned> &enclosingLoopToVectorDim) {
+ const DenseMap<Instruction *, unsigned> &enclosingLoopToVectorDim) {
using functional::makePtrDynCaster;
using functional::map;
auto unwrappedIndices = map(makePtrDynCaster<Value, Value>(), indices);
@@ -113,8 +115,9 @@ static AffineMap makePermutationMap(
getAffineConstantExpr(0, context));
for (auto kvp : enclosingLoopToVectorDim) {
assert(kvp.second < perm.size());
- auto invariants =
- getInvariantAccesses(*kvp.first->getInductionVar(), unwrappedIndices);
+ auto invariants = getInvariantAccesses(
+ *cast<OperationInst>(kvp.first)->cast<AffineForOp>()->getInductionVar(),
+ unwrappedIndices);
unsigned numIndices = unwrappedIndices.size();
unsigned countInvariantIndices = 0;
for (unsigned dim = 0; dim < numIndices; ++dim) {
@@ -139,30 +142,30 @@ static AffineMap makePermutationMap(
/// TODO(ntv): could also be implemented as a collect parents followed by a
/// filter and made available outside this file.
template <typename T>
-static SetVector<T *> getParentsOfType(Instruction *inst) {
- SetVector<T *> res;
+static SetVector<OperationInst *> getParentsOfType(Instruction *inst) {
+ SetVector<OperationInst *> res;
auto *current = inst;
while (auto *parent = current->getParentInst()) {
- auto *typedParent = dyn_cast<T>(parent);
- if (typedParent) {
- assert(res.count(typedParent) == 0 && "Already inserted");
- res.insert(typedParent);
+ if (auto typedParent =
+ cast<OperationInst>(parent)->template dyn_cast<T>()) {
+ assert(res.count(cast<OperationInst>(parent)) == 0 && "Already inserted");
+ res.insert(cast<OperationInst>(parent));
}
current = parent;
}
return res;
}
-/// Returns the enclosing ForInst, from closest to farthest.
-static SetVector<ForInst *> getEnclosingforInsts(Instruction *inst) {
- return getParentsOfType<ForInst>(inst);
+/// Returns the enclosing AffineForOp, from closest to farthest.
+static SetVector<OperationInst *> getEnclosingforOps(Instruction *inst) {
+ return getParentsOfType<AffineForOp>(inst);
}
-AffineMap
-mlir::makePermutationMap(OperationInst *opInst,
- const DenseMap<ForInst *, unsigned> &loopToVectorDim) {
- DenseMap<ForInst *, unsigned> enclosingLoopToVectorDim;
- auto enclosingLoops = getEnclosingforInsts(opInst);
+AffineMap mlir::makePermutationMap(
+ OperationInst *opInst,
+ const DenseMap<Instruction *, unsigned> &loopToVectorDim) {
+ DenseMap<Instruction *, unsigned> enclosingLoopToVectorDim;
+ auto enclosingLoops = getEnclosingforOps(opInst);
for (auto *forInst : enclosingLoops) {
auto it = loopToVectorDim.find(forInst);
if (it != loopToVectorDim.end()) {
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
index 474eeb2a28e..a69831053ad 100644
--- a/mlir/lib/Analysis/Verifier.cpp
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -72,7 +72,6 @@ public:
bool verify();
bool verifyBlock(const Block &block, bool isTopLevel);
bool verifyOperation(const OperationInst &op);
- bool verifyForInst(const ForInst &forInst);
bool verifyDominance(const Block &block);
bool verifyInstDominance(const Instruction &inst);
@@ -175,10 +174,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) {
if (verifyOperation(cast<OperationInst>(inst)))
return true;
break;
- case Instruction::Kind::For:
- if (verifyForInst(cast<ForInst>(inst)))
- return true;
- break;
}
}
@@ -240,11 +235,6 @@ bool FuncVerifier::verifyOperation(const OperationInst &op) {
return false;
}
-bool FuncVerifier::verifyForInst(const ForInst &forInst) {
- // TODO: check that loop bounds are properly formed.
- return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false);
-}
-
bool FuncVerifier::verifyDominance(const Block &block) {
for (auto &inst : block) {
// Check that all operands on the instruction are ok.
@@ -262,10 +252,6 @@ bool FuncVerifier::verifyDominance(const Block &block) {
return true;
break;
}
- case Instruction::Kind::For:
- if (verifyDominance(*cast<ForInst>(inst).getBody()))
- return true;
- break;
}
}
return false;
diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp
index dc85c5ed682..f4d5d36d25b 100644
--- a/mlir/lib/EDSC/MLIREmitter.cpp
+++ b/mlir/lib/EDSC/MLIREmitter.cpp
@@ -21,12 +21,14 @@
#include "llvm/Support/raw_ostream.h"
#include "mlir-c/Core.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/EDSC/MLIREmitter.h"
#include "mlir/EDSC/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instructions.h"
+#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"
#include "mlir/StandardOps/StandardOps.h"
@@ -133,8 +135,8 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
inst->print(os);
return;
}
- if (auto *forInst = getForInductionVarOwner(&v)) {
- forInst->print(os);
+ if (auto forInst = getForInductionVarOwner(&v)) {
+ forInst->getInstruction()->print(os);
} else {
os << "unknown_ssa_value";
}
@@ -300,7 +302,9 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
exprs[1]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
auto step =
exprs[2]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
- res = builder->createFor(location, lb, ub, step)->getInductionVar();
+ auto forOp = builder->create<AffineForOp>(location, lb, ub, step);
+ forOp->createBody();
+ res = forOp->getInductionVar();
}
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index cb4c1f0edce..0fb18fa0004 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -130,21 +130,8 @@ private:
void recordTypeReference(Type ty) { usedTypes.insert(ty); }
- // Return true if this map could be printed using the custom assembly form.
- static bool hasCustomForm(AffineMap boundMap) {
- if (boundMap.isSingleConstant())
- return true;
-
- // Check if the affine map is single dim id or single symbol identity -
- // (i)->(i) or ()[s]->(i)
- return boundMap.getNumInputs() == 1 && boundMap.getNumResults() == 1 &&
- (boundMap.getResult(0).isa<AffineDimExpr>() ||
- boundMap.getResult(0).isa<AffineSymbolExpr>());
- }
-
// Visit functions.
void visitInstruction(const Instruction *inst);
- void visitForInst(const ForInst *forInst);
void visitOperationInst(const OperationInst *opInst);
void visitType(Type type);
void visitAttribute(Attribute attr);
@@ -196,16 +183,6 @@ void ModuleState::visitAttribute(Attribute attr) {
}
}
-void ModuleState::visitForInst(const ForInst *forInst) {
- AffineMap lbMap = forInst->getLowerBoundMap();
- if (!hasCustomForm(lbMap))
- recordAffineMapReference(lbMap);
-
- AffineMap ubMap = forInst->getUpperBoundMap();
- if (!hasCustomForm(ubMap))
- recordAffineMapReference(ubMap);
-}
-
void ModuleState::visitOperationInst(const OperationInst *op) {
// Visit all the types used in the operation.
for (auto *operand : op->getOperands())
@@ -220,8 +197,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) {
void ModuleState::visitInstruction(const Instruction *inst) {
switch (inst->getKind()) {
- case Instruction::Kind::For:
- return visitForInst(cast<ForInst>(inst));
case Instruction::Kind::OperationInst:
return visitOperationInst(cast<OperationInst>(inst));
}
@@ -1069,7 +1044,6 @@ public:
// Methods to print instructions.
void print(const Instruction *inst);
void print(const OperationInst *inst);
- void print(const ForInst *inst);
void print(const Block *block, bool printBlockArgs = true);
void printOperation(const OperationInst *op);
@@ -1117,10 +1091,8 @@ public:
unsigned index) override;
/// Print a block list.
- void printBlockList(const BlockList &blocks) override {
- printBlockList(blocks, /*printEntryBlockArgs=*/true);
- }
- void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) {
+ void printBlockList(const BlockList &blocks,
+ bool printEntryBlockArgs) override {
os << " {\n";
if (!blocks.empty()) {
auto *entryBlock = &blocks.front();
@@ -1132,10 +1104,6 @@ public:
os.indent(currentIndent) << "}";
}
- // Print if and loop bounds.
- void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims);
- void printBound(AffineBound bound, const char *prefix);
-
// Number of spaces used for indenting nested instructions.
const static unsigned indentWidth = 2;
@@ -1205,10 +1173,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
numberValuesInBlock(block);
break;
}
- case Instruction::Kind::For:
- // Recursively number the stuff in the body.
- numberValuesInBlock(*cast<ForInst>(&inst)->getBody());
- break;
}
}
}
@@ -1404,8 +1368,6 @@ void FunctionPrinter::print(const Instruction *inst) {
switch (inst->getKind()) {
case Instruction::Kind::OperationInst:
return print(cast<OperationInst>(inst));
- case Instruction::Kind::For:
- return print(cast<ForInst>(inst));
}
}
@@ -1415,24 +1377,6 @@ void FunctionPrinter::print(const OperationInst *inst) {
printTrailingLocation(inst->getLoc());
}
-void FunctionPrinter::print(const ForInst *inst) {
- os.indent(currentIndent) << "for ";
- printOperand(inst->getInductionVar());
- os << " = ";
- printBound(inst->getLowerBound(), "max");
- os << " to ";
- printBound(inst->getUpperBound(), "min");
-
- if (inst->getStep() != 1)
- os << " step " << inst->getStep();
-
- printTrailingLocation(inst->getLoc());
-
- os << " {\n";
- print(inst->getBody(), /*printBlockArgs=*/false);
- os.indent(currentIndent) << "}";
-}
-
void FunctionPrinter::printValueID(const Value *value,
bool printResultNo) const {
int resultNo = -1;
@@ -1560,62 +1504,6 @@ void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term,
os << ')';
}
-void FunctionPrinter::printDimAndSymbolList(ArrayRef<InstOperand> ops,
- unsigned numDims) {
- auto printComma = [&]() { os << ", "; };
- os << '(';
- interleave(
- ops.begin(), ops.begin() + numDims,
- [&](const InstOperand &v) { printOperand(v.get()); }, printComma);
- os << ')';
-
- if (numDims < ops.size()) {
- os << '[';
- interleave(
- ops.begin() + numDims, ops.end(),
- [&](const InstOperand &v) { printOperand(v.get()); }, printComma);
- os << ']';
- }
-}
-
-void FunctionPrinter::printBound(AffineBound bound, const char *prefix) {
- AffineMap map = bound.getMap();
-
- // Check if this bound should be printed using custom assembly form.
- // The decision to restrict printing custom assembly form to trivial cases
- // comes from the will to roundtrip MLIR binary -> text -> binary in a
- // lossless way.
- // Therefore, custom assembly form parsing and printing is only supported for
- // zero-operand constant maps and single symbol operand identity maps.
- if (map.getNumResults() == 1) {
- AffineExpr expr = map.getResult(0);
-
- // Print constant bound.
- if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
- if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
- os << constExpr.getValue();
- return;
- }
- }
-
- // Print bound that consists of a single SSA symbol if the map is over a
- // single symbol.
- if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
- if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
- printOperand(bound.getOperand(0));
- return;
- }
- }
- } else {
- // Map has multiple results. Print 'min' or 'max' prefix.
- os << prefix << ' ';
- }
-
- // Print the map and its operands.
- printAffineMapReference(map);
- printDimAndSymbolList(bound.getInstOperands(), map.getNumDims());
-}
-
// Prints function with initialized module state.
void ModulePrinter::print(const Function *fn) {
FunctionPrinter(fn, *this).print();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index ffeb4e0317f..68fbef2d27a 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -312,19 +312,3 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) {
block->getInstructions().insert(insertPoint, op);
return op;
}
-
-ForInst *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands,
- AffineMap lbMap, ArrayRef<Value *> ubOperands,
- AffineMap ubMap, int64_t step) {
- auto *inst =
- ForInst::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
- block->getInstructions().insert(insertPoint, inst);
- return inst;
-}
-
-ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
- int64_t step) {
- auto lbMap = AffineMap::getConstantMap(lb, context);
- auto ubMap = AffineMap::getConstantMap(ub, context);
- return createFor(location, {}, lbMap, {}, ubMap, step);
-}
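With the dedicated createFor entry points removed, affine loops are constructed like any other operation. The following is a minimal sketch of what building a constant-bound loop could look like under the new scheme; the create<AffineForOp> call, the bounds, and the step are illustrative assumptions, not part of this patch, and setting up the loop body is not shown.

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Builders.h"
using namespace mlir;

// Sketch only: assumes FuncBuilder's templated create<OpTy>() forwards these
// arguments to AffineForOp's build method; values are illustrative.
static OpPointer<AffineForOp> buildConstantLoop(FuncBuilder &b, Location loc) {
  // Constant bounds are zero-input affine maps, mirroring the removed
  // createFor(Location, int64_t, int64_t, int64_t) overload.
  auto lbMap = AffineMap::getConstantMap(/*val=*/0, b.getContext());
  auto ubMap = AffineMap::getConstantMap(/*val=*/128, b.getContext());
  return b.create<AffineForOp>(loc, /*lbOperands=*/llvm::None, lbMap,
                               /*ubOperands=*/llvm::None, ubMap, /*step=*/1);
}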
diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp
index 8d43e3a783d..03f1a2702c9 100644
--- a/mlir/lib/IR/Instruction.cpp
+++ b/mlir/lib/IR/Instruction.cpp
@@ -143,9 +143,6 @@ void Instruction::destroy() {
case Kind::OperationInst:
cast<OperationInst>(this)->destroy();
break;
- case Kind::For:
- cast<ForInst>(this)->destroy();
- break;
}
}
@@ -209,8 +206,6 @@ unsigned Instruction::getNumOperands() const {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getNumOperands();
- case Kind::For:
- return cast<ForInst>(this)->getNumOperands();
}
}
@@ -218,8 +213,6 @@ MutableArrayRef<InstOperand> Instruction::getInstOperands() {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getInstOperands();
- case Kind::For:
- return cast<ForInst>(this)->getInstOperands();
}
}
@@ -349,10 +342,6 @@ void Instruction::dropAllReferences() {
op.drop();
switch (getKind()) {
- case Kind::For:
- // Make sure to drop references held by instructions within the body.
- cast<ForInst>(this)->getBody()->dropAllReferences();
- break;
case Kind::OperationInst: {
auto *opInst = cast<OperationInst>(this);
if (isTerminator())
@@ -656,217 +645,6 @@ bool OperationInst::emitOpError(const Twine &message) const {
}
//===----------------------------------------------------------------------===//
-// ForInst
-//===----------------------------------------------------------------------===//
-
-ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands,
- AffineMap lbMap, ArrayRef<Value *> ubOperands,
- AffineMap ubMap, int64_t step) {
- assert((!lbMap && lbOperands.empty()) ||
- lbOperands.size() == lbMap.getNumInputs() &&
- "lower bound operand count does not match the affine map");
- assert((!ubMap && ubOperands.empty()) ||
- ubOperands.size() == ubMap.getNumInputs() &&
- "upper bound operand count does not match the affine map");
- assert(step > 0 && "step has to be a positive integer constant");
-
- // Compute the byte size for the instruction and the operand storage.
- unsigned numOperands = lbOperands.size() + ubOperands.size();
- auto byteSize = totalSizeToAlloc<detail::OperandStorage>(
- /*detail::OperandStorage*/ 1);
- byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize(
- numOperands, /*resizable=*/true),
- alignof(ForInst));
- void *rawMem = malloc(byteSize);
-
- // Initialize the OperationInst part of the instruction.
- ForInst *inst = ::new (rawMem) ForInst(location, lbMap, ubMap, step);
- new (&inst->getOperandStorage())
- detail::OperandStorage(numOperands, /*resizable=*/true);
-
- auto operands = inst->getInstOperands();
- unsigned i = 0;
- for (unsigned e = lbOperands.size(); i != e; ++i)
- new (&operands[i]) InstOperand(inst, lbOperands[i]);
-
- for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j)
- new (&operands[i]) InstOperand(inst, ubOperands[j]);
-
- return inst;
-}
-
-ForInst::ForInst(Location location, AffineMap lbMap, AffineMap ubMap,
- int64_t step)
- : Instruction(Instruction::Kind::For, location), body(this), lbMap(lbMap),
- ubMap(ubMap), step(step) {
-
- // The body of a for inst always has one block.
- auto *bodyEntry = new Block();
- body.push_back(bodyEntry);
-
- // Add an argument to the block for the induction variable.
- bodyEntry->addArgument(Type::getIndex(lbMap.getResult(0).getContext()));
-}
-
-ForInst::~ForInst() { getOperandStorage().~OperandStorage(); }
-
-const AffineBound ForInst::getLowerBound() const {
- return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap);
-}
-
-const AffineBound ForInst::getUpperBound() const {
- return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
-}
-
-void ForInst::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
- assert(lbOperands.size() == map.getNumInputs());
- assert(map.getNumResults() >= 1 && "bound map has at least one result");
-
- SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end());
-
- auto ubOperands = getUpperBoundOperands();
- newOperands.append(ubOperands.begin(), ubOperands.end());
- getOperandStorage().setOperands(this, newOperands);
-
- this->lbMap = map;
-}
-
-void ForInst::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
- assert(ubOperands.size() == map.getNumInputs());
- assert(map.getNumResults() >= 1 && "bound map has at least one result");
-
- SmallVector<Value *, 4> newOperands(getLowerBoundOperands());
- newOperands.append(ubOperands.begin(), ubOperands.end());
- getOperandStorage().setOperands(this, newOperands);
-
- this->ubMap = map;
-}
-
-void ForInst::setLowerBoundMap(AffineMap map) {
- assert(lbMap.getNumDims() == map.getNumDims() &&
- lbMap.getNumSymbols() == map.getNumSymbols());
- assert(map.getNumResults() >= 1 && "bound map has at least one result");
- this->lbMap = map;
-}
-
-void ForInst::setUpperBoundMap(AffineMap map) {
- assert(ubMap.getNumDims() == map.getNumDims() &&
- ubMap.getNumSymbols() == map.getNumSymbols());
- assert(map.getNumResults() >= 1 && "bound map has at least one result");
- this->ubMap = map;
-}
-
-bool ForInst::hasConstantLowerBound() const { return lbMap.isSingleConstant(); }
-
-bool ForInst::hasConstantUpperBound() const { return ubMap.isSingleConstant(); }
-
-int64_t ForInst::getConstantLowerBound() const {
- return lbMap.getSingleConstantResult();
-}
-
-int64_t ForInst::getConstantUpperBound() const {
- return ubMap.getSingleConstantResult();
-}
-
-void ForInst::setConstantLowerBound(int64_t value) {
- setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
-}
-
-void ForInst::setConstantUpperBound(int64_t value) {
- setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
-}
-
-ForInst::operand_range ForInst::getLowerBoundOperands() {
- return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
-}
-
-ForInst::const_operand_range ForInst::getLowerBoundOperands() const {
- return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
-}
-
-ForInst::operand_range ForInst::getUpperBoundOperands() {
- return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
-}
-
-ForInst::const_operand_range ForInst::getUpperBoundOperands() const {
- return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
-}
-
-bool ForInst::matchingBoundOperandList() const {
- if (lbMap.getNumDims() != ubMap.getNumDims() ||
- lbMap.getNumSymbols() != ubMap.getNumSymbols())
- return false;
-
- unsigned numOperands = lbMap.getNumInputs();
- for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
- // Compare Value *'s.
- if (getOperand(i) != getOperand(numOperands + i))
- return false;
- }
- return true;
-}
-
-void ForInst::walkOps(std::function<void(OperationInst *)> callback) {
- struct Walker : public InstWalker<Walker> {
- std::function<void(OperationInst *)> const &callback;
- Walker(std::function<void(OperationInst *)> const &callback)
- : callback(callback) {}
-
- void visitOperationInst(OperationInst *opInst) { callback(opInst); }
- };
-
- Walker w(callback);
- w.walk(this);
-}
-
-void ForInst::walkOpsPostOrder(std::function<void(OperationInst *)> callback) {
- struct Walker : public InstWalker<Walker> {
- std::function<void(OperationInst *)> const &callback;
- Walker(std::function<void(OperationInst *)> const &callback)
- : callback(callback) {}
-
- void visitOperationInst(OperationInst *opInst) { callback(opInst); }
- };
-
- Walker v(callback);
- v.walkPostOrder(this);
-}
-
-/// Returns the induction variable for this loop.
-Value *ForInst::getInductionVar() { return getBody()->getArgument(0); }
-
-void ForInst::destroy() {
- this->~ForInst();
- free(this);
-}
-
-/// Returns if the provided value is the induction variable of a ForInst.
-bool mlir::isForInductionVar(const Value *val) {
- return getForInductionVarOwner(val) != nullptr;
-}
-
-/// Returns the loop parent of an induction variable. If the provided value is
-/// not an induction variable, then return nullptr.
-ForInst *mlir::getForInductionVarOwner(Value *val) {
- const BlockArgument *ivArg = dyn_cast<BlockArgument>(val);
- if (!ivArg || !ivArg->getOwner())
- return nullptr;
- return dyn_cast_or_null<ForInst>(
- ivArg->getOwner()->getParent()->getContainingInst());
-}
-const ForInst *mlir::getForInductionVarOwner(const Value *val) {
- return getForInductionVarOwner(const_cast<Value *>(val));
-}
-
-/// Extracts the induction variables from a list of ForInsts and returns them.
-SmallVector<Value *, 8>
-mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) {
- SmallVector<Value *, 8> results;
- for (auto *forInst : forInsts)
- results.push_back(forInst->getInductionVar());
- return results;
-}
-//===----------------------------------------------------------------------===//
// Instruction Cloning
//===----------------------------------------------------------------------===//
@@ -879,84 +657,59 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper,
MLIRContext *context) const {
SmallVector<Value *, 8> operands;
SmallVector<Block *, 2> successors;
- if (auto *opInst = dyn_cast<OperationInst>(this)) {
- operands.reserve(getNumOperands() + opInst->getNumSuccessors());
- if (!opInst->isTerminator()) {
- // Non-terminators just add all the operands.
- for (auto *opValue : getOperands())
+ auto *opInst = cast<OperationInst>(this);
+ operands.reserve(getNumOperands() + opInst->getNumSuccessors());
+
+ if (!opInst->isTerminator()) {
+ // Non-terminators just add all the operands.
+ for (auto *opValue : getOperands())
+ operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue)));
+ } else {
+ // We add the operands separated by nullptr's for each successor.
+ unsigned firstSuccOperand = opInst->getNumSuccessors()
+ ? opInst->getSuccessorOperandIndex(0)
+ : opInst->getNumOperands();
+ auto InstOperands = opInst->getInstOperands();
+
+ unsigned i = 0;
+ for (; i != firstSuccOperand; ++i)
+ operands.push_back(
+ mapper.lookupOrDefault(const_cast<Value *>(InstOperands[i].get())));
+
+ successors.reserve(opInst->getNumSuccessors());
+ for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e; ++succ) {
+ successors.push_back(mapper.lookupOrDefault(
+ const_cast<Block *>(opInst->getSuccessor(succ))));
+
+ // Add sentinel to delineate successor operands.
+ operands.push_back(nullptr);
+
+ // Remap the successors operands.
+ for (auto *operand : opInst->getSuccessorOperands(succ))
operands.push_back(
- mapper.lookupOrDefault(const_cast<Value *>(opValue)));
- } else {
- // We add the operands separated by nullptr's for each successor.
- unsigned firstSuccOperand = opInst->getNumSuccessors()
- ? opInst->getSuccessorOperandIndex(0)
- : opInst->getNumOperands();
- auto InstOperands = opInst->getInstOperands();
-
- unsigned i = 0;
- for (; i != firstSuccOperand; ++i)
- operands.push_back(
- mapper.lookupOrDefault(const_cast<Value *>(InstOperands[i].get())));
-
- successors.reserve(opInst->getNumSuccessors());
- for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e;
- ++succ) {
- successors.push_back(mapper.lookupOrDefault(
- const_cast<Block *>(opInst->getSuccessor(succ))));
-
- // Add sentinel to delineate successor operands.
- operands.push_back(nullptr);
-
- // Remap the successors operands.
- for (auto *operand : opInst->getSuccessorOperands(succ))
- operands.push_back(
- mapper.lookupOrDefault(const_cast<Value *>(operand)));
- }
+ mapper.lookupOrDefault(const_cast<Value *>(operand)));
}
-
- SmallVector<Type, 8> resultTypes;
- resultTypes.reserve(opInst->getNumResults());
- for (auto *result : opInst->getResults())
- resultTypes.push_back(result->getType());
-
- unsigned numBlockLists = opInst->getNumBlockLists();
- auto *newOp = OperationInst::create(
- getLoc(), opInst->getName(), operands, resultTypes, opInst->getAttrs(),
- successors, numBlockLists, opInst->hasResizableOperandsList(), context);
-
- // Clone the block lists.
- for (unsigned i = 0; i != numBlockLists; ++i)
- opInst->getBlockList(i).cloneInto(&newOp->getBlockList(i), mapper,
- context);
-
- // Remember the mapping of any results.
- for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i)
- mapper.map(opInst->getResult(i), newOp->getResult(i));
- return newOp;
}
- operands.reserve(getNumOperands());
- for (auto *opValue : getOperands())
- operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue)));
+ SmallVector<Type, 8> resultTypes;
+ resultTypes.reserve(opInst->getNumResults());
+ for (auto *result : opInst->getResults())
+ resultTypes.push_back(result->getType());
- // Otherwise, this must be a ForInst.
- auto *forInst = cast<ForInst>(this);
- auto lbMap = forInst->getLowerBoundMap();
- auto ubMap = forInst->getUpperBoundMap();
+ unsigned numBlockLists = opInst->getNumBlockLists();
+ auto *newOp = OperationInst::create(
+ getLoc(), opInst->getName(), operands, resultTypes, opInst->getAttrs(),
+ successors, numBlockLists, opInst->hasResizableOperandsList(), context);
- auto *newFor = ForInst::create(
- getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
- lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), ubMap,
- forInst->getStep());
-
- // Remember the induction variable mapping.
- mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
+ // Clone the block lists.
+ for (unsigned i = 0; i != numBlockLists; ++i)
+ opInst->getBlockList(i).cloneInto(&newOp->getBlockList(i), mapper, context);
- // Recursively clone the body of the for loop.
- for (auto &subInst : *forInst->getBody())
- newFor->getBody()->push_back(subInst.clone(mapper, context));
- return newFor;
+ // Remember the mapping of any results.
+ for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i)
+ mapper.map(opInst->getResult(i), newOp->getResult(i));
+ return newOp;
}
Instruction *Instruction::clone(MLIRContext *context) const {
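For clarity, here is a small standalone illustration (plain C++, not an MLIR API) of the operand layout the terminator branch of clone() produces: the non-successor operands come first, and each successor's operand group is introduced by a nullptr sentinel.

#include <vector>

// Illustrative helper: split a sentinel-delimited operand list of the shape
// built in Instruction::clone into the leading non-successor operands plus one
// group per successor. ValueT stands in for mlir::Value.
template <typename ValueT>
std::vector<std::vector<ValueT *>>
splitSentinelOperands(const std::vector<ValueT *> &operands) {
  std::vector<std::vector<ValueT *>> groups(1);
  for (ValueT *operand : operands) {
    if (!operand) {
      // A nullptr marks the start of the next successor's operand group.
      groups.emplace_back();
      continue;
    }
    groups.back().push_back(operand);
  }
  return groups;
}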
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 7103eeb7389..a9c046dc7b1 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -64,8 +64,6 @@ MLIRContext *IROperandOwner::getContext() const {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getContext();
- case Kind::ForInst:
- return cast<ForInst>(this)->getContext();
}
}
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index f0c140166ed..a9c62767734 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -2128,23 +2128,6 @@ public:
parseSuccessors(SmallVectorImpl<Block *> &destinations,
SmallVectorImpl<SmallVector<Value *, 4>> &operands);
- ParseResult
- parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results,
- Block *owner);
-
- ParseResult parseOperationBlockList(SmallVectorImpl<Block *> &results);
- ParseResult parseBlockListBody(SmallVectorImpl<Block *> &results);
- ParseResult parseBlock(Block *&block);
- ParseResult parseBlockBody(Block *block);
-
- /// Cleans up the memory for allocated blocks when a parser error occurs.
- void cleanupInvalidBlocks(ArrayRef<Block *> invalidBlocks) {
- // Add the referenced blocks to the function so that they can be properly
- // cleaned up when the function is destroyed.
- for (auto *block : invalidBlocks)
- function->push_back(block);
- }
-
/// After the function is finished parsing, this function checks to see if
/// there are any remaining issues.
ParseResult finalizeFunction(SMLoc loc);
@@ -2187,6 +2170,25 @@ public:
// Block references.
+ ParseResult
+ parseOperationBlockList(SmallVectorImpl<Block *> &results,
+ ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments);
+ ParseResult parseBlockListBody(SmallVectorImpl<Block *> &results);
+ ParseResult parseBlock(Block *&block);
+ ParseResult parseBlockBody(Block *block);
+
+ ParseResult
+ parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results,
+ Block *owner);
+
+ /// Cleans up the memory for allocated blocks when a parser error occurs.
+ void cleanupInvalidBlocks(ArrayRef<Block *> invalidBlocks) {
+ // Add the referenced blocks to the function so that they can be properly
+ // cleaned up when the function is destroyed.
+ for (auto *block : invalidBlocks)
+ function->push_back(block);
+ }
+
/// Get the block with the specified name, creating it if it doesn't
/// already exist. The location specified is the point of use, which allows
/// us to diagnose references to blocks that are not defined precisely.
@@ -2201,13 +2203,6 @@ public:
OperationInst *parseGenericOperation();
OperationInst *parseCustomOperation();
- ParseResult parseForInst();
- ParseResult parseIntConstant(int64_t &val);
- ParseResult parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
- unsigned numDims, unsigned numOperands,
- const char *affineStructName);
- ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
- bool isLower);
ParseResult parseInstructions(Block *block);
private:
@@ -2287,25 +2282,43 @@ ParseResult FunctionParser::parseFunctionBody(bool hadNamedArguments) {
///
/// block-list ::= '{' block-list-body
///
-ParseResult
-FunctionParser::parseOperationBlockList(SmallVectorImpl<Block *> &results) {
+ParseResult FunctionParser::parseOperationBlockList(
+ SmallVectorImpl<Block *> &results,
+ ArrayRef<std::pair<FunctionParser::SSAUseInfo, Type>> entryArguments) {
// Parse the '{'.
if (parseToken(Token::l_brace, "expected '{' to begin block list"))
return ParseFailure;
+
// Check for an empty block list.
- if (consumeIf(Token::r_brace))
+ if (entryArguments.empty() && consumeIf(Token::r_brace))
return ParseSuccess;
Block *currentBlock = builder.getInsertionBlock();
// Parse the first block directly to allow for it to be unnamed.
Block *block = new Block();
+
+ // Add arguments to the entry block.
+ for (auto &placeholderArgPair : entryArguments)
+ if (addDefinition(placeholderArgPair.first,
+ block->addArgument(placeholderArgPair.second))) {
+ delete block;
+ return ParseFailure;
+ }
+
if (parseBlock(block)) {
- cleanupInvalidBlocks(block);
+ delete block;
return ParseFailure;
}
- results.push_back(block);
+
+ // Verify that no other arguments were parsed.
+ if (!entryArguments.empty() &&
+ block->getNumArguments() > entryArguments.size()) {
+ delete block;
+ return emitError("entry block arguments were already defined");
+ }
// Parse the rest of the block list.
+ results.push_back(block);
if (parseBlockListBody(results))
return ParseFailure;
@@ -2385,10 +2398,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) {
if (parseOperation())
return ParseFailure;
break;
- case Token::kw_for:
- if (parseForInst())
- return ParseFailure;
- break;
}
}
@@ -2859,7 +2868,7 @@ OperationInst *FunctionParser::parseGenericOperation() {
std::vector<SmallVector<Block *, 2>> blocks;
while (getToken().is(Token::l_brace)) {
SmallVector<Block *, 2> newBlocks;
- if (parseOperationBlockList(newBlocks)) {
+ if (parseOperationBlockList(newBlocks, /*entryArguments=*/llvm::None)) {
for (auto &blockList : blocks)
cleanupInvalidBlocks(blockList);
return nullptr;
@@ -2884,6 +2893,27 @@ public:
CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser)
: nameLoc(nameLoc), opName(opName), parser(parser) {}
+ bool parseOperation(const AbstractOperation *opDefinition,
+ OperationState *opState) {
+ if (opDefinition->parseAssembly(this, opState))
+ return true;
+
+ // Check that enough block lists were reserved for those that were parsed.
+ if (parsedBlockLists.size() > opState->numBlockLists) {
+ return emitError(
+ nameLoc,
+ "parsed more block lists than those reserved in the operation state");
+ }
+
+ // Check there were no dangling entry block arguments.
+ if (!parsedBlockListEntryArguments.empty()) {
+ return emitError(
+ nameLoc,
+ "no block list was attached to parsed entry block arguments");
+ }
+ return false;
+ }
+
//===--------------------------------------------------------------------===//
// High level parsing methods.
//===--------------------------------------------------------------------===//
@@ -2895,6 +2925,9 @@ public:
bool parseComma() override {
return parser.parseToken(Token::comma, "expected ','");
}
+ bool parseEqual() override {
+ return parser.parseToken(Token::equal, "expected '='");
+ }
bool parseType(Type &result) override {
return !(result = parser.parseType());
@@ -3083,13 +3116,35 @@ public:
/// Parses a list of blocks.
bool parseBlockList() override {
+ // Parse the block list.
SmallVector<Block *, 2> results;
- if (parser.parseOperationBlockList(results))
+ if (parser.parseOperationBlockList(results, parsedBlockListEntryArguments))
return true;
+
+ parsedBlockListEntryArguments.clear();
parsedBlockLists.emplace_back(results);
return false;
}
+ /// Parses an argument for the entry block of the next block list to be
+ /// parsed.
+ bool parseBlockListEntryBlockArgument(Type argType) override {
+ SmallVector<Value *, 1> argValues;
+ OperandType operand;
+ if (parseOperand(operand))
+ return true;
+
+ // Create a place holder for this argument.
+ FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
+ operand.location};
+ if (auto *value = parser.resolveSSAUse(operandInfo, argType)) {
+ parsedBlockListEntryArguments.emplace_back(operandInfo, argType);
+ return false;
+ }
+
+ return true;
+ }
+
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//
@@ -3130,6 +3185,8 @@ public:
private:
std::vector<SmallVector<Block *, 2>> parsedBlockLists;
+ SmallVector<std::pair<FunctionParser::SSAUseInfo, Type>, 2>
+ parsedBlockListEntryArguments;
SMLoc nameLoc;
StringRef opName;
FunctionParser &parser;
@@ -3161,26 +3218,18 @@ OperationInst *FunctionParser::parseCustomOperation() {
// Have the op implementation take a crack and parsing this.
OperationState opState(builder.getContext(), srcLocation, opName);
- if (opDefinition->parseAssembly(&opAsmParser, &opState))
+ if (opAsmParser.parseOperation(opDefinition, &opState))
return nullptr;
// If it emitted an error, we failed.
if (opAsmParser.didEmitError())
return nullptr;
- // Check that enough block lists were reserved for those that were parsed.
- auto parsedBlockLists = opAsmParser.getParsedBlockLists();
- if (parsedBlockLists.size() > opState.numBlockLists) {
- opAsmParser.emitError(
- opLoc,
- "parsed more block lists than those reserved in the operation state");
- return nullptr;
- }
-
// Otherwise, we succeeded. Use the state it parsed as our op information.
auto *opInst = builder.createOperation(opState);
// Resolve any parsed block lists.
+ auto parsedBlockLists = opAsmParser.getParsedBlockLists();
for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) {
auto &opBlockList = opInst->getBlockList(i).getBlocks();
opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(),
@@ -3189,213 +3238,6 @@ OperationInst *FunctionParser::parseCustomOperation() {
return opInst;
}
-/// For instruction.
-///
-/// ml-for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound
-/// (`step` integer-literal)? trailing-location? `{` inst* `}`
-///
-ParseResult FunctionParser::parseForInst() {
- consumeToken(Token::kw_for);
-
- // Parse induction variable.
- if (getToken().isNot(Token::percent_identifier))
- return emitError("expected SSA identifier for the loop variable");
-
- auto loc = getToken().getLoc();
- StringRef inductionVariableName = getTokenSpelling();
- consumeToken(Token::percent_identifier);
-
- if (parseToken(Token::equal, "expected '='"))
- return ParseFailure;
-
- // Parse lower bound.
- SmallVector<Value *, 4> lbOperands;
- AffineMap lbMap;
- if (parseBound(lbOperands, lbMap, /*isLower*/ true))
- return ParseFailure;
-
- if (parseToken(Token::kw_to, "expected 'to' between bounds"))
- return ParseFailure;
-
- // Parse upper bound.
- SmallVector<Value *, 4> ubOperands;
- AffineMap ubMap;
- if (parseBound(ubOperands, ubMap, /*isLower*/ false))
- return ParseFailure;
-
- // Parse step.
- int64_t step = 1;
- if (consumeIf(Token::kw_step) && parseIntConstant(step))
- return ParseFailure;
-
- // The loop step is a positive integer constant. Since index is stored as an
- // int64_t type, we restrict step to be in the set of positive integers that
- // int64_t can represent.
- if (step < 1) {
- return emitError("step has to be a positive integer");
- }
-
- // Create for instruction.
- ForInst *forInst =
- builder.createFor(getEncodedSourceLocation(loc), lbOperands, lbMap,
- ubOperands, ubMap, step);
-
- // Create SSA value definition for the induction variable.
- if (addDefinition({inductionVariableName, 0, loc},
- forInst->getInductionVar()))
- return ParseFailure;
-
- // Try to parse the optional trailing location.
- if (parseOptionalTrailingLocation(forInst))
- return ParseFailure;
-
- // If parsing of the 'for' instruction body fails, the IR still contains the
- // 'for' instruction together with whatever nested instructions were
- // successfully parsed.
- auto *forBody = forInst->getBody();
- if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
- parseBlock(forBody) ||
- parseToken(Token::r_brace, "expected '}' after instruction list"))
- return ParseFailure;
-
- // Reset insertion point to the current block.
- builder.setInsertionPointToEnd(forInst->getBlock());
-
- return ParseSuccess;
-}
-
-/// Parse integer constant as affine constant expression.
-ParseResult FunctionParser::parseIntConstant(int64_t &val) {
- bool negate = consumeIf(Token::minus);
-
- if (getToken().isNot(Token::integer))
- return emitError("expected integer");
-
- auto uval = getToken().getUInt64IntegerValue();
-
- if (!uval.hasValue() || (int64_t)uval.getValue() < 0) {
- return emitError("bound or step is too large for index");
- }
-
- val = (int64_t)uval.getValue();
- if (negate)
- val = -val;
- consumeToken();
-
- return ParseSuccess;
-}
-
-/// Dimensions and symbol use list.
-///
-/// dim-use-list ::= `(` ssa-use-list? `)`
-/// symbol-use-list ::= `[` ssa-use-list? `]`
-/// dim-and-symbol-use-list ::= dim-use-list symbol-use-list?
-///
-ParseResult
-FunctionParser::parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
- unsigned numDims, unsigned numOperands,
- const char *affineStructName) {
- if (parseToken(Token::l_paren, "expected '('"))
- return ParseFailure;
-
- SmallVector<SSAUseInfo, 4> opInfo;
- parseOptionalSSAUseList(opInfo);
-
- if (parseToken(Token::r_paren, "expected ')'"))
- return ParseFailure;
-
- if (numDims != opInfo.size())
- return emitError("dim operand count and " + Twine(affineStructName) +
- " dim count must match");
-
- if (consumeIf(Token::l_square)) {
- parseOptionalSSAUseList(opInfo);
- if (parseToken(Token::r_square, "expected ']'"))
- return ParseFailure;
- }
-
- if (numOperands != opInfo.size())
- return emitError("symbol operand count and " + Twine(affineStructName) +
- " symbol count must match");
-
- // Resolve SSA uses.
- Type indexType = builder.getIndexType();
- for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
- Value *sval = resolveSSAUse(opInfo[i], indexType);
- if (!sval)
- return ParseFailure;
-
- if (i < numDims && !sval->isValidDim())
- return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
- "' cannot be used as a dimension id");
- if (i >= numDims && !sval->isValidSymbol())
- return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
- "' cannot be used as a symbol");
- operands.push_back(sval);
- }
-
- return ParseSuccess;
-}
-
-/// Loop bound.
-///
-///  lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound
-///  upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound
-///  shorthand-bound ::= ssa-id | `-`? integer-literal
-///
-ParseResult FunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
- AffineMap &map, bool isLower) {
- // 'min' / 'max' prefixes are syntactic sugar. Ignore them.
- if (isLower)
- consumeIf(Token::kw_max);
- else
- consumeIf(Token::kw_min);
-
- // Parse full form - affine map followed by dim and symbol list.
- if (getToken().isAny(Token::hash_identifier, Token::l_paren)) {
- map = parseAffineMapReference();
- if (!map)
- return ParseFailure;
-
- if (parseDimAndSymbolList(operands, map.getNumDims(), map.getNumInputs(),
- "affine map"))
- return ParseFailure;
- return ParseSuccess;
- }
-
- // Parse custom assembly form.
- if (getToken().isAny(Token::minus, Token::integer)) {
- int64_t val;
- if (!parseIntConstant(val)) {
- map = builder.getConstantAffineMap(val);
- return ParseSuccess;
- }
- return ParseFailure;
- }
-
- // Parse ssa-id as identity map.
- SSAUseInfo opInfo;
- if (parseSSAUse(opInfo))
- return ParseFailure;
-
- // TODO: improve error message when SSA value is not an affine integer.
- // Currently it is 'use of value ... expects different type than prior uses'
- if (auto *value = resolveSSAUse(opInfo, builder.getIndexType()))
- operands.push_back(value);
- else
- return ParseFailure;
-
- // Create an identity map using dim id for an induction variable and
- // symbol otherwise. This representation is optimized for storage.
- // Analysis passes may expand it into a multi-dimensional map if desired.
- if (isForInductionVar(operands[0]))
- map = builder.getDimIdentityMap();
- else
- map = builder.getSymbolIdentityMap();
-
- return ParseSuccess;
-}
-
/// Parse an affine constraint.
/// affine-constraint ::= affine-expr `>=` `0`
/// | affine-expr `==` `0`
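The entry-block-argument plumbing added above exists so that ops whose region defines a value (such as a loop's induction variable) can be parsed through the generic custom-op path. A rough sketch of how a loop-like op's custom parser might use the new hooks follows; the op and the elided bound parsing are hypothetical, only the parser hooks come from this patch.

#include "mlir/IR/OpImplementation.h"
using namespace mlir;

// Sketch only: parseBlockListEntryBlockArgument records the induction
// variable, and the next parseBlockList attaches it as the entry block
// argument of the parsed block list.
static bool parseLoopLikeOp(OpAsmParser *parser, OperationState *result) {
  Type indexType = parser->getBuilder().getIndexType();

  // Parse "%iv =" and remember %iv as the pending entry block argument.
  if (parser->parseBlockListEntryBlockArgument(indexType) ||
      parser->parseEqual())
    return true;

  // ... parse the lower bound, 'to', the upper bound, and an optional step
  // into 'result' here (not shown) ...

  // Reserve the single block list and parse the loop body.
  result->reserveBlockLists(/*numReserved=*/1);
  return parser->parseBlockList();
}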
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index afd18a49b79..e471b6792c5 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -183,11 +183,6 @@ void CSE::simplifyBlock(Block *bb) {
}
break;
}
- case Instruction::Kind::For: {
- ScopedMapTy::ScopeTy scope(knownValues);
- simplifyBlock(cast<ForInst>(i).getBody());
- break;
- }
}
}
}
diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp
index f9d02f7a47a..9c20e79180a 100644
--- a/mlir/lib/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Transforms/ConstantFold.cpp
@@ -15,6 +15,7 @@
// limitations under the License.
// =============================================================================
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
@@ -37,7 +38,6 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
bool foldOperation(OperationInst *op,
SmallVectorImpl<Value *> &existingConstants);
void visitOperationInst(OperationInst *inst);
- void visitForInst(ForInst *inst);
PassResult runOnFunction(Function *f) override;
static char passID;
@@ -50,6 +50,12 @@ char ConstantFold::passID = 0;
/// constants are found, we keep track of them in the existingConstants list.
///
void ConstantFold::visitOperationInst(OperationInst *op) {
+ // If this operation is an AffineForOp, then fold the bounds.
+ if (auto forOp = op->dyn_cast<AffineForOp>()) {
+ constantFoldBounds(forOp);
+ return;
+ }
+
// If this operation is already a constant, just remember it for cleanup
// later, and don't try to fold it.
if (auto constant = op->dyn_cast<ConstantOp>()) {
@@ -98,11 +104,6 @@ void ConstantFold::visitOperationInst(OperationInst *op) {
opInstsToErase.push_back(op);
}
-// Override the walker's 'for' instruction visit for constant folding.
-void ConstantFold::visitForInst(ForInst *forInst) {
- constantFoldBounds(forInst);
-}
-
// For now, we do a simple top-down pass over a function folding constants. We
// don't handle conditional control flow, block arguments, folding
// conditional branches, or anything else fancy.
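Since loops now come through the walker as ordinary operations, passes recover the typed view with isa/dyn_cast instead of a dedicated visitForInst hook. A minimal sketch of that dispatch pattern (the counter struct is illustrative only):

#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/InstVisitor.h"
using namespace mlir;

// Illustrative walker: every loop is seen through visitOperationInst and
// identified with isa<AffineForOp>, mirroring the dispatch used above.
struct LoopCounter : public InstWalker<LoopCounter> {
  unsigned numLoops = 0;
  void visitOperationInst(OperationInst *opInst) {
    if (opInst->isa<AffineForOp>())
      ++numLoops;
  }
};

Driving it would look like "LoopCounter counter; counter.walk(someInst);", assuming the usual InstWalker walk entry points in this codebase.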
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 5c3a66208ec..83ec726ec2a 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -21,6 +21,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
@@ -71,9 +72,9 @@ struct DmaGeneration : public FunctionPass {
}
PassResult runOnFunction(Function *f) override;
- void runOnForInst(ForInst *forInst);
+ void runOnAffineForOp(OpPointer<AffineForOp> forOp);
- bool generateDma(const MemRefRegion &region, ForInst *forInst,
+ bool generateDma(const MemRefRegion &region, OpPointer<AffineForOp> forOp,
uint64_t *sizeInBytes);
// List of memory regions to DMA for. We need a map vector to have a
@@ -174,7 +175,7 @@ static bool getFullMemRefAsRegion(OperationInst *opInst,
// Just get the first numSymbols IVs, which the memref region is parametric
// on.
- SmallVector<ForInst *, 4> ivs;
+ SmallVector<OpPointer<AffineForOp>, 4> ivs;
getLoopIVs(*opInst, &ivs);
ivs.resize(numParamLoopIVs);
SmallVector<Value *, 4> symbols = extractForInductionVars(ivs);
@@ -195,8 +196,10 @@ static bool getFullMemRefAsRegion(OperationInst *opInst,
// generates a DMA from the lower memory space to this one, and replaces all
// loads to load from that buffer. Returns false if DMAs could not be generated
// due to yet unimplemented cases.
-bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst,
+bool DmaGeneration::generateDma(const MemRefRegion &region,
+ OpPointer<AffineForOp> forOp,
uint64_t *sizeInBytes) {
+ auto *forInst = forOp->getInstruction();
// DMAs for read regions are going to be inserted just before the for loop.
FuncBuilder prologue(forInst);
@@ -386,39 +389,43 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst,
remapExprs.push_back(dimExpr - offsets[i]);
}
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
- // *Only* those uses within the body of 'forInst' are replaced.
+ // *Only* those uses within the body of 'forOp' are replaced.
replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
- /*domInstFilter=*/&*forInst->getBody()->begin());
+ /*domInstFilter=*/&*forOp->getBody()->begin());
return true;
}
// TODO(bondhugula): make this run on a Block instead of a 'for' inst.
-void DmaGeneration::runOnForInst(ForInst *forInst) {
+void DmaGeneration::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
// For now (for testing purposes), we'll run this on the outermost among 'for'
// inst's with unit stride, i.e., right at the top of the tile if tiling has
// been done. In the future, the DMA generation has to be done at a level
// where the generated data fits in a higher level of the memory hierarchy; so
// the pass has to be instantiated with additional information that we aren't
// provided with at the moment.
- if (forInst->getStep() != 1) {
- if (auto *innerFor = dyn_cast<ForInst>(&*forInst->getBody()->begin())) {
- runOnForInst(innerFor);
+ if (forOp->getStep() != 1) {
+ auto *forBody = forOp->getBody();
+ if (forBody->empty())
+ return;
+ if (auto innerFor =
+ cast<OperationInst>(forBody->front()).dyn_cast<AffineForOp>()) {
+ runOnAffineForOp(innerFor);
}
return;
}
// DMAs will be generated for this depth, i.e., for all data accessed by this
// loop.
- unsigned dmaDepth = getNestingDepth(*forInst);
+ unsigned dmaDepth = getNestingDepth(*forOp->getInstruction());
readRegions.clear();
writeRegions.clear();
fastBufferMap.clear();
// Walk this 'for' instruction to gather all memory regions.
- forInst->walkOps([&](OperationInst *opInst) {
+ forOp->walkOps([&](OperationInst *opInst) {
// Gather regions to promote to buffers in faster memory space.
// TODO(bondhugula): handle store op's; only load's handled for now.
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
@@ -443,7 +450,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) {
LLVM_DEBUG(
- forInst->emitError("Non-constant memref sizes not yet supported"));
+ forOp->emitError("Non-constant memref sizes not yet supported"));
return;
}
}
@@ -472,10 +479,10 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
// Perform a union with the existing region.
if (!(*it).second->unionBoundingBox(*region)) {
LLVM_DEBUG(llvm::dbgs()
- << "Memory region bounding box failed; "
+ << "Memory region bounding box failed"
"over-approximating to the entire memref\n");
if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) {
- LLVM_DEBUG(forInst->emitError(
+ LLVM_DEBUG(forOp->emitError(
"Non-constant memref sizes not yet supported"));
}
}
@@ -501,7 +508,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
&regions) {
for (const auto &regionEntry : regions) {
uint64_t sizeInBytes;
- bool iRet = generateDma(*regionEntry.second, forInst, &sizeInBytes);
+ bool iRet = generateDma(*regionEntry.second, forOp, &sizeInBytes);
if (iRet)
totalSizeInBytes += sizeInBytes;
ret = ret & iRet;
@@ -510,7 +517,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
processRegions(readRegions);
processRegions(writeRegions);
if (!ret) {
- forInst->emitError("DMA generation failed for one or more memref's\n");
+ forOp->emitError("DMA generation failed for one or more memref's\n");
return;
}
LLVM_DEBUG(llvm::dbgs() << Twine(llvm::divideCeil(totalSizeInBytes, 1024))
@@ -519,7 +526,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
if (clFastMemoryCapacity && totalSizeInBytes > clFastMemoryCapacity) {
// TODO(bondhugula): selecting the DMA depth so that the result DMA buffers
// fit in fast memory is a TODO - not complex.
- forInst->emitError(
+ forOp->emitError(
"Total size of all DMA buffers' exceeds memory capacity\n");
}
}
@@ -531,8 +538,8 @@ PassResult DmaGeneration::runOnFunction(Function *f) {
for (auto &block : *f) {
for (auto &inst : block) {
- if (auto *forInst = dyn_cast<ForInst>(&inst)) {
- runOnForInst(forInst);
+ if (auto forOp = cast<OperationInst>(inst).dyn_cast<AffineForOp>()) {
+ runOnAffineForOp(forOp);
}
}
}
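The unit-stride check above recurses into the first nested loop through the new op-based idiom; extracted as a helper, the lookup would read roughly as follows (the helper name is an assumption).

#include "mlir/AffineOps/AffineOps.h"
using namespace mlir;

// Sketch: report whether the first instruction in a loop body is itself an
// affine for loop, using the cast<OperationInst>(...).isa<AffineForOp>() idiom
// from the code above. An empty body yields false.
static bool firstBodyInstIsAffineFor(OpPointer<AffineForOp> forOp) {
  auto *body = forOp->getBody();
  if (body->empty())
    return false;
  return cast<OperationInst>(body->front()).isa<AffineForOp>();
}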
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index fa0e3b51de3..7d4ff03e306 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -97,15 +97,15 @@ namespace {
// operations, and whether or not an IfInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
- SmallVector<ForInst *, 4> forInsts;
+ SmallVector<OpPointer<AffineForOp>, 4> forOps;
SmallVector<OperationInst *, 4> loadOpInsts;
SmallVector<OperationInst *, 4> storeOpInsts;
bool hasNonForRegion = false;
- void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
-
void visitOperationInst(OperationInst *opInst) {
- if (opInst->getNumBlockLists() != 0)
+ if (opInst->isa<AffineForOp>())
+ forOps.push_back(opInst->cast<AffineForOp>());
+ else if (opInst->getNumBlockLists() != 0)
hasNonForRegion = true;
else if (opInst->isa<LoadOp>())
loadOpInsts.push_back(opInst);
@@ -491,14 +491,14 @@ bool MemRefDependenceGraph::init(Function *f) {
if (f->getBlocks().size() != 1)
return false;
- DenseMap<ForInst *, unsigned> forToNodeMap;
+ DenseMap<Instruction *, unsigned> forToNodeMap;
for (auto &inst : f->front()) {
- if (auto *forInst = dyn_cast<ForInst>(&inst)) {
- // Create graph node 'id' to represent top-level 'forInst' and record
+ if (auto forOp = cast<OperationInst>(&inst)->dyn_cast<AffineForOp>()) {
+ // Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
- collector.walkForInst(forInst);
- // Return false if IfInsts are found (not currently supported).
+ collector.walk(&inst);
+ // Return false if a non 'for' region was found (not currently supported).
if (collector.hasNonForRegion)
return false;
Node node(nextNodeId++, &inst);
@@ -512,10 +512,9 @@ bool MemRefDependenceGraph::init(Function *f) {
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
- forToNodeMap[forInst] = node.id;
+ forToNodeMap[&inst] = node.id;
nodes.insert({node.id, node});
- }
- if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
+ } else if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
// Create graph node for top-level load op.
Node node(nextNodeId++, &inst);
@@ -552,12 +551,12 @@ bool MemRefDependenceGraph::init(Function *f) {
for (auto *value : opInst->getResults()) {
for (auto &use : value->getUses()) {
auto *userOpInst = cast<OperationInst>(use.getOwner());
- SmallVector<ForInst *, 4> loops;
+ SmallVector<OpPointer<AffineForOp>, 4> loops;
getLoopIVs(*userOpInst, &loops);
if (loops.empty())
continue;
- assert(forToNodeMap.count(loops[0]) > 0);
- unsigned userLoopNestId = forToNodeMap[loops[0]];
+ assert(forToNodeMap.count(loops[0]->getInstruction()) > 0);
+ unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()];
addEdge(node.id, userLoopNestId, value);
}
}
@@ -587,12 +586,12 @@ namespace {
// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
- // Map from ForInst to immediate child ForInsts in its loop body.
- DenseMap<ForInst *, SmallVector<ForInst *, 2>> loopMap;
- // Map from ForInst to count of operations in its loop body.
- DenseMap<ForInst *, uint64_t> opCountMap;
- // Map from ForInst to its constant trip count.
- DenseMap<ForInst *, uint64_t> tripCountMap;
+ // Map from AffineForOp to immediate child AffineForOps in its loop body.
+ DenseMap<Instruction *, SmallVector<OpPointer<AffineForOp>, 2>> loopMap;
+ // Map from AffineForOp to count of operations in its loop body.
+ DenseMap<Instruction *, uint64_t> opCountMap;
+ // Map from AffineForOp to its constant trip count.
+ DenseMap<Instruction *, uint64_t> tripCountMap;
};
// LoopNestStatsCollector walks a single loop nest and gathers per-loop
@@ -604,23 +603,31 @@ public:
LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
- void visitForInst(ForInst *forInst) {
- auto *parentInst = forInst->getParentInst();
+ void visitOperationInst(OperationInst *opInst) {
+ auto forOp = opInst->dyn_cast<AffineForOp>();
+ if (!forOp)
+ return;
+
+ auto *forInst = forOp->getInstruction();
+ auto *parentInst = forOp->getInstruction()->getParentInst();
if (parentInst != nullptr) {
- assert(isa<ForInst>(parentInst) && "Expected parent ForInst");
- // Add mapping to 'forInst' from its parent ForInst.
- stats->loopMap[cast<ForInst>(parentInst)].push_back(forInst);
+ assert(cast<OperationInst>(parentInst)->isa<AffineForOp>() &&
+ "Expected parent AffineForOp");
+ // Add mapping to 'forOp' from its parent AffineForOp.
+ stats->loopMap[parentInst].push_back(forOp);
}
- // Record the number of op instructions in the body of 'forInst'.
+
+ // Record the number of op instructions in the body of 'forOp'.
unsigned count = 0;
stats->opCountMap[forInst] = 0;
- for (auto &inst : *forInst->getBody()) {
- if (isa<OperationInst>(&inst))
+ for (auto &inst : *forOp->getBody()) {
+ if (!(cast<OperationInst>(inst).isa<AffineForOp>() ||
+ cast<OperationInst>(inst).isa<AffineIfOp>()))
++count;
}
stats->opCountMap[forInst] = count;
- // Record trip count for 'forInst'. Set flag if trip count is not constant.
- Optional<uint64_t> maybeConstTripCount = getConstantTripCount(*forInst);
+ // Record trip count for 'forOp'. Set flag if trip count is not constant.
+ Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
if (!maybeConstTripCount.hasValue()) {
hasLoopWithNonConstTripCount = true;
return;
@@ -629,7 +636,7 @@ public:
}
};
-// Computes the total cost of the loop nest rooted at 'forInst'.
+// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. the per-iteration loop body operation count multiplied
// by the loop trip count) for the entire loop nest.
@@ -637,7 +644,7 @@ public:
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
-// If 'computeCostMap' is non-null, the total op count for forInsts specified
+// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
@@ -645,15 +652,15 @@ public:
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
static int64_t getComputeCost(
- ForInst *forInst, LoopNestStats *stats,
- llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
- DenseMap<ForInst *, int64_t> *computeCostMap) {
- // 'opCount' is the total number operations in one iteration of 'forInst' body
+ Instruction *forInst, LoopNestStats *stats,
+ llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountOverrideMap,
+ DenseMap<Instruction *, int64_t> *computeCostMap) {
+ // 'opCount' is the total number of operations in one iteration of 'forOp' body
int64_t opCount = stats->opCountMap[forInst];
if (stats->loopMap.count(forInst) > 0) {
- for (auto *childForInst : stats->loopMap[forInst]) {
- opCount += getComputeCost(childForInst, stats, tripCountOverrideMap,
- computeCostMap);
+ for (auto childForOp : stats->loopMap[forInst]) {
+ opCount += getComputeCost(childForOp->getInstruction(), stats,
+ tripCountOverrideMap, computeCostMap);
}
}
// Add in additional op instances from slice (if specified in map).
@@ -694,18 +701,18 @@ static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
return cExpr.getValue();
}
-// Builds a map 'tripCountMap' from ForInst to constant trip count for loop
+// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
OperationInst *srcOpInst, ComputationSliceState *sliceState,
- llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
- SmallVector<ForInst *, 4> srcLoopIVs;
+ llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) {
+ SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
- // Populate map from ForInst -> trip count
+ // Populate map from AffineForOp -> trip count
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
AffineMap lbMap = sliceState->lbs[i];
AffineMap ubMap = sliceState->ubs[i];
@@ -713,7 +720,7 @@ static bool buildSliceTripCountMap(
// The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
if (srcLoopIVs[i]->hasConstantLowerBound() &&
srcLoopIVs[i]->hasConstantUpperBound()) {
- (*tripCountMap)[srcLoopIVs[i]] =
+ (*tripCountMap)[srcLoopIVs[i]->getInstruction()] =
srcLoopIVs[i]->getConstantUpperBound() -
srcLoopIVs[i]->getConstantLowerBound();
continue;
@@ -723,7 +730,7 @@ static bool buildSliceTripCountMap(
Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
if (!tripCount.hasValue())
return false;
- (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
+ (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue();
}
return true;
}
@@ -750,7 +757,7 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
unsigned numOps = ops.size();
assert(numOps > 0);
- std::vector<SmallVector<ForInst *, 4>> loops(numOps);
+ std::vector<SmallVector<OpPointer<AffineForOp>, 4>> loops(numOps);
unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
for (unsigned i = 0; i < numOps; ++i) {
getLoopIVs(*ops[i], &loops[i]);
@@ -762,9 +769,8 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
for (unsigned d = 0; d < loopDepthLimit; ++d) {
unsigned i;
for (i = 1; i < numOps; ++i) {
- if (loops[i - 1][d] != loops[i][d]) {
+ if (loops[i - 1][d] != loops[i][d])
break;
- }
}
if (i != numOps)
break;
@@ -871,14 +877,16 @@ static bool getSliceUnion(const ComputationSliceState &sliceStateA,
}
// Creates and returns a private (single-user) memref for fused loop rooted
-// at 'forInst', with (potentially reduced) memref size based on the
+// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
-static Value *createPrivateMemRef(ForInst *forInst,
+static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
OperationInst *srcStoreOpInst,
unsigned dstLoopDepth) {
- // Create builder to insert alloc op just before 'forInst'.
+ auto *forInst = forOp->getInstruction();
+
+ // Create builder to insert alloc op just before 'forOp'.
FuncBuilder b(forInst);
// Builder to create constants at the top level.
FuncBuilder top(forInst->getFunction());
@@ -934,16 +942,16 @@ static Value *createPrivateMemRef(ForInst *forInst,
for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1)
allocOperands.push_back(
- top.create<DimOp>(forInst->getLoc(), oldMemRef, dynamicDimCount++));
+ top.create<DimOp>(forOp->getLoc(), oldMemRef, dynamicDimCount++));
}
- // Create new private memref for fused loop 'forInst'.
+ // Create new private memref for fused loop 'forOp'.
// TODO(andydavis) Create/move alloc ops for private memrefs closer to their
// consumer loop nests to reduce their live range. Currently they are added
// at the beginning of the function, because loop nests can be reordered
// during the fusion pass.
Value *newMemRef =
- top.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
+ top.create<AllocOp>(forOp->getLoc(), newMemRefType, allocOperands);
// Build an AffineMap to remap access functions based on lower bound offsets.
SmallVector<AffineExpr, 4> remapExprs;
@@ -967,7 +975,7 @@ static Value *createPrivateMemRef(ForInst *forInst,
bool ret =
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
/*extraOperands=*/outerIVs,
- /*domInstFilter=*/&*forInst->getBody()->begin());
+ /*domInstFilter=*/&*forOp->getBody()->begin());
assert(ret && "replaceAllMemrefUsesWith should always succeed here");
(void)ret;
return newMemRef;
@@ -975,7 +983,7 @@ static Value *createPrivateMemRef(ForInst *forInst,
// Does the slice have a single iteration?
static uint64_t getSliceIterationCount(
- const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) {
+ const llvm::SmallDenseMap<Instruction *, uint64_t, 8> &sliceTripCountMap) {
uint64_t iterCount = 1;
for (const auto &count : sliceTripCountMap) {
iterCount *= count.second;
@@ -1030,25 +1038,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
});
// Compute cost of sliced and unsliced src loop nest.
- SmallVector<ForInst *, 4> srcLoopIVs;
+ SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
- srcStatsCollector.walk(srcLoopIVs[0]);
+ srcStatsCollector.walk(srcLoopIVs[0]->getInstruction());
// Currently only constant trip count loop nests are supported.
if (srcStatsCollector.hasLoopWithNonConstTripCount)
return false;
// Compute cost of dst loop nest.
- SmallVector<ForInst *, 4> dstLoopIVs;
+ SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
LoopNestStats dstLoopNestStats;
LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
- dstStatsCollector.walk(dstLoopIVs[0]);
+ dstStatsCollector.walk(dstLoopIVs[0]->getInstruction());
// Currently only constant trip count loop nests are supported.
if (dstStatsCollector.hasLoopWithNonConstTripCount)
return false;
@@ -1075,17 +1083,19 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
Optional<unsigned> bestDstLoopDepth = None;
// Compute op instance count for the src loop nest without iteration slicing.
- uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
+ uint64_t srcLoopNestCost =
+ getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
+ /*tripCountOverrideMap=*/nullptr,
+ /*computeCostMap=*/nullptr);
// Compute op instance count for the src loop nest.
- uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
+ uint64_t dstLoopNestCost =
+ getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
+ /*tripCountOverrideMap=*/nullptr,
+ /*computeCostMap=*/nullptr);
- llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
- DenseMap<ForInst *, int64_t> computeCostMap;
+ llvm::SmallDenseMap<Instruction *, uint64_t, 8> sliceTripCountMap;
+ DenseMap<Instruction *, int64_t> computeCostMap;
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
MemRefAccess srcAccess(srcOpInst);
// Handle the common case of one dst load without a copy.
@@ -1121,24 +1131,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
// The store and loads to this memref will disappear.
if (storeLoadFwdGuaranteed) {
// A single store disappears: -1 for that.
- computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1;
+ computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
for (auto *loadOp : dstLoadOpInsts) {
- if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst()))
- computeCostMap[loadLoop] = -1;
+ auto *parentInst = loadOp->getParentInst();
+ if (parentInst && cast<OperationInst>(parentInst)->isa<AffineForOp>())
+ computeCostMap[parentInst] = -1;
}
}
// Compute op instance count for the src loop nest with iteration slicing.
int64_t sliceComputeCost =
- getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
+ getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
/*tripCountOverrideMap=*/&sliceTripCountMap,
/*computeCostMap=*/&computeCostMap);
// Compute cost of fusion for this depth.
- computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;
+ computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost;
int64_t fusedLoopNestComputeCost =
- getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
+ getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
/*tripCountOverrideMap=*/nullptr, &computeCostMap);
double additionalComputeFraction =
@@ -1211,8 +1222,8 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
<< "\n fused loop nest compute cost: "
<< minFusedLoopNestComputeCost << "\n");
- auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]);
- auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]);
+ auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
+ auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
Optional<double> storageReduction = None;
@@ -1292,9 +1303,9 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
-// *) Pop a ForInst of the worklist. This 'dstForInst' will be a candidate
-// destination ForInst into which fusion will be attempted.
-// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
+// *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
+// candidate destination AffineForOp into which fusion will be attempted.
+// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
// *) For each LoadOp in 'dstLoadOps' do:
// *) Lookup dependent loop nests at earlier positions in the Function
// which have a single store op to the same memref.
@@ -1342,7 +1353,7 @@ public:
// Get 'dstNode' into which to attempt fusion.
auto *dstNode = mdg->getNode(dstId);
// Skip if 'dstNode' is not a loop nest.
- if (!isa<ForInst>(dstNode->inst))
+ if (!cast<OperationInst>(dstNode->inst)->isa<AffineForOp>())
continue;
SmallVector<OperationInst *, 4> loads = dstNode->loads;
@@ -1375,7 +1386,7 @@ public:
// Get 'srcNode' from which to attempt fusion into 'dstNode'.
auto *srcNode = mdg->getNode(srcId);
// Skip if 'srcNode' is not a loop nest.
- if (!isa<ForInst>(srcNode->inst))
+ if (!cast<OperationInst>(srcNode->inst)->isa<AffineForOp>())
continue;
// Skip if 'srcNode' has more than one store to any memref.
// TODO(andydavis) Support fusing multi-output src loop nests.
@@ -1417,25 +1428,26 @@ public:
continue;
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
- auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
+ auto sliceLoopNest = mlir::insertBackwardComputationSlice(
srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
if (sliceLoopNest != nullptr) {
- // Move 'dstForInst' before 'insertPointInst' if needed.
- auto *dstForInst = cast<ForInst>(dstNode->inst);
- if (insertPointInst != dstForInst) {
- dstForInst->moveBefore(insertPointInst);
+ // Move 'dstAffineForOp' before 'insertPointInst' if needed.
+ auto dstAffineForOp =
+ cast<OperationInst>(dstNode->inst)->cast<AffineForOp>();
+ if (insertPointInst != dstAffineForOp->getInstruction()) {
+ dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
}
// Update edges between 'srcNode' and 'dstNode'.
mdg->updateEdges(srcNode->id, dstNode->id, memref);
// Collect slice loop stats.
LoopNestStateCollector sliceCollector;
- sliceCollector.walkForInst(sliceLoopNest);
+ sliceCollector.walk(sliceLoopNest->getInstruction());
// Promote single iteration slice loops to single IV value.
- for (auto *forInst : sliceCollector.forInsts) {
- promoteIfSingleIteration(forInst);
+ for (auto forOp : sliceCollector.forOps) {
+ promoteIfSingleIteration(forOp);
}
- // Create private memref for 'memref' in 'dstForInst'.
+ // Create private memref for 'memref' in 'dstAffineForOp'.
SmallVector<OperationInst *, 4> storesForMemref;
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
@@ -1443,7 +1455,7 @@ public:
}
assert(storesForMemref.size() == 1);
auto *newMemRef = createPrivateMemRef(
- dstForInst, storesForMemref[0], bestDstLoopDepth);
+ dstAffineForOp, storesForMemref[0], bestDstLoopDepth);
visitedMemrefs.insert(newMemRef);
// Create new node in dependence graph for 'newMemRef' alloc op.
unsigned newMemRefNodeId =
@@ -1453,7 +1465,7 @@ public:
        // Collect dst loop stats after memref privatization transformation.
LoopNestStateCollector dstLoopCollector;
- dstLoopCollector.walkForInst(dstForInst);
+ dstLoopCollector.walk(dstAffineForOp->getInstruction());
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
@@ -1472,7 +1484,7 @@ public:
// function.
if (mdg->canRemoveNode(srcNode->id)) {
mdg->removeNode(srcNode->id);
- cast<ForInst>(srcNode->inst)->erase();
+ srcNode->inst->erase();
}
}
}
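
The LoopFusion hunks above all apply one idiom: with ForInst gone, recognizing and manipulating an affine loop goes through OperationInst and OpPointer<AffineForOp>. A minimal sketch of that idiom, illustrative only and not part of the change (it uses only accessors already visible in the hunks):

    // An affine loop is now an AffineForOp operation, not a dedicated ForInst.
    static bool isAffineLoop(Instruction *inst) {
      return cast<OperationInst>(inst)->isa<AffineForOp>();
    }

    // OpPointer<AffineForOp> wraps the op; getInstruction() recovers the
    // underlying OperationInst whenever a raw Instruction is needed, e.g. to
    // reposition the loop in its block.
    static void moveLoopBefore(OpPointer<AffineForOp> forOp,
                               Instruction *insertPointInst) {
      forOp->getInstruction()->moveBefore(insertPointInst);
    }
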
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index 396fc8eb658..f1ee7fd1853 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
@@ -60,16 +61,17 @@ char LoopTiling::passID = 0;
/// Function.
FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
-// Move the loop body of ForInst 'src' from 'src' into the specified location in
-// destination's body.
-static inline void moveLoopBody(ForInst *src, ForInst *dest,
+// Move the loop body of AffineForOp 'src' from 'src' into the specified
+// location in destination's body.
+static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest,
Block::iterator loc) {
dest->getBody()->getInstructions().splice(loc,
src->getBody()->getInstructions());
}
-// Move the loop body of ForInst 'src' from 'src' to the start of dest's body.
-static inline void moveLoopBody(ForInst *src, ForInst *dest) {
+// Move the loop body of AffineForOp 'src' from 'src' to the start of dest's
+// body.
+static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest) {
moveLoopBody(src, dest, dest->getBody()->begin());
}
@@ -78,13 +80,14 @@ static inline void moveLoopBody(ForInst *src, ForInst *dest) {
/// depend on other dimensions. Bounds of each dimension can thus be treated
/// independently, and deriving the new bounds is much simpler and faster
/// than for the case of tiling arbitrary polyhedral shapes.
-static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
- ArrayRef<ForInst *> newLoops,
- ArrayRef<unsigned> tileSizes) {
+static void constructTiledIndexSetHyperRect(
+ MutableArrayRef<OpPointer<AffineForOp>> origLoops,
+ MutableArrayRef<OpPointer<AffineForOp>> newLoops,
+ ArrayRef<unsigned> tileSizes) {
assert(!origLoops.empty());
assert(origLoops.size() == tileSizes.size());
- FuncBuilder b(origLoops[0]);
+ FuncBuilder b(origLoops[0]->getInstruction());
unsigned width = origLoops.size();
// Bounds for tile space loops.
@@ -99,8 +102,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
}
// Bounds for intra-tile loops.
for (unsigned i = 0; i < width; i++) {
- int64_t largestDiv = getLargestDivisorOfTripCount(*origLoops[i]);
- auto mayBeConstantCount = getConstantTripCount(*origLoops[i]);
+ int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]);
+ auto mayBeConstantCount = getConstantTripCount(origLoops[i]);
// The lower bound is just the tile-space loop.
AffineMap lbMap = b.getDimIdentityMap();
newLoops[width + i]->setLowerBound(
@@ -144,38 +147,40 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
/// Tiles the specified band of perfectly nested loops creating tile-space loops
/// and intra-tile loops. A band is a contiguous set of loops.
// TODO(bondhugula): handle non hyper-rectangular spaces.
-UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
+UtilResult mlir::tileCodeGen(MutableArrayRef<OpPointer<AffineForOp>> band,
ArrayRef<unsigned> tileSizes) {
assert(!band.empty());
assert(band.size() == tileSizes.size());
  // Check if the supplied for insts are all successively nested.
for (unsigned i = 1, e = band.size(); i < e; i++) {
- assert(band[i]->getParentInst() == band[i - 1]);
+ assert(band[i]->getInstruction()->getParentInst() ==
+ band[i - 1]->getInstruction());
}
auto origLoops = band;
- ForInst *rootForInst = origLoops[0];
- auto loc = rootForInst->getLoc();
+ OpPointer<AffineForOp> rootAffineForOp = origLoops[0];
+ auto loc = rootAffineForOp->getLoc();
// Note that width is at least one since band isn't empty.
unsigned width = band.size();
- SmallVector<ForInst *, 12> newLoops(2 * width);
- ForInst *innermostPointLoop;
+ SmallVector<OpPointer<AffineForOp>, 12> newLoops(2 * width);
+ OpPointer<AffineForOp> innermostPointLoop;
// The outermost among the loops as we add more..
- auto *topLoop = rootForInst;
+ auto *topLoop = rootAffineForOp->getInstruction();
// Add intra-tile (or point) loops.
for (unsigned i = 0; i < width; i++) {
FuncBuilder b(topLoop);
// Loop bounds will be set later.
- auto *pointLoop = b.createFor(loc, 0, 0);
+ auto pointLoop = b.create<AffineForOp>(loc, 0, 0);
+ pointLoop->createBody();
pointLoop->getBody()->getInstructions().splice(
pointLoop->getBody()->begin(), topLoop->getBlock()->getInstructions(),
topLoop);
newLoops[2 * width - 1 - i] = pointLoop;
- topLoop = pointLoop;
+ topLoop = pointLoop->getInstruction();
if (i == 0)
innermostPointLoop = pointLoop;
}
@@ -184,12 +189,13 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
for (unsigned i = width; i < 2 * width; i++) {
FuncBuilder b(topLoop);
// Loop bounds will be set later.
- auto *tileSpaceLoop = b.createFor(loc, 0, 0);
+ auto tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0);
+ tileSpaceLoop->createBody();
tileSpaceLoop->getBody()->getInstructions().splice(
tileSpaceLoop->getBody()->begin(),
topLoop->getBlock()->getInstructions(), topLoop);
newLoops[2 * width - i - 1] = tileSpaceLoop;
- topLoop = tileSpaceLoop;
+ topLoop = tileSpaceLoop->getInstruction();
}
// Move the loop body of the original nest to the new one.
@@ -201,8 +207,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
getIndexSet(band, &cst);
if (!cst.isHyperRectangular(0, width)) {
- rootForInst->emitError("tiled code generation unimplemented for the"
- "non-hyperrectangular case");
+    rootAffineForOp->emitError("tiled code generation unimplemented for the "
+ "non-hyperrectangular case");
return UtilResult::Failure;
}
@@ -213,7 +219,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
}
// Erase the old loop nest.
- rootForInst->erase();
+ rootAffineForOp->erase();
return UtilResult::Success;
}
@@ -221,38 +227,36 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
// Identify valid and profitable bands of loops to tile. This is currently just
// a temporary placeholder to test the mechanics of tiled code generation.
// Returns all maximal outermost perfect loop nests to tile.
-static void getTileableBands(Function *f,
- std::vector<SmallVector<ForInst *, 6>> *bands) {
+static void
+getTileableBands(Function *f,
+ std::vector<SmallVector<OpPointer<AffineForOp>, 6>> *bands) {
// Get maximal perfect nest of 'for' insts starting from root (inclusive).
- auto getMaximalPerfectLoopNest = [&](ForInst *root) {
- SmallVector<ForInst *, 6> band;
- ForInst *currInst = root;
+ auto getMaximalPerfectLoopNest = [&](OpPointer<AffineForOp> root) {
+ SmallVector<OpPointer<AffineForOp>, 6> band;
+ OpPointer<AffineForOp> currInst = root;
do {
band.push_back(currInst);
} while (currInst->getBody()->getInstructions().size() == 1 &&
- (currInst = dyn_cast<ForInst>(&currInst->getBody()->front())));
+ (currInst = cast<OperationInst>(currInst->getBody()->front())
+ .dyn_cast<AffineForOp>()));
bands->push_back(band);
};
- for (auto &block : *f) {
- for (auto &inst : block) {
- auto *forInst = dyn_cast<ForInst>(&inst);
- if (!forInst)
- continue;
- getMaximalPerfectLoopNest(forInst);
- }
- }
+ for (auto &block : *f)
+ for (auto &inst : block)
+ if (auto forOp = cast<OperationInst>(inst).dyn_cast<AffineForOp>())
+ getMaximalPerfectLoopNest(forOp);
}
PassResult LoopTiling::runOnFunction(Function *f) {
- std::vector<SmallVector<ForInst *, 6>> bands;
+ std::vector<SmallVector<OpPointer<AffineForOp>, 6>> bands;
getTileableBands(f, &bands);
// Temporary tile sizes.
unsigned tileSize =
clTileSize.getNumOccurrences() > 0 ? clTileSize : kDefaultTileSize;
- for (const auto &band : bands) {
+ for (auto &band : bands) {
SmallVector<unsigned, 6> tileSizes(band.size(), tileSize);
if (tileCodeGen(band, tileSizes)) {
return failure();
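
The tiling rewrite above also switches loop construction to the generic builder. A short sketch of the two-step pattern it now relies on, illustrative only and using the names from the hunks above:

    FuncBuilder b(topLoop);
    // AffineForOp is built through the generic create<> API, and its body
    // block list has to be created explicitly before instructions can be
    // spliced into it.
    auto newLoop = b.create<AffineForOp>(loc, /*lb=*/0, /*ub=*/0);
    newLoop->createBody();
    // The body block is then available for splicing or for further building.
    FuncBuilder bodyBuilder(newLoop->getBody());
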
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 6d63e4afd2d..86e913bd71f 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -21,6 +21,7 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -70,18 +71,19 @@ struct LoopUnroll : public FunctionPass {
const Optional<bool> unrollFull;
// Callback to obtain unroll factors; if this has a callable target, takes
// precedence over command-line argument or passed argument.
- const std::function<unsigned(const ForInst &)> getUnrollFactor;
+ const std::function<unsigned(ConstOpPointer<AffineForOp>)> getUnrollFactor;
- explicit LoopUnroll(
- Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
- const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr)
+ explicit LoopUnroll(Optional<unsigned> unrollFactor = None,
+ Optional<bool> unrollFull = None,
+ const std::function<unsigned(ConstOpPointer<AffineForOp>)>
+ &getUnrollFactor = nullptr)
: FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor),
unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {}
PassResult runOnFunction(Function *f) override;
/// Unroll this for inst. Returns false if nothing was done.
- bool runOnForInst(ForInst *forInst);
+ bool runOnAffineForOp(OpPointer<AffineForOp> forOp);
static const unsigned kDefaultUnrollFactor = 4;
@@ -96,7 +98,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> {
public:
// Store innermost loops as we walk.
- std::vector<ForInst *> loops;
+ std::vector<OpPointer<AffineForOp>> loops;
    // This method is specialized to encode custom return logic.
using InstListType = llvm::iplist<Instruction>;
@@ -111,20 +113,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
return hasInnerLoops;
}
- bool walkForInstPostOrder(ForInst *forInst) {
- bool hasInnerLoops =
- walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end());
- if (!hasInnerLoops)
- loops.push_back(forInst);
- return true;
- }
-
bool walkOpInstPostOrder(OperationInst *opInst) {
+ bool hasInnerLoops = false;
for (auto &blockList : opInst->getBlockLists())
for (auto &block : blockList)
- if (walkPostOrder(block.begin(), block.end()))
- return true;
- return false;
+ hasInnerLoops |= walkPostOrder(block.begin(), block.end());
+ if (opInst->isa<AffineForOp>()) {
+ if (!hasInnerLoops)
+ loops.push_back(opInst->cast<AffineForOp>());
+ return true;
+ }
+ return hasInnerLoops;
}
// FIXME: can't use base class method for this because that in turn would
@@ -137,14 +136,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> {
public:
// Store short loops as we walk.
- std::vector<ForInst *> loops;
+ std::vector<OpPointer<AffineForOp>> loops;
const unsigned minTripCount;
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
- void visitForInst(ForInst *forInst) {
- Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
+ void visitOperationInst(OperationInst *opInst) {
+ auto forOp = opInst->dyn_cast<AffineForOp>();
+ if (!forOp)
+ return;
+ Optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
- loops.push_back(forInst);
+ loops.push_back(forOp);
}
};
@@ -156,8 +158,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
// ones).
slg.walkPostOrder(f);
auto &loops = slg.loops;
- for (auto *forInst : loops)
- loopUnrollFull(forInst);
+ for (auto forOp : loops)
+ loopUnrollFull(forOp);
return success();
}
@@ -172,8 +174,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
if (loops.empty())
break;
bool unrolled = false;
- for (auto *forInst : loops)
- unrolled |= runOnForInst(forInst);
+ for (auto forOp : loops)
+ unrolled |= runOnAffineForOp(forOp);
if (!unrolled)
// Break out if nothing was unrolled.
break;
@@ -183,29 +185,30 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false
/// otherwise. The default unroll factor is 4.
-bool LoopUnroll::runOnForInst(ForInst *forInst) {
+bool LoopUnroll::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
// Use the function callback if one was provided.
if (getUnrollFactor) {
- return loopUnrollByFactor(forInst, getUnrollFactor(*forInst));
+ return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
}
// Unroll by the factor passed, if any.
if (unrollFactor.hasValue())
- return loopUnrollByFactor(forInst, unrollFactor.getValue());
+ return loopUnrollByFactor(forOp, unrollFactor.getValue());
// Unroll by the command line factor if one was specified.
if (clUnrollFactor.getNumOccurrences() > 0)
- return loopUnrollByFactor(forInst, clUnrollFactor);
+ return loopUnrollByFactor(forOp, clUnrollFactor);
// Unroll completely if full loop unroll was specified.
if (clUnrollFull.getNumOccurrences() > 0 ||
(unrollFull.hasValue() && unrollFull.getValue()))
- return loopUnrollFull(forInst);
+ return loopUnrollFull(forOp);
// Unroll by four otherwise.
- return loopUnrollByFactor(forInst, kDefaultUnrollFactor);
+ return loopUnrollByFactor(forOp, kDefaultUnrollFactor);
}
FunctionPass *mlir::createLoopUnrollPass(
int unrollFactor, int unrollFull,
- const std::function<unsigned(const ForInst &)> &getUnrollFactor) {
+ const std::function<unsigned(ConstOpPointer<AffineForOp>)>
+ &getUnrollFactor) {
return new LoopUnroll(
unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
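
For reference, a usage sketch of the updated factory with the new callback type; the unrolling policy in the lambda is hypothetical and only assumes the getConstantTripCount signature suggested by the hunks above:

    // The callback now receives a ConstOpPointer<AffineForOp> instead of a
    // const ForInst &.
    FunctionPass *pass = mlir::createLoopUnrollPass(
        /*unrollFactor=*/-1, /*unrollFull=*/-1,
        [](ConstOpPointer<AffineForOp> forOp) -> unsigned {
          // Hypothetical policy: fully unroll short loops, unroll the rest by 4.
          auto tripCount = mlir::getConstantTripCount(forOp);
          if (tripCount.hasValue() && tripCount.getValue() <= 8)
            return static_cast<unsigned>(tripCount.getValue());
          return 4;
        });
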
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index 7deaf850362..7327a37ee3a 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -43,6 +43,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -80,7 +81,7 @@ struct LoopUnrollAndJam : public FunctionPass {
unrollJamFactor(unrollJamFactor) {}
PassResult runOnFunction(Function *f) override;
- bool runOnForInst(ForInst *forInst);
+ bool runOnAffineForOp(OpPointer<AffineForOp> forOp);
static char passID;
};
@@ -95,47 +96,51 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
PassResult LoopUnrollAndJam::runOnFunction(Function *f) {
// Currently, just the outermost loop from the first loop nest is
- // unroll-and-jammed by this pass. However, runOnForInst can be called on any
- // for Inst.
+ // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on
+  // any AffineForOp.
auto &entryBlock = f->front();
if (!entryBlock.empty())
- if (auto *forInst = dyn_cast<ForInst>(&entryBlock.front()))
- runOnForInst(forInst);
+ if (auto forOp =
+ cast<OperationInst>(entryBlock.front()).dyn_cast<AffineForOp>())
+ runOnAffineForOp(forOp);
return success();
}
/// Unroll and jam a 'for' inst. Default unroll jam factor is
/// kDefaultUnrollJamFactor. Return false if nothing was done.
-bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) {
+bool LoopUnrollAndJam::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
// Unroll and jam by the factor that was passed if any.
if (unrollJamFactor.hasValue())
- return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue());
+ return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue());
// Otherwise, unroll jam by the command-line factor if one was specified.
if (clUnrollJamFactor.getNumOccurrences() > 0)
- return loopUnrollJamByFactor(forInst, clUnrollJamFactor);
+ return loopUnrollJamByFactor(forOp, clUnrollJamFactor);
// Unroll and jam by four otherwise.
- return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor);
+ return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor);
}
-bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) {
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
+bool mlir::loopUnrollJamUpToFactor(OpPointer<AffineForOp> forOp,
+ uint64_t unrollJamFactor) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollJamFactor)
- return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue());
- return loopUnrollJamByFactor(forInst, unrollJamFactor);
+ return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue());
+ return loopUnrollJamByFactor(forOp, unrollJamFactor);
}
/// Unrolls and jams this loop by the specified factor.
-bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
+bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
+ uint64_t unrollJamFactor) {
// Gathers all maximal sub-blocks of instructions that do not themselves
  // include a for inst (an instruction could have a descendant for inst though
// in its tree).
class JamBlockGatherer : public InstWalker<JamBlockGatherer> {
public:
using InstListType = llvm::iplist<Instruction>;
+ using InstWalker<JamBlockGatherer>::walk;
// Store iterators to the first and last inst of each sub-block found.
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
@@ -144,30 +149,30 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
void walk(InstListType::iterator Start, InstListType::iterator End) {
for (auto it = Start; it != End;) {
auto subBlockStart = it;
- while (it != End && !isa<ForInst>(it))
+ while (it != End && !cast<OperationInst>(it)->isa<AffineForOp>())
++it;
if (it != subBlockStart)
subBlocks.push_back({subBlockStart, std::prev(it)});
// Process all for insts that appear next.
- while (it != End && isa<ForInst>(it))
- walkForInst(cast<ForInst>(it++));
+ while (it != End && cast<OperationInst>(it)->isa<AffineForOp>())
+ walk(&*it++);
}
}
};
assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
- if (unrollJamFactor == 1 || forInst->getBody()->empty())
+ if (unrollJamFactor == 1 || forOp->getBody()->empty())
return false;
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (!mayBeConstantTripCount.hasValue() &&
- getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0)
+ getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0)
return false;
- auto lbMap = forInst->getLowerBoundMap();
- auto ubMap = forInst->getUpperBoundMap();
+ auto lbMap = forOp->getLowerBoundMap();
+ auto ubMap = forOp->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to
@@ -178,7 +183,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different sets of operands.
- if (!forInst->matchingBoundOperandList())
+ if (!forOp->matchingBoundOperandList())
return false;
// If the trip count is lower than the unroll jam factor, no unroll jam.
@@ -187,35 +192,38 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
mayBeConstantTripCount.getValue() < unrollJamFactor)
return false;
+ auto *forInst = forOp->getInstruction();
+
// Gather all sub-blocks to jam upon the loop being unrolled.
JamBlockGatherer jbg;
- jbg.walkForInst(forInst);
+ jbg.walkOpInst(forInst);
auto &subBlocks = jbg.subBlocks;
// Generate the cleanup loop if trip count isn't a multiple of
// unrollJamFactor.
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
- // Insert the cleanup loop right after 'forInst'.
+ // Insert the cleanup loop right after 'forOp'.
FuncBuilder builder(forInst->getBlock(),
std::next(Block::iterator(forInst)));
- auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst));
- cleanupForInst->setLowerBoundMap(
- getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder));
+ auto cleanupAffineForOp =
+ cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>();
+ cleanupAffineForOp->setLowerBoundMap(
+ getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder));
// The upper bound needs to be adjusted.
- forInst->setUpperBoundMap(
- getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder));
+ forOp->setUpperBoundMap(
+ getUnrolledLoopUpperBound(forOp, unrollJamFactor, &builder));
// Promote the loop body up if this has turned into a single iteration loop.
- promoteIfSingleIteration(cleanupForInst);
+ promoteIfSingleIteration(cleanupAffineForOp);
}
// Scale the step of loop being unroll-jammed by the unroll-jam factor.
- int64_t step = forInst->getStep();
- forInst->setStep(step * unrollJamFactor);
+ int64_t step = forOp->getStep();
+ forOp->setStep(step * unrollJamFactor);
- auto *forInstIV = forInst->getInductionVar();
+ auto *forOpIV = forOp->getInductionVar();
for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of
// sub-block.
@@ -227,13 +235,13 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
- if (!forInstIV->use_empty()) {
+ if (!forOpIV->use_empty()) {
// iv' = iv + i, i = 1 to unrollJamFactor-1.
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
- auto ivUnroll = builder.create<AffineApplyOp>(forInst->getLoc(),
- bumpMap, forInstIV);
- operandMapping.map(forInstIV, ivUnroll);
+ auto ivUnroll =
+ builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forOpIV);
+ operandMapping.map(forOpIV, ivUnroll);
}
// Clone the sub-block being unroll-jammed.
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
@@ -243,7 +251,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
}
// Promote the loop body up if this has turned into a single iteration loop.
- promoteIfSingleIteration(forInst);
+ promoteIfSingleIteration(forOp);
return true;
}
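
A worked illustration of the induction-variable remapping done for each unroll-jammed copy; the numbers are hypothetical and the identifiers (builder, forInst, forOpIV, operandMapping) are the ones from the hunk above:

    // With step == 2 and unrollJamFactor == 4, the loop step is rescaled to 8
    // and cloned copy i (1 <= i <= 3) sees iv + i * 2, i.e. iv + 2, iv + 4,
    // iv + 6. For copy i == 1:
    auto d0 = builder.getAffineDimExpr(0);
    auto bumpMap = builder.getAffineMap(1, 0, {d0 + 1 * 2}, {});
    auto ivUnroll =
        builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forOpIV);
    operandMapping.map(forOpIV, ivUnroll);
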
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index f770684f519..24ca4e95082 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
@@ -246,7 +247,7 @@ public:
LowerAffinePass() : FunctionPass(&passID) {}
PassResult runOnFunction(Function *function) override;
- bool lowerForInst(ForInst *forInst);
+ bool lowerAffineFor(OpPointer<AffineForOp> forOp);
bool lowerAffineIf(AffineIfOp *ifOp);
bool lowerAffineApply(AffineApplyOp *op);
@@ -295,11 +296,11 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
// a nested loop). Induction variable modification is appended to the body SESE
// region that always loops back to the condition block.
//
-// +--------------------------------+
-// | <code before the ForInst> |
-// | <compute initial %iv value> |
-// | br cond(%iv) |
-// +--------------------------------+
+// +---------------------------------+
+// | <code before the AffineForOp> |
+// | <compute initial %iv value> |
+// | br cond(%iv) |
+// +---------------------------------+
// |
// -------| |
// | v v
@@ -322,11 +323,12 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
// v
// +--------------------------------+
// | end: |
-// | <code after the ForInst> |
+// | <code after the AffineForOp> |
// +--------------------------------+
//
-bool LowerAffinePass::lowerForInst(ForInst *forInst) {
- auto loc = forInst->getLoc();
+bool LowerAffinePass::lowerAffineFor(OpPointer<AffineForOp> forOp) {
+ auto loc = forOp->getLoc();
+ auto *forInst = forOp->getInstruction();
// Start by splitting the block containing the 'for' into two parts. The part
// before will get the init code, the part after will be the end point.
@@ -339,23 +341,23 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
conditionBlock->insertBefore(endBlock);
auto *iv = conditionBlock->addArgument(IndexType::get(forInst->getContext()));
- // Create the body block, moving the body of the forInst over to it.
+ // Create the body block, moving the body of the forOp over to it.
auto *bodyBlock = new Block();
bodyBlock->insertBefore(endBlock);
- auto *oldBody = forInst->getBody();
+ auto *oldBody = forOp->getBody();
bodyBlock->getInstructions().splice(bodyBlock->begin(),
oldBody->getInstructions(),
oldBody->begin(), oldBody->end());
- // The code in the body of the forInst now uses 'iv' as its indvar.
- forInst->getInductionVar()->replaceAllUsesWith(iv);
+ // The code in the body of the forOp now uses 'iv' as its indvar.
+ forOp->getInductionVar()->replaceAllUsesWith(iv);
// Append the induction variable stepping logic and branch back to the exit
// condition block. Construct an affine expression f : (x -> x+step) and
// apply this expression to the induction variable.
FuncBuilder builder(bodyBlock);
- auto affStep = builder.getAffineConstantExpr(forInst->getStep());
+ auto affStep = builder.getAffineConstantExpr(forOp->getStep());
auto affDim = builder.getAffineDimExpr(0);
auto stepped = expandAffineExpr(&builder, loc, affDim + affStep, iv, {});
if (!stepped)
@@ -368,18 +370,18 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
builder.setInsertionPointToEnd(initBlock);
// Compute loop bounds.
- SmallVector<Value *, 8> operands(forInst->getLowerBoundOperands());
+ SmallVector<Value *, 8> operands(forOp->getLowerBoundOperands());
auto lbValues = expandAffineMap(&builder, forInst->getLoc(),
- forInst->getLowerBoundMap(), operands);
+ forOp->getLowerBoundMap(), operands);
if (!lbValues)
return true;
Value *lowerBound =
buildMinMaxReductionSeq(loc, CmpIPredicate::SGT, *lbValues, builder);
- operands.assign(forInst->getUpperBoundOperands().begin(),
- forInst->getUpperBoundOperands().end());
+ operands.assign(forOp->getUpperBoundOperands().begin(),
+ forOp->getUpperBoundOperands().end());
auto ubValues = expandAffineMap(&builder, forInst->getLoc(),
- forInst->getUpperBoundMap(), operands);
+ forOp->getUpperBoundMap(), operands);
if (!ubValues)
return true;
Value *upperBound =
@@ -394,7 +396,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
endBlock, ArrayRef<Value *>());
// Ok, we're done!
- forInst->erase();
+ forOp->erase();
return false;
}
@@ -614,28 +616,26 @@ PassResult LowerAffinePass::runOnFunction(Function *function) {
// Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
// We do this as a prepass to avoid invalidating the walker with our rewrite.
function->walkInsts([&](Instruction *inst) {
- if (isa<ForInst>(inst))
- instsToRewrite.push_back(inst);
- auto op = dyn_cast<OperationInst>(inst);
- if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>()))
+ auto op = cast<OperationInst>(inst);
+ if (op->isa<AffineApplyOp>() || op->isa<AffineForOp>() ||
+ op->isa<AffineIfOp>())
instsToRewrite.push_back(inst);
});
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
// so we know that we will rewrite them in the same order.
- for (auto *inst : instsToRewrite)
- if (auto *forInst = dyn_cast<ForInst>(inst)) {
- if (lowerForInst(forInst))
+ for (auto *inst : instsToRewrite) {
+ auto op = cast<OperationInst>(inst);
+ if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
+ if (lowerAffineIf(ifOp))
return failure();
- } else {
- auto op = cast<OperationInst>(inst);
- if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
- if (lowerAffineIf(ifOp))
- return failure();
- } else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
+ } else if (auto forOp = op->dyn_cast<AffineForOp>()) {
+ if (lowerAffineFor(forOp))
return failure();
- }
+ } else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
+ return failure();
}
+ }
return success();
}
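
One detail of the lowering worth spelling out: affine bound maps may have several results, and lowerAffineFor reduces the expanded lower-bound values with CmpIPredicate::SGT, which keeps the larger value at each step, so a multi-result lower bound lowers to a max (and the upper bound, symmetrically, to a min). A hypothetical sketch of that reduction shape, not the actual helper in this file, assuming an ArrayRef<Value *> values, a Location loc, a FuncBuilder builder, and a CmpIPredicate predicate:

    // Fold the values pairwise, keeping whichever operand the predicate
    // selects at each step.
    Value *acc = values.front();
    for (Value *v : llvm::drop_begin(values, 1)) {
      auto cmp = builder.create<CmpIOp>(loc, predicate, acc, v);
      acc = builder.create<SelectOp>(loc, cmp, acc, v);
    }
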
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 432ad1f39b8..f2dae11112b 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -75,7 +75,7 @@
/// Implementation details
/// ======================
/// The current decisions made by the super-vectorization pass guarantee that
-/// use-def chains do not escape an enclosing vectorized ForInst. In other
+/// use-def chains do not escape an enclosing vectorized AffineForOp. In other
/// words, this pass operates on a scoped program slice. Furthermore, since we
/// do not vectorize in the presence of conditionals for now, sliced chains are
/// guaranteed not to escape the innermost scope, which has to be either the top
@@ -285,13 +285,12 @@ static Value *substitute(Value *v, VectorType hwVectorType,
///
/// The general problem this function solves is as follows:
/// Assume a vector_transfer operation at the super-vector granularity that has
-/// `l` enclosing loops (ForInst). Assume the vector transfer operation operates
-/// on a MemRef of rank `r`, a super-vector of rank `s` and a hardware vector of
-/// rank `h`.
-/// For the purpose of illustration assume l==4, r==3, s==2, h==1 and that the
-/// super-vector is vector<3x32xf32> and the hardware vector is vector<8xf32>.
-/// Assume the following MLIR snippet after super-vectorization has been
-/// applied:
+/// `l` enclosing loops (AffineForOp). Assume the vector transfer operation
+/// operates on a MemRef of rank `r`, a super-vector of rank `s` and a hardware
+/// vector of rank `h`. For the purpose of illustration assume l==4, r==3, s==2,
+/// h==1 and that the super-vector is vector<3x32xf32> and the hardware vector
+/// is vector<8xf32>. Assume the following MLIR snippet after
+/// super-vectorization has been applied:
///
/// ```mlir
/// for %i0 = 0 to %M {
@@ -351,7 +350,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
SmallVector<AffineExpr, 8> affineExprs;
// TODO(ntv): support a concrete map and composition.
unsigned i = 0;
- // The first numMemRefIndices correspond to ForInst that have not been
+ // The first numMemRefIndices correspond to AffineForOp that have not been
// vectorized, the transformation is the identity on those.
for (i = 0; i < numMemRefIndices; ++i) {
auto d_i = b->getAffineDimExpr(i);
@@ -554,9 +553,6 @@ static bool instantiateMaterialization(Instruction *inst,
MaterializationState *state) {
LLVM_DEBUG(dbgs() << "\ninstantiate: " << *inst);
- if (isa<ForInst>(inst))
- return inst->emitError("NYI path ForInst");
-
// Create a builder here for unroll-and-jam effects.
FuncBuilder b(inst);
auto *opInst = cast<OperationInst>(inst);
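
For the rank bookkeeping described above, a small worked count (illustrative, using the l/r/s/h figures from the comment): with a super-vector of vector<3x32xf32> and a hardware vector of vector<8xf32>, only the innermost dimension is divided, 32 / 8 = 4, so each super-vector operation materializes into 3 * 4 = 12 hardware-vector operations.
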
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 811741d08d1..2e083bbfd79 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -21,11 +21,11 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
@@ -38,15 +38,12 @@ using namespace mlir;
namespace {
-struct PipelineDataTransfer : public FunctionPass,
- InstWalker<PipelineDataTransfer> {
+struct PipelineDataTransfer : public FunctionPass {
PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {}
PassResult runOnFunction(Function *f) override;
- PassResult runOnForInst(ForInst *forInst);
+ PassResult runOnAffineForOp(OpPointer<AffineForOp> forOp);
- // Collect all 'for' instructions.
- void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
- std::vector<ForInst *> forInsts;
+ std::vector<OpPointer<AffineForOp>> forOps;
static char passID;
};
@@ -79,8 +76,8 @@ static unsigned getTagMemRefPos(const OperationInst &dmaInst) {
/// of the old memref by the new one while indexing the newly added dimension by
/// the loop IV of the specified 'for' instruction modulo 2. Returns false if
/// such a replacement cannot be performed.
-static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
- auto *forBody = forInst->getBody();
+static bool doubleBuffer(Value *oldMemRef, OpPointer<AffineForOp> forOp) {
+ auto *forBody = forOp->getBody();
FuncBuilder bInner(forBody, forBody->begin());
bInner.setInsertionPoint(forBody, forBody->begin());
@@ -101,6 +98,7 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
auto newMemRefType = doubleShape(oldMemRefType);
// Put together alloc operands for the dynamic dimensions of the memref.
+ auto *forInst = forOp->getInstruction();
FuncBuilder bOuter(forInst);
SmallVector<Value *, 4> allocOperands;
unsigned dynamicDimCount = 0;
@@ -118,16 +116,16 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
// Create 'iv mod 2' value to index the leading dimension.
auto d0 = bInner.getAffineDimExpr(0);
- int64_t step = forInst->getStep();
+ int64_t step = forOp->getStep();
auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
{d0.floorDiv(step) % 2}, {});
- auto ivModTwoOp = bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap,
- forInst->getInductionVar());
+ auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp->getLoc(), modTwoMap,
+ forOp->getInductionVar());
- // replaceAllMemRefUsesWith will always succeed unless the forInst body has
+ // replaceAllMemRefUsesWith will always succeed unless the forOp body has
  // non-dereferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, {ivModTwoOp}, AffineMap(),
- {}, &*forInst->getBody()->begin())) {
+ {}, &*forOp->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getInstruction()->erase();
@@ -143,11 +141,14 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) {
// invalid (erased) when the outer loop is pipelined (the pipelined one gets
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
- forInsts.clear();
- walkPostOrder(f);
+ forOps.clear();
+ f->walkOpsPostOrder([&](OperationInst *opInst) {
+ if (auto forOp = opInst->dyn_cast<AffineForOp>())
+ forOps.push_back(forOp);
+ });
bool ret = false;
- for (auto *forInst : forInsts) {
- ret = ret | runOnForInst(forInst);
+ for (auto forOp : forOps) {
+ ret = ret | runOnAffineForOp(forOp);
}
return ret ? failure() : success();
}
@@ -178,13 +179,13 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
// Identify matching DMA start/finish instructions to overlap computation with.
static void findMatchingStartFinishInsts(
- ForInst *forInst,
+ OpPointer<AffineForOp> forOp,
SmallVectorImpl<std::pair<OperationInst *, OperationInst *>>
&startWaitPairs) {
// Collect outgoing DMA instructions - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
- for (auto &inst : *forInst->getBody()) {
+ for (auto &inst : *forOp->getBody()) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst)
continue;
@@ -195,7 +196,7 @@ static void findMatchingStartFinishInsts(
}
SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts;
- for (auto &inst : *forInst->getBody()) {
+ for (auto &inst : *forOp->getBody()) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst)
continue;
@@ -227,7 +228,7 @@ static void findMatchingStartFinishInsts(
auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos());
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
- if (!forInst->getBody()->findAncestorInstInBlock(*use.getOwner())) {
+ if (!forOp->getBody()->findAncestorInstInBlock(*use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
@@ -251,17 +252,18 @@ static void findMatchingStartFinishInsts(
}
/// Overlap DMA transfers with computation in this loop. If successful,
-/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are
+/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// inserted right before where it was.
-PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
- auto mayBeConstTripCount = getConstantTripCount(*forInst);
+PassResult
+PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
+ auto mayBeConstTripCount = getConstantTripCount(forOp);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
return success();
}
SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs;
- findMatchingStartFinishInsts(forInst, startWaitPairs);
+ findMatchingStartFinishInsts(forOp, startWaitPairs);
if (startWaitPairs.empty()) {
LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);
@@ -280,7 +282,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
auto *dmaStartInst = pair.first;
Value *oldMemRef = dmaStartInst->getOperand(
dmaStartInst->cast<DmaStartOp>()->getFasterMemPos());
- if (!doubleBuffer(oldMemRef, forInst)) {
+ if (!doubleBuffer(oldMemRef, forOp)) {
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
@@ -302,7 +304,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
auto *dmaFinishInst = pair.second;
Value *oldTagMemRef =
dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
- if (!doubleBuffer(oldTagMemRef, forInst)) {
+ if (!doubleBuffer(oldTagMemRef, forOp)) {
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
return success();
}
@@ -315,7 +317,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
// Double buffering would have invalidated all the old DMA start/wait insts.
startWaitPairs.clear();
- findMatchingStartFinishInsts(forInst, startWaitPairs);
+ findMatchingStartFinishInsts(forOp, startWaitPairs);
  // Store the shift for each instruction for later lookup for AffineApplyOp's.
DenseMap<const Instruction *, unsigned> instShiftMap;
@@ -342,16 +344,16 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
}
}
  // Everything else (including compute ops and dma finish) is shifted by one.
- for (const auto &inst : *forInst->getBody()) {
+ for (const auto &inst : *forOp->getBody()) {
if (instShiftMap.find(&inst) == instShiftMap.end()) {
instShiftMap[&inst] = 1;
}
}
// Get shifts stored in map.
- std::vector<uint64_t> shifts(forInst->getBody()->getInstructions().size());
+ std::vector<uint64_t> shifts(forOp->getBody()->getInstructions().size());
unsigned s = 0;
- for (auto &inst : *forInst->getBody()) {
+ for (auto &inst : *forOp->getBody()) {
assert(instShiftMap.find(&inst) != instShiftMap.end());
shifts[s++] = instShiftMap[&inst];
LLVM_DEBUG(
@@ -363,13 +365,13 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
});
}
- if (!isInstwiseShiftValid(*forInst, shifts)) {
+ if (!isInstwiseShiftValid(forOp, shifts)) {
// Violates dependences.
LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
return success();
}
- if (instBodySkew(forInst, shifts)) {
+ if (instBodySkew(forOp, shifts)) {
LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";);
return success();
}
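
A worked illustration of the double-buffering index used above, with a hypothetical step value and the identifiers from the hunk (bInner, forOp): the leading dimension of the doubled memref is addressed by (iv floordiv step) mod 2, so successive iterations alternate halves, letting the DMA into one half overlap with compute on the other.

    // With step == 4, iterations iv = 0, 4, 8, 12 index the doubled leading
    // dimension as 0, 1, 0, 1.
    auto d0 = bInner.getAffineDimExpr(0);
    auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
                                         {d0.floorDiv(4) % 2}, {});
    auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp->getLoc(), modTwoMap,
                                                   forOp->getInductionVar());
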
diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
index ba59123c700..ae003b3e495 100644
--- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
@@ -22,6 +22,7 @@
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Instructions.h"
+#include "mlir/IR/IntegerSet.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 59da2b0a56e..ce16656243d 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -21,6 +21,7 @@
#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -39,22 +40,22 @@ using namespace mlir;
/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
/// the specified trip count, stride, and unroll factor. Returns nullptr when
/// the trip count can't be expressed as an affine expression.
-AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst,
+AffineMap mlir::getUnrolledLoopUpperBound(ConstOpPointer<AffineForOp> forOp,
unsigned unrollFactor,
FuncBuilder *builder) {
- auto lbMap = forInst.getLowerBoundMap();
+ auto lbMap = forOp->getLowerBoundMap();
// Single result lower bound map only.
if (lbMap.getNumResults() != 1)
return AffineMap();
// Sometimes, the trip count cannot be expressed as an affine expression.
- auto tripCount = getTripCountExpr(forInst);
+ auto tripCount = getTripCountExpr(forOp);
if (!tripCount)
return AffineMap();
AffineExpr lb(lbMap.getResult(0));
- unsigned step = forInst.getStep();
+ unsigned step = forOp->getStep();
auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
@@ -65,50 +66,51 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst,
/// bound 'lb' and with the specified trip count, stride, and unroll factor.
/// Returns an AffineMap with nullptr storage (that evaluates to false)
/// when the trip count can't be expressed as an affine expression.
-AffineMap mlir::getCleanupLoopLowerBound(const ForInst &forInst,
+AffineMap mlir::getCleanupLoopLowerBound(ConstOpPointer<AffineForOp> forOp,
unsigned unrollFactor,
FuncBuilder *builder) {
- auto lbMap = forInst.getLowerBoundMap();
+ auto lbMap = forOp->getLowerBoundMap();
// Single result lower bound map only.
if (lbMap.getNumResults() != 1)
return AffineMap();
// Sometimes the trip count cannot be expressed as an affine expression.
- AffineExpr tripCount(getTripCountExpr(forInst));
+ AffineExpr tripCount(getTripCountExpr(forOp));
if (!tripCount)
return AffineMap();
AffineExpr lb(lbMap.getResult(0));
- unsigned step = forInst.getStep();
+ unsigned step = forOp->getStep();
auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
{newLb}, {});
}
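
A worked example of the two bound computations above, with illustrative numbers only: take lb = 0, step = 1, tripCount = 10, unrollFactor = 4.

    //   unrolled upper bound: 0 + (10 - 10 % 4 - 1) * 1 = 7
    //   cleanup lower bound:  0 + (10 - 10 % 4) * 1     = 8
    // The unrolled loop, with its step later scaled by the factor, covers the
    // first 8 iterations; the cleanup loop starting at 8 handles the last 2.
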
-/// Promotes the loop body of a forInst to its containing block if the forInst
+/// Promotes the loop body of a forOp to its containing block if the forOp
/// was known to have a single iteration. Returns false otherwise.
// TODO(bondhugula): extend this for arbitrary affine bounds.
-bool mlir::promoteIfSingleIteration(ForInst *forInst) {
- Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
+bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
+ Optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (!tripCount.hasValue() || tripCount.getValue() != 1)
return false;
// TODO(mlir-team): there is no builder for a max.
- if (forInst->getLowerBoundMap().getNumResults() != 1)
+ if (forOp->getLowerBoundMap().getNumResults() != 1)
return false;
// Replaces all IV uses to its single iteration value.
- auto *iv = forInst->getInductionVar();
+ auto *iv = forOp->getInductionVar();
+ OperationInst *forInst = forOp->getInstruction();
if (!iv->use_empty()) {
- if (forInst->hasConstantLowerBound()) {
+ if (forOp->hasConstantLowerBound()) {
auto *mlFunc = forInst->getFunction();
FuncBuilder topBuilder(mlFunc);
auto constOp = topBuilder.create<ConstantIndexOp>(
- forInst->getLoc(), forInst->getConstantLowerBound());
+ forOp->getLoc(), forOp->getConstantLowerBound());
iv->replaceAllUsesWith(constOp);
} else {
- const AffineBound lb = forInst->getLowerBound();
+ const AffineBound lb = forOp->getLowerBound();
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst));
if (lb.getMap() == builder.getDimIdentityMap()) {
@@ -124,8 +126,8 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) {
// Move the loop body instructions to the loop's containing block.
auto *block = forInst->getBlock();
block->getInstructions().splice(Block::iterator(forInst),
- forInst->getBody()->getInstructions());
- forInst->erase();
+ forOp->getBody()->getInstructions());
+ forOp->erase();
return true;
}
@@ -133,13 +135,10 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) {
/// their body into the containing Block.
void mlir::promoteSingleIterationLoops(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
- class LoopBodyPromoter : public InstWalker<LoopBodyPromoter> {
- public:
- void visitForInst(ForInst *forInst) { promoteIfSingleIteration(forInst); }
- };
-
- LoopBodyPromoter fsw;
- fsw.walkPostOrder(f);
+ f->walkOpsPostOrder([](OperationInst *inst) {
+ if (auto forOp = inst->dyn_cast<AffineForOp>())
+ promoteIfSingleIteration(forOp);
+ });
}
/// Generates a 'for' inst with the specified lower and upper bounds while
@@ -149,19 +148,22 @@ void mlir::promoteSingleIterationLoops(Function *f) {
/// the pair specifies the shift applied to that group of instructions; note
/// that the shift is multiplied by the loop step before being applied. Returns
/// nullptr if the generated loop simplifies to a single iteration one.
-static ForInst *
+static OpPointer<AffineForOp>
generateLoop(AffineMap lbMap, AffineMap ubMap,
const std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>>
&instGroupQueue,
- unsigned offset, ForInst *srcForInst, FuncBuilder *b) {
+ unsigned offset, OpPointer<AffineForOp> srcForInst,
+ FuncBuilder *b) {
SmallVector<Value *, 4> lbOperands(srcForInst->getLowerBoundOperands());
SmallVector<Value *, 4> ubOperands(srcForInst->getUpperBoundOperands());
assert(lbMap.getNumInputs() == lbOperands.size());
assert(ubMap.getNumInputs() == ubOperands.size());
- auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap,
- ubOperands, ubMap, srcForInst->getStep());
+ auto loopChunk =
+ b->create<AffineForOp>(srcForInst->getLoc(), lbOperands, lbMap,
+ ubOperands, ubMap, srcForInst->getStep());
+ loopChunk->createBody();
auto *loopChunkIV = loopChunk->getInductionVar();
auto *srcIV = srcForInst->getInductionVar();
@@ -176,7 +178,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
// Generate the remapping if the shift is not zero: remappedIV = newIV -
// shift.
if (!srcIV->use_empty() && shift != 0) {
- auto b = FuncBuilder::getForInstBodyBuilder(loopChunk);
+ FuncBuilder b(loopChunk->getBody());
auto ivRemap = b.create<AffineApplyOp>(
srcForInst->getLoc(),
b.getSingleDimShiftAffineMap(
@@ -191,7 +193,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
}
}
if (promoteIfSingleIteration(loopChunk))
- return nullptr;
+ return OpPointer<AffineForOp>();
return loopChunk;
}
@@ -210,28 +212,29 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
// asserts preservation of SSA dominance. A check for that, as well as for
// memory-based dependence preservation, rests with the users of this
// method.
-UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
+UtilResult mlir::instBodySkew(OpPointer<AffineForOp> forOp,
+ ArrayRef<uint64_t> shifts,
bool unrollPrologueEpilogue) {
- if (forInst->getBody()->empty())
+ if (forOp->getBody()->empty())
return UtilResult::Success;
// If the trip counts aren't constant, we would need versioning and
// conditional guards (or context information to prevent such versioning). The
// better way to pipeline for such loops is to first tile them and extract
// constant trip count "full tiles" before applying this.
- auto mayBeConstTripCount = getConstantTripCount(*forInst);
+ auto mayBeConstTripCount = getConstantTripCount(forOp);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
return UtilResult::Success;
}
uint64_t tripCount = mayBeConstTripCount.getValue();
- assert(isInstwiseShiftValid(*forInst, shifts) &&
+ assert(isInstwiseShiftValid(forOp, shifts) &&
"shifts will lead to an invalid transformation\n");
- int64_t step = forInst->getStep();
+ int64_t step = forOp->getStep();
- unsigned numChildInsts = forInst->getBody()->getInstructions().size();
+ unsigned numChildInsts = forOp->getBody()->getInstructions().size();
// Do a linear time (counting) sort for the shifts.
uint64_t maxShift = 0;
@@ -249,7 +252,7 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
// body of the 'for' inst.
std::vector<std::vector<Instruction *>> sortedInstGroups(maxShift + 1);
unsigned pos = 0;
- for (auto &inst : *forInst->getBody()) {
+ for (auto &inst : *forOp->getBody()) {
auto shift = shifts[pos++];
sortedInstGroups[shift].push_back(&inst);
}
@@ -259,17 +262,17 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
// Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
// loop generated as the prologue and the last as epilogue and unroll these
// fully.
- ForInst *prologue = nullptr;
- ForInst *epilogue = nullptr;
+ OpPointer<AffineForOp> prologue;
+ OpPointer<AffineForOp> epilogue;
// Do a sweep over the sorted shifts while storing open groups in a
// vector, and generating loop portions as necessary during the sweep. A block
// of instructions is paired with its shift.
std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>> instGroupQueue;
- auto origLbMap = forInst->getLowerBoundMap();
+ auto origLbMap = forOp->getLowerBoundMap();
uint64_t lbShift = 0;
- FuncBuilder b(forInst);
+ FuncBuilder b(forOp->getInstruction());
for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) {
// If nothing is shifted by d, continue.
if (sortedInstGroups[d].empty())
@@ -280,19 +283,19 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
// The interval for which the loop needs to be generated here is:
// [lbShift, min(lbShift + tripCount, d)) and the body of the
// loop needs to have all instructions in instQueue in that order.
- ForInst *res;
+ OpPointer<AffineForOp> res;
if (lbShift + tripCount * step < d * step) {
res = generateLoop(
b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
- instGroupQueue, 0, forInst, &b);
+ instGroupQueue, 0, forOp, &b);
// Entire loop for the queued inst groups generated, empty it.
instGroupQueue.clear();
lbShift += tripCount * step;
} else {
res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, d), instGroupQueue,
- 0, forInst, &b);
+ 0, forOp, &b);
lbShift = d * step;
}
if (!prologue && res)
@@ -312,60 +315,63 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step;
epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, ubShift),
- instGroupQueue, i, forInst, &b);
+ instGroupQueue, i, forOp, &b);
lbShift = ubShift;
if (!prologue)
prologue = epilogue;
}
// Erase the original for inst.
- forInst->erase();
+ forOp->erase();
if (unrollPrologueEpilogue && prologue)
loopUnrollFull(prologue);
- if (unrollPrologueEpilogue && !epilogue && epilogue != prologue)
+  if (unrollPrologueEpilogue && epilogue &&
+ epilogue->getInstruction() != prologue->getInstruction())
loopUnrollFull(epilogue);
return UtilResult::Success;
}
/// Unrolls this loop completely.
-bool mlir::loopUnrollFull(ForInst *forInst) {
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
+bool mlir::loopUnrollFull(OpPointer<AffineForOp> forOp) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (mayBeConstantTripCount.hasValue()) {
uint64_t tripCount = mayBeConstantTripCount.getValue();
if (tripCount == 1) {
- return promoteIfSingleIteration(forInst);
+ return promoteIfSingleIteration(forOp);
}
- return loopUnrollByFactor(forInst, tripCount);
+ return loopUnrollByFactor(forOp, tripCount);
}
return false;
}
/// Unrolls this loop by the specified factor or by the trip count (if
/// constant), whichever is lower.
-bool mlir::loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor) {
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
+bool mlir::loopUnrollUpToFactor(OpPointer<AffineForOp> forOp,
+ uint64_t unrollFactor) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollFactor)
- return loopUnrollByFactor(forInst, mayBeConstantTripCount.getValue());
- return loopUnrollByFactor(forInst, unrollFactor);
+ return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue());
+ return loopUnrollByFactor(forOp, unrollFactor);
}
/// Unrolls this loop by the specified factor. Returns true if the loop
/// is successfully unrolled.
-bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
+bool mlir::loopUnrollByFactor(OpPointer<AffineForOp> forOp,
+ uint64_t unrollFactor) {
assert(unrollFactor >= 1 && "unroll factor should be >= 1");
if (unrollFactor == 1)
- return promoteIfSingleIteration(forInst);
+ return promoteIfSingleIteration(forOp);
- if (forInst->getBody()->empty())
+ if (forOp->getBody()->empty())
return false;
- auto lbMap = forInst->getLowerBoundMap();
- auto ubMap = forInst->getUpperBoundMap();
+ auto lbMap = forOp->getLowerBoundMap();
+ auto ubMap = forOp->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to
@@ -376,10 +382,10 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different operand lists.
- if (!forInst->matchingBoundOperandList())
+ if (!forOp->matchingBoundOperandList())
return false;
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
// If the trip count is lower than the unroll factor, no unrolled body.
// TODO(bondhugula): option to specify cleanup loop unrolling.
@@ -388,10 +394,12 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
return false;
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
- if (getLargestDivisorOfTripCount(*forInst) % unrollFactor != 0) {
+ OperationInst *forInst = forOp->getInstruction();
+ if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst));
- auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst));
- auto clLbMap = getCleanupLoopLowerBound(*forInst, unrollFactor, &builder);
+ auto cleanupForInst =
+ cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>();
+ auto clLbMap = getCleanupLoopLowerBound(forOp, unrollFactor, &builder);
assert(clLbMap &&
"cleanup loop lower bound map for single result bound maps can "
"always be determined");
@@ -401,50 +409,50 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
// Adjust upper bound.
auto unrolledUbMap =
- getUnrolledLoopUpperBound(*forInst, unrollFactor, &builder);
+ getUnrolledLoopUpperBound(forOp, unrollFactor, &builder);
assert(unrolledUbMap &&
"upper bound map can alwayys be determined for an unrolled loop "
"with single result bounds");
- forInst->setUpperBoundMap(unrolledUbMap);
+ forOp->setUpperBoundMap(unrolledUbMap);
}
// Scale the step of loop being unrolled by unroll factor.
- int64_t step = forInst->getStep();
- forInst->setStep(step * unrollFactor);
+ int64_t step = forOp->getStep();
+ forOp->setStep(step * unrollFactor);
// Builder to insert unrolled bodies right after the last instruction in the
- // body of 'forInst'.
- FuncBuilder builder(forInst->getBody(), forInst->getBody()->end());
+ // body of 'forOp'.
+ FuncBuilder builder(forOp->getBody(), forOp->getBody()->end());
// Keep a pointer to the last instruction in the original block so that we
// know what to clone (since we are doing this in-place).
- Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end());
+ Block::iterator srcBlockEnd = std::prev(forOp->getBody()->end());
- // Unroll the contents of 'forInst' (append unrollFactor-1 additional copies).
- auto *forInstIV = forInst->getInductionVar();
+ // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies).
+ auto *forOpIV = forOp->getInductionVar();
for (unsigned i = 1; i < unrollFactor; i++) {
BlockAndValueMapping operandMap;
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
- if (!forInstIV->use_empty()) {
+ if (!forOpIV->use_empty()) {
// iv' = iv + 1/2/3...unrollFactor-1;
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto ivUnroll =
- builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInstIV);
- operandMap.map(forInstIV, ivUnroll);
+ builder.create<AffineApplyOp>(forOp->getLoc(), bumpMap, forOpIV);
+ operandMap.map(forOpIV, ivUnroll);
}
- // Clone the original body of 'forInst'.
- for (auto it = forInst->getBody()->begin(); it != std::next(srcBlockEnd);
+ // Clone the original body of 'forOp'.
+ for (auto it = forOp->getBody()->begin(); it != std::next(srcBlockEnd);
it++) {
builder.clone(*it, operandMap);
}
}
// Promote the loop body up if this has turned into a single iteration loop.
- promoteIfSingleIteration(forInst);
+ promoteIfSingleIteration(forOp);
return true;
}
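As a standalone illustration of what the hunk above does (plain C++, not MLIR code; the loop, body() and bound handling are invented for the example, and the sketch assumes step > 0 and unrollFactor >= 1 as the pass asserts): the main loop's step is scaled by the unroll factor, each cloned copy of the body runs with the induction variable remapped to iv + i * step, and a cleanup loop picks up the iterations left over when the trip count is not a multiple of the factor.

#include <cstdint>
#include <cstdio>

// Placeholder for the original loop body.
static void body(int64_t iv) { std::printf("iteration %lld\n", (long long)iv); }

static void unrollByFactor(int64_t lb, int64_t ub, int64_t step,
                           unsigned unrollFactor) {
  int64_t tripCount = (ub - lb + step - 1) / step;   // ceil((ub - lb) / step)
  int64_t unrolledTrips = tripCount / unrollFactor;
  int64_t iv = lb;
  // Main loop: step scaled by unrollFactor, body cloned unrollFactor times,
  // clone i running with the remapped induction variable iv + i * step.
  for (int64_t t = 0; t < unrolledTrips; ++t, iv += step * unrollFactor)
    for (unsigned i = 0; i < unrollFactor; ++i)
      body(iv + i * step);
  // Cleanup loop for the remaining tripCount % unrollFactor iterations.
  for (; iv < ub; iv += step)
    body(iv);
}

With lb = 0, ub = 10, step = 1 and unrollFactor = 4, the main loop covers iterations 0 through 7 and the cleanup loop covers 8 and 9, which is the case the cleanup-loop generation above handles.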
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index d3689d056d6..819f1a59b6f 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -22,6 +22,7 @@
#include "mlir/Transforms/Utils.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Dominance.h"
@@ -278,8 +279,8 @@ void mlir::createAffineComputationSlice(
/// Folds the specified (lower or upper) bound to a constant if possible
/// considering its operands. Returns false if the folding happens for any of
/// the bounds, true otherwise.
-bool mlir::constantFoldBounds(ForInst *forInst) {
- auto foldLowerOrUpperBound = [forInst](bool lower) {
+bool mlir::constantFoldBounds(OpPointer<AffineForOp> forInst) {
+ auto foldLowerOrUpperBound = [&forInst](bool lower) {
// Check if the bound is already a constant.
if (lower && forInst->hasConstantLowerBound())
return true;
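A rough standalone sketch of what "folding a bound to a constant" means here (an assumed representation, not the MLIR AffineMap API): a single-result bound c0 + c1*x1 + ... + cn*xn can be replaced by a constant exactly when every operand xi is itself known to be constant.

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

struct Bound {
  int64_t constantTerm;               // c0
  std::vector<int64_t> coefficients;  // c1 .. cn, one per bound operand
};

// Returns the folded constant when all operands are known, nullopt otherwise.
static std::optional<int64_t>
foldBound(const Bound &bound,
          const std::vector<std::optional<int64_t>> &operands) {
  int64_t value = bound.constantTerm;
  for (std::size_t i = 0; i < bound.coefficients.size(); ++i) {
    if (!operands[i])
      return std::nullopt;            // a non-constant operand blocks folding
    value += bound.coefficients[i] * *operands[i];
  }
  return value;
}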
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index ac551d7c20c..7f26161e520 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/VectorAnalysis.h"
@@ -252,9 +253,9 @@ using namespace mlir;
/// ==========
/// The algorithm proceeds in a few steps:
/// 1. defining super-vectorization patterns and matching them on the tree of
-/// ForInst. A super-vectorization pattern is defined as a recursive data
-/// structures that matches and captures nested, imperfectly-nested loops
-/// that have a. comformable loop annotations attached (e.g. parallel,
+/// AffineForOp. A super-vectorization pattern is defined as a recursive
+/// data structure that matches and captures nested, imperfectly-nested
+/// loops that have a. conformable loop annotations attached (e.g. parallel,
/// reduction, vectorizable, ...) as well as b. all contiguous load/store
/// operations along a specified minor dimension (not necessarily the
/// fastest varying) ;
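A minimal sketch of the "recursive data structure" idea in step 1, in plain C++ (the Loop, Pattern and Capture types are invented for illustration and are not the NestedMatcher API): each pattern level carries a predicate on a loop plus child patterns to match against its nested loops, and a successful match captures the loop together with its children's captures.

#include <functional>
#include <vector>

struct Loop {
  std::vector<Loop> nested;           // imperfectly nested child loops
  bool parallel = false;              // example annotation a filter may test
};

struct Pattern {
  std::function<bool(const Loop &)> filter;  // e.g. "vectorizable along dim d"
  std::vector<Pattern> children;
};

struct Capture {
  const Loop *loop = nullptr;
  std::vector<Capture> children;
};

static bool matchPattern(const Pattern &p, const Loop &l, Capture &out) {
  if (!p.filter(l))
    return false;
  out.loop = &l;
  for (const Pattern &childPat : p.children) {
    bool matched = false;
    for (const Loop &childLoop : l.nested) {   // skip non-matching siblings
      Capture childCapture;
      if (matchPattern(childPat, childLoop, childCapture)) {
        out.children.push_back(childCapture);
        matched = true;
        break;
      }
    }
    if (!matched)
      return false;                            // every child pattern must match
  }
  return true;
}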
@@ -279,11 +280,11 @@ using namespace mlir;
/// it by its vector form. Otherwise, if the scalar value is a constant,
/// it is vectorized into a splat. In all other cases, vectorization for
/// the pattern currently fails.
-/// e. if everything under the root ForInst in the current pattern vectorizes
-/// properly, we commit that loop to the IR. Otherwise we discard it and
-/// restore a previously cloned version of the loop. Thanks to the
-/// recursive scoping nature of matchers and captured patterns, this is
-/// transparently achieved by a simple RAII implementation.
+/// e. if everything under the root AffineForOp in the current pattern
+/// vectorizes properly, we commit that loop to the IR. Otherwise we
+/// discard it and restore a previously cloned version of the loop. Thanks
+/// to the recursive scoping nature of matchers and captured patterns,
+/// this is transparently achieved by a simple RAII implementation.
/// f. vectorization is applied on the next pattern in the list. Because
/// pattern interference avoidance is not yet implemented and that we do
/// not support further vectorizing an already vector load we need to
@@ -667,12 +668,13 @@ namespace {
struct VectorizationStrategy {
SmallVector<int64_t, 8> vectorSizes;
- DenseMap<ForInst *, unsigned> loopToVectorDim;
+ DenseMap<Instruction *, unsigned> loopToVectorDim;
};
} // end anonymous namespace
-static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
+static void vectorizeLoopIfProfitable(Instruction *loop,
+ unsigned depthInPattern,
unsigned patternDepth,
VectorizationStrategy *strategy) {
assert(patternDepth > depthInPattern &&
@@ -704,13 +706,13 @@ static bool analyzeProfitability(ArrayRef<NestedMatch> matches,
unsigned depthInPattern, unsigned patternDepth,
VectorizationStrategy *strategy) {
for (auto m : matches) {
- auto *loop = cast<ForInst>(m.getMatchedInstruction());
bool fail = analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1,
patternDepth, strategy);
if (fail) {
return fail;
}
- vectorizeLoopIfProfitable(loop, depthInPattern, patternDepth, strategy);
+ vectorizeLoopIfProfitable(m.getMatchedInstruction(), depthInPattern,
+ patternDepth, strategy);
}
return false;
}
@@ -855,8 +857,8 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
/// Coarsens the loop's bounds and transforms all remaining load and store
/// operations into the appropriate vector_transfer.
-static bool vectorizeForInst(ForInst *loop, int64_t step,
- VectorizationState *state) {
+static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step,
+ VectorizationState *state) {
using namespace functional;
loop->setStep(step);
@@ -873,7 +875,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
};
auto loadAndStores = matcher::Op(notVectorizedThisPattern);
SmallVector<NestedMatch, 8> loadAndStoresMatches;
- loadAndStores.match(loop, &loadAndStoresMatches);
+ loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches);
for (auto ls : loadAndStoresMatches) {
auto *opInst = cast<OperationInst>(ls.getMatchedInstruction());
auto load = opInst->dyn_cast<LoadOp>();
@@ -898,7 +900,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
static FilterFunctionType
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
return [fastestVaryingMemRefDimension](const Instruction &forInst) {
- const auto &loop = cast<ForInst>(forInst);
+ auto loop = cast<OperationInst>(forInst).cast<AffineForOp>();
return isVectorizableLoopAlongFastestVaryingMemRefDim(
loop, fastestVaryingMemRefDimension);
};
@@ -912,7 +914,8 @@ static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
/// if all vectorizations in `childrenMatches` have already succeeded
/// recursively in DFS post-order.
static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
- ForInst *loop = cast<ForInst>(oneMatch.getMatchedInstruction());
+ auto *loopInst = oneMatch.getMatchedInstruction();
+ auto loop = cast<OperationInst>(loopInst)->cast<AffineForOp>();
auto childrenMatches = oneMatch.getMatchedChildren();
// 1. DFS postorder recursion, if any of my children fails, I fail too.
@@ -924,7 +927,7 @@ static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
// 2. This loop may have been omitted from vectorization for various reasons
// (e.g. due to the performance model or pattern depth > vector size).
- auto it = state->strategy->loopToVectorDim.find(loop);
+ auto it = state->strategy->loopToVectorDim.find(loopInst);
if (it == state->strategy->loopToVectorDim.end()) {
return false;
}
@@ -939,10 +942,10 @@ static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
// exploratory tradeoffs (see top of the file). Apply coarsening, i.e.:
// | ub -> ub
// | step -> step * vectorSize
- LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForInst by " << vectorSize
+ LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize
<< " : ");
- LLVM_DEBUG(loop->print(dbgs()));
- return vectorizeForInst(loop, loop->getStep() * vectorSize, state);
+ LLVM_DEBUG(loopInst->print(dbgs()));
+ return vectorizeAffineForOp(loop, loop->getStep() * vectorSize, state);
}
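The "ub -> ub, step -> step * vectorSize" coarsening above, shown on a plain counted loop (a standalone sketch, not the pass's rewrite): the bounds are untouched and only the step is scaled, so each surviving iteration stands for vectorSize consecutive original iterations, which the rewritten loads and stores then process as one vector operation.

#include <cstdint>
#include <cstdio>

static void coarsen(int64_t lb, int64_t ub, int64_t step, int64_t vectorSize) {
  // Before: for (iv = lb; iv < ub; iv += step)              -- scalar iterations
  // After:  for (iv = lb; iv < ub; iv += step * vectorSize) -- vector iterations
  for (int64_t iv = lb; iv < ub; iv += step * vectorSize)
    std::printf("vector iteration at %lld replaces scalars %lld..%lld\n",
                (long long)iv, (long long)iv,
                (long long)(iv + step * (vectorSize - 1)));
}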
/// Non-root pattern iterates over the matches at this level, calls doVectorize
@@ -1186,7 +1189,8 @@ static bool vectorizeOperations(VectorizationState *state) {
/// Each root may succeed independently but will otherwise clean up after itself
/// if anything below it fails.
static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
- auto *loop = cast<ForInst>(m.getMatchedInstruction());
+ auto loop =
+ cast<OperationInst>(m.getMatchedInstruction())->cast<AffineForOp>();
VectorizationState state;
state.strategy = strategy;
@@ -1197,17 +1201,20 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
// vectorizable. If a pattern is not vectorizable anymore, we just skip it.
// TODO(ntv): implement a non-greedy profitability analysis that keeps only
// non-intersecting patterns.
- if (!isVectorizableLoop(*loop)) {
+ if (!isVectorizableLoop(loop)) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
return true;
}
- FuncBuilder builder(loop); // builder to insert in place of loop
- ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop));
+ auto *loopInst = loop->getInstruction();
+ FuncBuilder builder(loopInst);
+ auto clonedLoop =
+ cast<OperationInst>(builder.clone(*loopInst))->cast<AffineForOp>();
+
auto fail = doVectorize(m, &state);
/// Sets up error handling for this root loop. This is how the root match
/// maintains a clone for handling failure and restores the proper state via
/// RAII.
- ScopeGuard sg2([&fail, loop, clonedLoop]() {
+ ScopeGuard sg2([&fail, &loop, &clonedLoop]() {
if (fail) {
loop->getInductionVar()->replaceAllUsesWith(
clonedLoop->getInductionVar());
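The clone-then-rollback idiom that the ScopeGuard implements above, as a standalone C++ sketch (the Loop type and its members are stand-ins, not the MLIR builder API): back up the loop before attempting the rewrite, and let a scope-exit guard either restore the backup on failure or simply drop it on success.

#include <functional>
#include <utility>

class ScopeExit {
  std::function<void()> fn;
public:
  explicit ScopeExit(std::function<void()> f) : fn(std::move(f)) {}
  ~ScopeExit() { fn(); }              // runs on every exit path, like ScopeGuard
};

struct Loop {
  int state = 0;                      // stand-in for the loop's body/IR
};

static bool tryVectorize(Loop &loop) {
  Loop backup = loop;                 // stand-in for builder.clone(*loopInst)
  bool fail = true;
  ScopeExit guard([&] {
    if (fail)
      loop = backup;                  // failure: restore the original loop
    // success: nothing to do, the backup goes out of scope (erase the clone)
  });
  // ... attempt the vectorizing rewrite of `loop`, setting `fail` ...
  loop.state = 1;                     // pretend the rewrite succeeded
  fail = false;
  return !fail;
}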
@@ -1291,8 +1298,8 @@ PassResult Vectorize::runOnFunction(Function *f) {
if (fail) {
continue;
}
- auto *loop = cast<ForInst>(m.getMatchedInstruction());
- vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy);
+ vectorizeLoopIfProfitable(m.getMatchedInstruction(), 0, patternDepth,
+ &strategy);
// TODO(ntv): if pattern does not apply, report it; alter the
// cost/benefit.
fail = vectorizeRootMatch(m, &strategy);