diff options
Diffstat (limited to 'mlir/lib/Transforms/PipelineDataTransfer.cpp')
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 200 |
1 files changed, 100 insertions, 100 deletions
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index c8a6ced4ed1..debaac3a33c 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -25,7 +25,7 @@ #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" @@ -39,14 +39,14 @@ using namespace mlir; namespace { struct PipelineDataTransfer : public FunctionPass, - StmtWalker<PipelineDataTransfer> { + InstWalker<PipelineDataTransfer> { PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} PassResult runOnMLFunction(Function *f) override; - PassResult runOnForStmt(ForStmt *forStmt); + PassResult runOnForInst(ForInst *forInst); - // Collect all 'for' statements. - void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); } - std::vector<ForStmt *> forStmts; + // Collect all 'for' instructions. + void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } + std::vector<ForInst *> forInsts; static char passID; }; @@ -61,26 +61,26 @@ FunctionPass *mlir::createPipelineDataTransferPass() { return new PipelineDataTransfer(); } -// Returns the position of the tag memref operand given a DMA statement. +// Returns the position of the tag memref operand given a DMA instruction. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getTagMemRefPos(const OperationInst &dmaStmt) { - assert(dmaStmt.isa<DmaStartOp>() || dmaStmt.isa<DmaWaitOp>()); - if (dmaStmt.isa<DmaStartOp>()) { +static unsigned getTagMemRefPos(const OperationInst &dmaInst) { + assert(dmaInst.isa<DmaStartOp>() || dmaInst.isa<DmaWaitOp>()); + if (dmaInst.isa<DmaStartOp>()) { // Second to last operand. - return dmaStmt.getNumOperands() - 2; + return dmaInst.getNumOperands() - 2; } - // First operand for a dma finish statement. + // First operand for a dma finish instruction. return 0; } -/// Doubles the buffer of the supplied memref on the specified 'for' statement +/// Doubles the buffer of the supplied memref on the specified 'for' instruction /// by adding a leading dimension of size two to the memref. Replaces all uses /// of the old memref by the new one while indexing the newly added dimension by -/// the loop IV of the specified 'for' statement modulo 2. Returns false if such -/// a replacement cannot be performed. -static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) { - auto *forBody = forStmt->getBody(); +/// the loop IV of the specified 'for' instruction modulo 2. Returns false if +/// such a replacement cannot be performed. +static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { + auto *forBody = forInst->getBody(); FuncBuilder bInner(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin()); @@ -101,33 +101,33 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) { auto newMemRefType = doubleShape(oldMemRefType); // Put together alloc operands for the dynamic dimensions of the memref. - FuncBuilder bOuter(forStmt); + FuncBuilder bOuter(forInst); SmallVector<Value *, 4> allocOperands; unsigned dynamicDimCount = 0; for (auto dimSize : oldMemRefType.getShape()) { if (dimSize == -1) - allocOperands.push_back(bOuter.create<DimOp>(forStmt->getLoc(), oldMemRef, + allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef, dynamicDimCount++)); } - // Create and place the alloc right before the 'for' statement. + // Create and place the alloc right before the 'for' instruction. // TODO(mlir-team): we are assuming scoped allocation here, and aren't // inserting a dealloc -- this isn't the right thing. Value *newMemRef = - bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands); + bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands); // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0); auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {}); auto ivModTwoOp = - bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt); + bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap, forInst); - // replaceAllMemRefUsesWith will always succeed unless the forStmt body has + // replaceAllMemRefUsesWith will always succeed unless the forInst body has // non-deferencing uses of the memref. if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0), AffineMap::Null(), {}, - &*forStmt->getBody()->begin())) { + &*forInst->getBody()->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); ivModTwoOp->getInstruction()->erase(); @@ -139,15 +139,15 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) { /// Returns success if the IR is in a valid state. PassResult PipelineDataTransfer::runOnMLFunction(Function *f) { // Do a post order walk so that inner loop DMAs are processed first. This is - // necessary since 'for' statements nested within would otherwise become + // necessary since 'for' instructions nested within would otherwise become // invalid (erased) when the outer loop is pipelined (the pipelined one gets // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). - forStmts.clear(); + forInsts.clear(); walkPostOrder(f); bool ret = false; - for (auto *forStmt : forStmts) { - ret = ret | runOnForStmt(forStmt); + for (auto *forInst : forInsts) { + ret = ret | runOnForInst(forInst); } return ret ? failure() : success(); } @@ -176,36 +176,36 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp, return true; } -// Identify matching DMA start/finish statements to overlap computation with. -static void findMatchingStartFinishStmts( - ForStmt *forStmt, +// Identify matching DMA start/finish instructions to overlap computation with. +static void findMatchingStartFinishInsts( + ForInst *forInst, SmallVectorImpl<std::pair<OperationInst *, OperationInst *>> &startWaitPairs) { - // Collect outgoing DMA statements - needed to check for dependences below. + // Collect outgoing DMA instructions - needed to check for dependences below. SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps; - for (auto &stmt : *forStmt->getBody()) { - auto *opStmt = dyn_cast<OperationInst>(&stmt); - if (!opStmt) + for (auto &inst : *forInst->getBody()) { + auto *opInst = dyn_cast<OperationInst>(&inst); + if (!opInst) continue; OpPointer<DmaStartOp> dmaStartOp; - if ((dmaStartOp = opStmt->dyn_cast<DmaStartOp>()) && + if ((dmaStartOp = opInst->dyn_cast<DmaStartOp>()) && dmaStartOp->isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } - SmallVector<OperationInst *, 4> dmaStartStmts, dmaFinishStmts; - for (auto &stmt : *forStmt->getBody()) { - auto *opStmt = dyn_cast<OperationInst>(&stmt); - if (!opStmt) + SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts; + for (auto &inst : *forInst->getBody()) { + auto *opInst = dyn_cast<OperationInst>(&inst); + if (!opInst) continue; - // Collect DMA finish statements. - if (opStmt->isa<DmaWaitOp>()) { - dmaFinishStmts.push_back(opStmt); + // Collect DMA finish instructions. + if (opInst->isa<DmaWaitOp>()) { + dmaFinishInsts.push_back(opInst); continue; } OpPointer<DmaStartOp> dmaStartOp; - if (!(dmaStartOp = opStmt->dyn_cast<DmaStartOp>())) + if (!(dmaStartOp = opInst->dyn_cast<DmaStartOp>())) continue; // Only DMAs incoming into higher memory spaces are pipelined for now. // TODO(bondhugula): handle outgoing DMA pipelining. @@ -227,7 +227,7 @@ static void findMatchingStartFinishStmts( auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()); bool escapingUses = false; for (const auto &use : memref->getUses()) { - if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) { + if (!dominates(*forInst->getBody()->begin(), *use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -235,15 +235,15 @@ static void findMatchingStartFinishStmts( } } if (!escapingUses) - dmaStartStmts.push_back(opStmt); + dmaStartInsts.push_back(opInst); } - // For each start statement, we look for a matching finish statement. - for (auto *dmaStartStmt : dmaStartStmts) { - for (auto *dmaFinishStmt : dmaFinishStmts) { - if (checkTagMatch(dmaStartStmt->cast<DmaStartOp>(), - dmaFinishStmt->cast<DmaWaitOp>())) { - startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt}); + // For each start instruction, we look for a matching finish instruction. + for (auto *dmaStartInst : dmaStartInsts) { + for (auto *dmaFinishInst : dmaFinishInsts) { + if (checkTagMatch(dmaStartInst->cast<DmaStartOp>(), + dmaFinishInst->cast<DmaWaitOp>())) { + startWaitPairs.push_back({dmaStartInst, dmaFinishInst}); break; } } @@ -251,17 +251,17 @@ static void findMatchingStartFinishStmts( } /// Overlap DMA transfers with computation in this loop. If successful, -/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are +/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are /// inserted right before where it was. -PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { - auto mayBeConstTripCount = getConstantTripCount(*forStmt); +PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { + auto mayBeConstTripCount = getConstantTripCount(*forInst); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n"); return success(); } SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs; - findMatchingStartFinishStmts(forStmt, startWaitPairs); + findMatchingStartFinishInsts(forInst, startWaitPairs); if (startWaitPairs.empty()) { LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";); @@ -269,22 +269,22 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { } // Double the buffers for the higher memory space memref's. - // Identify memref's to replace by scanning through all DMA start statements. - // A DMA start statement has two memref's - the one from the higher level of - // memory hierarchy is the one to double buffer. + // Identify memref's to replace by scanning through all DMA start + // instructions. A DMA start instruction has two memref's - the one from the + // higher level of memory hierarchy is the one to double buffer. // TODO(bondhugula): check whether double-buffering is even necessary. // TODO(bondhugula): make this work with different layouts: assuming here that // the dimension we are adding here for the double buffering is the outermost // dimension. for (auto &pair : startWaitPairs) { - auto *dmaStartStmt = pair.first; - Value *oldMemRef = dmaStartStmt->getOperand( - dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos()); - if (!doubleBuffer(oldMemRef, forStmt)) { + auto *dmaStartInst = pair.first; + Value *oldMemRef = dmaStartInst->getOperand( + dmaStartInst->cast<DmaStartOp>()->getFasterMemPos()); + if (!doubleBuffer(oldMemRef, forInst)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); - LLVM_DEBUG(dmaStartStmt->dump()); + LLVM_DEBUG(dmaStartInst->dump()); // IR still in a valid state. return success(); } @@ -293,80 +293,80 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { // operation could have been used on it if it was dynamically shaped in // order to create the double buffer above) if (oldMemRef->use_empty()) - if (auto *allocStmt = oldMemRef->getDefiningInst()) - allocStmt->erase(); + if (auto *allocInst = oldMemRef->getDefiningInst()) + allocInst->erase(); } // Double the buffers for tag memrefs. for (auto &pair : startWaitPairs) { - auto *dmaFinishStmt = pair.second; + auto *dmaFinishInst = pair.second; Value *oldTagMemRef = - dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)); - if (!doubleBuffer(oldTagMemRef, forStmt)) { + dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); + if (!doubleBuffer(oldTagMemRef, forInst)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); return success(); } // If the old tag has no more uses, remove its 'dead' alloc if it was // alloc'ed. if (oldTagMemRef->use_empty()) - if (auto *allocStmt = oldTagMemRef->getDefiningInst()) - allocStmt->erase(); + if (auto *allocInst = oldTagMemRef->getDefiningInst()) + allocInst->erase(); } - // Double buffering would have invalidated all the old DMA start/wait stmts. + // Double buffering would have invalidated all the old DMA start/wait insts. startWaitPairs.clear(); - findMatchingStartFinishStmts(forStmt, startWaitPairs); + findMatchingStartFinishInsts(forInst, startWaitPairs); - // Store shift for statement for later lookup for AffineApplyOp's. - DenseMap<const Statement *, unsigned> stmtShiftMap; + // Store shift for instruction for later lookup for AffineApplyOp's. + DenseMap<const Instruction *, unsigned> instShiftMap; for (auto &pair : startWaitPairs) { - auto *dmaStartStmt = pair.first; - assert(dmaStartStmt->isa<DmaStartOp>()); - stmtShiftMap[dmaStartStmt] = 0; - // Set shifts for DMA start stmt's affine operand computation slices to 0. - if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) { - stmtShiftMap[slice] = 0; + auto *dmaStartInst = pair.first; + assert(dmaStartInst->isa<DmaStartOp>()); + instShiftMap[dmaStartInst] = 0; + // Set shifts for DMA start inst's affine operand computation slices to 0. + if (auto *slice = mlir::createAffineComputationSlice(dmaStartInst)) { + instShiftMap[slice] = 0; } else { // If a slice wasn't created, the reachable affine_apply op's from its // operands are the ones that go with it. - SmallVector<OperationInst *, 4> affineApplyStmts; - SmallVector<Value *, 4> operands(dmaStartStmt->getOperands()); - getReachableAffineApplyOps(operands, affineApplyStmts); - for (const auto *stmt : affineApplyStmts) { - stmtShiftMap[stmt] = 0; + SmallVector<OperationInst *, 4> affineApplyInsts; + SmallVector<Value *, 4> operands(dmaStartInst->getOperands()); + getReachableAffineApplyOps(operands, affineApplyInsts); + for (const auto *inst : affineApplyInsts) { + instShiftMap[inst] = 0; } } } // Everything else (including compute ops and dma finish) are shifted by one. - for (const auto &stmt : *forStmt->getBody()) { - if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) { - stmtShiftMap[&stmt] = 1; + for (const auto &inst : *forInst->getBody()) { + if (instShiftMap.find(&inst) == instShiftMap.end()) { + instShiftMap[&inst] = 1; } } // Get shifts stored in map. - std::vector<uint64_t> shifts(forStmt->getBody()->getInstructions().size()); + std::vector<uint64_t> shifts(forInst->getBody()->getInstructions().size()); unsigned s = 0; - for (auto &stmt : *forStmt->getBody()) { - assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end()); - shifts[s++] = stmtShiftMap[&stmt]; + for (auto &inst : *forInst->getBody()) { + assert(instShiftMap.find(&inst) != instShiftMap.end()); + shifts[s++] = instShiftMap[&inst]; LLVM_DEBUG( - // Tagging statements with shifts for debugging purposes. - if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) { - FuncBuilder b(opStmt); - opStmt->setAttr(b.getIdentifier("shift"), + // Tagging instructions with shifts for debugging purposes. + if (auto *opInst = dyn_cast<OperationInst>(&inst)) { + FuncBuilder b(opInst); + opInst->setAttr(b.getIdentifier("shift"), b.getI64IntegerAttr(shifts[s - 1])); }); } - if (!isStmtwiseShiftValid(*forStmt, shifts)) { + if (!isInstwiseShiftValid(*forInst, shifts)) { // Violates dependences. LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); return success(); } - if (stmtBodySkew(forStmt, shifts)) { - LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed - unexpected\n";); + if (instBodySkew(forInst, shifts)) { + LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";); return success(); } |

