diff options
| author | Chris Lattner <clattner@google.com> | 2018-12-26 11:21:53 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 14:36:35 -0700 |
| commit | d613f5ab65bbb80c3f5a0a38fef22cb4878c4358 (patch) | |
| tree | dfa4207dd899f87741ad784fe8fc343afba6885d | |
| parent | 9a4060d3f50f046973d4e0b61f324bbc6ebbbdb9 (diff) | |
| download | bcm5719-llvm-d613f5ab65bbb80c3f5a0a38fef22cb4878c4358.tar.gz bcm5719-llvm-d613f5ab65bbb80c3f5a0a38fef22cb4878c4358.zip | |
Refactor MLFunction to contain a StmtBlock for its body instead of inheriting
from it. This is necessary progress to squaring away the parent relationship
that a StmtBlock has with its enclosing if/for/fn, and makes room for functions
to have more than one block in the future. This also removes IfClause and ForStmtBody.
This is step 5/n towards merging instructions and statements, NFC.
PiperOrigin-RevId: 226936541
| -rw-r--r-- | mlir/include/mlir/IR/Builders.h | 8 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/MLFunction.h | 9 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/Statements.h | 70 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/StmtBlock.h | 24 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/StmtVisitor.h | 6 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Verifier.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 6 | ||||
| -rw-r--r-- | mlir/lib/IR/BuiltinOps.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/IR/Function.cpp | 9 | ||||
| -rw-r--r-- | mlir/lib/IR/Operation.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/IR/StmtBlock.cpp | 27 | ||||
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/Transforms/ConvertToCFG.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/DmaGeneration.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopFusion.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopTiling.cpp | 4 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopUnrollAndJam.cpp | 4 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/LoopUtils.cpp | 2 |
20 files changed, 81 insertions, 123 deletions
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index f743930bd58..3525c31e099 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -294,6 +294,12 @@ public: setInsertionPoint(stmt); } + MLFuncBuilder(StmtBlock *block) + // TODO: Eliminate findFunction from this. + : MLFuncBuilder(block->findFunction()) { + setInsertionPoint(block, block->end()); + } + MLFuncBuilder(StmtBlock *block, StmtBlock::iterator insertPoint) // TODO: Eliminate findFunction from this. : MLFuncBuilder(block->findFunction()) { @@ -304,7 +310,7 @@ public: /// the function. MLFuncBuilder(MLFunction *func) : Builder(func->getContext()), function(func) { - setInsertionPoint(func, func->begin()); + setInsertionPoint(func->getBody(), func->getBody()->begin()); } /// Return the function this builder is referring to. diff --git a/mlir/include/mlir/IR/MLFunction.h b/mlir/include/mlir/IR/MLFunction.h index cf7f64f869a..58261e04d8f 100644 --- a/mlir/include/mlir/IR/MLFunction.h +++ b/mlir/include/mlir/IR/MLFunction.h @@ -36,7 +36,6 @@ template <typename ObjectType, typename ElementType> class ArgumentIterator; // include nested affine for loops, conditionals and operations. class MLFunction final : public Function, - public StmtBlock, private llvm::TrailingObjects<MLFunction, MLFuncArgument> { public: /// Creates a new MLFunction with the specific type. @@ -44,6 +43,9 @@ public: FunctionType type, ArrayRef<NamedAttribute> attrs = {}); + StmtBlock *getBody() { return &body; } + const StmtBlock *getBody() const { return &body; } + /// Destroys this statement and its subclass data. void destroy(); @@ -98,9 +100,6 @@ public: static bool classof(const Function *func) { return func->getKind() == Function::Kind::MLFunc; } - static bool classof(const StmtBlock *block) { - return block->getStmtBlockKind() == StmtBlockKind::MLFunc; - } private: MLFunction(Location location, StringRef name, FunctionType type, @@ -119,6 +118,8 @@ private: MutableArrayRef<MLFuncArgument> getArgumentsInternal() { return {getTrailingObjects<MLFuncArgument>(), getNumArguments()}; } + + StmtBlock body; }; //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index fcffa9397c9..7ff955cc824 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -274,29 +274,6 @@ private: size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; } }; -/// A ForStmtBody represents statements contained within a ForStmt. -class ForStmtBody : public StmtBlock { -public: - explicit ForStmtBody(ForStmt *stmt) - : StmtBlock(StmtBlockKind::ForBody), forStmt(stmt) { - assert(stmt != nullptr && "ForStmtBody must have non-null parent"); - } - - ~ForStmtBody() {} - - /// Methods for support type inquiry through isa, cast, and dyn_cast - static bool classof(const StmtBlock *block) { - return block->getStmtBlockKind() == StmtBlockKind::ForBody; - } - - /// Returns the 'for' statement that contains this body. - ForStmt *getFor() { return forStmt; } - const ForStmt *getFor() const { return forStmt; } - -private: - ForStmt *forStmt; -}; - /// For statement represents an affine loop nest. class ForStmt : public Statement, public MLValue { public: @@ -324,10 +301,10 @@ public: using const_operand_range = llvm::iterator_range<const_operand_iterator>; /// Get the body of the ForStmt. - ForStmtBody *getBody() { return &body; } + StmtBlock *getBody() { return &body; } /// Get the body of the ForStmt. - const ForStmtBody *getBody() const { return &body; } + const StmtBlock *getBody() const { return &body; } //===--------------------------------------------------------------------===// // Bounds and step @@ -455,7 +432,7 @@ public: private: // The StmtBlock for the body. - ForStmtBody body; + StmtBlock body; // Affine map for the lower bound. AffineMap lbMap; @@ -525,31 +502,6 @@ private: friend class ForStmt; }; -/// An if clause represents statements contained within a then or an else clause -/// of an if statement. -class IfClause : public StmtBlock { -public: - explicit IfClause(IfStmt *stmt) - : StmtBlock(StmtBlockKind::IfClause), ifStmt(stmt) { - assert(stmt != nullptr && "If clause must have non-null parent"); - } - - /// Methods for support type inquiry through isa, cast, and dyn_cast - static bool classof(const StmtBlock *block) { - return block->getStmtBlockKind() == StmtBlockKind::IfClause; - } - - ~IfClause() {} - - /// Returns the if statement that contains this clause. - const IfStmt *getIf() const { return ifStmt; } - - IfStmt *getIf() { return ifStmt; } - -private: - IfStmt *ifStmt; -}; - /// If statement restricts execution to a subset of the loop iteration space. class IfStmt : public Statement { public: @@ -561,15 +513,15 @@ public: // Then, else, condition. //===--------------------------------------------------------------------===// - IfClause *getThen() { return &thenClause; } - const IfClause *getThen() const { return &thenClause; } - IfClause *getElse() { return elseClause; } - const IfClause *getElse() const { return elseClause; } + StmtBlock *getThen() { return &thenClause; } + const StmtBlock *getThen() const { return &thenClause; } + StmtBlock *getElse() { return elseClause; } + const StmtBlock *getElse() const { return elseClause; } bool hasElse() const { return elseClause != nullptr; } - IfClause *createElse() { + StmtBlock *createElse() { assert(elseClause == nullptr && "already has an else clause!"); - return (elseClause = new IfClause(this)); + return (elseClause = new StmtBlock(this)); } const AffineCondition getCondition() const; @@ -634,9 +586,9 @@ public: private: // it is always present. - IfClause thenClause; + StmtBlock thenClause; // 'else' clause of the if statement. 'nullptr' if there is no else clause. - IfClause *elseClause; + StmtBlock *elseClause; // The integer set capturing the conditional guard. IntegerSet set; diff --git a/mlir/include/mlir/IR/StmtBlock.h b/mlir/include/mlir/IR/StmtBlock.h index 65e0f19066e..9ee4d651029 100644 --- a/mlir/include/mlir/IR/StmtBlock.h +++ b/mlir/include/mlir/IR/StmtBlock.h @@ -39,12 +39,8 @@ template <typename BlockType> class StmtSuccessorIterator; /// children of a parent statement in the ML Function. class StmtBlock : public IRObjectWithUseList { public: - enum class StmtBlockKind { - MLFunc, // MLFunction - ForBody, // ForStmtBody - IfClause // IfClause - }; - + explicit StmtBlock(MLFunction *parent); + explicit StmtBlock(Statement *parent); ~StmtBlock(); void clear() { @@ -54,7 +50,9 @@ public: statements.pop_back(); } - StmtBlockKind getStmtBlockKind() const { return kind; } + llvm::PointerUnion<MLFunction *, Statement *> getParent() const { + return parent; + } /// Returns the closest surrounding statement that contains this block or /// nullptr if this is a top-level statement block. @@ -66,7 +64,10 @@ public: /// Returns the function that this statement block is part of. /// The function is determined by traversing the chain of parent statements. - MLFunction *findFunction() const; + MLFunction *findFunction(); + const MLFunction *findFunction() const { + return const_cast<StmtBlock *>(this)->findFunction(); + } //===--------------------------------------------------------------------===// // Block argument management @@ -224,11 +225,10 @@ public: void printBlock(raw_ostream &os) const; void dumpBlock() const; -protected: - StmtBlock(StmtBlockKind kind) : kind(kind) {} - private: - StmtBlockKind kind; + /// This is the parent function/IfStmt/ForStmt that owns this block. + llvm::PointerUnion<MLFunction *, Statement *> parent; + /// This is the list of statements in the block. StmtListType statements; diff --git a/mlir/include/mlir/IR/StmtVisitor.h b/mlir/include/mlir/IR/StmtVisitor.h index 94bc0b0cdc1..8dcd5863096 100644 --- a/mlir/include/mlir/IR/StmtVisitor.h +++ b/mlir/include/mlir/IR/StmtVisitor.h @@ -132,11 +132,13 @@ public: // Define walkers for MLFunction and all MLFunction statement kinds. void walk(MLFunction *f) { static_cast<SubClass *>(this)->visitMLFunction(f); - static_cast<SubClass *>(this)->walk(f->begin(), f->end()); + static_cast<SubClass *>(this)->walk(f->getBody()->begin(), + f->getBody()->end()); } void walkPostOrder(MLFunction *f) { - static_cast<SubClass *>(this)->walkPostOrder(f->begin(), f->end()); + static_cast<SubClass *>(this)->walkPostOrder(f->getBody()->begin(), + f->getBody()->end()); static_cast<SubClass *>(this)->visitMLFunction(f); } diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 6e1522a656f..07324ba7d52 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -288,8 +288,8 @@ bool MLFuncVerifier::verifyDominance() { HashTable::ScopeTy blockScope(liveValues); // The induction variable of a for statement is live within its body. - if (auto *forStmtBody = dyn_cast<ForStmtBody>(&block)) - liveValues.insert(forStmtBody->getFor(), true); + if (auto *forStmt = dyn_cast_or_null<ForStmt>(block.getContainingStmt())) + liveValues.insert(forStmt, true); for (auto &stmt : block) { // Verify that each of the operands are live. @@ -330,16 +330,16 @@ bool MLFuncVerifier::verifyDominance() { }; // Check the whole function out. - return walkBlock(fn); + return walkBlock(*fn.getBody()); } bool MLFuncVerifier::verifyReturn() { // TODO: fold return verification in the pass that verifies all statements. const char missingReturnMsg[] = "ML function must end with return statement"; - if (fn.getStatements().empty()) + if (fn.getBody()->getStatements().empty()) return failure(missingReturnMsg, fn); - const auto &stmt = fn.getStatements().back(); + const auto &stmt = fn.getBody()->getStatements().back(); if (const auto *op = dyn_cast<OperationStmt>(&stmt)) { if (!op->isReturn()) return failure(missingReturnMsg, fn); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 2de17563d93..3b193117355 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -230,7 +230,7 @@ void ModuleState::visitStatement(const Statement *stmt) { void ModuleState::visitMLFunction(const MLFunction *fn) { visitType(fn->getType()); - for (auto &stmt : *fn) { + for (auto &stmt : *fn->getBody()) { ModuleState::visitStatement(&stmt); } } @@ -1390,7 +1390,7 @@ void MLFunctionPrinter::print() { printFunctionSignature(); printFunctionAttributes(getFunction()); os << " {\n"; - print(function); + print(function->getBody()); os << "}\n\n"; } @@ -1649,7 +1649,7 @@ void Statement::print(raw_ostream &os) const { void Statement::dump() const { print(llvm::errs()); } void StmtBlock::printBlock(raw_ostream &os) const { - MLFunction *function = findFunction(); + const MLFunction *function = findFunction(); ModuleState state(function->getContext()); ModulePrinter modulePrinter(os, state); MLFunctionPrinter(function, modulePrinter).print(this); diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 5d7ba237b44..dfd59c4d380 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -474,7 +474,7 @@ void ReturnOp::print(OpAsmPrinter *p) const { bool ReturnOp::verify() const { const Function *function; if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) - function = cast<MLFunction>(stmt->getBlock()); + function = stmt->getBlock()->findFunction(); else function = cast<Instruction>(getOperation())->getFunction(); diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index b79e1596a65..533be8e2a29 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -202,14 +202,13 @@ MLFunction *MLFunction::create(Location location, StringRef name, MLFunction::MLFunction(Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs) - : Function(Kind::MLFunc, location, name, type, attrs), - StmtBlock(StmtBlockKind::MLFunc) {} + : Function(Kind::MLFunc, location, name, type, attrs), body(this) {} MLFunction::~MLFunction() { // Explicitly erase statements instead of relying of 'StmtBlock' destructor // since child statements need to be destroyed before function arguments // are destroyed. - clear(); + getBody()->clear(); // Explicitly run the destructors for the function arguments. for (auto &arg : getArgumentsInternal()) @@ -222,11 +221,11 @@ void MLFunction::destroy() { } const OperationStmt *MLFunction::getReturnStmt() const { - return cast<OperationStmt>(&back()); + return cast<OperationStmt>(&getBody()->back()); } OperationStmt *MLFunction::getReturnStmt() { - return cast<OperationStmt>(&back()); + return cast<OperationStmt>(&getBody()->back()); } void MLFunction::walk(std::function<void(OperationStmt *)> callback) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index d3a618d7da5..c946a76a98b 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -581,7 +581,7 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) { // Verify that the operation is at the end of the respective parent block. if (auto *stmt = dyn_cast<OperationStmt>(op)) { StmtBlock *block = stmt->getBlock(); - if (!block || !isa<MLFunction>(block) || &block->back() != stmt) + if (!block || block->getContainingStmt() || &block->back() != stmt) return op->emitOpError("must be the last statement in the ML function"); } else { const Instruction *inst = cast<Instruction>(op); diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp index 8ecb903d21d..fdee491c150 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/StmtBlock.cpp @@ -20,33 +20,30 @@ #include "mlir/IR/Statements.h" using namespace mlir; +StmtBlock::StmtBlock(MLFunction *parent) : parent(parent) {} + +StmtBlock::StmtBlock(Statement *parent) : parent(parent) {} + StmtBlock::~StmtBlock() { clear(); llvm::DeleteContainerPointers(arguments); } +/// Returns the closest surrounding statement that contains this block or +/// nullptr if this is a top-level statement block. Statement *StmtBlock::getContainingStmt() { - switch (kind) { - case StmtBlockKind::MLFunc: - return nullptr; - case StmtBlockKind::ForBody: - return cast<ForStmtBody>(this)->getFor(); - case StmtBlockKind::IfClause: - return cast<IfClause>(this)->getIf(); - } + return parent.dyn_cast<Statement *>(); } -MLFunction *StmtBlock::findFunction() const { - // FIXME: const incorrect. - StmtBlock *block = const_cast<StmtBlock *>(this); - - while (block->getContainingStmt()) { - block = block->getContainingStmt()->getBlock(); +MLFunction *StmtBlock::findFunction() { + StmtBlock *block = this; + while (auto *stmt = block->getContainingStmt()) { + block = stmt->getBlock(); if (!block) return nullptr; } - return dyn_cast<MLFunction>(block); + return block->getParent().get<MLFunction *>(); } /// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 781ec461b62..1a28648eba9 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2777,7 +2777,7 @@ class MLFunctionParser : public FunctionParser { public: MLFunctionParser(ParserState &state, MLFunction *function) : FunctionParser(state, Kind::MLFunc), function(function), - builder(function, function->end()) {} + builder(function->getBody()) {} ParseResult parseFunctionBody(); @@ -2796,7 +2796,7 @@ private: ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap &map, bool isLower); ParseResult parseIfStmt(); - ParseResult parseElseClause(IfClause *elseClause); + ParseResult parseElseClause(StmtBlock *elseClause); ParseResult parseStatements(StmtBlock *block); ParseResult parseStmtBlock(StmtBlock *block); @@ -2812,7 +2812,7 @@ ParseResult MLFunctionParser::parseFunctionBody() { auto braceLoc = getToken().getLoc(); // Parse statements in this function. - if (parseStmtBlock(function)) + if (parseStmtBlock(function->getBody())) return ParseFailure; return finalizeFunction(function, braceLoc); @@ -3121,7 +3121,7 @@ ParseResult MLFunctionParser::parseIfStmt() { IfStmt *ifStmt = builder.createIf(getEncodedSourceLocation(loc), operands, set); - IfClause *thenClause = ifStmt->getThen(); + StmtBlock *thenClause = ifStmt->getThen(); // When parsing of an if statement body fails, the IR contains // the if statement with the portion of the body that has been @@ -3141,7 +3141,7 @@ ParseResult MLFunctionParser::parseIfStmt() { return ParseSuccess; } -ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) { +ParseResult MLFunctionParser::parseElseClause(StmtBlock *elseClause) { if (getToken().is(Token::kw_if)) { builder.setInsertionPointToEnd(elseClause); return parseIfStmt(); diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 8620230b2f1..4fafff51322 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -490,7 +490,7 @@ CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) { } // Convert statements in order. - for (auto &stmt : *mlFunc) { + for (auto &stmt : *mlFunc->getBody()) { visit(&stmt); } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 2b79064e53f..a927516345a 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -426,7 +426,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) { } PassResult DmaGeneration::runOnMLFunction(MLFunction *f) { - for (auto &stmt : *f) { + for (auto &stmt : *f->getBody()) { if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) { runOnForStmt(forStmt); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index df68765aeb7..e3609496cc5 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -348,7 +348,7 @@ public: bool MemRefDependenceGraph::init(MLFunction *f) { unsigned id = 0; DenseMap<MLValue *, SetVector<unsigned>> memrefAccesses; - for (auto &stmt : *f) { + for (auto &stmt : *f->getBody()) { if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) { // Create graph node 'id' to represent top-level 'forStmt' and record // all loads and store accesses it contains. diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 847db83aebc..b5c12865790 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -230,8 +230,8 @@ static void getTileableBands(MLFunction *f, bands->push_back(band); }; - for (auto &stmt : *f) { - ForStmt *forStmt = dyn_cast<ForStmt>(&stmt); + for (auto &stmt : *f->getBody()) { + auto *forStmt = dyn_cast<ForStmt>(&stmt); if (!forStmt) continue; getMaximalPerfectLoopNest(forStmt); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index dd491f8119b..ffff1c5b615 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -92,10 +92,10 @@ PassResult LoopUnrollAndJam::runOnMLFunction(MLFunction *f) { // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnForStmt can be called on any // for Stmt. - if (!isa<ForStmt>(f->begin())) + auto *forStmt = dyn_cast<ForStmt>(f->getBody()->begin()); + if (!forStmt) return success(); - auto *forStmt = cast<ForStmt>(f->begin()); runOnForStmt(forStmt); return success(); } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index d4069eaa638..fd07619a165 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -238,7 +238,7 @@ struct LowerVectorTransfersPass makeFuncWiseState(MLFunction *f) const override { auto state = llvm::make_unique<LowerVectorTransfersState>(); auto builder = MLFuncBuilder(f); - builder.setInsertionPointToStart(f); + builder.setInsertionPointToStart(f->getBody()); state->zero = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), 0); return state; } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 554e3cb47a9..fbde1fd1692 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -177,7 +177,8 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, cast<Instruction>(op)->moveBefore(&entryBB, entryBB.begin()); } else { auto *mlFunc = cast<MLFunction>(currentFunction); - cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin()); + cast<OperationStmt>(op)->moveBefore(mlFunc->getBody(), + mlFunc->getBody()->begin()); } continue; diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 4d75f7c0835..023d3ebc643 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -102,7 +102,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { if (!forStmt->use_empty()) { if (forStmt->hasConstantLowerBound()) { auto *mlFunc = forStmt->findFunction(); - MLFuncBuilder topBuilder(&mlFunc->front()); + MLFuncBuilder topBuilder(&mlFunc->getBody()->front()); auto constOp = topBuilder.create<ConstantIndexOp>( forStmt->getLoc(), forStmt->getConstantLowerBound()); forStmt->replaceAllUsesWith(constOp); |

