summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorChris Lattner <clattner@google.com>2018-12-23 08:17:48 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 14:35:19 -0700
commit1301f907a10e25e4a05483977a50c8b4f34b2ed4 (patch)
tree5d289f379659c4d458411958cf631e4d22ada678 /mlir/lib
parent4eef795a1dbd7eafa9a45303f01c51921729f1f4 (diff)
downloadbcm5719-llvm-1301f907a10e25e4a05483977a50c8b4f34b2ed4.tar.gz
bcm5719-llvm-1301f907a10e25e4a05483977a50c8b4f34b2ed4.zip
Refactor ForStmt: having it contain a StmtBlock instead of subclassing
StmtBlock. This is more consistent with IfStmt and also conceptually makes more sense - a forstmt "isn't" its body, it contains its body. This is step 1/N towards merging BasicBlock and StmtBlock. This is required because in the new regime StmtBlock will have a use list (just like BasicBlock does) of operands, and ForStmt already has a use list for its induction variable. This is a mechanical patch, NFC. PiperOrigin-RevId: 226684158
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/AffineAnalysis.cpp2
-rw-r--r--mlir/lib/Analysis/LoopAnalysis.cpp9
-rw-r--r--mlir/lib/Analysis/Utils.cpp6
-rw-r--r--mlir/lib/Analysis/Verifier.cpp6
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp4
-rw-r--r--mlir/lib/IR/Instructions.cpp2
-rw-r--r--mlir/lib/IR/Statement.cpp6
-rw-r--r--mlir/lib/IR/StmtBlock.cpp7
-rw-r--r--mlir/lib/Parser/Parser.cpp2
-rw-r--r--mlir/lib/Transforms/ConvertToCFG.cpp2
-rw-r--r--mlir/lib/Transforms/DmaGeneration.cpp4
-rw-r--r--mlir/lib/Transforms/LoopTiling.cpp18
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp3
-rw-r--r--mlir/lib/Transforms/LoopUnrollAndJam.cpp2
-rw-r--r--mlir/lib/Transforms/LowerVectorTransfers.cpp2
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp19
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp19
17 files changed, 60 insertions, 53 deletions
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 91f4ccf4804..bdc2c7ec286 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -905,7 +905,7 @@ static StmtBlock *getCommonStmtBlock(const MemRefAccess &srcAccess,
}
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
assert(isa<ForStmt>(commonForValue));
- return dyn_cast<ForStmt>(commonForValue);
+ return cast<ForStmt>(commonForValue)->getBody();
}
// Returns true if the ancestor operation statement of 'srcAccess' properly
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index 5e6bd7fa59b..3ee62bb2c42 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -305,9 +305,10 @@ bool mlir::isVectorizableLoop(const ForStmt &loop) {
// violation when we have the support.
bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
ArrayRef<uint64_t> shifts) {
- assert(shifts.size() == forStmt.getStatements().size());
+ auto *forBody = forStmt.getBody();
+ assert(shifts.size() == forBody->getStatements().size());
unsigned s = 0;
- for (const auto &stmt : forStmt) {
+ for (const auto &stmt : *forBody) {
// A for or if stmt does not produce any def/results (that are used
// outside).
if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
@@ -319,8 +320,8 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
// This is a naive way. If performance becomes an issue, a map can
// be used to store 'shifts' - to look up the shift for a statement in
// constant time.
- if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner()))
- if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)])
+ if (auto *ancStmt = forBody->findAncestorStmtInBlock(*use.getOwner()))
+ if (shifts[s] != shifts[forBody->findStmtPosInBlock(*ancStmt)])
return false;
}
}
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index cc30cfffb06..2428265acdb 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -362,7 +362,7 @@ static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
if (level == positions.size() - 1)
return &stmt;
if (auto *childForStmt = dyn_cast<ForStmt>(&stmt))
- return getStmtAtPosition(positions, level + 1, childForStmt);
+ return getStmtAtPosition(positions, level + 1, childForStmt->getBody());
if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen());
@@ -453,13 +453,13 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
// Clone src loop nest and insert it a the beginning of the statement block
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
- MLFuncBuilder b(dstForStmt, dstForStmt->begin());
+ MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
DenseMap<const MLValue *, MLValue *> operandMap;
auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
// Lookup stmt in cloned 'sliceLoopNest' at 'positions'.
Statement *sliceStmt =
- getStmtAtPosition(positions, /*level=*/0, sliceLoopNest);
+ getStmtAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
// Get loop nest surrounding 'sliceStmt'.
SmallVector<ForStmt *, 4> sliceSurroundingLoops;
getLoopIVs(*sliceStmt, &sliceSurroundingLoops);
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
index d955bcd5edb..6e1522a656f 100644
--- a/mlir/lib/Analysis/Verifier.cpp
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -288,8 +288,8 @@ bool MLFuncVerifier::verifyDominance() {
HashTable::ScopeTy blockScope(liveValues);
// The induction variable of a for statement is live within its body.
- if (auto *forStmt = dyn_cast<ForStmt>(&block))
- liveValues.insert(forStmt, true);
+ if (auto *forStmtBody = dyn_cast<ForStmtBody>(&block))
+ liveValues.insert(forStmtBody->getFor(), true);
for (auto &stmt : block) {
// Verify that each of the operands are live.
@@ -322,7 +322,7 @@ bool MLFuncVerifier::verifyDominance() {
return true;
}
if (auto *forStmt = dyn_cast<ForStmt>(&stmt))
- if (walkBlock(*forStmt))
+ if (walkBlock(*forStmt->getBody()))
return true;
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index b798e3890a0..58f34af60f5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -206,7 +206,7 @@ void ModuleState::visitForStmt(const ForStmt *forStmt) {
if (!hasShorthandForm(ubMap))
recordAffineMapReference(ubMap);
- for (auto &childStmt : *forStmt)
+ for (auto &childStmt : *forStmt->getBody())
visitStatement(&childStmt);
}
@@ -1447,7 +1447,7 @@ void MLFunctionPrinter::print(const ForStmt *stmt) {
os << " step " << stmt->getStep();
os << " {\n";
- print(static_cast<const StmtBlock *>(stmt));
+ print(stmt->getBody());
os.indent(numSpaces) << "}";
}
diff --git a/mlir/lib/IR/Instructions.cpp b/mlir/lib/IR/Instructions.cpp
index 9d65f4376b3..de73f3a96d3 100644
--- a/mlir/lib/IR/Instructions.cpp
+++ b/mlir/lib/IR/Instructions.cpp
@@ -147,7 +147,7 @@ Instruction *Instruction::clone() const {
int cloneOperandIt = operands.size() - 1, operandIt = getNumOperands() - 1;
for (int succIt = getNumSuccessors() - 1, succE = 0; succIt >= succE;
--succIt) {
- successors[succIt] = getSuccessor(succIt);
+ successors[succIt] = const_cast<BasicBlock *>(getSuccessor(succIt));
// Add the successor operands in-place in reverse order.
for (unsigned i = 0, e = getNumSuccessorOperands(succIt); i != e;
diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp
index 69afc5c1e98..f63c76605de 100644
--- a/mlir/lib/IR/Statement.cpp
+++ b/mlir/lib/IR/Statement.cpp
@@ -338,7 +338,7 @@ ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
: Statement(Kind::For, location),
MLValue(MLValueKind::ForStmt,
Type::getIndex(lbMap.getResult(0).getContext())),
- StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {
+ body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
operands.reserve(numOperands);
}
@@ -544,8 +544,8 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
operandMap[forStmt] = newFor;
// Recursively clone the body of the for loop.
- for (auto &subStmt : *forStmt)
- newFor->push_back(subStmt.clone(operandMap, context));
+ for (auto &subStmt : *forStmt->getBody())
+ newFor->getBody()->push_back(subStmt.clone(operandMap, context));
return newFor;
}
diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp
index 40a31f6c3b9..898dd7bc337 100644
--- a/mlir/lib/IR/StmtBlock.cpp
+++ b/mlir/lib/IR/StmtBlock.cpp
@@ -24,18 +24,19 @@ using namespace mlir;
// Statement block
//===----------------------------------------------------------------------===//
-Statement *StmtBlock::getContainingStmt() const {
+Statement *StmtBlock::getContainingStmt() {
switch (kind) {
case StmtBlockKind::MLFunc:
return nullptr;
- case StmtBlockKind::For:
- return cast<ForStmt>(const_cast<StmtBlock *>(this));
+ case StmtBlockKind::ForBody:
+ return cast<ForStmtBody>(this)->getFor();
case StmtBlockKind::IfClause:
return cast<IfClause>(this)->getIf();
}
}
MLFunction *StmtBlock::findFunction() const {
+ // FIXME: const incorrect.
StmtBlock *block = const_cast<StmtBlock *>(this);
while (block->getContainingStmt()) {
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 46dd35682fd..781ec461b62 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -2876,7 +2876,7 @@ ParseResult MLFunctionParser::parseForStmt() {
// If parsing of the for statement body fails,
// MLIR contains for statement with those nested statements that have been
// successfully parsed.
- if (parseStmtBlock(forStmt))
+ if (parseStmtBlock(forStmt->getBody()))
return ParseFailure;
// Reset insertion point to the current block.
diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp
index 247a264cd5c..0ed803db64d 100644
--- a/mlir/lib/Transforms/ConvertToCFG.cpp
+++ b/mlir/lib/Transforms/ConvertToCFG.cpp
@@ -242,7 +242,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// Walking manually because we need custom logic before and after traversing
// the list of children.
builder.setInsertionPoint(loopBodyFirstBlock);
- visitStmtBlock(forStmt);
+ visitStmtBlock(forStmt->getBody());
// Builder point is currently at the last block of the loop body. Append the
// induction variable stepping to this block and branch back to the exit
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index bd7cad7fd3d..2b79064e53f 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -365,7 +365,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef),
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
- /*domStmtFilter=*/&*forStmt->begin());
+ /*domStmtFilter=*/&*forStmt->getBody()->begin());
return true;
}
@@ -391,7 +391,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
// the pass has to be instantiated with additional information that we aren't
// provided with at the moment.
if (forStmt->getStep() != 1) {
- if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->begin())) {
+ if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) {
runOnForStmt(innerFor);
}
return;
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index 85c88f785d1..847db83aebc 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -59,12 +59,12 @@ FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
// destination's body.
static inline void moveLoopBody(ForStmt *src, ForStmt *dest,
StmtBlock::iterator loc) {
- dest->getStatements().splice(loc, src->getStatements());
+ dest->getBody()->getStatements().splice(loc, src->getBody()->getStatements());
}
// Move the loop body of ForStmt 'src' from 'src' to the start of dest's body.
static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
- moveLoopBody(src, dest, dest->begin());
+ moveLoopBody(src, dest, dest->getBody()->begin());
}
/// Constructs and sets new loop bounds after tiling for the case of
@@ -167,8 +167,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
MLFuncBuilder b(topLoop);
// Loop bounds will be set later.
auto *pointLoop = b.createFor(loc, 0, 0);
- pointLoop->getStatements().splice(
- pointLoop->begin(), topLoop->getBlock()->getStatements(), topLoop);
+ pointLoop->getBody()->getStatements().splice(
+ pointLoop->getBody()->begin(), topLoop->getBlock()->getStatements(),
+ topLoop);
newLoops[2 * width - 1 - i] = pointLoop;
topLoop = pointLoop;
if (i == 0)
@@ -180,8 +181,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
MLFuncBuilder b(topLoop);
// Loop bounds will be set later.
auto *tileSpaceLoop = b.createFor(loc, 0, 0);
- tileSpaceLoop->getStatements().splice(
- tileSpaceLoop->begin(), topLoop->getBlock()->getStatements(), topLoop);
+ tileSpaceLoop->getBody()->getStatements().splice(
+ tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getStatements(),
+ topLoop);
newLoops[2 * width - i - 1] = tileSpaceLoop;
topLoop = tileSpaceLoop;
}
@@ -223,8 +225,8 @@ static void getTileableBands(MLFunction *f,
ForStmt *currStmt = root;
do {
band.push_back(currStmt);
- } while (currStmt->getStatements().size() == 1 &&
- (currStmt = dyn_cast<ForStmt>(&*currStmt->begin())));
+ } while (currStmt->getBody()->getStatements().size() == 1 &&
+ (currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin())));
bands->push_back(band);
};
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index a43087bd2e1..183613a2f69 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -104,7 +104,8 @@ PassResult LoopUnroll::runOnMLFunction(MLFunction *f) {
}
bool walkForStmtPostOrder(ForStmt *forStmt) {
- bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
+ bool hasInnerLoops =
+ walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end());
if (!hasInnerLoops)
loops.push_back(forStmt);
return true;
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index 45ca9dd98df..dd491f8119b 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -152,7 +152,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
- if (unrollJamFactor == 1 || forStmt->getStatements().empty())
+ if (unrollJamFactor == 1 || forStmt->getBody()->empty())
return false;
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index df30a779461..d4069eaa638 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -147,7 +147,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value());
loops.insert(forStmt);
// Setting the insertion point to the innermost loop achieves nesting.
- b.setInsertionPointToStart(loops.back());
+ b.setInsertionPointToStart(loops.back()->getBody());
if (composed == getAffineConstantExpr(0, b.getContext())) {
transfer->emitWarning(
"Redundant copy can be implemented as a vector broadcast");
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index b656af0d69d..8d75bfbd7ae 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -81,8 +81,9 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
/// a replacement cannot be performed.
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
- MLFuncBuilder bInner(forStmt, forStmt->begin());
- bInner.setInsertionPoint(forStmt, forStmt->begin());
+ auto *forBody = forStmt->getBody();
+ MLFuncBuilder bInner(forBody, forBody->begin());
+ bInner.setInsertionPoint(forBody, forBody->begin());
// Doubles the shape with a leading dimension extent of 2.
auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
@@ -127,7 +128,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
// non-deferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef),
ivModTwoOp->getResult(0), AffineMap::Null(), {},
- &*forStmt->begin())) {
+ &*forStmt->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getOperation()->erase();
@@ -184,7 +185,7 @@ static void findMatchingStartFinishStmts(
// Collect outgoing DMA statements - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
- for (auto &stmt : *forStmt) {
+ for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
@@ -195,7 +196,7 @@ static void findMatchingStartFinishStmts(
}
SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
- for (auto &stmt : *forStmt) {
+ for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
@@ -228,7 +229,7 @@ static void findMatchingStartFinishStmts(
cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
- if (!dominates(*forStmt->begin(), *use.getOwner())) {
+ if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
@@ -339,16 +340,16 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
}
}
// Everything else (including compute ops and dma finish) are shifted by one.
- for (const auto &stmt : *forStmt) {
+ for (const auto &stmt : *forStmt->getBody()) {
if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) {
stmtShiftMap[&stmt] = 1;
}
}
// Get shifts stored in map.
- std::vector<uint64_t> shifts(forStmt->getStatements().size());
+ std::vector<uint64_t> shifts(forStmt->getBody()->getStatements().size());
unsigned s = 0;
- for (auto &stmt : *forStmt) {
+ for (auto &stmt : *forStmt->getBody()) {
assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end());
shifts[s++] = stmtShiftMap[&stmt];
LLVM_DEBUG(
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 791997e7ff1..4d75f7c0835 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -119,7 +119,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
// Move the loop body statements to the loop's containing block.
auto *block = forStmt->getBlock();
block->getStatements().splice(StmtBlock::iterator(forStmt),
- forStmt->getStatements());
+ forStmt->getBody()->getStatements());
forStmt->erase();
return true;
}
@@ -181,7 +181,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
operandMap[srcForStmt] = loopChunk;
}
for (auto *stmt : stmts) {
- loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
+ loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext()));
}
}
if (promoteIfSingleIteration(loopChunk))
@@ -206,7 +206,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
// method.
UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
bool unrollPrologueEpilogue) {
- if (forStmt->getStatements().empty())
+ if (forStmt->getBody()->empty())
return UtilResult::Success;
// If the trip counts aren't constant, we would need versioning and
@@ -225,7 +225,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
int64_t step = forStmt->getStep();
- unsigned numChildStmts = forStmt->getStatements().size();
+ unsigned numChildStmts = forStmt->getBody()->getStatements().size();
// Do a linear time (counting) sort for the shifts.
uint64_t maxShift = 0;
@@ -243,7 +243,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
// body of the 'for' stmt.
std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1);
unsigned pos = 0;
- for (auto &stmt : *forStmt) {
+ for (auto &stmt : *forStmt->getBody()) {
auto shift = shifts[pos++];
sortedStmtGroups[shift].push_back(&stmt);
}
@@ -352,7 +352,7 @@ bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
assert(unrollFactor >= 1 && "unroll factor should be >= 1");
- if (unrollFactor == 1 || forStmt->getStatements().empty())
+ if (unrollFactor == 1 || forStmt->getBody()->empty())
return false;
auto lbMap = forStmt->getLowerBoundMap();
@@ -406,11 +406,11 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
// Builder to insert unrolled bodies right after the last statement in the
// body of 'forStmt'.
- MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
+ MLFuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
// Keep a pointer to the last statement in the original block so that we know
// what to clone (since we are doing this in-place).
- StmtBlock::iterator srcBlockEnd = std::prev(forStmt->end());
+ StmtBlock::iterator srcBlockEnd = std::prev(forStmt->getBody()->end());
// Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
for (unsigned i = 1; i < unrollFactor; i++) {
@@ -429,7 +429,8 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
}
// Clone the original body of 'forStmt'.
- for (auto it = forStmt->begin(); it != std::next(srcBlockEnd); it++) {
+ for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd);
+ it++) {
builder.clone(*it, operandMap);
}
}
OpenPOWER on IntegriCloud