diff options
| author | Uday Bondhugula <bondhugula@google.com> | 2018-08-01 22:36:12 -0700 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 12:53:02 -0700 |
| commit | 2a003256ae3f88ebcbb021e0308a4f786605ae4f (patch) | |
| tree | 045a597ee21da87ea533272ddf69af88262536a5 /mlir/lib | |
| parent | b92378e8fa94ecac5d0f11f5e19b1e958f43f4f1 (diff) | |
| download | bcm5719-llvm-2a003256ae3f88ebcbb021e0308a4f786605ae4f.tar.gz bcm5719-llvm-2a003256ae3f88ebcbb021e0308a4f786605ae4f.zip | |
MLStmt cloning and IV replacement for loop unrolling, add constant pool to
MLFunctions.
- MLStmt cloning and IV replacement
- While at this, fix the innermostLoopGatherer to actually gather all the
innermost loops (it was stopping its walk at the first innermost loop it
found)
- Improve comments for MLFunction statement classes, fix inheritance order.
- Fixed StmtBlock destructor.
PiperOrigin-RevId: 207049173
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/IR/Instructions.cpp | 15 | ||||
| -rw-r--r-- | mlir/lib/IR/Statement.cpp | 24 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopUnroll.cpp | 66 |
3 files changed, 82 insertions, 23 deletions
diff --git a/mlir/lib/IR/Instructions.cpp b/mlir/lib/IR/Instructions.cpp index a10cb3ae5b4..847907b753c 100644 --- a/mlir/lib/IR/Instructions.cpp +++ b/mlir/lib/IR/Instructions.cpp @@ -145,6 +145,21 @@ OperationInst *OperationInst::create(Identifier name, return inst; } +OperationInst *OperationInst::clone() const { + SmallVector<CFGValue *, 8> operands; + SmallVector<Type *, 8> resultTypes; + + // TODO(clattner): switch to iterator logic. + // Put together the operands and results. + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) + operands.push_back(getInstOperand(i).get()); + + for (unsigned i = 0, e = getNumResults(); i != e; ++i) + resultTypes.push_back(getInstResult(i).getType()); + + return create(getName(), operands, resultTypes, getAttrs(), getContext()); +} + OperationInst::OperationInst(Identifier name, unsigned numOperands, unsigned numResults, ArrayRef<NamedAttribute> attributes, diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 5bc99e07a52..4e2a7b03f8f 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -120,7 +120,6 @@ void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList( /// Remove this statement (and its descendants) from its StmtBlock and delete /// all of them. -/// TODO: erase all descendents for ForStmt/IfStmt. void Statement::eraseFromBlock() { assert(getBlock() && "Statement has no block"); getBlock()->getStatements().erase(this); @@ -155,6 +154,22 @@ OperationStmt *OperationStmt::create(Identifier name, return stmt; } +/// Clone an existing OperationStmt. +OperationStmt *OperationStmt::clone() const { + SmallVector<MLValue *, 8> operands; + SmallVector<Type *, 8> resultTypes; + + // TODO(clattner): switch this to iterator logic. + // Put together operands and results. + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) + operands.push_back(getStmtOperand(i).get()); + + for (unsigned i = 0, e = getNumResults(); i != e; ++i) + resultTypes.push_back(getStmtResult(i).getType()); + + return create(getName(), operands, resultTypes, getAttrs(), getContext()); +} + OperationStmt::OperationStmt(Identifier name, unsigned numOperands, unsigned numResults, ArrayRef<NamedAttribute> attributes, @@ -205,9 +220,10 @@ void OperationStmt::dropAllReferences() { ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound, AffineConstantExpr *step, MLIRContext *context) - : Statement(Kind::For), StmtBlock(StmtBlockKind::For), + : Statement(Kind::For), MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)), - lowerBound(lowerBound), upperBound(upperBound), step(step) {} + StmtBlock(StmtBlockKind::For), lowerBound(lowerBound), + upperBound(upperBound), step(step) {} //===----------------------------------------------------------------------===// // IfStmt @@ -215,6 +231,6 @@ ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound, IfStmt::~IfStmt() { delete thenClause; - if (elseClause != nullptr) + if (elseClause) delete elseClause; } diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 160a463eb78..fe110d21d66 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/CFGFunction.h" #include "mlir/IR/MLFunction.h" @@ -54,10 +55,13 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) { typedef llvm::iplist<Statement> StmtListType; bool walkPostOrder(StmtListType::iterator Start, StmtListType::iterator End) { + bool hasInnerLoops = false; + // We need to walk all elements since all innermost loops need to be + // gathered as opposed to determining whether this list has any inner + // loops or not. while (Start != End) - if (walkPostOrder(&(*Start++))) - return true; - return false; + hasInnerLoops |= walkPostOrder(&(*Start++)); + return hasInnerLoops; } // FIXME: can't use base class method for this because that in turn would @@ -73,12 +77,11 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) { } bool walkIfStmtPostOrder(IfStmt *ifStmt) { - if (walkPostOrder(ifStmt->getThenClause()->begin(), - ifStmt->getThenClause()->end()) || - walkPostOrder(ifStmt->getElseClause()->begin(), - ifStmt->getElseClause()->end())) - return true; - return false; + bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(), + ifStmt->getThenClause()->end()); + hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(), + ifStmt->getElseClause()->end()); + return hasInnerLoops; } bool walkOpStmt(OperationStmt *opStmt) { return false; } @@ -93,17 +96,45 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) { runOnForStmt(forStmt); } -/// Unrolls this loop completely. Returns true if the unrolling happens. +/// Replace an IV with a constant value. +static void replaceIterator(Statement *stmt, const ForStmt &iv, + MLValue *constVal) { + struct ReplaceIterator : public StmtWalker<ReplaceIterator> { + // IV to be replaced. + const ForStmt *iv; + // Constant to be replaced with. + MLValue *constVal; + + ReplaceIterator(const ForStmt &iv, MLValue *constVal) + : iv(&iv), constVal(constVal){}; + + void visitOperationStmt(OperationStmt *os) { + for (auto &operand : os->getStmtOperands()) { + if (operand.get() == static_cast<const MLValue *>(iv)) { + operand.set(constVal); + } + } + } + }; + + ReplaceIterator ri(iv, constVal); + ri.walk(stmt); +} + +/// Unrolls this loop completely. void LoopUnroll::runOnForStmt(ForStmt *forStmt) { auto lb = forStmt->getLowerBound()->getValue(); auto ub = forStmt->getUpperBound()->getValue(); auto step = forStmt->getStep()->getValue(); auto trip_count = (ub - lb + 1) / step; - auto *block = forStmt->getBlock(); - MLFuncBuilder builder(block); + auto *mlFunc = forStmt->Statement::findFunction(); + MLFuncBuilder funcTopBuilder(mlFunc); + funcTopBuilder.setInsertionPointAtStart(mlFunc); + MLFuncBuilder builder(forStmt->getBlock()); for (int i = 0; i < trip_count; i++) { + auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(i)->getResult(0); for (auto &stmt : forStmt->getStatements()) { switch (stmt.getKind()) { case Statement::Kind::For: @@ -113,16 +144,13 @@ void LoopUnroll::runOnForStmt(ForStmt *forStmt) { llvm_unreachable("unrolling loops that have only operations"); break; case Statement::Kind::Operation: - auto *op = cast<OperationStmt>(&stmt); - // TODO: clone operands and result types. - builder.createOperation(op->getName(), /*operands*/ {}, - /*resultTypes*/ {}, op->getAttrs()); - // TODO: loop iterator parsing not yet implemented; replace loop - // iterator uses in unrolled body appropriately. + auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt)); + // TODO(bondhugula): only generate constants when the IV actually + // appears in the body. + replaceIterator(cloneOp, *forStmt, ivUnrolledVal); break; } } } - forStmt->eraseFromBlock(); } |

