path: root/mlir/lib/Transforms/PipelineDataTransfer.cpp
author     Uday Bondhugula <bondhugula@google.com>   2018-10-04 17:15:30 -0700
committer  jpienaar <jpienaar@google.com>            2019-03-29 13:23:19 -0700
commit     6cfdb756b165ebd32068506fc70d623f86bb80b3 (patch)
tree       2ac173f35c2b0b7913a243d0b5eed17c78c5a09b /mlir/lib/Transforms/PipelineDataTransfer.cpp
parent     b55b4076011419c8d8d8cac58c8fda7631067bb2 (diff)
Introduce memref replacement/rewrite support: replace an existing memref with a
new one (of a potentially different rank/shape), with an optional index
remapping.

- introduce Utils::replaceAllMemRefUsesWith
- use this for DMA double buffering

This CL also adds a few temporary utilities / code that will be done away with
once:
1) abstract DMA op's are added,
2) a memref dereferencing side-effect / trait is available on op's, and
3) b/117159533 is resolved (memref index computation slices).

PiperOrigin-RevId: 215831373
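For context, a condensed sketch of how the new utility is driven, adapted from
the doubleBuffer() helper in the diff below. The exact declaration of
replaceAllMemRefUsesWith lives in mlir/Transforms/Utils.h and is inferred here
from its call site, so treat this as illustrative only:

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/Builders.h"
  #include "mlir/Transforms/Utils.h"

  // Sketch only: remap all uses of 'oldMemRef' to 'newMemRef', prepending the
  // extra index (iv mod 2) so successive iterations alternate buffers. The
  // signature of replaceAllMemRefUsesWith is inferred from its use in
  // doubleBuffer() below.
  static bool remapToDoubleBuffer(MLValue *oldMemRef, MLValue *newMemRef,
                                  ForStmt *forStmt) {
    MLFuncBuilder b(forStmt, forStmt->begin());
    // Build the map (d0) -> (d0 mod 2) and apply it to the loop induction
    // variable.
    auto d0 = b.getDimExpr(0);
    auto *modTwoMap = b.getAffineMap(1, 0, {d0 % 2}, {});
    auto ivModTwo =
        b.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
    // Returns false if some use could not be rewritten.
    return replaceAllMemRefUsesWith(oldMemRef, newMemRef,
                                    ivModTwo->getResult(0));
  }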
Diffstat (limited to 'mlir/lib/Transforms/PipelineDataTransfer.cpp')
-rw-r--r--  mlir/lib/Transforms/PipelineDataTransfer.cpp  235
1 file changed, 224 insertions, 11 deletions
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 96d706e98ac..5aaac1c6c29 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -21,10 +21,13 @@
#include "mlir/Transforms/Passes.h"
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
using namespace mlir;
@@ -43,27 +46,237 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
-// For testing purposes, this just runs on the first statement of the MLFunction
-// if that statement is a for stmt, and shifts the second half of its body by
-// one.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's or
+// op traits for it are added. TODO(b/117228571)
+static bool isDmaStartStmt(const OperationStmt &stmt) {
+ return stmt.getName().strref().contains("dma.in.start") ||
+ stmt.getName().strref().contains("dma.out.start");
+}
+
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+static bool isDmaFinishStmt(const OperationStmt &stmt) {
+ return stmt.getName().strref().contains("dma.finish");
+}
+
+/// Given a DMA start operation, returns the operand position of either the
+/// source or destination memref depending on the one that is at the higher
+/// level of the memory hierarchy.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
+ assert(isDmaStartStmt(dmaStartStmt));
+ unsigned srcDmaPos = 0;
+ unsigned destDmaPos =
+ cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1;
+
+ if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType())
+ ->getMemorySpace() >
+ cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType())
+ ->getMemorySpace())
+ return srcDmaPos;
+ return destDmaPos;
+}
+
+// Returns the position of the tag memref operand given a DMA statement.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
+ assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt));
+ if (isDmaStartStmt(dmaStmt)) {
+ // Second to last operand.
+ return dmaStmt.getNumOperands() - 2;
+ }
+ // First operand for a dma finish statement.
+ return 0;
+}
+
+/// Doubles the buffer of the supplied memref.
+static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
+ MLFuncBuilder bInner(forStmt, forStmt->begin());
+ bInner.setInsertionPoint(forStmt, forStmt->begin());
+
+  // Returns the old memref type with a leading dimension of extent 2 added.
+ auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * {
+ // Add the leading dimension in the shape for the double buffer.
+ ArrayRef<int> shape = origMemRefType->getShape();
+ SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
+ shapeSizes.insert(shapeSizes.begin(), 2);
+
+ auto *newMemRefType = bInner.getMemRefType(shapeSizes, bInner.getF32Type());
+ return newMemRefType;
+ };
+
+ auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
+
+ // Create and place the alloc at the top level.
+ auto *func = forStmt->getFunction();
+ MLFuncBuilder topBuilder(func, func->begin());
+ auto *newMemRef = cast<MLValue>(
+ topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
+ ->getResult());
+
+ auto d0 = bInner.getDimExpr(0);
+ auto *modTwoMap = bInner.getAffineMap(1, 0, {d0 % 2}, {});
+ auto ivModTwoOp =
+ bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
+ if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0)))
+ return false;
+  // We don't need ivModTwoOp any more - this is cloned by
+ // replaceAllMemRefUsesWith wherever the memref replacement happens. Once
+ // b/117159533 is addressed, we'll eventually only need to pass
+ // ivModTwoOp->getResult(0) to replaceAllMemRefUsesWith.
+ cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock();
+ return true;
+}
+
+// For testing purposes, this just runs on the first for statement of an
+// MLFunction at the top level.
+// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
+// the other TODOs listed inside are dealt with.
PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
if (f->empty())
return PassResult::Success;
- auto *forStmt = dyn_cast<ForStmt>(&f->front());
+
+ ForStmt *forStmt = nullptr;
+ for (auto &stmt : *f) {
+ if ((forStmt = dyn_cast<ForStmt>(&stmt))) {
+ break;
+ }
+ }
if (!forStmt)
- return PassResult::Failure;
+ return PassResult::Success;
unsigned numStmts = forStmt->getStatements().size();
+
if (numStmts == 0)
return PassResult::Success;
- std::vector<uint64_t> delays(numStmts);
- for (unsigned i = 0; i < numStmts; i++)
- delays[i] = (i < numStmts / 2) ? 0 : 1;
+ SmallVector<OperationStmt *, 4> dmaStartStmts;
+ SmallVector<OperationStmt *, 4> dmaFinishStmts;
+ for (auto &stmt : *forStmt) {
+ auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+ if (!opStmt)
+ continue;
+ if (isDmaStartStmt(*opStmt)) {
+ dmaStartStmts.push_back(opStmt);
+ } else if (isDmaFinishStmt(*opStmt)) {
+ dmaFinishStmts.push_back(opStmt);
+ }
+ }
+
+ // TODO(bondhugula,andydavis): match tag memref's (requires memory-based
+ // subscript check utilities). Assume for now that start/finish are matched in
+ // the order they appear.
+ if (dmaStartStmts.size() != dmaFinishStmts.size())
+ return PassResult::Failure;
+
+ // Double the buffers for the higher memory space memref's.
+ // TODO(bondhugula): assuming we don't have multiple DMA starts for the same
+ // memref.
+ // TODO(bondhugula): check whether double-buffering is even necessary.
+ // TODO(bondhugula): make this work with different layouts: assuming here that
+ // the dimension we are adding here for the double buffering is the outermost
+ // dimension.
+ // Identify memref's to replace by scanning through all DMA start statements.
+ // A DMA start statement has two memref's - the one from the higher level of
+ // memory hierarchy is the one to double buffer.
+ for (auto *dmaStartStmt : dmaStartStmts) {
+ MLValue *oldMemRef = cast<MLValue>(
+ dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt)));
+ if (!doubleBuffer(oldMemRef, forStmt))
+ return PassResult::Failure;
+ }
+
+ // Double the buffers for tag memref's.
+ for (auto *dmaFinishStmt : dmaFinishStmts) {
+ MLValue *oldTagMemRef = cast<MLValue>(
+ dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
+ if (!doubleBuffer(oldTagMemRef, forStmt))
+ return PassResult::Failure;
+ }
+
+ // Collect all compute ops.
+ std::vector<const Statement *> computeOps;
+ computeOps.reserve(forStmt->getStatements().size());
+  // Store each statement's delay for later lookup for AffineApplyOp's.
+ DenseMap<const Statement *, unsigned> opDelayMap;
+ for (const auto &stmt : *forStmt) {
+ auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+ if (!opStmt) {
+ // All for and if stmt's are treated as pure compute operations.
+ // TODO(bondhugula): check whether such statements do not have any DMAs
+ // nested within.
+ opDelayMap[&stmt] = 1;
+ } else if (isDmaStartStmt(*opStmt)) {
+ // DMA starts are not shifted.
+ opDelayMap[&stmt] = 0;
+ } else if (isDmaFinishStmt(*opStmt)) {
+ // DMA finish op shifted by one.
+ opDelayMap[&stmt] = 1;
+ } else if (!opStmt->is<AffineApplyOp>()) {
+ // Compute op shifted by one.
+ opDelayMap[&stmt] = 1;
+ computeOps.push_back(&stmt);
+ }
+ // Shifts for affine apply op's determined later.
+ }
+
+ // Get the ancestor of a 'stmt' that lies in forStmt's block.
+ auto getAncestorInForBlock =
+ [&](const Statement *stmt, const StmtBlock &block) -> const Statement * {
+  // Traverse up the statement hierarchy from the operand's owner to
+ // find the ancestor statement that resides in the block of 'forStmt'.
+ while (stmt != nullptr && stmt->getBlock() != &block) {
+ stmt = stmt->getParentStmt();
+ }
+ return stmt;
+ };
+
+  // Determine each affine apply op's delay: look it up from its consumer op.
+ // This code will be thrown away once we have a way to obtain indices through
+ // a composed affine_apply op. See TODO(b/117159533). Such a composed
+  // affine_apply will be used exclusively by a given memref dereferencing op.
+ for (const auto &stmt : *forStmt) {
+ auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+ // Skip statements that aren't affine apply ops.
+ if (!opStmt || !opStmt->is<AffineApplyOp>())
+ continue;
+ // Traverse uses of each result of the affine apply op.
+ for (auto *res : opStmt->getResults()) {
+ for (auto &use : res->getUses()) {
+ auto *ancestorInForBlock =
+ getAncestorInForBlock(use.getOwner(), *forStmt);
+ assert(ancestorInForBlock &&
+ "traversing parent should reach forStmt block");
+ auto *opCheck = dyn_cast<OperationStmt>(ancestorInForBlock);
+ if (!opCheck || opCheck->is<AffineApplyOp>())
+ continue;
+ assert(opDelayMap.find(ancestorInForBlock) != opDelayMap.end());
+ if (opDelayMap.find(&stmt) != opDelayMap.end()) {
+ // This is where we enforce all uses of this affine_apply to have
+ // the same shifts - so that we know what shift to use for the
+ // affine_apply to preserve semantics.
+ assert(opDelayMap[&stmt] == opDelayMap[ancestorInForBlock]);
+ } else {
+ // Obtain delay from its consumer.
+ opDelayMap[&stmt] = opDelayMap[ancestorInForBlock];
+ }
+ }
+ }
+ }
+
+ // Get delays stored in map.
+ std::vector<uint64_t> delays(forStmt->getStatements().size());
+ unsigned s = 0;
+ for (const auto &stmt : *forStmt) {
+ delays[s++] = opDelayMap[&stmt];
+ }
- if (!checkDominancePreservationOnShift(*forStmt, delays))
+ if (!checkDominancePreservationOnShift(*forStmt, delays)) {
// Violates SSA dominance.
return PassResult::Failure;
+ }
if (stmtBodySkew(forStmt, delays))
return PassResult::Failure;
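To make the shift assignment concrete, here is a small standalone sketch
(illustrative only, not part of this CL) of the delay policy computed above:
DMA starts keep shift 0 while everything else is shifted by one iteration,
which is what lets iteration i's compute overlap with iteration i+1's incoming
DMA once the buffers are doubled.

  #include <cstdint>
  #include <iostream>
  #include <string>
  #include <utility>
  #include <vector>

  enum class Kind { DmaStart, DmaFinish, Compute };

  int main() {
    // Toy loop body in source order: start the incoming DMA, wait on it,
    // then compute.
    std::vector<std::pair<std::string, Kind>> body = {
        {"dma.in.start", Kind::DmaStart},
        {"dma.finish", Kind::DmaFinish},
        {"compute", Kind::Compute},
    };

    // The policy from runOnMLFunction(): DMA starts stay at shift 0; DMA
    // finishes, compute ops, and nested for/if statements get shift 1.
    std::vector<uint64_t> delays;
    delays.reserve(body.size());
    for (const auto &stmt : body)
      delays.push_back(stmt.second == Kind::DmaStart ? 0 : 1);

    // stmtBodySkew would then rewrite the loop so that iteration i's
    // dma.in.start runs alongside iteration i-1's dma.finish and compute,
    // with the doubled (iv mod 2)-indexed buffers preventing the transfer
    // from clobbering data still being consumed.
    for (size_t i = 0; i < body.size(); ++i)
      std::cout << body[i].first << " -> shift " << delays[i] << "\n";
    return 0;
  }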