diff options
| author | Uday Bondhugula <bondhugula@google.com> | 2018-10-18 11:14:26 -0700 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 13:32:28 -0700 |
| commit | 18e666702cd00e0b9c1dafc9801fcdda6dfdb704 (patch) | |
| tree | 17326676b703fbf74216a38657cf8df3cf5d3d2b /mlir/lib | |
| parent | 3013dadb7c3326f016b3e6bf02f3df9a0d3efa6a (diff) | |
| download | bcm5719-llvm-18e666702cd00e0b9c1dafc9801fcdda6dfdb704.tar.gz bcm5719-llvm-18e666702cd00e0b9c1dafc9801fcdda6dfdb704.zip | |
Generalize / improve DMA transfer overlap; nested and multiple DMA support; resolve
multiple TODOs.
- replace the fake test pass (that worked on just the first loop in the
MLFunction) with one that performs DMA pipelining on all suitable loops.
- nested DMAs work now (DMAs in an outer loop, more DMAs in nested inner loops)
- fix bugs / assumptions: correctly copy memory space and elemental type of source
memref for double buffering.
- correctly identify matching start/finish statements, handle multiple DMAs per
loop.
- introduce dominates/properlyDominates utilities for MLFunction statements.
- move checkDominancePreservationOnShift to LoopAnalysis.h; rename it
isStmtwiseShiftValid
- refactor getContainingStmtPos -> findAncestorStmtInBlock - move into
Analysis/Utils.h; has two users.
- other improvements / cleanup for related API/utilities
- add size argument to dma_wait - for nested DMAs, or in general, this makes it
easy to obtain the size to use when lowering the dma_wait, since we wouldn't
want to have to identify the matching dma_start; more importantly, in the
future, there may not always be a dma_start dominating the dma_wait.
- add debug information in the pass
PiperOrigin-RevId: 217734892
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Analysis/LoopAnalysis.cpp | 37 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Utils.cpp | 62 | ||||
| -rw-r--r-- | mlir/lib/IR/StmtBlock.cpp | 16 | ||||
| -rw-r--r-- | mlir/lib/StandardOps/StandardOps.cpp | 21 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopUtils.cpp | 53 | ||||
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 264 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils.cpp | 13 |
7 files changed, 302 insertions, 164 deletions
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 232d162f4e1..9e65e7b4d9e 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -24,8 +24,6 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/MLFunctionMatcher.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Statements.h" @@ -128,12 +126,13 @@ static bool isAccessInvariant(MLValue *input, MemRefType *memRefType, assert(indices.size() == memRefType->getRank()); assert(dim < indices.size()); auto layoutMap = memRefType->getAffineMaps(); - assert(layoutMap.size() <= 1); + assert(memRefType->getAffineMaps().size() <= 1); // TODO(ntv): remove dependency on Builder once we support non-identity // layout map. Builder b(memRefType->getContext()); assert(layoutMap.empty() || layoutMap[0] == b.getMultiDimIdentityMap(indices.size())); + (void)layoutMap; SmallVector<OperationStmt *, 4> affineApplyOps; getReachableAffineApplyOps({indices[dim]}, affineApplyOps); @@ -197,3 +196,35 @@ bool mlir::isVectorizableLoop(const ForStmt &loop) { } return true; } + +/// Checks whether SSA dominance would be violated if a for stmt's body +/// statements are shifted by the specified shifts. This method checks if a +/// 'def' and all its uses have the same shift factor. +// TODO(mlir-team): extend this to check for memory-based dependence +// violation when we have the support. +bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, + ArrayRef<uint64_t> shifts) { + assert(shifts.size() == forStmt.getStatements().size()); + unsigned s = 0; + for (const auto &stmt : forStmt) { + // A for or if stmt does not produce any def/results (that are used + // outside). 
+ if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) { + for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) { + const MLValue *result = opStmt->getResult(i); + for (const StmtOperand &use : result->getUses()) { + // If an ancestor statement doesn't lie in the block of forStmt, there + // is no shift to check. + // This is a naive way. If performance becomes an issue, a map can + // be used to store 'shifts' - to look up the shift for a statement in + // constant time. + if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner())) + if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)]) + return false; + } + } + } + s++; + } + return true; +} diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp new file mode 100644 index 00000000000..20aac12d5e2 --- /dev/null +++ b/mlir/lib/Analysis/Utils.cpp @@ -0,0 +1,62 @@ +//===- Utils.cpp ---- Misc utilities for analysis -------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements miscellaneous analysis routines for non-loop IR +// structures. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Utils.h" + +#include "mlir/IR/Statements.h" + +using namespace mlir; + +/// Returns true if statement 'a' properly dominates statement b. 
+bool mlir::properlyDominates(const Statement &a, const Statement &b) { + if (&a == &b) + return false; + + if (a.findFunction() != b.findFunction()) + return false; + + if (a.getBlock() == b.getBlock()) { + // Do a linear scan to determine whether b comes after a. + auto aIter = StmtBlock::const_iterator(a); + auto bIter = StmtBlock::const_iterator(b); + auto aBlockStart = a.getBlock()->begin(); + while (bIter != aBlockStart) { + --bIter; + if (aIter == bIter) + return true; + } + return false; + } + + // Traverse up b's hierarchy to check if b's block is contained in a's. + if (const auto *bAncestor = a.getBlock()->findAncestorStmtInBlock(b)) + // a and bAncestor are in the same block; check if the former dominates it. + return dominates(a, *bAncestor); + + // b's block is not contained in A. + return false; +} + +/// Returns true if statement A dominates statement B. +bool mlir::dominates(const Statement &a, const Statement &b) { + return &a == &b || properlyDominates(a, b); +} diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp index 7cd92ffc980..40a31f6c3b9 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/StmtBlock.cpp @@ -45,3 +45,19 @@ MLFunction *StmtBlock::findFunction() const { } return dyn_cast<MLFunction>(block); } + +/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor +/// statement of 'stmt' that lies in this block. Returns nullptr if the latter +/// fails. +const Statement * +StmtBlock::findAncestorStmtInBlock(const Statement &stmt) const { + // Traverse up the statement hierarchy starting from the owner of operand to + // find the ancestor statement that resides in the block of 'forStmt'. 
+ const auto *currStmt = &stmt; + while (currStmt->getBlock() != this) { + currStmt = currStmt->getParentStmt(); + if (!currStmt) + return nullptr; + } + return currStmt; +} diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index bd63fa3887c..d2f71123540 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -392,7 +392,7 @@ void DmaStartOp::print(OpAsmPrinter *p) const { } // Parse DmaStartOp. -// EX: +// Ex: // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, // %tag[%index] : // memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>, @@ -458,33 +458,38 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { // --------------------------------------------------------------------------- // DmaWaitOp // --------------------------------------------------------------------------- -// Parse DmaWaitOp. -// Eg: -// dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4> -// + void DmaWaitOp::print(OpAsmPrinter *p) const { *p << getOperationName() << ' '; // Print operands. p->printOperand(getTagMemRef()); *p << '['; p->printOperands(getTagIndices()); - *p << ']'; + *p << "], "; + p->printOperand(getNumElements()); *p << " : " << *getTagMemRef()->getType(); } +// Parse DmaWaitOp. +// Eg: +// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> +// bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType tagMemrefInfo; SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; Type *type; auto *indexType = parser->getBuilder().getIndexType(); + OpAsmParser::OperandType numElementsInfo; - // Parse tag memref and index. + // Parse tag memref, its indices, and dma size. 
if (parser->parseOperand(tagMemrefInfo) || parser->parseOperandList(tagIndexInfos, -1, OpAsmParser::Delimiter::Square) || + parser->parseComma() || parser->parseOperand(numElementsInfo) || parser->parseColonType(type) || parser->resolveOperand(tagMemrefInfo, type, result->operands) || - parser->resolveOperands(tagIndexInfos, indexType, result->operands)) + parser->resolveOperands(tagIndexInfos, indexType, result->operands) || + parser->resolveOperand(numElementsInfo, indexType, result->operands)) return true; if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank()) diff --git a/mlir/lib/Transforms/LoopUtils.cpp b/mlir/lib/Transforms/LoopUtils.cpp index 743e15bbbae..c47f965055f 100644 --- a/mlir/lib/Transforms/LoopUtils.cpp +++ b/mlir/lib/Transforms/LoopUtils.cpp @@ -181,57 +181,6 @@ generateLoop(AffineMap lb, AffineMap ub, return loopChunk; } -// Returns delay of that child statement of 'forStmt' which either has 'operand' -// as one of its operands or has a descendant statement with operand 'operand'. -// This is a naive implementation. If performance becomes an issue, a map can -// be used to store 'delays' - to look up the delay for a statement in constant -// time. -static uint64_t getContainingStmtDelay(const StmtOperand &operand, - const ForStmt &forStmt, - ArrayRef<uint64_t> delays) { - // Traverse up the statement hierarchy starting from the owner of operand to - // find the ancestor statement that resides in the block of 'forStmt'. - const Statement *stmt = operand.getOwner(); - assert(stmt != nullptr); - while (stmt->getParentStmt() != &forStmt) { - stmt = stmt->getParentStmt(); - assert(stmt && "traversing parent's should reach forStmt block"); - } - // Look up the delay of 'stmt'. 
- unsigned j = 0; - for (const auto &s : forStmt) { - if (&s == stmt) - break; - j++; - } - assert(j < forStmt.getStatements().size() && "child stmt should be found"); - return delays[j]; -} - -/// Checks if SSA dominance would be violated if a for stmt's body statements -/// are shifted by the specified delays. This method checks if a 'def' and all -/// its uses have the same delay factor. -bool mlir::checkDominancePreservationOnShift(const ForStmt &forStmt, - ArrayRef<uint64_t> delays) { - assert(delays.size() == forStmt.getStatements().size()); - unsigned s = 0; - for (const auto &stmt : forStmt) { - // A for or if stmt does not produce any def/results (that are used - // outside). - if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) { - for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) { - const MLValue *result = opStmt->getResult(i); - for (const StmtOperand &use : result->getUses()) { - if (delays[s] != getContainingStmtDelay(use, forStmt, delays)) - return false; - } - } - } - s++; - } - return true; -} - /// Skew the statements in the body of a 'for' statement with the specified /// statement-wise delays. The delays are with respect to the original execution /// order. A delay of zero for each statement will lead to no change. 
@@ -260,7 +209,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays, return UtilResult::Failure; uint64_t tripCount = mayBeConstTripCount.getValue(); - assert(checkDominancePreservationOnShift(*forStmt, delays) && + assert(isStmtwiseShiftValid(*forStmt, delays) && "dominance preservation failed\n"); unsigned numChildStmts = forStmt->getStatements().size(); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index bb60d8e9d78..d6a064988fb 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -22,21 +22,31 @@ #include "mlir/Transforms/Passes.h" #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/StmtVisitor.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Pass.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "pipeline-data-transfer" using namespace mlir; namespace { -struct PipelineDataTransfer : public MLFunctionPass { - explicit PipelineDataTransfer() {} +struct PipelineDataTransfer : public MLFunctionPass, + StmtWalker<PipelineDataTransfer> { PassResult runOnMLFunction(MLFunction *f) override; + PassResult runOnForStmt(ForStmt *forStmt); + + // Collect all 'for' statements. + void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); } + std::vector<ForStmt *> forStmts; }; } // end anonymous namespace @@ -47,20 +57,6 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() { return new PipelineDataTransfer(); } -/// Given a DMA start operation, returns the operand position of either the -/// source or destination memref depending on the one that is at the higher -/// level of the memory hierarchy. 
-// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are -// added. TODO(b/117228571) -static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) { - unsigned srcDmaPos = 0; - unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1; - - if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace()) - return srcDmaPos; - return destDmaPos; -} - // Returns the position of the tag memref operand given a DMA statement. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) @@ -76,18 +72,20 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { /// Doubles the buffer of the supplied memref while replacing all uses of the /// old memref. Returns false if such a replacement cannot be performed. -static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { +static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) { MLFuncBuilder bInner(forStmt, forStmt->begin()); bInner.setInsertionPoint(forStmt, forStmt->begin()); // Doubles the shape with a leading dimension extent of 2. - auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * { + auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * { // Add the leading dimension in the shape for the double buffer. 
- ArrayRef<int> shape = origMemRefType->getShape(); + ArrayRef<int> shape = oldMemRefType->getShape(); SmallVector<int, 4> shapeSizes(shape.begin(), shape.end()); shapeSizes.insert(shapeSizes.begin(), 2); - auto *newMemRefType = bInner.getMemRefType(shapeSizes, bInner.getF32Type()); + auto *newMemRefType = + bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {}, + oldMemRefType->getMemorySpace()); return newMemRefType; }; @@ -105,113 +103,187 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { auto ivModTwoOp = bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt); if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, - cast<MLValue>(ivModTwoOp->getResult(0)))) + cast<MLValue>(ivModTwoOp->getResult(0)))) { + LLVM_DEBUG(llvm::dbgs() + << "memref replacement for double buffering failed\n";); + cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock(); return false; + } return true; } -// For testing purposes, this just runs on the first 'for' statement of an -// MLFunction at the top level. -// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when -// the other TODOs listed inside are dealt with. +/// Returns false if this succeeds on at least one 'for' stmt. PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { if (f->empty()) return PassResult::Success; - ForStmt *forStmt = nullptr; - for (auto &stmt : *f) { - if ((forStmt = dyn_cast<ForStmt>(&stmt))) { - break; - } + // Do a post order walk so that inner loop DMAs are processed first. This is + // necessary since 'for' statements nested within would otherwise become + // invalid (erased) when the outer loop is pipelined (the pipelined one gets + // deleted and replaced by a prologue, a new steady-state loop and an + // epilogue). 
+ forStmts.clear(); + walkPostOrder(f); + bool ret = true; + for (auto *forStmt : forStmts) { + ret = ret & runOnForStmt(forStmt); } - if (!forStmt) - return PassResult::Success; - - unsigned numStmts = forStmt->getStatements().size(); + return ret ? failure() : success(); +} - if (numStmts == 0) - return PassResult::Success; +// Check if tags of the dma start op and dma wait op match. +static bool checkTagMatch(OpPointer<DmaStartOp> startOp, + OpPointer<DmaWaitOp> waitOp) { + if (startOp->getTagMemRef() != waitOp->getTagMemRef()) + return false; + auto startIndices = startOp->getTagIndices(); + auto waitIndices = waitOp->getTagIndices(); + // Both of these have the same number of indices since they correspond to the + // same tag memref. + for (auto it = startIndices.begin(), wIt = waitIndices.begin(), + e = startIndices.end(); + it != e; ++it, ++wIt) { + // Keep it simple for now, just checking if indices match. + // TODO(mlir-team): this would in general need to check if there is no + // intervening write writing to the same tag location, i.e., memory last + // write/data flow analysis. This is however sufficient/powerful enough for + // now since the DMA generation pass or the input for it will always have + // start/wait with matching tags (same SSA operand indices). + if (*it != *wIt) + return false; + } + return true; +} - SmallVector<OperationStmt *, 4> dmaStartStmts; - SmallVector<OperationStmt *, 4> dmaFinishStmts; +// Identify matching DMA start/finish statements to overlap computation with. +static void findMatchingStartFinishStmts( + ForStmt *forStmt, + SmallVectorImpl<std::pair<OperationStmt *, OperationStmt *>> + &startWaitPairs) { + SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts; for (auto &stmt : *forStmt) { auto *opStmt = dyn_cast<OperationStmt>(&stmt); if (!opStmt) continue; - if (opStmt->is<DmaStartOp>()) { - dmaStartStmts.push_back(opStmt); - } else if (opStmt->is<DmaWaitOp>()) { + // Collect DMA finish statements. 
+ if (opStmt->is<DmaWaitOp>()) { dmaFinishStmts.push_back(opStmt); + continue; + } + OpPointer<DmaStartOp> dmaStartOp; + if (!(dmaStartOp = opStmt->getAs<DmaStartOp>())) + continue; + // Only DMAs incoming into higher memory spaces. + // TODO(bondhugula): outgoing DMAs. + if (!dmaStartOp->isDestMemorySpaceFaster()) + continue; + + // We only double buffer if the buffer is not live out of loop. + const MLValue *memref = + cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos())); + bool escapingUses = false; + for (const auto &use : memref->getUses()) { + if (!dominates(*forStmt, *use.getOwner())) { + LLVM_DEBUG(llvm::dbgs() + << "can't pipeline: buffer is live out of loop\n";); + escapingUses = true; + break; + } + } + if (!escapingUses) + dmaStartStmts.push_back(opStmt); + } + + // For each start statement, we look for a matching finish statement. + for (auto *dmaStartStmt : dmaStartStmts) { + for (auto *dmaFinishStmt : dmaFinishStmts) { + if (checkTagMatch(dmaStartStmt->getAs<DmaStartOp>(), + dmaFinishStmt->getAs<DmaWaitOp>())) { + startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt}); + break; + } } } +} - // TODO(bondhugula,andydavis): match tag memref's (requires memory-based - // subscript check utilities). Assume for now that start/finish are matched in - // the order they appear. - if (dmaStartStmts.size() != dmaFinishStmts.size()) +/// Overlap DMA transfers with computation in this loop. If successful, +/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are +/// inserted right before where it was. 
+PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { + auto mayBeConstTripCount = getConstantTripCount(*forStmt); + if (!mayBeConstTripCount.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n"); return PassResult::Failure; + } + + SmallVector<std::pair<OperationStmt *, OperationStmt *>, 4> startWaitPairs; + findMatchingStartFinishStmts(forStmt, startWaitPairs); + + if (startWaitPairs.empty()) { + LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";); + return failure(); + } // Double the buffers for the higher memory space memref's. - // TODO(bondhugula): assuming we don't have multiple DMA starts for the same - // memref. + // Identify memref's to replace by scanning through all DMA start statements. + // A DMA start statement has two memref's - the one from the higher level of + // memory hierarchy is the one to double buffer. // TODO(bondhugula): check whether double-buffering is even necessary. // TODO(bondhugula): make this work with different layouts: assuming here that // the dimension we are adding here for the double buffering is the outermost // dimension. - // Identify memref's to replace by scanning through all DMA start statements. - // A DMA start statement has two memref's - the one from the higher level of - // memory hierarchy is the one to double buffer. - for (auto *dmaStartStmt : dmaStartStmts) { - MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand( - getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>()))); + for (auto &pair : startWaitPairs) { + auto *dmaStartStmt = pair.first; + const MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand( + dmaStartStmt->getAs<DmaStartOp>()->getFasterMemPos())); if (!doubleBuffer(oldMemRef, forStmt)) { - return PassResult::Failure; + // Normally, double buffering should not fail because we already checked + // that there are no uses outside. 
+ LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); + LLVM_DEBUG(dmaStartStmt->dump()); + return failure(); } } - // Double the buffers for tag memref's. - for (auto *dmaFinishStmt : dmaFinishStmts) { - MLValue *oldTagMemRef = cast<MLValue>( + // Double the buffers for tag memrefs. + for (auto &pair : startWaitPairs) { + const auto *dmaFinishStmt = pair.second; + const MLValue *oldTagMemRef = cast<MLValue>( dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt))); if (!doubleBuffer(oldTagMemRef, forStmt)) { - return PassResult::Failure; + LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); + return failure(); } } - // Collect all compute ops. - std::vector<const Statement *> computeOps; - computeOps.reserve(forStmt->getStatements().size()); + // Double buffering would have invalidated all the old DMA start/wait stmts. + startWaitPairs.clear(); + findMatchingStartFinishStmts(forStmt, startWaitPairs); + // Store delay for statement for later lookup for AffineApplyOp's. - DenseMap<const Statement *, unsigned> opDelayMap; - for (auto &stmt : *forStmt) { - auto *opStmt = dyn_cast<OperationStmt>(&stmt); - if (!opStmt) { - // All for and if stmt's are treated as pure compute operations. - opDelayMap[&stmt] = 1; - } else if (opStmt->is<DmaStartOp>()) { - // DMA starts are not shifted. - opDelayMap[opStmt] = 0; - // Set shifts for DMA start stmt's affine operand computation slices to 0. - if (auto *slice = mlir::createAffineComputationSlice(opStmt)) { - opDelayMap[slice] = 0; - } else { - // If a slice wasn't created, the reachable affine_apply op's from its - // operands are the ones that go with it. - SmallVector<OperationStmt *, 4> affineApplyStmts; - SmallVector<MLValue *, 4> operands(opStmt->getOperands()); - getReachableAffineApplyOps(operands, affineApplyStmts); - for (auto *op : affineApplyStmts) { - opDelayMap[op] = 0; - } - } - } else if (opStmt->is<DmaWaitOp>()) { - // DMA finish op shifted by one. 
- opDelayMap[opStmt] = 1; + DenseMap<const Statement *, unsigned> stmtDelayMap; + for (auto &pair : startWaitPairs) { + auto *dmaStartStmt = pair.first; + assert(dmaStartStmt->is<DmaStartOp>()); + stmtDelayMap[dmaStartStmt] = 0; + // Set shifts for DMA start stmt's affine operand computation slices to 0. + if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) { + stmtDelayMap[slice] = 0; } else { - // Everything else is a compute op; so shifted by one (op's supplying - // 'affine' operands to DMA start's have already been set right shifts. - opDelayMap[opStmt] = 1; - computeOps.push_back(&stmt); + // If a slice wasn't created, the reachable affine_apply op's from its + // operands are the ones that go with it. + SmallVector<OperationStmt *, 4> affineApplyStmts; + SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands()); + getReachableAffineApplyOps(operands, affineApplyStmts); + for (const auto *stmt : affineApplyStmts) { + stmtDelayMap[stmt] = 0; + } + } + } + // Everything else (including compute ops and dma finish) are shifted by one. + for (const auto &stmt : *forStmt) { + if (stmtDelayMap.find(&stmt) == stmtDelayMap.end()) { + stmtDelayMap[&stmt] = 1; } } @@ -219,18 +291,20 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { std::vector<uint64_t> delays(forStmt->getStatements().size()); unsigned s = 0; for (const auto &stmt : *forStmt) { - assert(opDelayMap.find(&stmt) != opDelayMap.end()); - delays[s++] = opDelayMap[&stmt]; + assert(stmtDelayMap.find(&stmt) != stmtDelayMap.end()); + delays[s++] = stmtDelayMap[&stmt]; } - if (!checkDominancePreservationOnShift(*forStmt, delays)) { + if (!isStmtwiseShiftValid(*forStmt, delays)) { // Violates SSA dominance. 
+ LLVM_DEBUG(llvm::dbgs() << "Dominance check failed\n";); return PassResult::Failure; } if (stmtBodySkew(forStmt, delays)) { + LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed\n";); return PassResult::Failure; } - return PassResult::Success; + return success(); } diff --git a/mlir/lib/Transforms/Utils.cpp b/mlir/lib/Transforms/Utils.cpp index 99ddb087bd4..5df6e99b592 100644 --- a/mlir/lib/Transforms/Utils.cpp +++ b/mlir/lib/Transforms/Utils.cpp @@ -48,7 +48,8 @@ static bool isMemRefDereferencingOp(const Operation &op) { /// at the start for now. // TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily // extended to add additional indices at any position. -bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef, +bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, + MLValue *newMemRef, ArrayRef<MLValue *> extraIndices, AffineMap indexRemap) { unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank(); @@ -219,11 +220,11 @@ mlir::createComposedAffineApplyOp(MLFuncBuilder *builder, Location *loc, /// This allows applying different transformations on send and compute (for eg. /// different shifts/delays). /// -/// Returns nullptr if none of the operands were the result of an affine_apply -/// and thus there was no affine computation slice to create. Returns the newly -/// affine_apply operation statement otherwise. -/// -/// +/// Returns nullptr either if none of opStmt's operands were the result of an +/// affine_apply and thus there was no affine computation slice to create, or if +/// all the affine_apply op's supplying operands to this opStmt do not have any +/// uses besides this opStmt. Returns the new affine_apply operation statement +/// otherwise. OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) { // Collect all operands that are results of affine apply ops. SmallVector<MLValue *, 4> subOperands; |

