summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR/Statement.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR/Statement.cpp')
-rw-r--r--mlir/lib/IR/Statement.cpp93
1 files changed, 44 insertions, 49 deletions
diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp
index 7d9ef98a7de..681eb4dac23 100644
--- a/mlir/lib/IR/Statement.cpp
+++ b/mlir/lib/IR/Statement.cpp
@@ -262,17 +262,16 @@ bool OperationStmt::isReturn() const { return is<ReturnOp>(); }
//===----------------------------------------------------------------------===//
ForStmt *ForStmt::create(Location *location, ArrayRef<MLValue *> lbOperands,
- AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
- AffineMap *ubMap, int64_t step, MLIRContext *context) {
- assert(lbOperands.size() == lbMap->getNumInputs() &&
+ AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
+ AffineMap ubMap, int64_t step) {
+ assert(lbOperands.size() == lbMap.getNumInputs() &&
"lower bound operand count does not match the affine map");
- assert(ubOperands.size() == ubMap->getNumInputs() &&
+ assert(ubOperands.size() == ubMap.getNumInputs() &&
"upper bound operand count does not match the affine map");
assert(step > 0 && "step has to be a positive integer constant");
unsigned numOperands = lbOperands.size() + ubOperands.size();
- ForStmt *stmt =
- new ForStmt(location, numOperands, lbMap, ubMap, step, context);
+ ForStmt *stmt = new ForStmt(location, numOperands, lbMap, ubMap, step);
unsigned i = 0;
for (unsigned e = lbOperands.size(); i != e; ++i)
@@ -284,30 +283,31 @@ ForStmt *ForStmt::create(Location *location, ArrayRef<MLValue *> lbOperands,
return stmt;
}
-ForStmt::ForStmt(Location *location, unsigned numOperands, AffineMap *lbMap,
- AffineMap *ubMap, int64_t step, MLIRContext *context)
+ForStmt::ForStmt(Location *location, unsigned numOperands, AffineMap lbMap,
+ AffineMap ubMap, int64_t step)
: Statement(Kind::For, location),
- MLValue(MLValueKind::ForStmt, Type::getIndex(context)),
+ MLValue(MLValueKind::ForStmt,
+ Type::getIndex(lbMap.getResult(0).getContext())),
StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {
operands.reserve(numOperands);
}
const AffineBound ForStmt::getLowerBound() const {
- return AffineBound(*this, 0, lbMap->getNumInputs(), lbMap);
+ return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap);
}
const AffineBound ForStmt::getUpperBound() const {
- return AffineBound(*this, lbMap->getNumInputs(), getNumOperands(), ubMap);
+ return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
}
-void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap *map) {
- assert(lbOperands.size() == map->getNumInputs());
- assert(map->getNumResults() >= 1 && "bound map has at least one result");
+void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) {
+ assert(lbOperands.size() == map.getNumInputs());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<MLValue *, 4> ubOperands(getUpperBoundOperands());
operands.clear();
- operands.reserve(lbOperands.size() + ubMap->getNumInputs());
+ operands.reserve(lbOperands.size() + ubMap.getNumInputs());
for (auto *operand : lbOperands) {
operands.emplace_back(StmtOperand(this, operand));
}
@@ -317,9 +317,9 @@ void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap *map) {
this->lbMap = map;
}
-void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap *map) {
- assert(ubOperands.size() == map->getNumInputs());
- assert(map->getNumResults() >= 1 && "bound map has at least one result");
+void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap map) {
+ assert(ubOperands.size() == map.getNumInputs());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<MLValue *, 4> lbOperands(getLowerBoundOperands());
@@ -334,34 +334,30 @@ void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap *map) {
this->ubMap = map;
}
-void ForStmt::setLowerBoundMap(AffineMap *map) {
- assert(lbMap->getNumDims() == map->getNumDims() &&
- lbMap->getNumSymbols() == map->getNumSymbols());
- assert(map->getNumResults() >= 1 && "bound map has at least one result");
+void ForStmt::setLowerBoundMap(AffineMap map) {
+ assert(lbMap.getNumDims() == map.getNumDims() &&
+ lbMap.getNumSymbols() == map.getNumSymbols());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
this->lbMap = map;
}
-void ForStmt::setUpperBoundMap(AffineMap *map) {
- assert(ubMap->getNumDims() == map->getNumDims() &&
- ubMap->getNumSymbols() == map->getNumSymbols());
- assert(map->getNumResults() >= 1 && "bound map has at least one result");
+void ForStmt::setUpperBoundMap(AffineMap map) {
+ assert(ubMap.getNumDims() == map.getNumDims() &&
+ ubMap.getNumSymbols() == map.getNumSymbols());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
this->ubMap = map;
}
-bool ForStmt::hasConstantLowerBound() const {
- return lbMap->isSingleConstant();
-}
+bool ForStmt::hasConstantLowerBound() const { return lbMap.isSingleConstant(); }
-bool ForStmt::hasConstantUpperBound() const {
- return ubMap->isSingleConstant();
-}
+bool ForStmt::hasConstantUpperBound() const { return ubMap.isSingleConstant(); }
int64_t ForStmt::getConstantLowerBound() const {
- return lbMap->getSingleConstantResult();
+ return lbMap.getSingleConstantResult();
}
int64_t ForStmt::getConstantUpperBound() const {
- return ubMap->getSingleConstantResult();
+ return ubMap.getSingleConstantResult();
}
void ForStmt::setConstantLowerBound(int64_t value) {
@@ -373,21 +369,20 @@ void ForStmt::setConstantUpperBound(int64_t value) {
}
ForStmt::operand_range ForStmt::getLowerBoundOperands() {
- return {operand_begin(),
- operand_begin() + getLowerBoundMap()->getNumInputs()};
+ return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
ForStmt::operand_range ForStmt::getUpperBoundOperands() {
- return {operand_begin() + getLowerBoundMap()->getNumInputs(), operand_end()};
+ return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}
bool ForStmt::matchingBoundOperandList() const {
- if (lbMap->getNumDims() != ubMap->getNumDims() ||
- lbMap->getNumSymbols() != ubMap->getNumSymbols())
+ if (lbMap.getNumDims() != ubMap.getNumDims() ||
+ lbMap.getNumSymbols() != ubMap.getNumSymbols())
return false;
- unsigned numOperands = lbMap->getNumInputs();
- for (unsigned i = 0, e = lbMap->getNumInputs(); i < e; i++) {
+ unsigned numOperands = lbMap.getNumInputs();
+ for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
// Compare MLValue *'s.
if (getOperand(i) != getOperand(numOperands + i))
return false;
@@ -419,11 +414,11 @@ bool ForStmt::constantFoldBound(bool lower) {
operandConstants.push_back(operandCst);
}
- AffineMap *boundMap = lower ? getLowerBoundMap() : getUpperBoundMap();
- assert(boundMap->getNumResults() >= 1 &&
+ AffineMap boundMap = lower ? getLowerBoundMap() : getUpperBoundMap();
+ assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result");
SmallVector<Attribute *, 4> foldedResults;
- if (boundMap->constantFold(operandConstants, foldedResults))
+ if (boundMap.constantFold(operandConstants, foldedResults))
return true;
// Compute the max or min as applicable over the results.
@@ -523,14 +518,14 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
}
if (auto *forStmt = dyn_cast<ForStmt>(this)) {
- auto *lbMap = forStmt->getLowerBoundMap();
- auto *ubMap = forStmt->getUpperBoundMap();
+ auto lbMap = forStmt->getLowerBoundMap();
+ auto ubMap = forStmt->getUpperBoundMap();
auto *newFor = ForStmt::create(
getLoc(),
- ArrayRef<MLValue *>(operands).take_front(lbMap->getNumInputs()), lbMap,
- ArrayRef<MLValue *>(operands).take_back(ubMap->getNumInputs()), ubMap,
- forStmt->getStep(), context);
+ ArrayRef<MLValue *>(operands).take_front(lbMap.getNumInputs()), lbMap,
+ ArrayRef<MLValue *>(operands).take_back(ubMap.getNumInputs()), ubMap,
+ forStmt->getStep());
// Remember the induction variable mapping.
operandMap[forStmt] = newFor;
OpenPOWER on IntegriCloud