diff options
Diffstat (limited to 'mlir/include/mlir/IR/Statements.h')
| -rw-r--r-- | mlir/include/mlir/IR/Statements.h | 44 |
1 files changed, 37 insertions, 7 deletions
diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index 9b391477336..28f5a14540d 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -216,8 +216,31 @@ private: } }; +/// A ForStmtBody represents statements contained within a ForStmt. +class ForStmtBody : public StmtBlock { +public: + explicit ForStmtBody(ForStmt *stmt) + : StmtBlock(StmtBlockKind::ForBody), forStmt(stmt) { + assert(stmt != nullptr && "ForStmtBody must have non-null parent"); + } + + ~ForStmtBody() {} + + /// Methods for support type inquiry through isa, cast, and dyn_cast + static bool classof(const StmtBlock *block) { + return block->getStmtBlockKind() == StmtBlockKind::ForBody; + } + + /// Returns the 'for' statement that contains this body. + ForStmt *getFor() { return forStmt; } + const ForStmt *getFor() const { return forStmt; } + +private: + ForStmt *forStmt; +}; + /// For statement represents an affine loop nest. -class ForStmt : public Statement, public MLValue, public StmtBlock { +class ForStmt : public Statement, public MLValue { public: static ForStmt *create(Location location, ArrayRef<MLValue *> lbOperands, AffineMap lbMap, ArrayRef<MLValue *> ubOperands, @@ -228,7 +251,7 @@ public: // since child statements need to be destroyed before the MLValue that this // for stmt represents is destroyed. Affine maps are immortal objects and // don't need to be deleted. - clear(); + getBody()->clear(); } /// Resolve base class ambiguity. @@ -242,6 +265,12 @@ public: using operand_range = llvm::iterator_range<operand_iterator>; using const_operand_range = llvm::iterator_range<const_operand_iterator>; + /// Get the body of the ForStmt. + ForStmtBody *getBody() { return &body; } + + /// Get the body of the ForStmt. + const ForStmtBody *getBody() const { return &body; } + //===--------------------------------------------------------------------===// // Bounds and step //===--------------------------------------------------------------------===// @@ -359,10 +388,6 @@ public: return ptr->getKind() == IROperandOwner::Kind::ForStmt; } - static bool classof(const StmtBlock *block) { - return block->getStmtBlockKind() == StmtBlockKind::For; - } - // For statement represents implicitly represents induction variable by // inheriting from MLValue class. Whenever you need to refer to the loop // induction variable, just use the for statement itself. @@ -371,6 +396,9 @@ public: } private: + // The StmtBlock for the body. + ForStmtBody body; + // Affine map for the lower bound. AffineMap lbMap; // Affine map for the upper bound. The upper bound is exclusive. @@ -456,7 +484,9 @@ public: ~IfClause() {} /// Returns the if statement that contains this clause. - IfStmt *getIf() const { return ifStmt; } + const IfStmt *getIf() const { return ifStmt; } + + IfStmt *getIf() { return ifStmt; } private: IfStmt *ifStmt; |

