summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LoopUnroll.cpp
diff options
context:
space:
mode:
authorUday Bondhugula <bondhugula@google.com>2018-08-01 22:36:12 -0700
committerjpienaar <jpienaar@google.com>2019-03-29 12:53:02 -0700
commit2a003256ae3f88ebcbb021e0308a4f786605ae4f (patch)
tree045a597ee21da87ea533272ddf69af88262536a5 /mlir/lib/Transforms/LoopUnroll.cpp
parentb92378e8fa94ecac5d0f11f5e19b1e958f43f4f1 (diff)
downloadbcm5719-llvm-2a003256ae3f88ebcbb021e0308a4f786605ae4f.tar.gz
bcm5719-llvm-2a003256ae3f88ebcbb021e0308a4f786605ae4f.zip
MLStmt cloning and IV replacement for loop unrolling, add constant pool to
MLFunctions. - MLStmt cloning and IV replacement - While at this, fix the innermostLoopGatherer to actually gather all the innermost loops (it was stopping its walk at the first innermost loop it found) - Improve comments for MLFunction statement classes, fix inheritance order. - Fixed StmtBlock destructor. PiperOrigin-RevId: 207049173
Diffstat (limited to 'mlir/lib/Transforms/LoopUnroll.cpp')
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp66
1 files changed, 47 insertions, 19 deletions
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 160a463eb78..fe110d21d66 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
@@ -54,10 +55,13 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
typedef llvm::iplist<Statement> StmtListType;
bool walkPostOrder(StmtListType::iterator Start,
StmtListType::iterator End) {
+ bool hasInnerLoops = false;
+ // We need to walk all elements since all innermost loops need to be
+ // gathered as opposed to determining whether this list has any inner
+ // loops or not.
while (Start != End)
- if (walkPostOrder(&(*Start++)))
- return true;
- return false;
+ hasInnerLoops |= walkPostOrder(&(*Start++));
+ return hasInnerLoops;
}
// FIXME: can't use base class method for this because that in turn would
@@ -73,12 +77,11 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
}
bool walkIfStmtPostOrder(IfStmt *ifStmt) {
- if (walkPostOrder(ifStmt->getThenClause()->begin(),
- ifStmt->getThenClause()->end()) ||
- walkPostOrder(ifStmt->getElseClause()->begin(),
- ifStmt->getElseClause()->end()))
- return true;
- return false;
+ bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
+ ifStmt->getThenClause()->end());
+ hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(),
+ ifStmt->getElseClause()->end());
+ return hasInnerLoops;
}
bool walkOpStmt(OperationStmt *opStmt) { return false; }
@@ -93,17 +96,45 @@ void LoopUnroll::runOnMLFunction(MLFunction *f) {
runOnForStmt(forStmt);
}
-/// Unrolls this loop completely. Returns true if the unrolling happens.
+/// Replace an IV with a constant value.
+static void replaceIterator(Statement *stmt, const ForStmt &iv,
+ MLValue *constVal) {
+ struct ReplaceIterator : public StmtWalker<ReplaceIterator> {
+ // IV to be replaced.
+ const ForStmt *iv;
+ // Constant to be replaced with.
+ MLValue *constVal;
+
+ ReplaceIterator(const ForStmt &iv, MLValue *constVal)
+ : iv(&iv), constVal(constVal){};
+
+ void visitOperationStmt(OperationStmt *os) {
+ for (auto &operand : os->getStmtOperands()) {
+ if (operand.get() == static_cast<const MLValue *>(iv)) {
+ operand.set(constVal);
+ }
+ }
+ }
+ };
+
+ ReplaceIterator ri(iv, constVal);
+ ri.walk(stmt);
+}
+
+/// Unrolls this loop completely.
void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep()->getValue();
auto trip_count = (ub - lb + 1) / step;
- auto *block = forStmt->getBlock();
- MLFuncBuilder builder(block);
+ auto *mlFunc = forStmt->Statement::findFunction();
+ MLFuncBuilder funcTopBuilder(mlFunc);
+ funcTopBuilder.setInsertionPointAtStart(mlFunc);
+ MLFuncBuilder builder(forStmt->getBlock());
for (int i = 0; i < trip_count; i++) {
+ auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(i)->getResult(0);
for (auto &stmt : forStmt->getStatements()) {
switch (stmt.getKind()) {
case Statement::Kind::For:
@@ -113,16 +144,13 @@ void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
llvm_unreachable("unrolling loops that have only operations");
break;
case Statement::Kind::Operation:
- auto *op = cast<OperationStmt>(&stmt);
- // TODO: clone operands and result types.
- builder.createOperation(op->getName(), /*operands*/ {},
- /*resultTypes*/ {}, op->getAttrs());
- // TODO: loop iterator parsing not yet implemented; replace loop
- // iterator uses in unrolled body appropriately.
+ auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt));
+ // TODO(bondhugula): only generate constants when the IV actually
+ // appears in the body.
+ replaceIterator(cloneOp, *forStmt, ivUnrolledVal);
break;
}
}
}
-
forStmt->eraseFromBlock();
}
OpenPOWER on IntegriCloud