diff options
| author | Uday Bondhugula <bondhugula@google.com> | 2018-08-01 22:36:12 -0700 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 12:53:02 -0700 |
| commit | 2a003256ae3f88ebcbb021e0308a4f786605ae4f (patch) | |
| tree | 045a597ee21da87ea533272ddf69af88262536a5 /mlir/lib/Transforms/LoopUnroll.cpp | |
| parent | b92378e8fa94ecac5d0f11f5e19b1e958f43f4f1 (diff) | |
| download | bcm5719-llvm-2a003256ae3f88ebcbb021e0308a4f786605ae4f.tar.gz bcm5719-llvm-2a003256ae3f88ebcbb021e0308a4f786605ae4f.zip | |
MLStmt cloning and IV replacement for loop unrolling, add constant pool to
MLFunctions.
- MLStmt cloning and IV replacement
- While at this, fix the innermostLoopGatherer to actually gather all the
innermost loops (it was stopping its walk at the first innermost loop it
found)
- Improve comments for MLFunction statement classes, fix inheritance order.
- Fixed StmtBlock destructor.
PiperOrigin-RevId: 207049173
Diffstat (limited to 'mlir/lib/Transforms/LoopUnroll.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LoopUnroll.cpp | 66 |
1 files changed, 47 insertions, 19 deletions
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 160a463eb78..fe110d21d66 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/CFGFunction.h" #include "mlir/IR/MLFunction.h" @@ -54,10 +55,13 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) { typedef llvm::iplist<Statement> StmtListType; bool walkPostOrder(StmtListType::iterator Start, StmtListType::iterator End) { + bool hasInnerLoops = false; + // We need to walk all elements since all innermost loops need to be + // gathered as opposed to determining whether this list has any inner + // loops or not. while (Start != End) - if (walkPostOrder(&(*Start++))) - return true; - return false; + hasInnerLoops |= walkPostOrder(&(*Start++)); + return hasInnerLoops; } // FIXME: can't use base class method for this because that in turn would @@ -73,12 +77,11 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) { } bool walkIfStmtPostOrder(IfStmt *ifStmt) { - if (walkPostOrder(ifStmt->getThenClause()->begin(), - ifStmt->getThenClause()->end()) || - walkPostOrder(ifStmt->getElseClause()->begin(), - ifStmt->getElseClause()->end())) - return true; - return false; + bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(), + ifStmt->getThenClause()->end()); + hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(), + ifStmt->getElseClause()->end()); + return hasInnerLoops; } bool walkOpStmt(OperationStmt *opStmt) { return false; } @@ -93,17 +96,45 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) { runOnForStmt(forStmt); } -/// Unrolls this loop completely. Returns true if the unrolling happens. +/// Replace an IV with a constant value. +static void replaceIterator(Statement *stmt, const ForStmt &iv, + MLValue *constVal) { + struct ReplaceIterator : public StmtWalker<ReplaceIterator> { + // IV to be replaced. + const ForStmt *iv; + // Constant to be replaced with. + MLValue *constVal; + + ReplaceIterator(const ForStmt &iv, MLValue *constVal) + : iv(&iv), constVal(constVal){}; + + void visitOperationStmt(OperationStmt *os) { + for (auto &operand : os->getStmtOperands()) { + if (operand.get() == static_cast<const MLValue *>(iv)) { + operand.set(constVal); + } + } + } + }; + + ReplaceIterator ri(iv, constVal); + ri.walk(stmt); +} + +/// Unrolls this loop completely. void LoopUnroll::runOnForStmt(ForStmt *forStmt) { auto lb = forStmt->getLowerBound()->getValue(); auto ub = forStmt->getUpperBound()->getValue(); auto step = forStmt->getStep()->getValue(); auto trip_count = (ub - lb + 1) / step; - auto *block = forStmt->getBlock(); - MLFuncBuilder builder(block); + auto *mlFunc = forStmt->Statement::findFunction(); + MLFuncBuilder funcTopBuilder(mlFunc); + funcTopBuilder.setInsertionPointAtStart(mlFunc); + MLFuncBuilder builder(forStmt->getBlock()); for (int i = 0; i < trip_count; i++) { + auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(i)->getResult(0); for (auto &stmt : forStmt->getStatements()) { switch (stmt.getKind()) { case Statement::Kind::For: @@ -113,16 +144,13 @@ void LoopUnroll::runOnForStmt(ForStmt *forStmt) { llvm_unreachable("unrolling loops that have only operations"); break; case Statement::Kind::Operation: - auto *op = cast<OperationStmt>(&stmt); - // TODO: clone operands and result types. - builder.createOperation(op->getName(), /*operands*/ {}, - /*resultTypes*/ {}, op->getAttrs()); - // TODO: loop iterator parsing not yet implemented; replace loop - // iterator uses in unrolled body appropriately. + auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt)); + // TODO(bondhugula): only generate constants when the IV actually + // appears in the body. + replaceIterator(cloneOp, *forStmt, ivUnrolledVal); break; } } } - forStmt->eraseFromBlock(); } |

