diff options
Diffstat (limited to 'mlir/lib/IR/Statement.cpp')
| -rw-r--r-- | mlir/lib/IR/Statement.cpp | 93 |
1 files changed, 44 insertions, 49 deletions
diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 7d9ef98a7de..681eb4dac23 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -262,17 +262,16 @@ bool OperationStmt::isReturn() const { return is<ReturnOp>(); } //===----------------------------------------------------------------------===// ForStmt *ForStmt::create(Location *location, ArrayRef<MLValue *> lbOperands, - AffineMap *lbMap, ArrayRef<MLValue *> ubOperands, - AffineMap *ubMap, int64_t step, MLIRContext *context) { - assert(lbOperands.size() == lbMap->getNumInputs() && + AffineMap lbMap, ArrayRef<MLValue *> ubOperands, + AffineMap ubMap, int64_t step) { + assert(lbOperands.size() == lbMap.getNumInputs() && "lower bound operand count does not match the affine map"); - assert(ubOperands.size() == ubMap->getNumInputs() && + assert(ubOperands.size() == ubMap.getNumInputs() && "upper bound operand count does not match the affine map"); assert(step > 0 && "step has to be a positive integer constant"); unsigned numOperands = lbOperands.size() + ubOperands.size(); - ForStmt *stmt = - new ForStmt(location, numOperands, lbMap, ubMap, step, context); + ForStmt *stmt = new ForStmt(location, numOperands, lbMap, ubMap, step); unsigned i = 0; for (unsigned e = lbOperands.size(); i != e; ++i) @@ -284,30 +283,31 @@ ForStmt *ForStmt::create(Location *location, ArrayRef<MLValue *> lbOperands, return stmt; } -ForStmt::ForStmt(Location *location, unsigned numOperands, AffineMap *lbMap, - AffineMap *ubMap, int64_t step, MLIRContext *context) +ForStmt::ForStmt(Location *location, unsigned numOperands, AffineMap lbMap, + AffineMap ubMap, int64_t step) : Statement(Kind::For, location), - MLValue(MLValueKind::ForStmt, Type::getIndex(context)), + MLValue(MLValueKind::ForStmt, + Type::getIndex(lbMap.getResult(0).getContext())), StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) { operands.reserve(numOperands); } const AffineBound ForStmt::getLowerBound() const { - return AffineBound(*this, 0, lbMap->getNumInputs(), lbMap); + return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap); } const AffineBound ForStmt::getUpperBound() const { - return AffineBound(*this, lbMap->getNumInputs(), getNumOperands(), ubMap); + return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap); } -void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap *map) { - assert(lbOperands.size() == map->getNumInputs()); - assert(map->getNumResults() >= 1 && "bound map has at least one result"); +void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) { + assert(lbOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); SmallVector<MLValue *, 4> ubOperands(getUpperBoundOperands()); operands.clear(); - operands.reserve(lbOperands.size() + ubMap->getNumInputs()); + operands.reserve(lbOperands.size() + ubMap.getNumInputs()); for (auto *operand : lbOperands) { operands.emplace_back(StmtOperand(this, operand)); } @@ -317,9 +317,9 @@ void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap *map) { this->lbMap = map; } -void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap *map) { - assert(ubOperands.size() == map->getNumInputs()); - assert(map->getNumResults() >= 1 && "bound map has at least one result"); +void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap map) { + assert(ubOperands.size() == map.getNumInputs()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); SmallVector<MLValue *, 4> lbOperands(getLowerBoundOperands()); @@ -334,34 +334,30 @@ void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap *map) { this->ubMap = map; } -void ForStmt::setLowerBoundMap(AffineMap *map) { - assert(lbMap->getNumDims() == map->getNumDims() && - lbMap->getNumSymbols() == map->getNumSymbols()); - assert(map->getNumResults() >= 1 && "bound map has at least one result"); +void ForStmt::setLowerBoundMap(AffineMap map) { + assert(lbMap.getNumDims() == map.getNumDims() && + lbMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); this->lbMap = map; } -void ForStmt::setUpperBoundMap(AffineMap *map) { - assert(ubMap->getNumDims() == map->getNumDims() && - ubMap->getNumSymbols() == map->getNumSymbols()); - assert(map->getNumResults() >= 1 && "bound map has at least one result"); +void ForStmt::setUpperBoundMap(AffineMap map) { + assert(ubMap.getNumDims() == map.getNumDims() && + ubMap.getNumSymbols() == map.getNumSymbols()); + assert(map.getNumResults() >= 1 && "bound map has at least one result"); this->ubMap = map; } -bool ForStmt::hasConstantLowerBound() const { - return lbMap->isSingleConstant(); -} +bool ForStmt::hasConstantLowerBound() const { return lbMap.isSingleConstant(); } -bool ForStmt::hasConstantUpperBound() const { - return ubMap->isSingleConstant(); -} +bool ForStmt::hasConstantUpperBound() const { return ubMap.isSingleConstant(); } int64_t ForStmt::getConstantLowerBound() const { - return lbMap->getSingleConstantResult(); + return lbMap.getSingleConstantResult(); } int64_t ForStmt::getConstantUpperBound() const { - return ubMap->getSingleConstantResult(); + return ubMap.getSingleConstantResult(); } void ForStmt::setConstantLowerBound(int64_t value) { @@ -373,21 +369,20 @@ void ForStmt::setConstantUpperBound(int64_t value) { } ForStmt::operand_range ForStmt::getLowerBoundOperands() { - return {operand_begin(), - operand_begin() + getLowerBoundMap()->getNumInputs()}; + return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; } ForStmt::operand_range ForStmt::getUpperBoundOperands() { - return {operand_begin() + getLowerBoundMap()->getNumInputs(), operand_end()}; + return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; } bool ForStmt::matchingBoundOperandList() const { - if (lbMap->getNumDims() != ubMap->getNumDims() || - lbMap->getNumSymbols() != ubMap->getNumSymbols()) + if (lbMap.getNumDims() != ubMap.getNumDims() || + lbMap.getNumSymbols() != ubMap.getNumSymbols()) return false; - unsigned numOperands = lbMap->getNumInputs(); - for (unsigned i = 0, e = lbMap->getNumInputs(); i < e; i++) { + unsigned numOperands = lbMap.getNumInputs(); + for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { // Compare MLValue *'s. if (getOperand(i) != getOperand(numOperands + i)) return false; @@ -419,11 +414,11 @@ bool ForStmt::constantFoldBound(bool lower) { operandConstants.push_back(operandCst); } - AffineMap *boundMap = lower ? getLowerBoundMap() : getUpperBoundMap(); - assert(boundMap->getNumResults() >= 1 && + AffineMap boundMap = lower ? getLowerBoundMap() : getUpperBoundMap(); + assert(boundMap.getNumResults() >= 1 && "bound maps should have at least one result"); SmallVector<Attribute *, 4> foldedResults; - if (boundMap->constantFold(operandConstants, foldedResults)) + if (boundMap.constantFold(operandConstants, foldedResults)) return true; // Compute the max or min as applicable over the results. @@ -523,14 +518,14 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap, } if (auto *forStmt = dyn_cast<ForStmt>(this)) { - auto *lbMap = forStmt->getLowerBoundMap(); - auto *ubMap = forStmt->getUpperBoundMap(); + auto lbMap = forStmt->getLowerBoundMap(); + auto ubMap = forStmt->getUpperBoundMap(); auto *newFor = ForStmt::create( getLoc(), - ArrayRef<MLValue *>(operands).take_front(lbMap->getNumInputs()), lbMap, - ArrayRef<MLValue *>(operands).take_back(ubMap->getNumInputs()), ubMap, - forStmt->getStep(), context); + ArrayRef<MLValue *>(operands).take_front(lbMap.getNumInputs()), lbMap, + ArrayRef<MLValue *>(operands).take_back(ubMap.getNumInputs()), ubMap, + forStmt->getStep()); // Remember the induction variable mapping. operandMap[forStmt] = newFor; |

