diff options
| author | Chris Lattner <clattner@google.com> | 2018-12-26 11:21:53 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 14:36:35 -0700 |
| commit | d613f5ab65bbb80c3f5a0a38fef22cb4878c4358 (patch) | |
| tree | dfa4207dd899f87741ad784fe8fc343afba6885d /mlir/lib | |
| parent | 9a4060d3f50f046973d4e0b61f324bbc6ebbbdb9 (diff) | |
| download | bcm5719-llvm-d613f5ab65bbb80c3f5a0a38fef22cb4878c4358.tar.gz bcm5719-llvm-d613f5ab65bbb80c3f5a0a38fef22cb4878c4358.zip | |
Refactor MLFunction to contain a StmtBlock for its body instead of inheriting
from it. This is necessary progress to squaring away the parent relationship
that a StmtBlock has with its enclosing if/for/fn, and makes room for functions
to have more than one block in the future. This also removes IfClause and ForStmtBody.
This is step 5/n towards merging instructions and statements, NFC.
PiperOrigin-RevId: 226936541
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Analysis/Verifier.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 6 | ||||
| -rw-r--r-- | mlir/lib/IR/BuiltinOps.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/IR/Function.cpp | 9 | ||||
| -rw-r--r-- | mlir/lib/IR/Operation.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/IR/StmtBlock.cpp | 27 | ||||
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/Transforms/ConvertToCFG.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/DmaGeneration.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopFusion.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopTiling.cpp | 4 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopUnrollAndJam.cpp | 4 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 |
15 files changed, 42 insertions, 45 deletions
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 6e1522a656f..07324ba7d52 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -288,8 +288,8 @@ bool MLFuncVerifier::verifyDominance() { HashTable::ScopeTy blockScope(liveValues); // The induction variable of a for statement is live within its body. - if (auto *forStmtBody = dyn_cast<ForStmtBody>(&block)) - liveValues.insert(forStmtBody->getFor(), true); + if (auto *forStmt = dyn_cast_or_null<ForStmt>(block.getContainingStmt())) + liveValues.insert(forStmt, true); for (auto &stmt : block) { // Verify that each of the operands are live. @@ -330,16 +330,16 @@ bool MLFuncVerifier::verifyDominance() { }; // Check the whole function out. - return walkBlock(fn); + return walkBlock(*fn.getBody()); } bool MLFuncVerifier::verifyReturn() { // TODO: fold return verification in the pass that verifies all statements. const char missingReturnMsg[] = "ML function must end with return statement"; - if (fn.getStatements().empty()) + if (fn.getBody()->getStatements().empty()) return failure(missingReturnMsg, fn); - const auto &stmt = fn.getStatements().back(); + const auto &stmt = fn.getBody()->getStatements().back(); if (const auto *op = dyn_cast<OperationStmt>(&stmt)) { if (!op->isReturn()) return failure(missingReturnMsg, fn); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 2de17563d93..3b193117355 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -230,7 +230,7 @@ void ModuleState::visitStatement(const Statement *stmt) { void ModuleState::visitMLFunction(const MLFunction *fn) { visitType(fn->getType()); - for (auto &stmt : *fn) { + for (auto &stmt : *fn->getBody()) { ModuleState::visitStatement(&stmt); } } @@ -1390,7 +1390,7 @@ void MLFunctionPrinter::print() { printFunctionSignature(); printFunctionAttributes(getFunction()); os << " {\n"; - print(function); + print(function->getBody()); os << "}\n\n"; } @@ -1649,7 +1649,7 @@ void Statement::print(raw_ostream &os) const { void Statement::dump() const { print(llvm::errs()); } void StmtBlock::printBlock(raw_ostream &os) const { - MLFunction *function = findFunction(); + const MLFunction *function = findFunction(); ModuleState state(function->getContext()); ModulePrinter modulePrinter(os, state); MLFunctionPrinter(function, modulePrinter).print(this); diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 5d7ba237b44..dfd59c4d380 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -474,7 +474,7 @@ void ReturnOp::print(OpAsmPrinter *p) const { bool ReturnOp::verify() const { const Function *function; if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) - function = cast<MLFunction>(stmt->getBlock()); + function = stmt->getBlock()->findFunction(); else function = cast<Instruction>(getOperation())->getFunction(); diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index b79e1596a65..533be8e2a29 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -202,14 +202,13 @@ MLFunction *MLFunction::create(Location location, StringRef name, MLFunction::MLFunction(Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) - : Function(Kind::MLFunc, location, name, type, attrs), - StmtBlock(StmtBlockKind::MLFunc) {} + : Function(Kind::MLFunc, location, name, type, attrs), body(this) {} MLFunction::~MLFunction() { // Explicitly erase statements instead of relying of 'StmtBlock' destructor // since child statements need to be destroyed before function arguments // are destroyed. - clear(); + getBody()->clear(); // Explicitly run the destructors for the function arguments. for (auto &arg : getArgumentsInternal()) @@ -222,11 +221,11 @@ void MLFunction::destroy() { } const OperationStmt *MLFunction::getReturnStmt() const { - return cast<OperationStmt>(&back()); + return cast<OperationStmt>(&getBody()->back()); } OperationStmt *MLFunction::getReturnStmt() { - return cast<OperationStmt>(&back()); + return cast<OperationStmt>(&getBody()->back()); } void MLFunction::walk(std::function<void(OperationStmt *)> callback) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index d3a618d7da5..c946a76a98b 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -581,7 +581,7 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) { // Verify that the operation is at the end of the respective parent block. if (auto *stmt = dyn_cast<OperationStmt>(op)) { StmtBlock *block = stmt->getBlock(); - if (!block || !isa<MLFunction>(block) || &block->back() != stmt) + if (!block || block->getContainingStmt() || &block->back() != stmt) return op->emitOpError("must be the last statement in the ML function"); } else { const Instruction *inst = cast<Instruction>(op); diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp index 8ecb903d21d..fdee491c150 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/StmtBlock.cpp @@ -20,33 +20,30 @@ #include "mlir/IR/Statements.h" using namespace mlir; +StmtBlock::StmtBlock(MLFunction *parent) : parent(parent) {} + +StmtBlock::StmtBlock(Statement *parent) : parent(parent) {} + StmtBlock::~StmtBlock() { clear(); llvm::DeleteContainerPointers(arguments); } +/// Returns the closest surrounding statement that contains this block or +/// nullptr if this is a top-level statement block. Statement *StmtBlock::getContainingStmt() { - switch (kind) { - case StmtBlockKind::MLFunc: - return nullptr; - case StmtBlockKind::ForBody: - return cast<ForStmtBody>(this)->getFor(); - case StmtBlockKind::IfClause: - return cast<IfClause>(this)->getIf(); - } + return parent.dyn_cast<Statement *>(); } -MLFunction *StmtBlock::findFunction() const { - // FIXME: const incorrect. - StmtBlock *block = const_cast<StmtBlock *>(this); - - while (block->getContainingStmt()) { - block = block->getContainingStmt()->getBlock(); +MLFunction *StmtBlock::findFunction() { + StmtBlock *block = this; + while (auto *stmt = block->getContainingStmt()) { + block = stmt->getBlock(); if (!block) return nullptr; } - return dyn_cast<MLFunction>(block); + return block->getParent().get<MLFunction *>(); } /// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 781ec461b62..1a28648eba9 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2777,7 +2777,7 @@ class MLFunctionParser : public FunctionParser { public: MLFunctionParser(ParserState &state, MLFunction *function) : FunctionParser(state, Kind::MLFunc), function(function), - builder(function, function->end()) {} + builder(function->getBody()) {} ParseResult parseFunctionBody(); @@ -2796,7 +2796,7 @@ private: ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap &map, bool isLower); ParseResult parseIfStmt(); - ParseResult parseElseClause(IfClause *elseClause); + ParseResult parseElseClause(StmtBlock *elseClause); ParseResult parseStatements(StmtBlock *block); ParseResult parseStmtBlock(StmtBlock *block); @@ -2812,7 +2812,7 @@ ParseResult MLFunctionParser::parseFunctionBody() { auto braceLoc = getToken().getLoc(); // Parse statements in this function. - if (parseStmtBlock(function)) + if (parseStmtBlock(function->getBody())) return ParseFailure; return finalizeFunction(function, braceLoc); @@ -3121,7 +3121,7 @@ ParseResult MLFunctionParser::parseIfStmt() { IfStmt *ifStmt = builder.createIf(getEncodedSourceLocation(loc), operands, set); - IfClause *thenClause = ifStmt->getThen(); + StmtBlock *thenClause = ifStmt->getThen(); // When parsing of an if statement body fails, the IR contains // the if statement with the portion of the body that has been @@ -3141,7 +3141,7 @@ ParseResult MLFunctionParser::parseIfStmt() { return ParseSuccess; } -ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) { +ParseResult MLFunctionParser::parseElseClause(StmtBlock *elseClause) { if (getToken().is(Token::kw_if)) { builder.setInsertionPointToEnd(elseClause); return parseIfStmt(); diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 8620230b2f1..4fafff51322 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -490,7 +490,7 @@ CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) { } // Convert statements in order. - for (auto &stmt : *mlFunc) { + for (auto &stmt : *mlFunc->getBody()) { visit(&stmt); } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 2b79064e53f..a927516345a 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -426,7 +426,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) { } PassResult DmaGeneration::runOnMLFunction(MLFunction *f) { - for (auto &stmt : *f) { + for (auto &stmt : *f->getBody()) { if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) { runOnForStmt(forStmt); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index df68765aeb7..e3609496cc5 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -348,7 +348,7 @@ public: bool MemRefDependenceGraph::init(MLFunction *f) { unsigned id = 0; DenseMap<MLValue *, SetVector<unsigned>> memrefAccesses; - for (auto &stmt : *f) { + for (auto &stmt : *f->getBody()) { if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) { // Create graph node 'id' to represent top-level 'forStmt' and record // all loads and store accesses it contains. diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 847db83aebc..b5c12865790 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -230,8 +230,8 @@ static void getTileableBands(MLFunction *f, bands->push_back(band); }; - for (auto &stmt : *f) { - ForStmt *forStmt = dyn_cast<ForStmt>(&stmt); + for (auto &stmt : *f->getBody()) { + auto *forStmt = dyn_cast<ForStmt>(&stmt); if (!forStmt) continue; getMaximalPerfectLoopNest(forStmt); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index dd491f8119b..ffff1c5b615 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -92,10 +92,10 @@ PassResult LoopUnrollAndJam::runOnMLFunction(MLFunction *f) { // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnForStmt can be called on any // for Stmt. - if (!isa<ForStmt>(f->begin())) + auto *forStmt = dyn_cast<ForStmt>(f->getBody()->begin()); + if (!forStmt) return success(); - auto *forStmt = cast<ForStmt>(f->begin()); runOnForStmt(forStmt); return success(); } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index d4069eaa638..fd07619a165 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -238,7 +238,7 @@ struct LowerVectorTransfersPass makeFuncWiseState(MLFunction *f) const override { auto state = llvm::make_unique<LowerVectorTransfersState>(); auto builder = MLFuncBuilder(f); - builder.setInsertionPointToStart(f); + builder.setInsertionPointToStart(f->getBody()); state->zero = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), 0); return state; } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 554e3cb47a9..fbde1fd1692 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -177,7 +177,8 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, cast<Instruction>(op)->moveBefore(&entryBB, entryBB.begin()); } else { auto *mlFunc = cast<MLFunction>(currentFunction); - cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin()); + cast<OperationStmt>(op)->moveBefore(mlFunc->getBody(), + mlFunc->getBody()->begin()); } continue; diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 4d75f7c0835..023d3ebc643 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -102,7 +102,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { if (!forStmt->use_empty()) { if (forStmt->hasConstantLowerBound()) { auto *mlFunc = forStmt->findFunction(); - MLFuncBuilder topBuilder(&mlFunc->front()); + MLFuncBuilder topBuilder(&mlFunc->getBody()->front()); auto constOp = topBuilder.create<ConstantIndexOp>( forStmt->getLoc(), forStmt->getConstantLowerBound()); forStmt->replaceAllUsesWith(constOp); |

