diff options
| author | Uday Bondhugula <bondhugula@google.com> | 2018-10-18 11:14:26 -0700 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 13:32:28 -0700 |
| commit | 18e666702cd00e0b9c1dafc9801fcdda6dfdb704 (patch) | |
| tree | 17326676b703fbf74216a38657cf8df3cf5d3d2b /mlir/lib/Transforms/PipelineDataTransfer.cpp | |
| parent | 3013dadb7c3326f016b3e6bf02f3df9a0d3efa6a (diff) | |
| download | bcm5719-llvm-18e666702cd00e0b9c1dafc9801fcdda6dfdb704.tar.gz bcm5719-llvm-18e666702cd00e0b9c1dafc9801fcdda6dfdb704.zip | |
Generalize / improve DMA transfer overlap; nested and multiple DMA support; resolve
multiple TODOs.
- replace the fake test pass (that worked on just the first loop in the
MLFunction) to perform DMA pipelining on all suitable loops.
- nested DMAs work now (DMAs in an outer loop, more DMAs in nested inner loops)
- fix bugs / assumptions: correctly copy memory space and elemental type of source
memref for double buffering.
- correctly identify matching start/finish statements, handle multiple DMAs per
loop.
- introduce dominates/properlyDominates utilities for MLFunction statements.
- move checkDominancePreservationOnShifts to LoopAnalysis.h; rename it
getShiftValidity
- refactor getContainingStmtPos -> findAncestorStmtInBlock - move into
Analysis/Utils.h; has two users.
- other improvements / cleanup for related API/utilities
- add size argument to dma_wait - for nested DMAs or in general, it makes it
easy to obtain the size to use when lowering the dma_wait since we wouldn't
want to identify the matching dma_start, and more importantly, in general/in the
future, there may not always be a dma_start dominating the dma_wait.
- add debug information in the pass
PiperOrigin-RevId: 217734892
Diffstat (limited to 'mlir/lib/Transforms/PipelineDataTransfer.cpp')
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 264 |
1 files changed, 169 insertions, 95 deletions
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index bb60d8e9d78..d6a064988fb 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -22,21 +22,31 @@ #include "mlir/Transforms/Passes.h" #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/StmtVisitor.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Pass.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "pipeline-data-transfer" using namespace mlir; namespace { -struct PipelineDataTransfer : public MLFunctionPass { - explicit PipelineDataTransfer() {} +struct PipelineDataTransfer : public MLFunctionPass, + StmtWalker<PipelineDataTransfer> { PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnForStmt(ForStmt *forStmt); + + // Collect all 'for' statements. + void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); } + std::vector<ForStmt *> forStmts; }; } // end anonymous namespace @@ -47,20 +57,6 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() { return new PipelineDataTransfer(); } -/// Given a DMA start operation, returns the operand position of either the -/// source or destination memref depending on the one that is at the higher -/// level of the memory hierarchy. -// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are -// added. 
TODO(b/117228571) -static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) { - unsigned srcDmaPos = 0; - unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1; - - if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace()) - return srcDmaPos; - return destDmaPos; -} - // Returns the position of the tag memref operand given a DMA statement. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) @@ -76,18 +72,20 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { /// Doubles the buffer of the supplied memref while replacing all uses of the /// old memref. Returns false if such a replacement cannot be performed. -static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { +static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) { MLFuncBuilder bInner(forStmt, forStmt->begin()); bInner.setInsertionPoint(forStmt, forStmt->begin()); // Doubles the shape with a leading dimension extent of 2. - auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * { + auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * { // Add the leading dimension in the shape for the double buffer. 
- ArrayRef<int> shape = origMemRefType->getShape(); + ArrayRef<int> shape = oldMemRefType->getShape(); SmallVector<int, 4> shapeSizes(shape.begin(), shape.end()); shapeSizes.insert(shapeSizes.begin(), 2); - auto *newMemRefType = bInner.getMemRefType(shapeSizes, bInner.getF32Type()); + auto *newMemRefType = + bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {}, + oldMemRefType->getMemorySpace()); return newMemRefType; }; @@ -105,113 +103,187 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { auto ivModTwoOp = bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt); if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, - cast<MLValue>(ivModTwoOp->getResult(0)))) + cast<MLValue>(ivModTwoOp->getResult(0)))) { + LLVM_DEBUG(llvm::dbgs() + << "memref replacement for double buffering failed\n";); + cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock(); return false; + } return true; } -// For testing purposes, this just runs on the first 'for' statement of an -// MLFunction at the top level. -// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when -// the other TODOs listed inside are dealt with. +/// Returns false if this succeeds on at least one 'for' stmt. PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { if (f->empty()) return PassResult::Success; - ForStmt *forStmt = nullptr; - for (auto &stmt : *f) { - if ((forStmt = dyn_cast<ForStmt>(&stmt))) { - break; - } + // Do a post order walk so that inner loop DMAs are processed first. This is + // necessary since 'for' statements nested within would otherwise become + // invalid (erased) when the outer loop is pipelined (the pipelined one gets + // deleted and replaced by a prologue, a new steady-state loop and an + // epilogue). 
+ forStmts.clear(); + walkPostOrder(f); + bool ret = true; + for (auto *forStmt : forStmts) { + ret = ret & runOnForStmt(forStmt); } - if (!forStmt) - return PassResult::Success; - - unsigned numStmts = forStmt->getStatements().size(); + return ret ? failure() : success(); +} - if (numStmts == 0) - return PassResult::Success; +// Check if tags of the dma start op and dma wait op match. +static bool checkTagMatch(OpPointer<DmaStartOp> startOp, + OpPointer<DmaWaitOp> waitOp) { + if (startOp->getTagMemRef() != waitOp->getTagMemRef()) + return false; + auto startIndices = startOp->getTagIndices(); + auto waitIndices = waitOp->getTagIndices(); + // Both of these have the same number of indices since they correspond to the + // same tag memref. + for (auto it = startIndices.begin(), wIt = waitIndices.begin(), + e = startIndices.end(); + it != e; ++it, ++wIt) { + // Keep it simple for now, just checking if indices match. + // TODO(mlir-team): this would in general need to check if there is no + // intervening write writing to the same tag location, i.e., memory last + // write/data flow analysis. This is however sufficient/powerful enough for + // now since the DMA generation pass or the input for it will always have + // start/wait with matching tags (same SSA operand indices). + if (*it != *wIt) + return false; + } + return true; +} - SmallVector<OperationStmt *, 4> dmaStartStmts; - SmallVector<OperationStmt *, 4> dmaFinishStmts; +// Identify matching DMA start/finish statements to overlap computation with. +static void findMatchingStartFinishStmts( + ForStmt *forStmt, + SmallVectorImpl<std::pair<OperationStmt *, OperationStmt *>> + &startWaitPairs) { + SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts; for (auto &stmt : *forStmt) { auto *opStmt = dyn_cast<OperationStmt>(&stmt); if (!opStmt) continue; - if (opStmt->is<DmaStartOp>()) { - dmaStartStmts.push_back(opStmt); - } else if (opStmt->is<DmaWaitOp>()) { + // Collect DMA finish statements. 
+ if (opStmt->is<DmaWaitOp>()) { dmaFinishStmts.push_back(opStmt); + continue; + } + OpPointer<DmaStartOp> dmaStartOp; + if (!(dmaStartOp = opStmt->getAs<DmaStartOp>())) + continue; + // Only DMAs incoming into higher memory spaces. + // TODO(bondhugula): outgoing DMAs. + if (!dmaStartOp->isDestMemorySpaceFaster()) + continue; + + // We only double buffer if the buffer is not live out of loop. + const MLValue *memref = + cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos())); + bool escapingUses = false; + for (const auto &use : memref->getUses()) { + if (!dominates(*forStmt, *use.getOwner())) { + LLVM_DEBUG(llvm::dbgs() + << "can't pipeline: buffer is live out of loop\n";); + escapingUses = true; + break; + } + } + if (!escapingUses) + dmaStartStmts.push_back(opStmt); + } + + // For each start statement, we look for a matching finish statement. + for (auto *dmaStartStmt : dmaStartStmts) { + for (auto *dmaFinishStmt : dmaFinishStmts) { + if (checkTagMatch(dmaStartStmt->getAs<DmaStartOp>(), + dmaFinishStmt->getAs<DmaWaitOp>())) { + startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt}); + break; + } } } +} - // TODO(bondhugula,andydavis): match tag memref's (requires memory-based - // subscript check utilities). Assume for now that start/finish are matched in - // the order they appear. - if (dmaStartStmts.size() != dmaFinishStmts.size()) +/// Overlap DMA transfers with computation in this loop. If successful, +/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are +/// inserted right before where it was. 
+PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { + auto mayBeConstTripCount = getConstantTripCount(*forStmt); + if (!mayBeConstTripCount.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n"); return PassResult::Failure; + } + + SmallVector<std::pair<OperationStmt *, OperationStmt *>, 4> startWaitPairs; + findMatchingStartFinishStmts(forStmt, startWaitPairs); + + if (startWaitPairs.empty()) { + LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";); + return failure(); + } // Double the buffers for the higher memory space memref's. - // TODO(bondhugula): assuming we don't have multiple DMA starts for the same - // memref. + // Identify memref's to replace by scanning through all DMA start statements. + // A DMA start statement has two memref's - the one from the higher level of + // memory hierarchy is the one to double buffer. // TODO(bondhugula): check whether double-buffering is even necessary. // TODO(bondhugula): make this work with different layouts: assuming here that // the dimension we are adding here for the double buffering is the outermost // dimension. - // Identify memref's to replace by scanning through all DMA start statements. - // A DMA start statement has two memref's - the one from the higher level of - // memory hierarchy is the one to double buffer. - for (auto *dmaStartStmt : dmaStartStmts) { - MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand( - getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>()))); + for (auto &pair : startWaitPairs) { + auto *dmaStartStmt = pair.first; + const MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand( + dmaStartStmt->getAs<DmaStartOp>()->getFasterMemPos())); if (!doubleBuffer(oldMemRef, forStmt)) { - return PassResult::Failure; + // Normally, double buffering should not fail because we already checked + // that there are no uses outside. 
+ LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); + LLVM_DEBUG(dmaStartStmt->dump()); + return failure(); } } - // Double the buffers for tag memref's. - for (auto *dmaFinishStmt : dmaFinishStmts) { - MLValue *oldTagMemRef = cast<MLValue>( + // Double the buffers for tag memrefs. + for (auto &pair : startWaitPairs) { + const auto *dmaFinishStmt = pair.second; + const MLValue *oldTagMemRef = cast<MLValue>( dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt))); if (!doubleBuffer(oldTagMemRef, forStmt)) { - return PassResult::Failure; + LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); + return failure(); } } - // Collect all compute ops. - std::vector<const Statement *> computeOps; - computeOps.reserve(forStmt->getStatements().size()); + // Double buffering would have invalidated all the old DMA start/wait stmts. + startWaitPairs.clear(); + findMatchingStartFinishStmts(forStmt, startWaitPairs); + // Store delay for statement for later lookup for AffineApplyOp's. - DenseMap<const Statement *, unsigned> opDelayMap; - for (auto &stmt : *forStmt) { - auto *opStmt = dyn_cast<OperationStmt>(&stmt); - if (!opStmt) { - // All for and if stmt's are treated as pure compute operations. - opDelayMap[&stmt] = 1; - } else if (opStmt->is<DmaStartOp>()) { - // DMA starts are not shifted. - opDelayMap[opStmt] = 0; - // Set shifts for DMA start stmt's affine operand computation slices to 0. - if (auto *slice = mlir::createAffineComputationSlice(opStmt)) { - opDelayMap[slice] = 0; - } else { - // If a slice wasn't created, the reachable affine_apply op's from its - // operands are the ones that go with it. - SmallVector<OperationStmt *, 4> affineApplyStmts; - SmallVector<MLValue *, 4> operands(opStmt->getOperands()); - getReachableAffineApplyOps(operands, affineApplyStmts); - for (auto *op : affineApplyStmts) { - opDelayMap[op] = 0; - } - } - } else if (opStmt->is<DmaWaitOp>()) { - // DMA finish op shifted by one. 
- opDelayMap[opStmt] = 1; + DenseMap<const Statement *, unsigned> stmtDelayMap; + for (auto &pair : startWaitPairs) { + auto *dmaStartStmt = pair.first; + assert(dmaStartStmt->is<DmaStartOp>()); + stmtDelayMap[dmaStartStmt] = 0; + // Set shifts for DMA start stmt's affine operand computation slices to 0. + if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) { + stmtDelayMap[slice] = 0; } else { - // Everything else is a compute op; so shifted by one (op's supplying - // 'affine' operands to DMA start's have already been set right shifts. - opDelayMap[opStmt] = 1; - computeOps.push_back(&stmt); + // If a slice wasn't created, the reachable affine_apply op's from its + // operands are the ones that go with it. + SmallVector<OperationStmt *, 4> affineApplyStmts; + SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands()); + getReachableAffineApplyOps(operands, affineApplyStmts); + for (const auto *stmt : affineApplyStmts) { + stmtDelayMap[stmt] = 0; + } + } + } + // Everything else (including compute ops and dma finish) are shifted by one. + for (const auto &stmt : *forStmt) { + if (stmtDelayMap.find(&stmt) == stmtDelayMap.end()) { + stmtDelayMap[&stmt] = 1; } } @@ -219,18 +291,20 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { std::vector<uint64_t> delays(forStmt->getStatements().size()); unsigned s = 0; for (const auto &stmt : *forStmt) { - assert(opDelayMap.find(&stmt) != opDelayMap.end()); - delays[s++] = opDelayMap[&stmt]; + assert(stmtDelayMap.find(&stmt) != stmtDelayMap.end()); + delays[s++] = stmtDelayMap[&stmt]; } - if (!checkDominancePreservationOnShift(*forStmt, delays)) { + if (!isStmtwiseShiftValid(*forStmt, delays)) { // Violates SSA dominance. + LLVM_DEBUG(llvm::dbgs() << "Dominance check failed\n";); return PassResult::Failure; } if (stmtBodySkew(forStmt, delays)) { + LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed\n";); return PassResult::Failure; } - return PassResult::Success; + return success(); } |

