summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorChris Lattner <clattner@google.com>2018-12-26 11:21:53 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 14:36:35 -0700
commitd613f5ab65bbb80c3f5a0a38fef22cb4878c4358 (patch)
treedfa4207dd899f87741ad784fe8fc343afba6885d /mlir/lib
parent9a4060d3f50f046973d4e0b61f324bbc6ebbbdb9 (diff)
downloadbcm5719-llvm-d613f5ab65bbb80c3f5a0a38fef22cb4878c4358.tar.gz
bcm5719-llvm-d613f5ab65bbb80c3f5a0a38fef22cb4878c4358.zip
Refactor MLFunction to contain a StmtBlock for its body instead of inheriting
from it. This is necessary progress to squaring away the parent relationship that a StmtBlock has with its enclosing if/for/fn, and makes room for functions to have more than one block in the future. This also removes IfClause and ForStmtBody. This is step 5/n towards merging instructions and statements, NFC. PiperOrigin-RevId: 226936541
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/Verifier.cpp10
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp6
-rw-r--r--mlir/lib/IR/BuiltinOps.cpp2
-rw-r--r--mlir/lib/IR/Function.cpp9
-rw-r--r--mlir/lib/IR/Operation.cpp2
-rw-r--r--mlir/lib/IR/StmtBlock.cpp27
-rw-r--r--mlir/lib/Parser/Parser.cpp10
-rw-r--r--mlir/lib/Transforms/ConvertToCFG.cpp2
-rw-r--r--mlir/lib/Transforms/DmaGeneration.cpp2
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp2
-rw-r--r--mlir/lib/Transforms/LoopTiling.cpp4
-rw-r--r--mlir/lib/Transforms/LoopUnrollAndJam.cpp4
-rw-r--r--mlir/lib/Transforms/LowerVectorTransfers.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp3
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp2
15 files changed, 42 insertions, 45 deletions
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
index 6e1522a656f..07324ba7d52 100644
--- a/mlir/lib/Analysis/Verifier.cpp
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -288,8 +288,8 @@ bool MLFuncVerifier::verifyDominance() {
HashTable::ScopeTy blockScope(liveValues);
// The induction variable of a for statement is live within its body.
- if (auto *forStmtBody = dyn_cast<ForStmtBody>(&block))
- liveValues.insert(forStmtBody->getFor(), true);
+ if (auto *forStmt = dyn_cast_or_null<ForStmt>(block.getContainingStmt()))
+ liveValues.insert(forStmt, true);
for (auto &stmt : block) {
// Verify that each of the operands are live.
@@ -330,16 +330,16 @@ bool MLFuncVerifier::verifyDominance() {
};
// Check the whole function out.
- return walkBlock(fn);
+ return walkBlock(*fn.getBody());
}
bool MLFuncVerifier::verifyReturn() {
// TODO: fold return verification in the pass that verifies all statements.
const char missingReturnMsg[] = "ML function must end with return statement";
- if (fn.getStatements().empty())
+ if (fn.getBody()->getStatements().empty())
return failure(missingReturnMsg, fn);
- const auto &stmt = fn.getStatements().back();
+ const auto &stmt = fn.getBody()->getStatements().back();
if (const auto *op = dyn_cast<OperationStmt>(&stmt)) {
if (!op->isReturn())
return failure(missingReturnMsg, fn);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 2de17563d93..3b193117355 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -230,7 +230,7 @@ void ModuleState::visitStatement(const Statement *stmt) {
void ModuleState::visitMLFunction(const MLFunction *fn) {
visitType(fn->getType());
- for (auto &stmt : *fn) {
+ for (auto &stmt : *fn->getBody()) {
ModuleState::visitStatement(&stmt);
}
}
@@ -1390,7 +1390,7 @@ void MLFunctionPrinter::print() {
printFunctionSignature();
printFunctionAttributes(getFunction());
os << " {\n";
- print(function);
+ print(function->getBody());
os << "}\n\n";
}
@@ -1649,7 +1649,7 @@ void Statement::print(raw_ostream &os) const {
void Statement::dump() const { print(llvm::errs()); }
void StmtBlock::printBlock(raw_ostream &os) const {
- MLFunction *function = findFunction();
+ const MLFunction *function = findFunction();
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
MLFunctionPrinter(function, modulePrinter).print(this);
diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp
index 5d7ba237b44..dfd59c4d380 100644
--- a/mlir/lib/IR/BuiltinOps.cpp
+++ b/mlir/lib/IR/BuiltinOps.cpp
@@ -474,7 +474,7 @@ void ReturnOp::print(OpAsmPrinter *p) const {
bool ReturnOp::verify() const {
const Function *function;
if (auto *stmt = dyn_cast<OperationStmt>(getOperation()))
- function = cast<MLFunction>(stmt->getBlock());
+ function = stmt->getBlock()->findFunction();
else
function = cast<Instruction>(getOperation())->getFunction();
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index b79e1596a65..533be8e2a29 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -202,14 +202,13 @@ MLFunction *MLFunction::create(Location location, StringRef name,
MLFunction::MLFunction(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
- : Function(Kind::MLFunc, location, name, type, attrs),
- StmtBlock(StmtBlockKind::MLFunc) {}
+ : Function(Kind::MLFunc, location, name, type, attrs), body(this) {}
MLFunction::~MLFunction() {
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
// since child statements need to be destroyed before function arguments
// are destroyed.
- clear();
+ getBody()->clear();
// Explicitly run the destructors for the function arguments.
for (auto &arg : getArgumentsInternal())
@@ -222,11 +221,11 @@ void MLFunction::destroy() {
}
const OperationStmt *MLFunction::getReturnStmt() const {
- return cast<OperationStmt>(&back());
+ return cast<OperationStmt>(&getBody()->back());
}
OperationStmt *MLFunction::getReturnStmt() {
- return cast<OperationStmt>(&back());
+ return cast<OperationStmt>(&getBody()->back());
}
void MLFunction::walk(std::function<void(OperationStmt *)> callback) {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index d3a618d7da5..c946a76a98b 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -581,7 +581,7 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
// Verify that the operation is at the end of the respective parent block.
if (auto *stmt = dyn_cast<OperationStmt>(op)) {
StmtBlock *block = stmt->getBlock();
- if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
+ if (!block || block->getContainingStmt() || &block->back() != stmt)
return op->emitOpError("must be the last statement in the ML function");
} else {
const Instruction *inst = cast<Instruction>(op);
diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp
index 8ecb903d21d..fdee491c150 100644
--- a/mlir/lib/IR/StmtBlock.cpp
+++ b/mlir/lib/IR/StmtBlock.cpp
@@ -20,33 +20,30 @@
#include "mlir/IR/Statements.h"
using namespace mlir;
+StmtBlock::StmtBlock(MLFunction *parent) : parent(parent) {}
+
+StmtBlock::StmtBlock(Statement *parent) : parent(parent) {}
+
StmtBlock::~StmtBlock() {
clear();
llvm::DeleteContainerPointers(arguments);
}
+/// Returns the closest surrounding statement that contains this block or
+/// nullptr if this is a top-level statement block.
Statement *StmtBlock::getContainingStmt() {
- switch (kind) {
- case StmtBlockKind::MLFunc:
- return nullptr;
- case StmtBlockKind::ForBody:
- return cast<ForStmtBody>(this)->getFor();
- case StmtBlockKind::IfClause:
- return cast<IfClause>(this)->getIf();
- }
+ return parent.dyn_cast<Statement *>();
}
-MLFunction *StmtBlock::findFunction() const {
- // FIXME: const incorrect.
- StmtBlock *block = const_cast<StmtBlock *>(this);
-
- while (block->getContainingStmt()) {
- block = block->getContainingStmt()->getBlock();
+MLFunction *StmtBlock::findFunction() {
+ StmtBlock *block = this;
+ while (auto *stmt = block->getContainingStmt()) {
+ block = stmt->getBlock();
if (!block)
return nullptr;
}
- return dyn_cast<MLFunction>(block);
+ return block->getParent().get<MLFunction *>();
}
/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 781ec461b62..1a28648eba9 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -2777,7 +2777,7 @@ class MLFunctionParser : public FunctionParser {
public:
MLFunctionParser(ParserState &state, MLFunction *function)
: FunctionParser(state, Kind::MLFunc), function(function),
- builder(function, function->end()) {}
+ builder(function->getBody()) {}
ParseResult parseFunctionBody();
@@ -2796,7 +2796,7 @@ private:
ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap &map,
bool isLower);
ParseResult parseIfStmt();
- ParseResult parseElseClause(IfClause *elseClause);
+ ParseResult parseElseClause(StmtBlock *elseClause);
ParseResult parseStatements(StmtBlock *block);
ParseResult parseStmtBlock(StmtBlock *block);
@@ -2812,7 +2812,7 @@ ParseResult MLFunctionParser::parseFunctionBody() {
auto braceLoc = getToken().getLoc();
// Parse statements in this function.
- if (parseStmtBlock(function))
+ if (parseStmtBlock(function->getBody()))
return ParseFailure;
return finalizeFunction(function, braceLoc);
@@ -3121,7 +3121,7 @@ ParseResult MLFunctionParser::parseIfStmt() {
IfStmt *ifStmt =
builder.createIf(getEncodedSourceLocation(loc), operands, set);
- IfClause *thenClause = ifStmt->getThen();
+ StmtBlock *thenClause = ifStmt->getThen();
// When parsing of an if statement body fails, the IR contains
// the if statement with the portion of the body that has been
@@ -3141,7 +3141,7 @@ ParseResult MLFunctionParser::parseIfStmt() {
return ParseSuccess;
}
-ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) {
+ParseResult MLFunctionParser::parseElseClause(StmtBlock *elseClause) {
if (getToken().is(Token::kw_if)) {
builder.setInsertionPointToEnd(elseClause);
return parseIfStmt();
diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp
index 8620230b2f1..4fafff51322 100644
--- a/mlir/lib/Transforms/ConvertToCFG.cpp
+++ b/mlir/lib/Transforms/ConvertToCFG.cpp
@@ -490,7 +490,7 @@ CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) {
}
// Convert statements in order.
- for (auto &stmt : *mlFunc) {
+ for (auto &stmt : *mlFunc->getBody()) {
visit(&stmt);
}
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 2b79064e53f..a927516345a 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -426,7 +426,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
}
PassResult DmaGeneration::runOnMLFunction(MLFunction *f) {
- for (auto &stmt : *f) {
+ for (auto &stmt : *f->getBody()) {
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
runOnForStmt(forStmt);
}
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index df68765aeb7..e3609496cc5 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -348,7 +348,7 @@ public:
bool MemRefDependenceGraph::init(MLFunction *f) {
unsigned id = 0;
DenseMap<MLValue *, SetVector<unsigned>> memrefAccesses;
- for (auto &stmt : *f) {
+ for (auto &stmt : *f->getBody()) {
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
// Create graph node 'id' to represent top-level 'forStmt' and record
// all loads and store accesses it contains.
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index 847db83aebc..b5c12865790 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -230,8 +230,8 @@ static void getTileableBands(MLFunction *f,
bands->push_back(band);
};
- for (auto &stmt : *f) {
- ForStmt *forStmt = dyn_cast<ForStmt>(&stmt);
+ for (auto &stmt : *f->getBody()) {
+ auto *forStmt = dyn_cast<ForStmt>(&stmt);
if (!forStmt)
continue;
getMaximalPerfectLoopNest(forStmt);
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index dd491f8119b..ffff1c5b615 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -92,10 +92,10 @@ PassResult LoopUnrollAndJam::runOnMLFunction(MLFunction *f) {
// Currently, just the outermost loop from the first loop nest is
// unroll-and-jammed by this pass. However, runOnForStmt can be called on any
// for Stmt.
- if (!isa<ForStmt>(f->begin()))
+ auto *forStmt = dyn_cast<ForStmt>(f->getBody()->begin());
+ if (!forStmt)
return success();
- auto *forStmt = cast<ForStmt>(f->begin());
runOnForStmt(forStmt);
return success();
}
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index d4069eaa638..fd07619a165 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -238,7 +238,7 @@ struct LowerVectorTransfersPass
makeFuncWiseState(MLFunction *f) const override {
auto state = llvm::make_unique<LowerVectorTransfersState>();
auto builder = MLFuncBuilder(f);
- builder.setInsertionPointToStart(f);
+ builder.setInsertionPointToStart(f->getBody());
state->zero = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), 0);
return state;
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 554e3cb47a9..fbde1fd1692 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -177,7 +177,8 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
cast<Instruction>(op)->moveBefore(&entryBB, entryBB.begin());
} else {
auto *mlFunc = cast<MLFunction>(currentFunction);
- cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
+ cast<OperationStmt>(op)->moveBefore(mlFunc->getBody(),
+ mlFunc->getBody()->begin());
}
continue;
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 4d75f7c0835..023d3ebc643 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -102,7 +102,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
if (!forStmt->use_empty()) {
if (forStmt->hasConstantLowerBound()) {
auto *mlFunc = forStmt->findFunction();
- MLFuncBuilder topBuilder(&mlFunc->front());
+ MLFuncBuilder topBuilder(&mlFunc->getBody()->front());
auto constOp = topBuilder.create<ConstantIndexOp>(
forStmt->getLoc(), forStmt->getConstantLowerBound());
forStmt->replaceAllUsesWith(constOp);
OpenPOWER on IntegriCloud