summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/Instructions.h33
-rw-r--r--mlir/include/mlir/IR/Value.h4
-rw-r--r--mlir/lib/Analysis/AffineAnalysis.cpp17
-rw-r--r--mlir/lib/Analysis/AffineStructures.cpp2
-rw-r--r--mlir/lib/Analysis/Dominance.cpp11
-rw-r--r--mlir/lib/Analysis/LoopAnalysis.cpp8
-rw-r--r--mlir/lib/Analysis/SliceAnalysis.cpp2
-rw-r--r--mlir/lib/Analysis/Utils.cpp7
-rw-r--r--mlir/lib/Analysis/VectorAnalysis.cpp3
-rw-r--r--mlir/lib/EDSC/MLIREmitter.cpp9
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp84
-rw-r--r--mlir/lib/IR/Instruction.cpp50
-rw-r--r--mlir/lib/IR/Value.cpp8
-rw-r--r--mlir/lib/Parser/Parser.cpp5
-rw-r--r--mlir/lib/Transforms/DmaGeneration.cpp3
-rw-r--r--mlir/lib/Transforms/LoopTiling.cpp13
-rw-r--r--mlir/lib/Transforms/LoopUnrollAndJam.cpp8
-rw-r--r--mlir/lib/Transforms/LowerAffine.cpp2
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp4
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp26
-rw-r--r--mlir/lib/Transforms/Vectorize.cpp8
21 files changed, 172 insertions, 135 deletions
diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h
index 8085e720260..71d832b8b90 100644
--- a/mlir/include/mlir/IR/Instructions.h
+++ b/mlir/include/mlir/IR/Instructions.h
@@ -555,19 +555,12 @@ inline auto OperationInst::getResultTypes() const
}
/// For instruction represents an affine loop nest.
-class ForInst : public Instruction, public Value {
+class ForInst : public Instruction {
public:
static ForInst *create(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step);
- ~ForInst() {
- // There may be references to the induction variable of this loop within its
- // body or, in case of ill-formed code during parsing, outside its body.
- // Explicitly drop all uses of the induction variable before destroying it.
- dropAllUses();
- }
-
/// Resolve base class ambiguity.
using Instruction::getFunction;
@@ -700,7 +693,9 @@ public:
//===--------------------------------------------------------------------===//
/// Return the context this operation is associated with.
- MLIRContext *getContext() const { return getType().getContext(); }
+ MLIRContext *getContext() const {
+ return getInductionVar()->getType().getContext();
+ }
using Instruction::dump;
using Instruction::print;
@@ -710,11 +705,10 @@ public:
return ptr->getKind() == IROperandOwner::Kind::ForInst;
}
- // For instruction represents implicitly represents induction variable by
- // inheriting from Value class. Whenever you need to refer to the loop
- // induction variable, just use the for instruction itself.
- static bool classof(const Value *value) {
- return value->getKind() == Value::Kind::ForInst;
+ /// Returns the induction variable for this loop.
+ Value *getInductionVar();
+ const Value *getInductionVar() const {
+ return const_cast<ForInst *>(this)->getInductionVar();
}
private:
@@ -738,6 +732,17 @@ private:
AffineMap ubMap, int64_t step);
};
+/// Returns if the provided value is the induction variable of a ForInst.
+bool isForInductionVar(const Value *val);
+
+/// Returns the loop parent of an induction variable. If the provided value is
+/// not an induction variable, then return nullptr.
+ForInst *getForInductionVarOwner(Value *val);
+const ForInst *getForInductionVarOwner(const Value *val);
+
+/// Extracts the induction variables from a list of ForInsts and returns them.
+SmallVector<Value *, 8> extractForInductionVars(ArrayRef<ForInst *> forInsts);
+
/// AffineBound represents a lower or upper bound in the for instruction.
/// This class does not own the underlying operands. Instead, it refers
/// to the operands stored in the ForInst. Its life span should not exceed
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 48af9c71be6..90f1f484b1f 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -45,7 +45,6 @@ public:
enum class Kind {
BlockArgument, // block argument
InstResult, // operation instruction result
- ForInst, // 'for' instruction induction variable
};
~Value() {}
@@ -141,6 +140,9 @@ public:
/// Returns the number of this argument.
unsigned getArgNumber() const;
+ /// Returns if the current argument is a function argument.
+ bool isFunctionArgument() const;
+
private:
friend class Block; // For access to private constructor.
BlockArgument(Type type, Block *owner)
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 1ecad8d4e90..a4d969bc203 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -555,7 +555,7 @@ void mlir::getReachableAffineApplyOps(
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
bool mlir::getIndexSet(ArrayRef<ForInst *> forInsts,
FlatAffineConstraints *domain) {
- SmallVector<Value *, 4> indices(forInsts.begin(), forInsts.end());
+ auto indices = extractForInductionVars(forInsts);
// Reset while associated Values in 'indices' to the domain.
domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
for (auto *forInst : forInsts) {
@@ -677,7 +677,7 @@ static void buildDimAndSymbolPositionMaps(
auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto *value = values[i];
- if (!isa<ForInst>(values[i])) {
+ if (!isForInductionVar(values[i])) {
assert(values[i]->isValidSymbol() &&
"access operand has to be either a loop IV or a symbol");
valuePosMap->addSymbolValue(value);
@@ -739,7 +739,7 @@ void initDependenceConstraints(const FlatAffineConstraints &srcDomain,
// Set values for the symbolic identifier dimensions.
auto setSymbolIds = [&](ArrayRef<Value *> values) {
for (auto *value : values) {
- if (!isa<ForInst>(value)) {
+ if (!isForInductionVar(value)) {
assert(value->isValidSymbol() && "expected symbol");
dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
}
@@ -907,7 +907,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
// Add equality constraints for any operands that are defined by constant ops.
auto addEqForConstOperands = [&](ArrayRef<const Value *> operands) {
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- if (isa<ForInst>(operands[i]))
+ if (isForInductionVar(operands[i]))
continue;
auto *symbol = operands[i];
assert(symbol->isValidSymbol());
@@ -976,8 +976,8 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
unsigned numCommonLoops = 0;
for (unsigned i = 0; i < minNumLoops; ++i) {
- if (!isa<ForInst>(srcDomain.getIdValue(i)) ||
- !isa<ForInst>(dstDomain.getIdValue(i)) ||
+ if (!isForInductionVar(srcDomain.getIdValue(i)) ||
+ !isForInductionVar(dstDomain.getIdValue(i)) ||
srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
break;
++numCommonLoops;
@@ -998,8 +998,9 @@ static const Block *getCommonBlock(const MemRefAccess &srcAccess,
return block;
}
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
- assert(isa<ForInst>(commonForValue));
- return cast<ForInst>(commonForValue)->getBody();
+ auto *forInst = getForInductionVarOwner(commonForValue);
+ assert(forInst && "commonForValue was not an induction variable");
+ return forInst->getBody();
}
// Returns true if the ancestor operation instruction of 'srcAccess' appears
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 268fbe0c9c6..7aa23bbe480 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1251,7 +1251,7 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
unsigned pos;
// Pre-condition for this method.
- if (!findId(forInst, &pos)) {
+ if (!findId(*forInst.getInductionVar(), &pos)) {
assert(0 && "Value not found");
return false;
}
diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp
index b98efa73e54..5cdeebbdf4a 100644
--- a/mlir/lib/Analysis/Dominance.cpp
+++ b/mlir/lib/Analysis/Dominance.cpp
@@ -53,9 +53,9 @@ bool DominanceInfo::properlyDominates(const Block *a, const Block *b) {
if (blockListA == blockListB)
return DominatorTreeBase::properlyDominates(a, b);
- // Otherwise, 'a' properly dominates 'b' if 'b' is defined in an
- // IfInst/ForInst that (recursively) ends up being dominated by 'a'. Walk up
- // the list of containers enclosing B.
+ // Otherwise, 'a' properly dominates 'b' if 'b' is defined in an instruction
+ // region that (recursively) ends up being dominated by 'a'. Walk up the list
+ // of containers enclosing B.
Instruction *bAncestor;
do {
bAncestor = blockListB->getContainingInst();
@@ -106,11 +106,6 @@ bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) {
if (auto *aInst = a->getDefiningInst())
return properlyDominates(aInst, b);
- // The induction variable of a ForInst properly dominantes its body, so we
- // can just do a simple block dominance check.
- if (auto *forInst = dyn_cast<ForInst>(a))
- return dominates(forInst->getBody(), b->getBlock());
-
// block arguments properly dominate all instructions in their own block, so
// we use a dominates check here, not a properlyDominates check.
return dominates(cast<BlockArgument>(a)->getOwner(), b->getBlock());
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index b154ebab105..640984bf866 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -125,7 +125,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) {
}
bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
- assert(isa<ForInst>(iv) && "iv must be a ForInst");
+ assert(isForInductionVar(&iv) && "iv must be a ForInst");
assert(index.getType().isa<IndexType>() && "index must be of IndexType");
SmallVector<OperationInst *, 4> affineApplyOps;
getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
@@ -288,8 +288,10 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
[fastestVaryingDim](const ForInst &loop, const OperationInst &op) {
auto load = op.dyn_cast<LoadOp>();
auto store = op.dyn_cast<StoreOp>();
- return load ? isContiguousAccess(loop, *load, fastestVaryingDim)
- : isContiguousAccess(loop, *store, fastestVaryingDim);
+ return load ? isContiguousAccess(*loop.getInductionVar(), *load,
+ fastestVaryingDim)
+ : isContiguousAccess(*loop.getInductionVar(), *store,
+ fastestVaryingDim);
});
return isVectorizableLoopWithCond(loop, fun);
}
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index a8cec771f0d..d16a7fcb1b3 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -64,7 +64,7 @@ void mlir::getForwardSlice(Instruction *inst,
}
}
} else if (auto *forInst = dyn_cast<ForInst>(inst)) {
- for (auto &u : forInst->getUses()) {
+ for (auto &u : forInst->getInductionVar()->getUses()) {
auto *ownerInst = u.getOwner();
if (forwardSlice->count(ownerInst) == 0) {
getForwardSlice(ownerInst, forwardSlice, filter,
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 39e58e8983c..939a2ede618 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -149,7 +149,8 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
// A rank 0 memref has a 0-d region.
SmallVector<ForInst *, 4> ivs;
getLoopIVs(*opInst, &ivs);
- SmallVector<Value *, 4> regionSymbols(ivs.begin(), ivs.end());
+
+ SmallVector<Value *, 8> regionSymbols = extractForInductionVars(ivs);
regionCst->reset(0, loopDepth, 0, regionSymbols);
return true;
}
@@ -172,7 +173,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
unsigned numSymbols = accessMap.getNumSymbols();
// Add inequalties for loop lower/upper bounds.
for (unsigned i = 0; i < numDims + numSymbols; ++i) {
- if (auto *loop = dyn_cast<ForInst>(accessValueMap.getOperand(i))) {
+ if (auto *loop = getForInductionVarOwner(accessValueMap.getOperand(i))) {
// Note that regionCst can now have more dimensions than accessMap if the
// bounds expressions involve outer loops or other symbols.
// TODO(bondhugula): rewrite this to use getInstIndexSet; this way
@@ -207,7 +208,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
outerIVs.resize(loopDepth);
for (auto *operand : accessValueMap.getOperands()) {
ForInst *iv;
- if ((iv = dyn_cast<ForInst>(operand)) &&
+ if ((iv = getForInductionVarOwner(operand)) &&
std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) {
regionCst->projectOut(operand);
}
diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp
index 37eed71508f..125020e92a3 100644
--- a/mlir/lib/Analysis/VectorAnalysis.cpp
+++ b/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -113,7 +113,8 @@ static AffineMap makePermutationMap(
getAffineConstantExpr(0, context));
for (auto kvp : enclosingLoopToVectorDim) {
assert(kvp.second < perm.size());
- auto invariants = getInvariantAccesses(*kvp.first, unwrappedIndices);
+ auto invariants =
+ getInvariantAccesses(*kvp.first->getInductionVar(), unwrappedIndices);
unsigned numIndices = unwrappedIndices.size();
unsigned countInvariantIndices = 0;
for (unsigned dim = 0; dim < numIndices; ++dim) {
diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp
index 56f211ec578..c2a6dc1f90a 100644
--- a/mlir/lib/EDSC/MLIREmitter.cpp
+++ b/mlir/lib/EDSC/MLIREmitter.cpp
@@ -133,9 +133,7 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
inst->print(os);
return;
}
- // &v is required here otherwise we get:
- // non-pointer operand type 'const mlir::ForInst' incompatible with nullptr
- if (auto *forInst = dyn_cast<ForInst>(&v)) {
+ if (auto *forInst = getForInductionVarOwner(&v)) {
forInst->print(os);
} else {
os << "unknown_ssa_value";
@@ -296,7 +294,7 @@ Value *MLIREmitter::emit(Expr e) {
exprs[1]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
auto step =
exprs[2]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
- res = builder->createFor(location, lb, ub, step);
+ res = builder->createFor(location, lb, ub, step)->getInductionVar();
}
}
@@ -347,7 +345,8 @@ void MLIREmitter::emitStmt(const Stmt &stmt) {
bind(stmt.getLHS(), val);
if (stmt.getRHS().getKind() == ExprKind::For) {
// Step into the loop.
- builder->setInsertionPointToStart(cast<ForInst>(val)->getBody());
+ builder->setInsertionPointToStart(
+ getForInductionVarOwner(val)->getBody());
}
}
emitStmts(stmt.getEnclosedStmts());
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index af996213418..21bc3b824b1 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1078,7 +1078,7 @@ public:
void print(const OperationInst *inst);
void print(const ForInst *inst);
void print(const IfInst *inst);
- void print(const Block *block);
+ void print(const Block *block, bool printBlockArgs = true);
void printOperation(const OperationInst *op);
void printGenericOp(const OperationInst *op);
@@ -1125,10 +1125,15 @@ public:
unsigned index) override;
/// Print a block list.
- void printBlockList(const BlockList &blocks) {
+ void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) {
os << " {\n";
- for (auto &b : blocks)
- print(&b);
+ if (!blocks.empty()) {
+ auto *entryBlock = &blocks.front();
+ print(entryBlock,
+ printEntryBlockArgs && entryBlock->getNumArguments() != 0);
+ for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1))
+ print(&b);
+ }
os.indent(currentIndent) << "}";
}
@@ -1164,8 +1169,8 @@ private:
/// This is the next value ID to assign in numbering.
unsigned nextValueID = 0;
- /// This is the ID to assign to the next induction variable.
- unsigned nextLoopID = 0;
+ /// This is the ID to assign to the next region entry block argument.
+ unsigned nextRegionArgumentID = 0;
/// This is the next ID to assign to a Function argument.
unsigned nextArgumentID = 0;
/// This is the next ID to assign when a name conflict is detected.
@@ -1205,14 +1210,10 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
numberValuesInBlock(block);
break;
}
- case Instruction::Kind::For: {
- auto *forInst = cast<ForInst>(&inst);
- // Number the induction variable.
- numberValueID(forInst);
+ case Instruction::Kind::For:
// Recursively number the stuff in the body.
- numberValuesInBlock(*forInst->getBody());
+ numberValuesInBlock(*cast<ForInst>(&inst)->getBody());
break;
- }
case Instruction::Kind::If: {
auto *ifInst = cast<IfInst>(&inst);
numberValuesInBlock(*ifInst->getThen());
@@ -1251,13 +1252,19 @@ void FunctionPrinter::numberValueID(const Value *value) {
if (specialNameBuffer.empty()) {
switch (value->getKind()) {
case Value::Kind::BlockArgument:
- // If this is an argument to the function, give it an 'arg' name.
- if (auto *block = cast<BlockArgument>(value)->getOwner())
- if (auto *fn = block->getFunction())
- if (&fn->getBlockList().front() == block) {
+ // If this is an argument to the function, give it an 'arg' name. If the
+ // argument is to an entry block of an operation region, give it an 'i'
+ // name.
+ if (auto *block = cast<BlockArgument>(value)->getOwner()) {
+ auto *parentBlockList = block->getParent();
+ if (parentBlockList && block == &parentBlockList->front()) {
+ if (parentBlockList->getContainingFunction())
specialName << "arg" << nextArgumentID++;
- break;
- }
+ else
+ specialName << "i" << nextRegionArgumentID++;
+ break;
+ }
+ }
// Otherwise number it normally.
valueIDs[value] = nextValueID++;
return;
@@ -1266,9 +1273,6 @@ void FunctionPrinter::numberValueID(const Value *value) {
// done with it.
valueIDs[value] = nextValueID++;
return;
- case Value::Kind::ForInst:
- specialName << 'i' << nextLoopID++;
- break;
}
}
@@ -1312,10 +1316,8 @@ void FunctionPrinter::print() {
printTrailingLocation(function->getLoc());
if (!function->empty()) {
- os << " {\n";
- for (const auto &block : *function)
- print(&block);
- os << "}\n";
+ printBlockList(function->getBlockList(), /*printEntryBlockArgs=*/false);
+ os << "\n";
}
os << '\n';
}
@@ -1357,26 +1359,10 @@ void FunctionPrinter::printFunctionSignature() {
}
}
-/// Return true if the introducer for the specified block should be printed.
-static bool shouldPrintBlockArguments(const Block *block) {
- // Never print the entry block of the function - it is included in the
- // argument list.
- if (block == &block->getFunction()->front())
- return false;
-
- // If this is the first block in a nested region, and if there are no
- // arguments, then we can omit it.
- if (block == &block->getParent()->front() && block->getNumArguments() == 0)
- return false;
-
- // Otherwise print it.
- return true;
-}
-
-void FunctionPrinter::print(const Block *block) {
+void FunctionPrinter::print(const Block *block, bool printBlockArgs) {
// Print the block label and argument list, unless this is the first block of
// the function, or the first block of an IfInst/ForInst with no arguments.
- if (shouldPrintBlockArguments(block)) {
+ if (printBlockArgs) {
os.indent(currentIndent);
printBlockName(block);
@@ -1445,7 +1431,7 @@ void FunctionPrinter::print(const OperationInst *inst) {
void FunctionPrinter::print(const ForInst *inst) {
os.indent(currentIndent) << "for ";
- printOperand(inst);
+ printOperand(inst->getInductionVar());
os << " = ";
printBound(inst->getLowerBound(), "max");
os << " to ";
@@ -1457,7 +1443,7 @@ void FunctionPrinter::print(const ForInst *inst) {
printTrailingLocation(inst->getLoc());
os << " {\n";
- print(inst->getBody());
+ print(inst->getBody(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
}
@@ -1468,11 +1454,11 @@ void FunctionPrinter::print(const IfInst *inst) {
printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
printTrailingLocation(inst->getLoc());
os << " {\n";
- print(inst->getThen());
+ print(inst->getThen(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
if (inst->hasElse()) {
os << " else {\n";
- print(inst->getElse());
+ print(inst->getElse(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
}
}
@@ -1583,7 +1569,7 @@ void FunctionPrinter::printGenericOp(const OperationInst *op) {
// Print any trailing block lists.
for (auto &blockList : op->getBlockLists())
- printBlockList(blockList);
+ printBlockList(blockList, /*printEntryBlockArgs=*/true);
}
void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term,
@@ -1729,8 +1715,6 @@ void Value::print(raw_ostream &os) const {
return;
case Value::Kind::InstResult:
return getDefiningInst()->print(os);
- case Value::Kind::ForInst:
- return cast<ForInst>(this)->print(os);
}
}
diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp
index b8a3e581329..6d74ed14257 100644
--- a/mlir/lib/IR/Instruction.cpp
+++ b/mlir/lib/IR/Instruction.cpp
@@ -126,9 +126,9 @@ bool Value::isValidSymbol() const {
return op->isValidSymbol();
return false;
}
- // This value is either a function argument or an induction variable.
- // Function argument is ok, induction variable is not.
- return isa<BlockArgument>(this);
+ // Otherwise, the only valid symbol is a function argument.
+ auto *arg = dyn_cast<BlockArgument>(this);
+ return arg && arg->isFunctionArgument();
}
void Instruction::setOperand(unsigned idx, Value *value) {
@@ -635,13 +635,16 @@ ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands,
ForInst::ForInst(Location location, unsigned numOperands, AffineMap lbMap,
AffineMap ubMap, int64_t step)
- : Instruction(Instruction::Kind::For, location),
- Value(Value::Kind::ForInst,
- Type::getIndex(lbMap.getResult(0).getContext())),
- body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
+ : Instruction(Instruction::Kind::For, location), body(this), lbMap(lbMap),
+ ubMap(ubMap), step(step) {
// The body of a for inst always has one block.
- body.push_back(new Block());
+ auto *bodyEntry = new Block();
+ body.push_back(bodyEntry);
+
+ // Add an argument to the block for the induction variable.
+ bodyEntry->addArgument(Type::getIndex(lbMap.getResult(0).getContext()));
+
operands.reserve(numOperands);
}
@@ -777,6 +780,35 @@ void ForInst::walkOpsPostOrder(std::function<void(OperationInst *)> callback) {
v.walkPostOrder(this);
}
+/// Returns the induction variable for this loop.
+Value *ForInst::getInductionVar() { return getBody()->getArgument(0); }
+
+/// Returns if the provided value is the induction variable of a ForInst.
+bool mlir::isForInductionVar(const Value *val) {
+ return getForInductionVarOwner(val) != nullptr;
+}
+
+/// Returns the loop parent of an induction variable. If the provided value is
+/// not an induction variable, then return nullptr.
+ForInst *mlir::getForInductionVarOwner(Value *val) {
+ const BlockArgument *ivArg = dyn_cast<BlockArgument>(val);
+ if (!ivArg || !ivArg->getOwner())
+ return nullptr;
+ return dyn_cast_or_null<ForInst>(
+ ivArg->getOwner()->getParent()->getContainingInst());
+}
+const ForInst *mlir::getForInductionVarOwner(const Value *val) {
+ return getForInductionVarOwner(const_cast<Value *>(val));
+}
+
+/// Extracts the induction variables from a list of ForInsts and returns them.
+SmallVector<Value *, 8>
+mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) {
+ SmallVector<Value *, 8> results;
+ for (auto *forInst : forInsts)
+ results.push_back(forInst->getInductionVar());
+ return results;
+}
//===----------------------------------------------------------------------===//
// IfInst
//===----------------------------------------------------------------------===//
@@ -909,7 +941,7 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper,
ubMap, forInst->getStep());
// Remember the induction variable mapping.
- mapper.map(forInst, newFor);
+ mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
// Recursively clone the body of the for loop.
for (auto &subInst : *forInst->getBody())
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index a2cb9910ab8..6418b062dc1 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -35,8 +35,6 @@ Function *Value::getFunction() {
return cast<BlockArgument>(this)->getFunction();
case Value::Kind::InstResult:
return getDefiningInst()->getFunction();
- case Value::Kind::ForInst:
- return cast<ForInst>(this)->getFunction();
}
}
@@ -83,3 +81,9 @@ Function *BlockArgument::getFunction() {
return owner->getFunction();
return nullptr;
}
+
+/// Returns if the current argument is a function argument.
+bool BlockArgument::isFunctionArgument() const {
+ auto *containingFn = getFunction();
+ return containingFn && &containingFn->front() == getOwner();
+}
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index ecb7fbc779e..c477ad1bbc5 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -3201,7 +3201,8 @@ ParseResult FunctionParser::parseForInst() {
ubOperands, ubMap, step);
// Create SSA value definition for the induction variable.
- if (addDefinition({inductionVariableName, 0, loc}, forInst))
+ if (addDefinition({inductionVariableName, 0, loc},
+ forInst->getInductionVar()))
return ParseFailure;
// Try to parse the optional trailing location.
@@ -3347,7 +3348,7 @@ ParseResult FunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
// Create an identity map using dim id for an induction variable and
// symbol otherwise. This representation is optimized for storage.
// Analysis passes may expand it into a multi-dimensional map if desired.
- if (isa<ForInst>(operands[0]))
+ if (isForInductionVar(operands[0]))
map = builder.getDimIdentityMap();
else
map = builder.getSymbolIdentityMap();
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 0437fb143e0..04eb38e9fc9 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -171,7 +171,8 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, unsigned numSymbols,
getLoopIVs(*opInst, &ivs);
auto *regionCst = region->getConstraints();
- SmallVector<Value *, 4> symbols(ivs.begin(), ivs.end());
+
+ SmallVector<Value *, 8> symbols = extractForInductionVars(ivs);
regionCst->reset(rank, numSymbols, 0, symbols);
// Memref dim sizes provide the bounds.
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index 2a4b7bcd262..396fc8eb658 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -103,7 +103,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
auto mayBeConstantCount = getConstantTripCount(*origLoops[i]);
// The lower bound is just the tile-space loop.
AffineMap lbMap = b.getDimIdentityMap();
- newLoops[width + i]->setLowerBound(/*operands=*/newLoops[i], lbMap);
+ newLoops[width + i]->setLowerBound(
+ /*operands=*/newLoops[i]->getInductionVar(), lbMap);
// Set the upper bound.
if (mayBeConstantCount.hasValue() &&
@@ -117,7 +118,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
// with 'i' (tile-space loop) appended to it. The new upper bound map is
// the original one with an additional expression i + tileSize appended.
SmallVector<Value *, 4> ubOperands(origLoops[i]->getUpperBoundOperands());
- ubOperands.push_back(newLoops[i]);
+ ubOperands.push_back(newLoops[i]->getInductionVar());
auto origUbMap = origLoops[i]->getUpperBoundMap();
SmallVector<AffineExpr, 4> boundExprs;
@@ -135,7 +136,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
// No need of the min expression.
auto dim = b.getAffineDimExpr(0);
auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {});
- newLoops[width + i]->setUpperBound(newLoops[i], ubMap);
+ newLoops[width + i]->setUpperBound(newLoops[i]->getInductionVar(), ubMap);
}
}
}
@@ -194,8 +195,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
// Move the loop body of the original nest to the new one.
moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop);
- SmallVector<Value *, 6> origLoopIVs(band.begin(), band.end());
- SmallVector<Optional<Value *>, 6> ids(band.begin(), band.end());
+ SmallVector<Value *, 8> origLoopIVs = extractForInductionVars(band);
+ SmallVector<Optional<Value *>, 6> ids(origLoopIVs.begin(), origLoopIVs.end());
FlatAffineConstraints cst;
getIndexSet(band, &cst);
@@ -208,7 +209,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes);
// In this case, the point loop IVs just replace the original ones.
for (unsigned i = 0; i < width; i++) {
- origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]);
+ origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]->getInductionVar());
}
// Erase the old loop nest.
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index 71d77817254..a8ec57c0426 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -215,6 +215,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
int64_t step = forInst->getStep();
forInst->setStep(step * unrollJamFactor);
+ auto *forInstIV = forInst->getInductionVar();
for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of
// sub-block.
@@ -226,14 +227,15 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
- if (!forInst->use_empty()) {
+ if (!forInstIV->use_empty()) {
// iv' = iv + i, i = 1 to unrollJamFactor-1.
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
- builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
+ builder
+ .create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInstIV)
->getResult(0);
- operandMapping.map(forInst, ivUnroll);
+ operandMapping.map(forInstIV, ivUnroll);
}
// Clone the sub-block being unroll-jammed.
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index 99ee603bb05..94f300bd16a 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -348,7 +348,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
oldBody->begin(), oldBody->end());
// The code in the body of the forInst now uses 'iv' as its indvar.
- forInst->replaceAllUsesWith(iv);
+ forInst->getInductionVar()->replaceAllUsesWith(iv);
// Append the induction variable stepping logic and branch back to the exit
// condition block. Construct an affine expression f : (x -> x+step) and
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index e72b9ef80df..0019714b6a3 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -121,8 +121,8 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
int64_t step = forInst->getStep();
auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
{d0.floorDiv(step) % 2}, {});
- auto ivModTwoOp =
- bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap, forInst);
+ auto ivModTwoOp = bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap,
+ forInst->getInductionVar());
// replaceAllMemRefUsesWith will always succeed unless the forInst body has
// non-deferencing uses of the memref.
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index d41614545d2..03673eaa535 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -99,24 +99,25 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) {
return false;
// Replaces all IV uses to its single iteration value.
- if (!forInst->use_empty()) {
+ auto *iv = forInst->getInductionVar();
+ if (!iv->use_empty()) {
if (forInst->hasConstantLowerBound()) {
auto *mlFunc = forInst->getFunction();
FuncBuilder topBuilder(mlFunc);
auto constOp = topBuilder.create<ConstantIndexOp>(
forInst->getLoc(), forInst->getConstantLowerBound());
- forInst->replaceAllUsesWith(constOp);
+ iv->replaceAllUsesWith(constOp);
} else {
const AffineBound lb = forInst->getLowerBound();
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst));
if (lb.getMap() == builder.getDimIdentityMap()) {
// No need of generating an affine_apply.
- forInst->replaceAllUsesWith(lbOperands[0]);
+ iv->replaceAllUsesWith(lbOperands[0]);
} else {
auto affineApplyOp = builder.create<AffineApplyOp>(
forInst->getLoc(), lb.getMap(), lbOperands);
- forInst->replaceAllUsesWith(affineApplyOp->getResult(0));
+ iv->replaceAllUsesWith(affineApplyOp->getResult(0));
}
}
}
@@ -161,6 +162,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap,
ubOperands, ubMap, srcForInst->getStep());
+ auto *loopChunkIV = loopChunk->getInductionVar();
+ auto *srcIV = srcForInst->getInductionVar();
BlockAndValueMapping operandMap;
@@ -172,17 +175,17 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
// remapped to results of cloned instructions, and their IV used remapped.
// Generate the remapping if the shift is not zero: remappedIV = newIV -
// shift.
- if (!srcForInst->use_empty() && shift != 0) {
+ if (!srcIV->use_empty() && shift != 0) {
auto b = FuncBuilder::getForInstBodyBuilder(loopChunk);
auto *ivRemap = b.create<AffineApplyOp>(
srcForInst->getLoc(),
b.getSingleDimShiftAffineMap(-static_cast<int64_t>(
srcForInst->getStep() * shift)),
- loopChunk)
+ loopChunkIV)
->getResult(0);
- operandMap.map(srcForInst, ivRemap);
+ operandMap.map(srcIV, ivRemap);
} else {
- operandMap.map(srcForInst, loopChunk);
+ operandMap.map(srcIV, loopChunkIV);
}
for (auto *inst : insts) {
loopChunk->getBody()->push_back(inst->clone(operandMap, b->getContext()));
@@ -419,19 +422,20 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end());
// Unroll the contents of 'forInst' (append unrollFactor-1 additional copies).
+ auto *forInstIV = forInst->getInductionVar();
for (unsigned i = 1; i < unrollFactor; i++) {
BlockAndValueMapping operandMap;
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
- if (!forInst->use_empty()) {
+ if (!forInstIV->use_empty()) {
// iv' = iv + 1/2/3...unrollFactor-1;
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
- builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
+ builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInstIV)
->getResult(0);
- operandMap.map(forInst, ivUnroll);
+ operandMap.map(forInstIV, ivUnroll);
}
// Clone the original body of 'forInst'.
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index e9b37fcc04c..cfde1ecf0a8 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -881,8 +881,9 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
auto load = opInst->dyn_cast<LoadOp>();
auto store = opInst->dyn_cast<StoreOp>();
LLVM_DEBUG(opInst->print(dbgs()));
- auto fail = load ? vectorizeRootOrTerminal(loop, load, state)
- : vectorizeRootOrTerminal(loop, store, state);
+ auto fail =
+ load ? vectorizeRootOrTerminal(loop->getInductionVar(), load, state)
+ : vectorizeRootOrTerminal(loop->getInductionVar(), store, state);
if (fail) {
return fail;
}
@@ -1210,7 +1211,8 @@ static bool vectorizeRootMatches(NestedMatch matches,
/// RAII.
ScopeGuard sg2([&fail, loop, clonedLoop]() {
if (fail) {
- loop->replaceAllUsesWith(clonedLoop);
+ loop->getInductionVar()->replaceAllUsesWith(
+ clonedLoop->getInductionVar());
loop->erase();
} else {
clonedLoop->erase();
OpenPOWER on IntegriCloud