summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorUday Bondhugula <bondhugula@google.com>2018-10-18 11:14:26 -0700
committerjpienaar <jpienaar@google.com>2019-03-29 13:32:28 -0700
commit18e666702cd00e0b9c1dafc9801fcdda6dfdb704 (patch)
tree17326676b703fbf74216a38657cf8df3cf5d3d2b /mlir/lib
parent3013dadb7c3326f016b3e6bf02f3df9a0d3efa6a (diff)
downloadbcm5719-llvm-18e666702cd00e0b9c1dafc9801fcdda6dfdb704.tar.gz
bcm5719-llvm-18e666702cd00e0b9c1dafc9801fcdda6dfdb704.zip
Generalize / improve DMA transfer overlap; nested and multiple DMA support; resolve
multiple TODOs. - replace the fake test pass (that worked on just the first loop in the MLFunction) to perform DMA pipelining on all suitable loops. - nested DMAs work now (DMAs in an outer loop, more DMAs in nested inner loops) - fix bugs / assumptions: correctly copy memory space and elemental type of source memref for double buffering. - correctly identify matching start/finish statements, handle multiple DMAs per loop. - introduce dominates/properlyDominates utitilies for MLFunction statements. - move checkDominancePreservationOnShifts to LoopAnalysis.h; rename it getShiftValidity - refactor getContainingStmtPos -> findAncestorStmtInBlock - move into Analysis/Utils.h; has two users. - other improvements / cleanup for related API/utilities - add size argument to dma_wait - for nested DMAs or in general, it makes it easy to obtain the size to use when lowering the dma_wait since we wouldn't want to identify the matching dma_start, and more importantly, in general/in the future, there may not always be a dma_start dominating the dma_wait. - add debug information in the pass PiperOrigin-RevId: 217734892
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/LoopAnalysis.cpp37
-rw-r--r--mlir/lib/Analysis/Utils.cpp62
-rw-r--r--mlir/lib/IR/StmtBlock.cpp16
-rw-r--r--mlir/lib/StandardOps/StandardOps.cpp21
-rw-r--r--mlir/lib/Transforms/LoopUtils.cpp53
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp264
-rw-r--r--mlir/lib/Transforms/Utils.cpp13
7 files changed, 302 insertions, 164 deletions
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index 232d162f4e1..9e65e7b4d9e 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -24,8 +24,6 @@
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/MLFunctionMatcher.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
@@ -128,12 +126,13 @@ static bool isAccessInvariant(MLValue *input, MemRefType *memRefType,
assert(indices.size() == memRefType->getRank());
assert(dim < indices.size());
auto layoutMap = memRefType->getAffineMaps();
- assert(layoutMap.size() <= 1);
+ assert(memRefType->getAffineMaps().size() <= 1);
// TODO(ntv): remove dependency on Builder once we support non-identity
// layout map.
Builder b(memRefType->getContext());
assert(layoutMap.empty() ||
layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
+ (void)layoutMap;
SmallVector<OperationStmt *, 4> affineApplyOps;
getReachableAffineApplyOps({indices[dim]}, affineApplyOps);
@@ -197,3 +196,35 @@ bool mlir::isVectorizableLoop(const ForStmt &loop) {
}
return true;
}
+
+/// Checks whether SSA dominance would be violated if a for stmt's body
+/// statements are shifted by the specified shifts. This method checks if a
+/// 'def' and all its uses have the same shift factor.
+// TODO(mlir-team): extend this to check for memory-based dependence
+// violation when we have the support.
+bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
+ ArrayRef<uint64_t> shifts) {
+ assert(shifts.size() == forStmt.getStatements().size());
+ unsigned s = 0;
+ for (const auto &stmt : forStmt) {
+ // A for or if stmt does not produce any def/results (that are used
+ // outside).
+ if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+ for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
+ const MLValue *result = opStmt->getResult(i);
+ for (const StmtOperand &use : result->getUses()) {
+ // If an ancestor statement doesn't lie in the block of forStmt, there
+ // is no shift to check.
+ // This is a naive way. If performance becomes an issue, a map can
+ // be used to store 'shifts' - to look up the shift for a statement in
+ // constant time.
+ if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner()))
+ if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)])
+ return false;
+ }
+ }
+ }
+ s++;
+ }
+ return true;
+}
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
new file mode 100644
index 00000000000..20aac12d5e2
--- /dev/null
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -0,0 +1,62 @@
+//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous analysis routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Utils.h"
+
+#include "mlir/IR/Statements.h"
+
+using namespace mlir;
+
+/// Returns true if statement 'a' properly dominates statement b.
+bool mlir::properlyDominates(const Statement &a, const Statement &b) {
+ if (&a == &b)
+ return false;
+
+ if (a.findFunction() != b.findFunction())
+ return false;
+
+ if (a.getBlock() == b.getBlock()) {
+ // Do a linear scan to determine whether b comes after a.
+ auto aIter = StmtBlock::const_iterator(a);
+ auto bIter = StmtBlock::const_iterator(b);
+ auto aBlockStart = a.getBlock()->begin();
+ while (bIter != aBlockStart) {
+ --bIter;
+ if (aIter == bIter)
+ return true;
+ }
+ return false;
+ }
+
+ // Traverse up b's hierarchy to check if b's block is contained in a's.
+ if (const auto *bAncestor = a.getBlock()->findAncestorStmtInBlock(b))
+ // a and bAncestor are in the same block; check if the former dominates it.
+ return dominates(a, *bAncestor);
+
+ // b's block is not contained in A.
+ return false;
+}
+
+/// Returns true if statement A dominates statement B.
+bool mlir::dominates(const Statement &a, const Statement &b) {
+ return &a == &b || properlyDominates(a, b);
+}
diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp
index 7cd92ffc980..40a31f6c3b9 100644
--- a/mlir/lib/IR/StmtBlock.cpp
+++ b/mlir/lib/IR/StmtBlock.cpp
@@ -45,3 +45,19 @@ MLFunction *StmtBlock::findFunction() const {
}
return dyn_cast<MLFunction>(block);
}
+
+/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor
+/// statement of 'stmt' that lies in this block. Returns nullptr if the latter
+/// fails.
+const Statement *
+StmtBlock::findAncestorStmtInBlock(const Statement &stmt) const {
+ // Traverse up the statement hierarchy starting from the owner of operand to
+ // find the ancestor statement that resides in the block of 'forStmt'.
+ const auto *currStmt = &stmt;
+ while (currStmt->getBlock() != this) {
+ currStmt = currStmt->getParentStmt();
+ if (!currStmt)
+ return nullptr;
+ }
+ return currStmt;
+}
diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp
index bd63fa3887c..d2f71123540 100644
--- a/mlir/lib/StandardOps/StandardOps.cpp
+++ b/mlir/lib/StandardOps/StandardOps.cpp
@@ -392,7 +392,7 @@ void DmaStartOp::print(OpAsmPrinter *p) const {
}
// Parse DmaStartOp.
-// EX:
+// Ex:
// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
// %tag[%index] :
// memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>,
@@ -458,33 +458,38 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------
-// Parse DmaWaitOp.
-// Eg:
-// dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4>
-//
+
void DmaWaitOp::print(OpAsmPrinter *p) const {
*p << getOperationName() << ' ';
// Print operands.
p->printOperand(getTagMemRef());
*p << '[';
p->printOperands(getTagIndices());
- *p << ']';
+ *p << "], ";
+ p->printOperand(getNumElements());
*p << " : " << *getTagMemRef()->getType();
}
+// Parse DmaWaitOp.
+// Eg:
+// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
+//
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
Type *type;
auto *indexType = parser->getBuilder().getIndexType();
+ OpAsmParser::OperandType numElementsInfo;
- // Parse tag memref and index.
+ // Parse tag memref, its indices, and dma size.
if (parser->parseOperand(tagMemrefInfo) ||
parser->parseOperandList(tagIndexInfos, -1,
OpAsmParser::Delimiter::Square) ||
+ parser->parseComma() || parser->parseOperand(numElementsInfo) ||
parser->parseColonType(type) ||
parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
- parser->resolveOperands(tagIndexInfos, indexType, result->operands))
+ parser->resolveOperands(tagIndexInfos, indexType, result->operands) ||
+ parser->resolveOperand(numElementsInfo, indexType, result->operands))
return true;
if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())
diff --git a/mlir/lib/Transforms/LoopUtils.cpp b/mlir/lib/Transforms/LoopUtils.cpp
index 743e15bbbae..c47f965055f 100644
--- a/mlir/lib/Transforms/LoopUtils.cpp
+++ b/mlir/lib/Transforms/LoopUtils.cpp
@@ -181,57 +181,6 @@ generateLoop(AffineMap lb, AffineMap ub,
return loopChunk;
}
-// Returns delay of that child statement of 'forStmt' which either has 'operand'
-// as one of its operands or has a descendant statement with operand 'operand'.
-// This is a naive implementation. If performance becomes an issue, a map can
-// be used to store 'delays' - to look up the delay for a statement in constant
-// time.
-static uint64_t getContainingStmtDelay(const StmtOperand &operand,
- const ForStmt &forStmt,
- ArrayRef<uint64_t> delays) {
- // Traverse up the statement hierarchy starting from the owner of operand to
- // find the ancestor statement that resides in the block of 'forStmt'.
- const Statement *stmt = operand.getOwner();
- assert(stmt != nullptr);
- while (stmt->getParentStmt() != &forStmt) {
- stmt = stmt->getParentStmt();
- assert(stmt && "traversing parent's should reach forStmt block");
- }
- // Look up the delay of 'stmt'.
- unsigned j = 0;
- for (const auto &s : forStmt) {
- if (&s == stmt)
- break;
- j++;
- }
- assert(j < forStmt.getStatements().size() && "child stmt should be found");
- return delays[j];
-}
-
-/// Checks if SSA dominance would be violated if a for stmt's body statements
-/// are shifted by the specified delays. This method checks if a 'def' and all
-/// its uses have the same delay factor.
-bool mlir::checkDominancePreservationOnShift(const ForStmt &forStmt,
- ArrayRef<uint64_t> delays) {
- assert(delays.size() == forStmt.getStatements().size());
- unsigned s = 0;
- for (const auto &stmt : forStmt) {
- // A for or if stmt does not produce any def/results (that are used
- // outside).
- if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
- for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
- const MLValue *result = opStmt->getResult(i);
- for (const StmtOperand &use : result->getUses()) {
- if (delays[s] != getContainingStmtDelay(use, forStmt, delays))
- return false;
- }
- }
- }
- s++;
- }
- return true;
-}
-
/// Skew the statements in the body of a 'for' statement with the specified
/// statement-wise delays. The delays are with respect to the original execution
/// order. A delay of zero for each statement will lead to no change.
@@ -260,7 +209,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
return UtilResult::Failure;
uint64_t tripCount = mayBeConstTripCount.getValue();
- assert(checkDominancePreservationOnShift(*forStmt, delays) &&
+ assert(isStmtwiseShiftValid(*forStmt, delays) &&
"dominance preservation failed\n");
unsigned numChildStmts = forStmt->getStatements().size();
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index bb60d8e9d78..d6a064988fb 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -22,21 +22,31 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/StmtVisitor.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "pipeline-data-transfer"
using namespace mlir;
namespace {
-struct PipelineDataTransfer : public MLFunctionPass {
- explicit PipelineDataTransfer() {}
+struct PipelineDataTransfer : public MLFunctionPass,
+ StmtWalker<PipelineDataTransfer> {
PassResult runOnMLFunction(MLFunction *f) override;
+ PassResult runOnForStmt(ForStmt *forStmt);
+
+ // Collect all 'for' statements.
+ void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
+ std::vector<ForStmt *> forStmts;
};
} // end anonymous namespace
@@ -47,20 +57,6 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
-/// Given a DMA start operation, returns the operand position of either the
-/// source or destination memref depending on the one that is at the higher
-/// level of the memory hierarchy.
-// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
-// added. TODO(b/117228571)
-static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) {
- unsigned srcDmaPos = 0;
- unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1;
-
- if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace())
- return srcDmaPos;
- return destDmaPos;
-}
-
// Returns the position of the tag memref operand given a DMA statement.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
@@ -76,18 +72,20 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
/// Doubles the buffer of the supplied memref while replacing all uses of the
/// old memref. Returns false if such a replacement cannot be performed.
-static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
+static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
MLFuncBuilder bInner(forStmt, forStmt->begin());
bInner.setInsertionPoint(forStmt, forStmt->begin());
// Doubles the shape with a leading dimension extent of 2.
- auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * {
+ auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * {
// Add the leading dimension in the shape for the double buffer.
- ArrayRef<int> shape = origMemRefType->getShape();
+ ArrayRef<int> shape = oldMemRefType->getShape();
SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
shapeSizes.insert(shapeSizes.begin(), 2);
- auto *newMemRefType = bInner.getMemRefType(shapeSizes, bInner.getF32Type());
+ auto *newMemRefType =
+ bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {},
+ oldMemRefType->getMemorySpace());
return newMemRefType;
};
@@ -105,113 +103,187 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
auto ivModTwoOp =
bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
- cast<MLValue>(ivModTwoOp->getResult(0))))
+ cast<MLValue>(ivModTwoOp->getResult(0)))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "memref replacement for double buffering failed\n";);
+ cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock();
return false;
+ }
return true;
}
-// For testing purposes, this just runs on the first 'for' statement of an
-// MLFunction at the top level.
-// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
-// the other TODOs listed inside are dealt with.
+/// Returns false if this succeeds on at least one 'for' stmt.
PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
if (f->empty())
return PassResult::Success;
- ForStmt *forStmt = nullptr;
- for (auto &stmt : *f) {
- if ((forStmt = dyn_cast<ForStmt>(&stmt))) {
- break;
- }
+ // Do a post order walk so that inner loop DMAs are processed first. This is
+ // necessary since 'for' statements nested within would otherwise become
+ // invalid (erased) when the outer loop is pipelined (the pipelined one gets
+ // deleted and replaced by a prologue, a new steady-state loop and an
+ // epilogue).
+ forStmts.clear();
+ walkPostOrder(f);
+ bool ret = true;
+ for (auto *forStmt : forStmts) {
+ ret = ret & runOnForStmt(forStmt);
}
- if (!forStmt)
- return PassResult::Success;
-
- unsigned numStmts = forStmt->getStatements().size();
+ return ret ? failure() : success();
+}
- if (numStmts == 0)
- return PassResult::Success;
+// Check if tags of the dma start op and dma wait op match.
+static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
+ OpPointer<DmaWaitOp> waitOp) {
+ if (startOp->getTagMemRef() != waitOp->getTagMemRef())
+ return false;
+ auto startIndices = startOp->getTagIndices();
+ auto waitIndices = waitOp->getTagIndices();
+ // Both of these have the same number of indices since they correspond to the
+ // same tag memref.
+ for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
+ e = startIndices.end();
+ it != e; ++it, ++wIt) {
+ // Keep it simple for now, just checking if indices match.
+ // TODO(mlir-team): this would in general need to check if there is no
+ // intervening write writing to the same tag location, i.e., memory last
+ // write/data flow analysis. This is however sufficient/powerful enough for
+ // now since the DMA generation pass or the input for it will always have
+ // start/wait with matching tags (same SSA operand indices).
+ if (*it != *wIt)
+ return false;
+ }
+ return true;
+}
- SmallVector<OperationStmt *, 4> dmaStartStmts;
- SmallVector<OperationStmt *, 4> dmaFinishStmts;
+// Identify matching DMA start/finish statements to overlap computation with.
+static void findMatchingStartFinishStmts(
+ ForStmt *forStmt,
+ SmallVectorImpl<std::pair<OperationStmt *, OperationStmt *>>
+ &startWaitPairs) {
+ SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
for (auto &stmt : *forStmt) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
- if (opStmt->is<DmaStartOp>()) {
- dmaStartStmts.push_back(opStmt);
- } else if (opStmt->is<DmaWaitOp>()) {
+ // Collect DMA finish statements.
+ if (opStmt->is<DmaWaitOp>()) {
dmaFinishStmts.push_back(opStmt);
+ continue;
+ }
+ OpPointer<DmaStartOp> dmaStartOp;
+ if (!(dmaStartOp = opStmt->getAs<DmaStartOp>()))
+ continue;
+ // Only DMAs incoming into higher memory spaces.
+ // TODO(bondhugula): outgoing DMAs.
+ if (!dmaStartOp->isDestMemorySpaceFaster())
+ continue;
+
+ // We only double buffer if the buffer is not live out of loop.
+ const MLValue *memref =
+ cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
+ bool escapingUses = false;
+ for (const auto &use : memref->getUses()) {
+ if (!dominates(*forStmt, *use.getOwner())) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "can't pipeline: buffer is live out of loop\n";);
+ escapingUses = true;
+ break;
+ }
+ }
+ if (!escapingUses)
+ dmaStartStmts.push_back(opStmt);
+ }
+
+ // For each start statement, we look for a matching finish statement.
+ for (auto *dmaStartStmt : dmaStartStmts) {
+ for (auto *dmaFinishStmt : dmaFinishStmts) {
+ if (checkTagMatch(dmaStartStmt->getAs<DmaStartOp>(),
+ dmaFinishStmt->getAs<DmaWaitOp>())) {
+ startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt});
+ break;
+ }
}
}
+}
- // TODO(bondhugula,andydavis): match tag memref's (requires memory-based
- // subscript check utilities). Assume for now that start/finish are matched in
- // the order they appear.
- if (dmaStartStmts.size() != dmaFinishStmts.size())
+/// Overlap DMA transfers with computation in this loop. If successful,
+/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are
+/// inserted right before where it was.
+PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
+ auto mayBeConstTripCount = getConstantTripCount(*forStmt);
+ if (!mayBeConstTripCount.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
return PassResult::Failure;
+ }
+
+ SmallVector<std::pair<OperationStmt *, OperationStmt *>, 4> startWaitPairs;
+ findMatchingStartFinishStmts(forStmt, startWaitPairs);
+
+ if (startWaitPairs.empty()) {
+ LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);
+ return failure();
+ }
// Double the buffers for the higher memory space memref's.
- // TODO(bondhugula): assuming we don't have multiple DMA starts for the same
- // memref.
+ // Identify memref's to replace by scanning through all DMA start statements.
+ // A DMA start statement has two memref's - the one from the higher level of
+ // memory hierarchy is the one to double buffer.
// TODO(bondhugula): check whether double-buffering is even necessary.
// TODO(bondhugula): make this work with different layouts: assuming here that
// the dimension we are adding here for the double buffering is the outermost
// dimension.
- // Identify memref's to replace by scanning through all DMA start statements.
- // A DMA start statement has two memref's - the one from the higher level of
- // memory hierarchy is the one to double buffer.
- for (auto *dmaStartStmt : dmaStartStmts) {
- MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
- getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>())));
+ for (auto &pair : startWaitPairs) {
+ auto *dmaStartStmt = pair.first;
+ const MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
+ dmaStartStmt->getAs<DmaStartOp>()->getFasterMemPos()));
if (!doubleBuffer(oldMemRef, forStmt)) {
- return PassResult::Failure;
+ // Normally, double buffering should not fail because we already checked
+ // that there are no uses outside.
+ LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
+ LLVM_DEBUG(dmaStartStmt->dump());
+ return failure();
}
}
- // Double the buffers for tag memref's.
- for (auto *dmaFinishStmt : dmaFinishStmts) {
- MLValue *oldTagMemRef = cast<MLValue>(
+ // Double the buffers for tag memrefs.
+ for (auto &pair : startWaitPairs) {
+ const auto *dmaFinishStmt = pair.second;
+ const MLValue *oldTagMemRef = cast<MLValue>(
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
if (!doubleBuffer(oldTagMemRef, forStmt)) {
- return PassResult::Failure;
+ LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
+ return failure();
}
}
- // Collect all compute ops.
- std::vector<const Statement *> computeOps;
- computeOps.reserve(forStmt->getStatements().size());
+ // Double buffering would have invalidated all the old DMA start/wait stmts.
+ startWaitPairs.clear();
+ findMatchingStartFinishStmts(forStmt, startWaitPairs);
+
// Store delay for statement for later lookup for AffineApplyOp's.
- DenseMap<const Statement *, unsigned> opDelayMap;
- for (auto &stmt : *forStmt) {
- auto *opStmt = dyn_cast<OperationStmt>(&stmt);
- if (!opStmt) {
- // All for and if stmt's are treated as pure compute operations.
- opDelayMap[&stmt] = 1;
- } else if (opStmt->is<DmaStartOp>()) {
- // DMA starts are not shifted.
- opDelayMap[opStmt] = 0;
- // Set shifts for DMA start stmt's affine operand computation slices to 0.
- if (auto *slice = mlir::createAffineComputationSlice(opStmt)) {
- opDelayMap[slice] = 0;
- } else {
- // If a slice wasn't created, the reachable affine_apply op's from its
- // operands are the ones that go with it.
- SmallVector<OperationStmt *, 4> affineApplyStmts;
- SmallVector<MLValue *, 4> operands(opStmt->getOperands());
- getReachableAffineApplyOps(operands, affineApplyStmts);
- for (auto *op : affineApplyStmts) {
- opDelayMap[op] = 0;
- }
- }
- } else if (opStmt->is<DmaWaitOp>()) {
- // DMA finish op shifted by one.
- opDelayMap[opStmt] = 1;
+ DenseMap<const Statement *, unsigned> stmtDelayMap;
+ for (auto &pair : startWaitPairs) {
+ auto *dmaStartStmt = pair.first;
+ assert(dmaStartStmt->is<DmaStartOp>());
+ stmtDelayMap[dmaStartStmt] = 0;
+ // Set shifts for DMA start stmt's affine operand computation slices to 0.
+ if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) {
+ stmtDelayMap[slice] = 0;
} else {
- // Everything else is a compute op; so shifted by one (op's supplying
- // 'affine' operands to DMA start's have already been set right shifts.
- opDelayMap[opStmt] = 1;
- computeOps.push_back(&stmt);
+ // If a slice wasn't created, the reachable affine_apply op's from its
+ // operands are the ones that go with it.
+ SmallVector<OperationStmt *, 4> affineApplyStmts;
+ SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands());
+ getReachableAffineApplyOps(operands, affineApplyStmts);
+ for (const auto *stmt : affineApplyStmts) {
+ stmtDelayMap[stmt] = 0;
+ }
+ }
+ }
+ // Everything else (including compute ops and dma finish) are shifted by one.
+ for (const auto &stmt : *forStmt) {
+ if (stmtDelayMap.find(&stmt) == stmtDelayMap.end()) {
+ stmtDelayMap[&stmt] = 1;
}
}
@@ -219,18 +291,20 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
std::vector<uint64_t> delays(forStmt->getStatements().size());
unsigned s = 0;
for (const auto &stmt : *forStmt) {
- assert(opDelayMap.find(&stmt) != opDelayMap.end());
- delays[s++] = opDelayMap[&stmt];
+ assert(stmtDelayMap.find(&stmt) != stmtDelayMap.end());
+ delays[s++] = stmtDelayMap[&stmt];
}
- if (!checkDominancePreservationOnShift(*forStmt, delays)) {
+ if (!isStmtwiseShiftValid(*forStmt, delays)) {
// Violates SSA dominance.
+ LLVM_DEBUG(llvm::dbgs() << "Dominance check failed\n";);
return PassResult::Failure;
}
if (stmtBodySkew(forStmt, delays)) {
+ LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed\n";);
return PassResult::Failure;
}
- return PassResult::Success;
+ return success();
}
diff --git a/mlir/lib/Transforms/Utils.cpp b/mlir/lib/Transforms/Utils.cpp
index 99ddb087bd4..5df6e99b592 100644
--- a/mlir/lib/Transforms/Utils.cpp
+++ b/mlir/lib/Transforms/Utils.cpp
@@ -48,7 +48,8 @@ static bool isMemRefDereferencingOp(const Operation &op) {
/// at the start for now.
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
// extended to add additional indices at any position.
-bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
+bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
+ MLValue *newMemRef,
ArrayRef<MLValue *> extraIndices,
AffineMap indexRemap) {
unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
@@ -219,11 +220,11 @@ mlir::createComposedAffineApplyOp(MLFuncBuilder *builder, Location *loc,
/// This allows applying different transformations on send and compute (for eg.
/// different shifts/delays).
///
-/// Returns nullptr if none of the operands were the result of an affine_apply
-/// and thus there was no affine computation slice to create. Returns the newly
-/// affine_apply operation statement otherwise.
-///
-///
+/// Returns nullptr either if none of opStmt's operands were the result of an
+/// affine_apply and thus there was no affine computation slice to create, or if
+/// all the affine_apply op's supplying operands to this opStmt do not have any
+/// uses besides this opStmt. Returns the new affine_apply operation statement
+/// otherwise.
OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
// Collect all operands that are results of affine apply ops.
SmallVector<MLValue *, 4> subOperands;
OpenPOWER on IntegriCloud