summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR/StmtBlock.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR/StmtBlock.cpp')
-rw-r--r--mlir/lib/IR/StmtBlock.cpp27
1 files changed, 12 insertions, 15 deletions
diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp
index 8ecb903d21d..fdee491c150 100644
--- a/mlir/lib/IR/StmtBlock.cpp
+++ b/mlir/lib/IR/StmtBlock.cpp
@@ -20,33 +20,30 @@
#include "mlir/IR/Statements.h"
using namespace mlir;
+StmtBlock::StmtBlock(MLFunction *parent) : parent(parent) {}
+
+StmtBlock::StmtBlock(Statement *parent) : parent(parent) {}
+
StmtBlock::~StmtBlock() {
clear();
llvm::DeleteContainerPointers(arguments);
}
+/// Returns the closest surrounding statement that contains this block or
+/// nullptr if this is a top-level statement block.
Statement *StmtBlock::getContainingStmt() {
- switch (kind) {
- case StmtBlockKind::MLFunc:
- return nullptr;
- case StmtBlockKind::ForBody:
- return cast<ForStmtBody>(this)->getFor();
- case StmtBlockKind::IfClause:
- return cast<IfClause>(this)->getIf();
- }
+ return parent.dyn_cast<Statement *>();
}
-MLFunction *StmtBlock::findFunction() const {
- // FIXME: const incorrect.
- StmtBlock *block = const_cast<StmtBlock *>(this);
-
- while (block->getContainingStmt()) {
- block = block->getContainingStmt()->getBlock();
+MLFunction *StmtBlock::findFunction() {
+ StmtBlock *block = this;
+ while (auto *stmt = block->getContainingStmt()) {
+ block = stmt->getBlock();
if (!block)
return nullptr;
}
- return dyn_cast<MLFunction>(block);
+ return block->getParent().get<MLFunction *>();
}
/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor
OpenPOWER on IntegriCloud