summaryrefslogtreecommitdiffstats
path: root/mlir/include/mlir/IR/Statements.h
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/include/mlir/IR/Statements.h')
-rw-r--r--mlir/include/mlir/IR/Statements.h44
1 files changed, 37 insertions, 7 deletions
diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h
index 9b391477336..28f5a14540d 100644
--- a/mlir/include/mlir/IR/Statements.h
+++ b/mlir/include/mlir/IR/Statements.h
@@ -216,8 +216,31 @@ private:
}
};
+/// A ForStmtBody represents statements contained within a ForStmt.
+class ForStmtBody : public StmtBlock {
+public:
+ explicit ForStmtBody(ForStmt *stmt)
+ : StmtBlock(StmtBlockKind::ForBody), forStmt(stmt) {
+ assert(stmt != nullptr && "ForStmtBody must have non-null parent");
+ }
+
+ ~ForStmtBody() {}
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast
+ static bool classof(const StmtBlock *block) {
+ return block->getStmtBlockKind() == StmtBlockKind::ForBody;
+ }
+
+ /// Returns the 'for' statement that contains this body.
+ ForStmt *getFor() { return forStmt; }
+ const ForStmt *getFor() const { return forStmt; }
+
+private:
+ ForStmt *forStmt;
+};
+
/// For statement represents an affine loop nest.
-class ForStmt : public Statement, public MLValue, public StmtBlock {
+class ForStmt : public Statement, public MLValue {
public:
static ForStmt *create(Location location, ArrayRef<MLValue *> lbOperands,
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
@@ -228,7 +251,7 @@ public:
// since child statements need to be destroyed before the MLValue that this
// for stmt represents is destroyed. Affine maps are immortal objects and
// don't need to be deleted.
- clear();
+ getBody()->clear();
}
/// Resolve base class ambiguity.
@@ -242,6 +265,12 @@ public:
using operand_range = llvm::iterator_range<operand_iterator>;
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
+ /// Get the body of the ForStmt.
+ ForStmtBody *getBody() { return &body; }
+
+ /// Get the body of the ForStmt.
+ const ForStmtBody *getBody() const { return &body; }
+
//===--------------------------------------------------------------------===//
// Bounds and step
//===--------------------------------------------------------------------===//
@@ -359,10 +388,6 @@ public:
return ptr->getKind() == IROperandOwner::Kind::ForStmt;
}
- static bool classof(const StmtBlock *block) {
- return block->getStmtBlockKind() == StmtBlockKind::For;
- }
-
// For statement represents implicitly represents induction variable by
// inheriting from MLValue class. Whenever you need to refer to the loop
// induction variable, just use the for statement itself.
@@ -371,6 +396,9 @@ public:
}
private:
+ // The StmtBlock for the body.
+ ForStmtBody body;
+
// Affine map for the lower bound.
AffineMap lbMap;
// Affine map for the upper bound. The upper bound is exclusive.
@@ -456,7 +484,9 @@ public:
~IfClause() {}
/// Returns the if statement that contains this clause.
- IfStmt *getIf() const { return ifStmt; }
+ const IfStmt *getIf() const { return ifStmt; }
+
+ IfStmt *getIf() { return ifStmt; }
private:
IfStmt *ifStmt;
OpenPOWER on IntegriCloud