diff options
| author | Chris Lattner <clattner@google.com> | 2018-12-23 08:17:48 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 14:35:19 -0700 |
| commit | 1301f907a10e25e4a05483977a50c8b4f34b2ed4 (patch) | |
| tree | 5d289f379659c4d458411958cf631e4d22ada678 /mlir/lib | |
| parent | 4eef795a1dbd7eafa9a45303f01c51921729f1f4 (diff) | |
| download | bcm5719-llvm-1301f907a10e25e4a05483977a50c8b4f34b2ed4.tar.gz bcm5719-llvm-1301f907a10e25e4a05483977a50c8b4f34b2ed4.zip | |
Refactor ForStmt: having it contain a StmtBlock instead of subclassing
StmtBlock. This is more consistent with IfStmt and also conceptually makes
more sense - a forstmt "isn't" its body, it contains its body.
This is step 1/N towards merging BasicBlock and StmtBlock. This is required
because in the new regime StmtBlock will have a use list (just like BasicBlock
does) of operands, and ForStmt already has a use list for its induction
variable.
This is a mechanical patch, NFC.
PiperOrigin-RevId: 226684158
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Analysis/AffineAnalysis.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Analysis/LoopAnalysis.cpp | 9 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Utils.cpp | 6 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Verifier.cpp | 6 | ||||
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 4 | ||||
| -rw-r--r-- | mlir/lib/IR/Instructions.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/IR/Statement.cpp | 6 | ||||
| -rw-r--r-- | mlir/lib/IR/StmtBlock.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/ConvertToCFG.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/DmaGeneration.cpp | 4 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopTiling.cpp | 18 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopUnroll.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 19 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/LoopUtils.cpp | 19 |
17 files changed, 60 insertions, 53 deletions
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 91f4ccf4804..bdc2c7ec286 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -905,7 +905,7 @@ static StmtBlock *getCommonStmtBlock(const MemRefAccess &srcAccess, } auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1); assert(isa<ForStmt>(commonForValue)); - return dyn_cast<ForStmt>(commonForValue); + return cast<ForStmt>(commonForValue)->getBody(); } // Returns true if the ancestor operation statement of 'srcAccess' properly diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 5e6bd7fa59b..3ee62bb2c42 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -305,9 +305,10 @@ bool mlir::isVectorizableLoop(const ForStmt &loop) { // violation when we have the support. bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, ArrayRef<uint64_t> shifts) { - assert(shifts.size() == forStmt.getStatements().size()); + auto *forBody = forStmt.getBody(); + assert(shifts.size() == forBody->getStatements().size()); unsigned s = 0; - for (const auto &stmt : forStmt) { + for (const auto &stmt : *forBody) { // A for or if stmt does not produce any def/results (that are used // outside). if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) { @@ -319,8 +320,8 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, // This is a naive way. If performance becomes an issue, a map can // be used to store 'shifts' - to look up the shift for a statement in // constant time. - if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner())) - if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)]) + if (auto *ancStmt = forBody->findAncestorStmtInBlock(*use.getOwner())) + if (shifts[s] != shifts[forBody->findStmtPosInBlock(*ancStmt)]) return false; } } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index cc30cfffb06..2428265acdb 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -362,7 +362,7 @@ static Statement *getStmtAtPosition(ArrayRef<unsigned> positions, if (level == positions.size() - 1) return &stmt; if (auto *childForStmt = dyn_cast<ForStmt>(&stmt)) - return getStmtAtPosition(positions, level + 1, childForStmt); + return getStmtAtPosition(positions, level + 1, childForStmt->getBody()); if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) { auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen()); @@ -453,13 +453,13 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess, // Clone src loop nest and insert it a the beginning of the statement block // of the loop at 'dstLoopDepth' in 'dstLoopNest'. auto *dstForStmt = dstLoopNest[dstLoopDepth - 1]; - MLFuncBuilder b(dstForStmt, dstForStmt->begin()); + MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin()); DenseMap<const MLValue *, MLValue *> operandMap; auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap)); // Lookup stmt in cloned 'sliceLoopNest' at 'positions'. Statement *sliceStmt = - getStmtAtPosition(positions, /*level=*/0, sliceLoopNest); + getStmtAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); // Get loop nest surrounding 'sliceStmt'. SmallVector<ForStmt *, 4> sliceSurroundingLoops; getLoopIVs(*sliceStmt, &sliceSurroundingLoops); diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index d955bcd5edb..6e1522a656f 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -288,8 +288,8 @@ bool MLFuncVerifier::verifyDominance() { HashTable::ScopeTy blockScope(liveValues); // The induction variable of a for statement is live within its body. - if (auto *forStmt = dyn_cast<ForStmt>(&block)) - liveValues.insert(forStmt, true); + if (auto *forStmtBody = dyn_cast<ForStmtBody>(&block)) + liveValues.insert(forStmtBody->getFor(), true); for (auto &stmt : block) { // Verify that each of the operands are live. @@ -322,7 +322,7 @@ bool MLFuncVerifier::verifyDominance() { return true; } if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) - if (walkBlock(*forStmt)) + if (walkBlock(*forStmt->getBody())) return true; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index b798e3890a0..58f34af60f5 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -206,7 +206,7 @@ void ModuleState::visitForStmt(const ForStmt *forStmt) { if (!hasShorthandForm(ubMap)) recordAffineMapReference(ubMap); - for (auto &childStmt : *forStmt) + for (auto &childStmt : *forStmt->getBody()) visitStatement(&childStmt); } @@ -1447,7 +1447,7 @@ void MLFunctionPrinter::print(const ForStmt *stmt) { os << " step " << stmt->getStep(); os << " {\n"; - print(static_cast<const StmtBlock *>(stmt)); + print(stmt->getBody()); os.indent(numSpaces) << "}"; } diff --git a/mlir/lib/IR/Instructions.cpp b/mlir/lib/IR/Instructions.cpp index 9d65f4376b3..de73f3a96d3 100644 --- a/mlir/lib/IR/Instructions.cpp +++ b/mlir/lib/IR/Instructions.cpp @@ -147,7 +147,7 @@ Instruction *Instruction::clone() const { int cloneOperandIt = operands.size() - 1, operandIt = getNumOperands() - 1; for (int succIt = getNumSuccessors() - 1, succE = 0; succIt >= succE; --succIt) { - successors[succIt] = getSuccessor(succIt); + successors[succIt] = const_cast<BasicBlock *>(getSuccessor(succIt)); // Add the successor operands in-place in reverse order. for (unsigned i = 0, e = getNumSuccessorOperands(succIt); i != e; diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 69afc5c1e98..f63c76605de 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -338,7 +338,7 @@ ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap, : Statement(Kind::For, location), MLValue(MLValueKind::ForStmt, Type::getIndex(lbMap.getResult(0).getContext())), - StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) { + body(this), lbMap(lbMap), ubMap(ubMap), step(step) { operands.reserve(numOperands); } @@ -544,8 +544,8 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap, operandMap[forStmt] = newFor; // Recursively clone the body of the for loop. - for (auto &subStmt : *forStmt) - newFor->push_back(subStmt.clone(operandMap, context)); + for (auto &subStmt : *forStmt->getBody()) + newFor->getBody()->push_back(subStmt.clone(operandMap, context)); return newFor; } diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp index 40a31f6c3b9..898dd7bc337 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/StmtBlock.cpp @@ -24,18 +24,19 @@ using namespace mlir; // Statement block //===----------------------------------------------------------------------===// -Statement *StmtBlock::getContainingStmt() const { +Statement *StmtBlock::getContainingStmt() { switch (kind) { case StmtBlockKind::MLFunc: return nullptr; - case StmtBlockKind::For: - return cast<ForStmt>(const_cast<StmtBlock *>(this)); + case StmtBlockKind::ForBody: + return cast<ForStmtBody>(this)->getFor(); case StmtBlockKind::IfClause: return cast<IfClause>(this)->getIf(); } } MLFunction *StmtBlock::findFunction() const { + // FIXME: const incorrect. StmtBlock *block = const_cast<StmtBlock *>(this); while (block->getContainingStmt()) { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 46dd35682fd..781ec461b62 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2876,7 +2876,7 @@ ParseResult MLFunctionParser::parseForStmt() { // If parsing of the for statement body fails, // MLIR contains for statement with those nested statements that have been // successfully parsed. - if (parseStmtBlock(forStmt)) + if (parseStmtBlock(forStmt->getBody())) return ParseFailure; // Reset insertion point to the current block. diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 247a264cd5c..0ed803db64d 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -242,7 +242,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) { // Walking manually because we need custom logic before and after traversing // the list of children. builder.setInsertionPoint(loopBodyFirstBlock); - visitStmtBlock(forStmt); + visitStmtBlock(forStmt->getBody()); // Builder point is currently at the last block of the loop body. Append the // induction variable stepping to this block and branch back to the exit diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index bd7cad7fd3d..2b79064e53f 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -365,7 +365,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt, replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef), /*extraIndices=*/{}, indexRemap, /*extraOperands=*/outerIVs, - /*domStmtFilter=*/&*forStmt->begin()); + /*domStmtFilter=*/&*forStmt->getBody()->begin()); return true; } @@ -391,7 +391,7 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) { // the pass has to be instantiated with additional information that we aren't // provided with at the moment. if (forStmt->getStep() != 1) { - if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->begin())) { + if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) { runOnForStmt(innerFor); } return; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 85c88f785d1..847db83aebc 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -59,12 +59,12 @@ FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } // destination's body. static inline void moveLoopBody(ForStmt *src, ForStmt *dest, StmtBlock::iterator loc) { - dest->getStatements().splice(loc, src->getStatements()); + dest->getBody()->getStatements().splice(loc, src->getBody()->getStatements()); } // Move the loop body of ForStmt 'src' from 'src' to the start of dest's body. static inline void moveLoopBody(ForStmt *src, ForStmt *dest) { - moveLoopBody(src, dest, dest->begin()); + moveLoopBody(src, dest, dest->getBody()->begin()); } /// Constructs and sets new loop bounds after tiling for the case of @@ -167,8 +167,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, MLFuncBuilder b(topLoop); // Loop bounds will be set later. auto *pointLoop = b.createFor(loc, 0, 0); - pointLoop->getStatements().splice( - pointLoop->begin(), topLoop->getBlock()->getStatements(), topLoop); + pointLoop->getBody()->getStatements().splice( + pointLoop->getBody()->begin(), topLoop->getBlock()->getStatements(), + topLoop); newLoops[2 * width - 1 - i] = pointLoop; topLoop = pointLoop; if (i == 0) @@ -180,8 +181,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, MLFuncBuilder b(topLoop); // Loop bounds will be set later. auto *tileSpaceLoop = b.createFor(loc, 0, 0); - tileSpaceLoop->getStatements().splice( - tileSpaceLoop->begin(), topLoop->getBlock()->getStatements(), topLoop); + tileSpaceLoop->getBody()->getStatements().splice( + tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getStatements(), + topLoop); newLoops[2 * width - i - 1] = tileSpaceLoop; topLoop = tileSpaceLoop; } @@ -223,8 +225,8 @@ static void getTileableBands(MLFunction *f, ForStmt *currStmt = root; do { band.push_back(currStmt); - } while (currStmt->getStatements().size() == 1 && - (currStmt = dyn_cast<ForStmt>(&*currStmt->begin()))); + } while (currStmt->getBody()->getStatements().size() == 1 && + (currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin()))); bands->push_back(band); }; diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index a43087bd2e1..183613a2f69 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -104,7 +104,8 @@ PassResult LoopUnroll::runOnMLFunction(MLFunction *f) { } bool walkForStmtPostOrder(ForStmt *forStmt) { - bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end()); + bool hasInnerLoops = + walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end()); if (!hasInnerLoops) loops.push_back(forStmt); return true; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 45ca9dd98df..dd491f8119b 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -152,7 +152,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); - if (unrollJamFactor == 1 || forStmt->getStatements().empty()) + if (unrollJamFactor == 1 || forStmt->getBody()->empty()) return false; Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index df30a779461..d4069eaa638 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -147,7 +147,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value()); loops.insert(forStmt); // Setting the insertion point to the innermost loop achieves nesting. - b.setInsertionPointToStart(loops.back()); + b.setInsertionPointToStart(loops.back()->getBody()); if (composed == getAffineConstantExpr(0, b.getContext())) { transfer->emitWarning( "Redundant copy can be implemented as a vector broadcast"); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index b656af0d69d..8d75bfbd7ae 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -81,8 +81,9 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { /// the loop IV of the specified 'for' statement modulo 2. Returns false if such /// a replacement cannot be performed. static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { - MLFuncBuilder bInner(forStmt, forStmt->begin()); - bInner.setInsertionPoint(forStmt, forStmt->begin()); + auto *forBody = forStmt->getBody(); + MLFuncBuilder bInner(forBody, forBody->begin()); + bInner.setInsertionPoint(forBody, forBody->begin()); // Doubles the shape with a leading dimension extent of 2. auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType { @@ -127,7 +128,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { // non-deferencing uses of the memref. if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef), ivModTwoOp->getResult(0), AffineMap::Null(), {}, - &*forStmt->begin())) { + &*forStmt->getBody()->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); ivModTwoOp->getOperation()->erase(); @@ -184,7 +185,7 @@ static void findMatchingStartFinishStmts( // Collect outgoing DMA statements - needed to check for dependences below. SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps; - for (auto &stmt : *forStmt) { + for (auto &stmt : *forStmt->getBody()) { auto *opStmt = dyn_cast<OperationStmt>(&stmt); if (!opStmt) continue; @@ -195,7 +196,7 @@ static void findMatchingStartFinishStmts( } SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts; - for (auto &stmt : *forStmt) { + for (auto &stmt : *forStmt->getBody()) { auto *opStmt = dyn_cast<OperationStmt>(&stmt); if (!opStmt) continue; @@ -228,7 +229,7 @@ static void findMatchingStartFinishStmts( cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos())); bool escapingUses = false; for (const auto &use : memref->getUses()) { - if (!dominates(*forStmt->begin(), *use.getOwner())) { + if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -339,16 +340,16 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { } } // Everything else (including compute ops and dma finish) are shifted by one. - for (const auto &stmt : *forStmt) { + for (const auto &stmt : *forStmt->getBody()) { if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) { stmtShiftMap[&stmt] = 1; } } // Get shifts stored in map. - std::vector<uint64_t> shifts(forStmt->getStatements().size()); + std::vector<uint64_t> shifts(forStmt->getBody()->getStatements().size()); unsigned s = 0; - for (auto &stmt : *forStmt) { + for (auto &stmt : *forStmt->getBody()) { assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end()); shifts[s++] = stmtShiftMap[&stmt]; LLVM_DEBUG( diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 791997e7ff1..4d75f7c0835 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -119,7 +119,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { // Move the loop body statements to the loop's containing block. auto *block = forStmt->getBlock(); block->getStatements().splice(StmtBlock::iterator(forStmt), - forStmt->getStatements()); + forStmt->getBody()->getStatements()); forStmt->erase(); return true; } @@ -181,7 +181,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, operandMap[srcForStmt] = loopChunk; } for (auto *stmt : stmts) { - loopChunk->push_back(stmt->clone(operandMap, b->getContext())); + loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext())); } } if (promoteIfSingleIteration(loopChunk)) @@ -206,7 +206,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // method. UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts, bool unrollPrologueEpilogue) { - if (forStmt->getStatements().empty()) + if (forStmt->getBody()->empty()) return UtilResult::Success; // If the trip counts aren't constant, we would need versioning and @@ -225,7 +225,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts, int64_t step = forStmt->getStep(); - unsigned numChildStmts = forStmt->getStatements().size(); + unsigned numChildStmts = forStmt->getBody()->getStatements().size(); // Do a linear time (counting) sort for the shifts. uint64_t maxShift = 0; @@ -243,7 +243,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts, // body of the 'for' stmt. std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1); unsigned pos = 0; - for (auto &stmt : *forStmt) { + for (auto &stmt : *forStmt->getBody()) { auto shift = shifts[pos++]; sortedStmtGroups[shift].push_back(&stmt); } @@ -352,7 +352,7 @@ bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) { bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { assert(unrollFactor >= 1 && "unroll factor should be >= 1"); - if (unrollFactor == 1 || forStmt->getStatements().empty()) + if (unrollFactor == 1 || forStmt->getBody()->empty()) return false; auto lbMap = forStmt->getLowerBoundMap(); @@ -406,11 +406,11 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { // Builder to insert unrolled bodies right after the last statement in the // body of 'forStmt'. - MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end())); + MLFuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end()); // Keep a pointer to the last statement in the original block so that we know // what to clone (since we are doing this in-place). - StmtBlock::iterator srcBlockEnd = std::prev(forStmt->end()); + StmtBlock::iterator srcBlockEnd = std::prev(forStmt->getBody()->end()); // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies). for (unsigned i = 1; i < unrollFactor; i++) { @@ -429,7 +429,8 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { } // Clone the original body of 'forStmt'. - for (auto it = forStmt->begin(); it != std::next(srcBlockEnd); it++) { + for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd); + it++) { builder.clone(*it, operandMap); } } |

