From 2ef57806ba0917195ad9c7917eb1da249bf86796 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Mon, 10 Dec 2018 11:39:31 -0800 Subject: Update/fix -pipeline-data-transfer; fix b/120770946 - fix replaceAllMemRefUsesWith call to replace only inside loop body. - handle the case where DMA buffers are dynamic; extend doubleBuffer() method to handle dynamically shaped DMA buffers (pass the right operands to AllocOp) - place alloc's for DMA buffers at the depth at which pipelining is being done (instead of at top-level) - add more test cases PiperOrigin-RevId: 224852231 --- mlir/lib/Transforms/PipelineDataTransfer.cpp | 47 ++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 13 deletions(-) (limited to 'mlir/lib/Transforms/PipelineDataTransfer.cpp') diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 3d8c21c543e..fc97aa8d2d2 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -75,9 +75,12 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { return 0; } -/// Doubles the buffer of the supplied memref while replacing all uses of the -/// old memref. Returns false if such a replacement cannot be performed. -static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) { +/// Doubles the buffer of the supplied memref on the specified 'for' statement +/// by adding a leading dimension of size two to the memref. Replaces all uses +/// of the old memref by the new one while indexing the newly added dimension by +/// the loop IV of the specified 'for' statement modulo 2. Returns false if such +/// a replacement cannot be performed. +static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { MLFuncBuilder bInner(forStmt, forStmt->begin()); bInner.setInsertionPoint(forStmt, forStmt->begin()); @@ -94,21 +97,37 @@ static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) { return newMemRefType; }; - auto newMemRefType = doubleShape(oldMemRef->getType().cast()); + auto oldMemRefType = oldMemRef->getType().cast(); + auto newMemRefType = doubleShape(oldMemRefType); + + // Put together alloc operands for the dynamic dimensions of the memref. + MLFuncBuilder bOuter(forStmt); + SmallVector allocOperands; + unsigned dynamicDimCount = 0; + for (auto dimSize : oldMemRefType.getShape()) { + if (dimSize == -1) + allocOperands.push_back(bOuter.create(forStmt->getLoc(), oldMemRef, + dynamicDimCount++)); + } - // Create and place the alloc at the top level. - MLFuncBuilder topBuilder(forStmt->getFunction()); - auto newMemRef = cast( - topBuilder.create(forStmt->getLoc(), newMemRefType) - ->getResult()); + // Create and place the alloc right before the 'for' statement. + // TODO(mlir-team): we are assuming scoped allocation here, and aren't + // inserting a dealloc -- this isn't the right thing. + SSAValue *newMemRef = + bOuter.create(forStmt->getLoc(), newMemRefType, allocOperands); + // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0); auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {}); auto ivModTwoOp = bInner.create(forStmt->getLoc(), modTwoMap, forStmt); - if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, - cast(ivModTwoOp->getResult(0)))) { + + // replaceAllMemRefUsesWith will always succeed unless the forStmt body has + // non-deferencing uses of the memref. + if (!replaceAllMemRefUsesWith(oldMemRef, cast(newMemRef), + ivModTwoOp->getResult(0), AffineMap::Null(), {}, + &*forStmt->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); ivModTwoOp->getOperation()->erase(); @@ -185,7 +204,7 @@ static void findMatchingStartFinishStmts( cast(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos())); bool escapingUses = false; for (const auto &use : memref->getUses()) { - if (!dominates(*forStmt, *use.getOwner())) { + if (!dominates(*forStmt->begin(), *use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -247,7 +266,9 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { return success(); } // If the old memref has no more uses, remove its 'dead' alloc if it was - // alloc'ed (note: DMA buffers are rarely function live-in). + // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim' + // operation could have been used on it if it was dynamically shaped in + // order to create the double buffer above) if (oldMemRef->use_empty()) if (auto *allocStmt = oldMemRef->getDefiningStmt()) allocStmt->erase(); -- cgit v1.2.3