diff options
Diffstat (limited to 'mlir/lib/IR/StmtBlock.cpp')
| -rw-r--r-- | mlir/lib/IR/StmtBlock.cpp | 27 |
1 files changed, 12 insertions, 15 deletions
diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp index 8ecb903d21d..fdee491c150 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/StmtBlock.cpp @@ -20,33 +20,30 @@ #include "mlir/IR/Statements.h" using namespace mlir; +StmtBlock::StmtBlock(MLFunction *parent) : parent(parent) {} + +StmtBlock::StmtBlock(Statement *parent) : parent(parent) {} + StmtBlock::~StmtBlock() { clear(); llvm::DeleteContainerPointers(arguments); } +/// Returns the closest surrounding statement that contains this block or +/// nullptr if this is a top-level statement block. Statement *StmtBlock::getContainingStmt() { - switch (kind) { - case StmtBlockKind::MLFunc: - return nullptr; - case StmtBlockKind::ForBody: - return cast<ForStmtBody>(this)->getFor(); - case StmtBlockKind::IfClause: - return cast<IfClause>(this)->getIf(); - } + return parent.dyn_cast<Statement *>(); } -MLFunction *StmtBlock::findFunction() const { - // FIXME: const incorrect. - StmtBlock *block = const_cast<StmtBlock *>(this); - - while (block->getContainingStmt()) { - block = block->getContainingStmt()->getBlock(); +MLFunction *StmtBlock::findFunction() { + StmtBlock *block = this; + while (auto *stmt = block->getContainingStmt()) { + block = stmt->getBlock(); if (!block) return nullptr; } - return dyn_cast<MLFunction>(block); + return block->getParent().get<MLFunction *>(); } /// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor |

