diff options
Diffstat (limited to 'mlir/lib/IR/Instruction.cpp')
| -rw-r--r-- | mlir/lib/IR/Instruction.cpp | 50 |
1 files changed, 41 insertions, 9 deletions
diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index b8a3e581329..6d74ed14257 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -126,9 +126,9 @@ bool Value::isValidSymbol() const { return op->isValidSymbol(); return false; } - // This value is either a function argument or an induction variable. - // Function argument is ok, induction variable is not. - return isa<BlockArgument>(this); + // Otherwise, the only valid symbol is a function argument. + auto *arg = dyn_cast<BlockArgument>(this); + return arg && arg->isFunctionArgument(); } void Instruction::setOperand(unsigned idx, Value *value) { @@ -635,13 +635,16 @@ ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands, ForInst::ForInst(Location location, unsigned numOperands, AffineMap lbMap, AffineMap ubMap, int64_t step) - : Instruction(Instruction::Kind::For, location), - Value(Value::Kind::ForInst, - Type::getIndex(lbMap.getResult(0).getContext())), - body(this), lbMap(lbMap), ubMap(ubMap), step(step) { + : Instruction(Instruction::Kind::For, location), body(this), lbMap(lbMap), + ubMap(ubMap), step(step) { // The body of a for inst always has one block. - body.push_back(new Block()); + auto *bodyEntry = new Block(); + body.push_back(bodyEntry); + + // Add an argument to the block for the induction variable. + bodyEntry->addArgument(Type::getIndex(lbMap.getResult(0).getContext())); + operands.reserve(numOperands); } @@ -777,6 +780,35 @@ void ForInst::walkOpsPostOrder(std::function<void(OperationInst *)> callback) { v.walkPostOrder(this); } +/// Returns the induction variable for this loop. +Value *ForInst::getInductionVar() { return getBody()->getArgument(0); } + +/// Returns if the provided value is the induction variable of a ForInst. +bool mlir::isForInductionVar(const Value *val) { + return getForInductionVarOwner(val) != nullptr; +} + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +ForInst *mlir::getForInductionVarOwner(Value *val) { + const BlockArgument *ivArg = dyn_cast<BlockArgument>(val); + if (!ivArg || !ivArg->getOwner()) + return nullptr; + return dyn_cast_or_null<ForInst>( + ivArg->getOwner()->getParent()->getContainingInst()); +} +const ForInst *mlir::getForInductionVarOwner(const Value *val) { + return getForInductionVarOwner(const_cast<Value *>(val)); +} + +/// Extracts the induction variables from a list of ForInsts and returns them. +SmallVector<Value *, 8> +mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) { + SmallVector<Value *, 8> results; + for (auto *forInst : forInsts) + results.push_back(forInst->getInductionVar()); + return results; +} //===----------------------------------------------------------------------===// // IfInst //===----------------------------------------------------------------------===// @@ -909,7 +941,7 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, ubMap, forInst->getStep()); // Remember the induction variable mapping. - mapper.map(forInst, newFor); + mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); // Recursively clone the body of the for loop. for (auto &subInst : *forInst->getBody()) |

