summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/PipelineDataTransfer.cpp
diff options
context:
space:
mode:
authorUday Bondhugula <bondhugula@google.com>2018-10-18 11:14:26 -0700
committerjpienaar <jpienaar@google.com>2019-03-29 13:32:28 -0700
commit18e666702cd00e0b9c1dafc9801fcdda6dfdb704 (patch)
tree17326676b703fbf74216a38657cf8df3cf5d3d2b /mlir/lib/Transforms/PipelineDataTransfer.cpp
parent3013dadb7c3326f016b3e6bf02f3df9a0d3efa6a (diff)
downloadbcm5719-llvm-18e666702cd00e0b9c1dafc9801fcdda6dfdb704.tar.gz
bcm5719-llvm-18e666702cd00e0b9c1dafc9801fcdda6dfdb704.zip
Generalize / improve DMA transfer overlap; nested and multiple DMA support; resolve
multiple TODOs. - replace the fake test pass (that worked on just the first loop in the MLFunction) to perform DMA pipelining on all suitable loops. - nested DMAs work now (DMAs in an outer loop, more DMAs in nested inner loops) - fix bugs / assumptions: correctly copy memory space and elemental type of source memref for double buffering. - correctly identify matching start/finish statements, handle multiple DMAs per loop. - introduce dominates/properlyDominates utitilies for MLFunction statements. - move checkDominancePreservationOnShifts to LoopAnalysis.h; rename it getShiftValidity - refactor getContainingStmtPos -> findAncestorStmtInBlock - move into Analysis/Utils.h; has two users. - other improvements / cleanup for related API/utilities - add size argument to dma_wait - for nested DMAs or in general, it makes it easy to obtain the size to use when lowering the dma_wait since we wouldn't want to identify the matching dma_start, and more importantly, in general/in the future, there may not always be a dma_start dominating the dma_wait. - add debug information in the pass PiperOrigin-RevId: 217734892
Diffstat (limited to 'mlir/lib/Transforms/PipelineDataTransfer.cpp')
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp264
1 files changed, 169 insertions, 95 deletions
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index bb60d8e9d78..d6a064988fb 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -22,21 +22,31 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/StmtVisitor.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "pipeline-data-transfer"
using namespace mlir;
namespace {
-struct PipelineDataTransfer : public MLFunctionPass {
- explicit PipelineDataTransfer() {}
+struct PipelineDataTransfer : public MLFunctionPass,
+ StmtWalker<PipelineDataTransfer> {
PassResult runOnMLFunction(MLFunction *f) override;
+ PassResult runOnForStmt(ForStmt *forStmt);
+
+ // Collect all 'for' statements.
+ void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
+ std::vector<ForStmt *> forStmts;
};
} // end anonymous namespace
@@ -47,20 +57,6 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
-/// Given a DMA start operation, returns the operand position of either the
-/// source or destination memref depending on the one that is at the higher
-/// level of the memory hierarchy.
-// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
-// added. TODO(b/117228571)
-static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) {
- unsigned srcDmaPos = 0;
- unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1;
-
- if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace())
- return srcDmaPos;
- return destDmaPos;
-}
-
// Returns the position of the tag memref operand given a DMA statement.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
@@ -76,18 +72,20 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
/// Doubles the buffer of the supplied memref while replacing all uses of the
/// old memref. Returns false if such a replacement cannot be performed.
-static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
+static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
MLFuncBuilder bInner(forStmt, forStmt->begin());
bInner.setInsertionPoint(forStmt, forStmt->begin());
// Doubles the shape with a leading dimension extent of 2.
- auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * {
+ auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * {
// Add the leading dimension in the shape for the double buffer.
- ArrayRef<int> shape = origMemRefType->getShape();
+ ArrayRef<int> shape = oldMemRefType->getShape();
SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
shapeSizes.insert(shapeSizes.begin(), 2);
- auto *newMemRefType = bInner.getMemRefType(shapeSizes, bInner.getF32Type());
+ auto *newMemRefType =
+ bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {},
+ oldMemRefType->getMemorySpace());
return newMemRefType;
};
@@ -105,113 +103,187 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
auto ivModTwoOp =
bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
- cast<MLValue>(ivModTwoOp->getResult(0))))
+ cast<MLValue>(ivModTwoOp->getResult(0)))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "memref replacement for double buffering failed\n";);
+ cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock();
return false;
+ }
return true;
}
-// For testing purposes, this just runs on the first 'for' statement of an
-// MLFunction at the top level.
-// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
-// the other TODOs listed inside are dealt with.
+/// Returns false if this succeeds on at least one 'for' stmt.
PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
if (f->empty())
return PassResult::Success;
- ForStmt *forStmt = nullptr;
- for (auto &stmt : *f) {
- if ((forStmt = dyn_cast<ForStmt>(&stmt))) {
- break;
- }
+ // Do a post order walk so that inner loop DMAs are processed first. This is
+ // necessary since 'for' statements nested within would otherwise become
+ // invalid (erased) when the outer loop is pipelined (the pipelined one gets
+ // deleted and replaced by a prologue, a new steady-state loop and an
+ // epilogue).
+ forStmts.clear();
+ walkPostOrder(f);
+ bool ret = true;
+ for (auto *forStmt : forStmts) {
+ ret = ret & runOnForStmt(forStmt);
}
- if (!forStmt)
- return PassResult::Success;
-
- unsigned numStmts = forStmt->getStatements().size();
+ return ret ? failure() : success();
+}
- if (numStmts == 0)
- return PassResult::Success;
+// Check if tags of the dma start op and dma wait op match.
+static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
+ OpPointer<DmaWaitOp> waitOp) {
+ if (startOp->getTagMemRef() != waitOp->getTagMemRef())
+ return false;
+ auto startIndices = startOp->getTagIndices();
+ auto waitIndices = waitOp->getTagIndices();
+ // Both of these have the same number of indices since they correspond to the
+ // same tag memref.
+ for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
+ e = startIndices.end();
+ it != e; ++it, ++wIt) {
+ // Keep it simple for now, just checking if indices match.
+ // TODO(mlir-team): this would in general need to check if there is no
+ // intervening write writing to the same tag location, i.e., memory last
+ // write/data flow analysis. This is however sufficient/powerful enough for
+ // now since the DMA generation pass or the input for it will always have
+ // start/wait with matching tags (same SSA operand indices).
+ if (*it != *wIt)
+ return false;
+ }
+ return true;
+}
- SmallVector<OperationStmt *, 4> dmaStartStmts;
- SmallVector<OperationStmt *, 4> dmaFinishStmts;
+// Identify matching DMA start/finish statements to overlap computation with.
+static void findMatchingStartFinishStmts(
+ ForStmt *forStmt,
+ SmallVectorImpl<std::pair<OperationStmt *, OperationStmt *>>
+ &startWaitPairs) {
+ SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
for (auto &stmt : *forStmt) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
- if (opStmt->is<DmaStartOp>()) {
- dmaStartStmts.push_back(opStmt);
- } else if (opStmt->is<DmaWaitOp>()) {
+ // Collect DMA finish statements.
+ if (opStmt->is<DmaWaitOp>()) {
dmaFinishStmts.push_back(opStmt);
+ continue;
+ }
+ OpPointer<DmaStartOp> dmaStartOp;
+ if (!(dmaStartOp = opStmt->getAs<DmaStartOp>()))
+ continue;
+ // Only DMAs incoming into higher memory spaces.
+ // TODO(bondhugula): outgoing DMAs.
+ if (!dmaStartOp->isDestMemorySpaceFaster())
+ continue;
+
+ // We only double buffer if the buffer is not live out of loop.
+ const MLValue *memref =
+ cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
+ bool escapingUses = false;
+ for (const auto &use : memref->getUses()) {
+ if (!dominates(*forStmt, *use.getOwner())) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "can't pipeline: buffer is live out of loop\n";);
+ escapingUses = true;
+ break;
+ }
+ }
+ if (!escapingUses)
+ dmaStartStmts.push_back(opStmt);
+ }
+
+ // For each start statement, we look for a matching finish statement.
+ for (auto *dmaStartStmt : dmaStartStmts) {
+ for (auto *dmaFinishStmt : dmaFinishStmts) {
+ if (checkTagMatch(dmaStartStmt->getAs<DmaStartOp>(),
+ dmaFinishStmt->getAs<DmaWaitOp>())) {
+ startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt});
+ break;
+ }
}
}
+}
- // TODO(bondhugula,andydavis): match tag memref's (requires memory-based
- // subscript check utilities). Assume for now that start/finish are matched in
- // the order they appear.
- if (dmaStartStmts.size() != dmaFinishStmts.size())
+/// Overlap DMA transfers with computation in this loop. If successful,
+/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are
+/// inserted right before where it was.
+PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
+ auto mayBeConstTripCount = getConstantTripCount(*forStmt);
+ if (!mayBeConstTripCount.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
return PassResult::Failure;
+ }
+
+ SmallVector<std::pair<OperationStmt *, OperationStmt *>, 4> startWaitPairs;
+ findMatchingStartFinishStmts(forStmt, startWaitPairs);
+
+ if (startWaitPairs.empty()) {
+ LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);
+ return failure();
+ }
// Double the buffers for the higher memory space memref's.
- // TODO(bondhugula): assuming we don't have multiple DMA starts for the same
- // memref.
+ // Identify memref's to replace by scanning through all DMA start statements.
+ // A DMA start statement has two memref's - the one from the higher level of
+ // memory hierarchy is the one to double buffer.
// TODO(bondhugula): check whether double-buffering is even necessary.
// TODO(bondhugula): make this work with different layouts: assuming here that
// the dimension we are adding here for the double buffering is the outermost
// dimension.
- // Identify memref's to replace by scanning through all DMA start statements.
- // A DMA start statement has two memref's - the one from the higher level of
- // memory hierarchy is the one to double buffer.
- for (auto *dmaStartStmt : dmaStartStmts) {
- MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
- getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>())));
+ for (auto &pair : startWaitPairs) {
+ auto *dmaStartStmt = pair.first;
+ const MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
+ dmaStartStmt->getAs<DmaStartOp>()->getFasterMemPos()));
if (!doubleBuffer(oldMemRef, forStmt)) {
- return PassResult::Failure;
+ // Normally, double buffering should not fail because we already checked
+ // that there are no uses outside.
+ LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
+ LLVM_DEBUG(dmaStartStmt->dump());
+ return failure();
}
}
- // Double the buffers for tag memref's.
- for (auto *dmaFinishStmt : dmaFinishStmts) {
- MLValue *oldTagMemRef = cast<MLValue>(
+ // Double the buffers for tag memrefs.
+ for (auto &pair : startWaitPairs) {
+ const auto *dmaFinishStmt = pair.second;
+ const MLValue *oldTagMemRef = cast<MLValue>(
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
if (!doubleBuffer(oldTagMemRef, forStmt)) {
- return PassResult::Failure;
+ LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
+ return failure();
}
}
- // Collect all compute ops.
- std::vector<const Statement *> computeOps;
- computeOps.reserve(forStmt->getStatements().size());
+ // Double buffering would have invalidated all the old DMA start/wait stmts.
+ startWaitPairs.clear();
+ findMatchingStartFinishStmts(forStmt, startWaitPairs);
+
// Store delay for statement for later lookup for AffineApplyOp's.
- DenseMap<const Statement *, unsigned> opDelayMap;
- for (auto &stmt : *forStmt) {
- auto *opStmt = dyn_cast<OperationStmt>(&stmt);
- if (!opStmt) {
- // All for and if stmt's are treated as pure compute operations.
- opDelayMap[&stmt] = 1;
- } else if (opStmt->is<DmaStartOp>()) {
- // DMA starts are not shifted.
- opDelayMap[opStmt] = 0;
- // Set shifts for DMA start stmt's affine operand computation slices to 0.
- if (auto *slice = mlir::createAffineComputationSlice(opStmt)) {
- opDelayMap[slice] = 0;
- } else {
- // If a slice wasn't created, the reachable affine_apply op's from its
- // operands are the ones that go with it.
- SmallVector<OperationStmt *, 4> affineApplyStmts;
- SmallVector<MLValue *, 4> operands(opStmt->getOperands());
- getReachableAffineApplyOps(operands, affineApplyStmts);
- for (auto *op : affineApplyStmts) {
- opDelayMap[op] = 0;
- }
- }
- } else if (opStmt->is<DmaWaitOp>()) {
- // DMA finish op shifted by one.
- opDelayMap[opStmt] = 1;
+ DenseMap<const Statement *, unsigned> stmtDelayMap;
+ for (auto &pair : startWaitPairs) {
+ auto *dmaStartStmt = pair.first;
+ assert(dmaStartStmt->is<DmaStartOp>());
+ stmtDelayMap[dmaStartStmt] = 0;
+ // Set shifts for DMA start stmt's affine operand computation slices to 0.
+ if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) {
+ stmtDelayMap[slice] = 0;
} else {
- // Everything else is a compute op; so shifted by one (op's supplying
- // 'affine' operands to DMA start's have already been set right shifts.
- opDelayMap[opStmt] = 1;
- computeOps.push_back(&stmt);
+ // If a slice wasn't created, the reachable affine_apply op's from its
+ // operands are the ones that go with it.
+ SmallVector<OperationStmt *, 4> affineApplyStmts;
+ SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands());
+ getReachableAffineApplyOps(operands, affineApplyStmts);
+ for (const auto *stmt : affineApplyStmts) {
+ stmtDelayMap[stmt] = 0;
+ }
+ }
+ }
+ // Everything else (including compute ops and dma finish) are shifted by one.
+ for (const auto &stmt : *forStmt) {
+ if (stmtDelayMap.find(&stmt) == stmtDelayMap.end()) {
+ stmtDelayMap[&stmt] = 1;
}
}
@@ -219,18 +291,20 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
std::vector<uint64_t> delays(forStmt->getStatements().size());
unsigned s = 0;
for (const auto &stmt : *forStmt) {
- assert(opDelayMap.find(&stmt) != opDelayMap.end());
- delays[s++] = opDelayMap[&stmt];
+ assert(stmtDelayMap.find(&stmt) != stmtDelayMap.end());
+ delays[s++] = stmtDelayMap[&stmt];
}
- if (!checkDominancePreservationOnShift(*forStmt, delays)) {
+ if (!isStmtwiseShiftValid(*forStmt, delays)) {
// Violates SSA dominance.
+ LLVM_DEBUG(llvm::dbgs() << "Dominance check failed\n";);
return PassResult::Failure;
}
if (stmtBodySkew(forStmt, delays)) {
+ LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed\n";);
return PassResult::Failure;
}
- return PassResult::Success;
+ return success();
}
OpenPOWER on IntegriCloud