summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR/Instruction.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR/Instruction.cpp')
-rw-r--r--mlir/lib/IR/Instruction.cpp50
1 files changed, 41 insertions, 9 deletions
diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp
index b8a3e581329..6d74ed14257 100644
--- a/mlir/lib/IR/Instruction.cpp
+++ b/mlir/lib/IR/Instruction.cpp
@@ -126,9 +126,9 @@ bool Value::isValidSymbol() const {
return op->isValidSymbol();
return false;
}
- // This value is either a function argument or an induction variable.
- // Function argument is ok, induction variable is not.
- return isa<BlockArgument>(this);
+ // Otherwise, the only valid symbol is a function argument.
+ auto *arg = dyn_cast<BlockArgument>(this);
+ return arg && arg->isFunctionArgument();
}
void Instruction::setOperand(unsigned idx, Value *value) {
@@ -635,13 +635,16 @@ ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands,
ForInst::ForInst(Location location, unsigned numOperands, AffineMap lbMap,
AffineMap ubMap, int64_t step)
- : Instruction(Instruction::Kind::For, location),
- Value(Value::Kind::ForInst,
- Type::getIndex(lbMap.getResult(0).getContext())),
- body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
+ : Instruction(Instruction::Kind::For, location), body(this), lbMap(lbMap),
+ ubMap(ubMap), step(step) {
// The body of a for inst always has one block.
- body.push_back(new Block());
+ auto *bodyEntry = new Block();
+ body.push_back(bodyEntry);
+
+ // Add an argument to the block for the induction variable.
+ bodyEntry->addArgument(Type::getIndex(lbMap.getResult(0).getContext()));
+
operands.reserve(numOperands);
}
@@ -777,6 +780,35 @@ void ForInst::walkOpsPostOrder(std::function<void(OperationInst *)> callback) {
v.walkPostOrder(this);
}
+/// Returns the induction variable for this loop.
+Value *ForInst::getInductionVar() { return getBody()->getArgument(0); }
+
+/// Returns if the provided value is the induction variable of a ForInst.
+bool mlir::isForInductionVar(const Value *val) {
+ return getForInductionVarOwner(val) != nullptr;
+}
+
+/// Returns the loop parent of an induction variable. If the provided value is
+/// not an induction variable, then return nullptr.
+ForInst *mlir::getForInductionVarOwner(Value *val) {
+ const BlockArgument *ivArg = dyn_cast<BlockArgument>(val);
+ if (!ivArg || !ivArg->getOwner())
+ return nullptr;
+ return dyn_cast_or_null<ForInst>(
+ ivArg->getOwner()->getParent()->getContainingInst());
+}
+const ForInst *mlir::getForInductionVarOwner(const Value *val) {
+ return getForInductionVarOwner(const_cast<Value *>(val));
+}
+
+/// Extracts the induction variables from a list of ForInsts and returns them.
+SmallVector<Value *, 8>
+mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) {
+ SmallVector<Value *, 8> results;
+ for (auto *forInst : forInsts)
+ results.push_back(forInst->getInductionVar());
+ return results;
+}
//===----------------------------------------------------------------------===//
// IfInst
//===----------------------------------------------------------------------===//
@@ -909,7 +941,7 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper,
ubMap, forInst->getStep());
// Remember the induction variable mapping.
- mapper.map(forInst, newFor);
+ mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
// Recursively clone the body of the for loop.
for (auto &subInst : *forInst->getBody())
OpenPOWER on IntegriCloud