Diffstat (limited to 'mlir')
 mlir/include/mlir/Transforms/Utils.h         | 15
 mlir/lib/Transforms/DmaGeneration.cpp        | 23
 mlir/lib/Transforms/PipelineDataTransfer.cpp | 21
 mlir/lib/Transforms/Utils/Utils.cpp          |  5
 mlir/test/Transforms/dma-generate.mlir       | 26
 5 files changed, 68 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h
index 3b828db6ae9..3dab02a4cd7 100644
--- a/mlir/include/mlir/Transforms/Utils.h
+++ b/mlir/include/mlir/Transforms/Utils.h
@@ -40,10 +40,10 @@ class Module;
class Function;
-/// Replaces all uses of oldMemRef with newMemRef while optionally remapping the
-/// old memref's indices using the supplied affine map, 'indexRemap'. The new
-/// memref could be of a different shape or rank. 'extraIndices' provides
-/// additional access indices to be added to the start.
+/// Replaces all "deferencing" uses of oldMemRef with newMemRef while optionally
+/// remapping the old memref's indices using the supplied affine map,
+/// 'indexRemap'. The new memref could be of a different shape or rank.
+/// 'extraIndices' provides additional access indices to be added to the start.
///
/// 'indexRemap' remaps indices of the old memref access to a new set of indices
/// that are used to index the memref. Additional input operands to indexRemap
@@ -57,9 +57,10 @@ class Function;
/// operations that are dominated by the former; similarly, `postDomInstFilter`
/// restricts replacement to only those operations that are postdominated by it.
///
-/// Returns true on success and false if the replacement is not possible
-/// (whenever a memref is used as an operand in a non-deferencing scenario). See
-/// comments at function definition for an example.
+/// Returns true on success and false if the replacement is not possible
+/// (whenever a memref is used as an operand in a non-dereferencing context),
+/// except for dealloc's on the memref, which are left untouched. See comments
+/// at function definition for an example.
//
// Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]:
// The SSA value corresponding to '%t mod 2' should be in 'extraIndices', and
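
For illustration, a minimal sketch of a conforming call (not code from this
commit; 'oldMemRef', 'newMemRef', 'ivModTwo', and 'forOp' are assumed caller
values, mirroring the PipelineDataTransfer call site further down):

    // Prepend 'ivModTwo' (e.g. the '%t mod 2' SSA value) as a new leading
    // index; a default-constructed AffineMap() leaves the remaining indices
    // unremapped. Dealloc's on 'oldMemRef' are skipped, per the new contract.
    if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
                                  /*extraIndices=*/{ivModTwo},
                                  /*indexRemap=*/AffineMap(),
                                  /*extraOperands=*/{},
                                  /*domInstFilter=*/&*forOp->getBody()->begin()))
      return false; // a non-dereferencing, non-dealloc use blocked replacement
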
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 45e57416111..29cc435a8a9 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -59,8 +59,6 @@ namespace {
/// by the latter. Only load op's handled for now.
// TODO(bondhugula): We currently can't generate DMAs correctly when stores are
// strided. Check for strided stores.
-// TODO(mlir-team): we don't insert dealloc's for the DMA buffers; this is thus
-// natural only for scoped allocations.
struct DmaGeneration : public FunctionPass {
explicit DmaGeneration(
unsigned slowMemorySpace = 0, unsigned fastMemorySpace = 1,
@@ -331,10 +329,8 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
Value *fastMemRef;
// Check if a buffer was already created.
- // TODO(bondhugula): union across all memory op's per buffer. For now assuming
- // that multiple memory op's on the same memref have the *same* memory
- // footprint.
- if (fastBufferMap.count(memref) == 0) {
+ bool existingBuf = fastBufferMap.count(memref) > 0;
+ if (!existingBuf) {
auto fastMemRefType = top.getMemRefType(
fastBufferShape, memRefType.getElementType(), {}, fastMemorySpace);
@@ -358,6 +354,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
// Create a tag (single element 1-d memref) for the DMA.
auto tagMemRefType = top.getMemRefType({1}, top.getIntegerType(32));
auto tagMemRef = prologue.create<AllocOp>(loc, tagMemRefType);
+
auto numElementsSSA =
top.create<ConstantIndexOp>(loc, numElements.getValue());
@@ -397,13 +394,23 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
zeroIndex, stride, numEltPerStride);
// Since new ops are being appended (for outgoing DMAs), adjust the end to
// mark end of range of the original.
- if (*nEnd == end)
- *nEnd = Block::iterator(op->getInstruction());
+ *nEnd = Block::iterator(op->getInstruction());
}
// Matching DMA wait to block on completion; tag always has a 0 index.
b->create<DmaWaitOp>(loc, tagMemRef, zeroIndex, numElementsSSA);
+ // Generate dealloc for the tag.
+ auto tagDeallocOp = epilogue.create<DeallocOp>(loc, tagMemRef);
+ if (*nEnd == end)
+ // Since new ops are being appended (for outgoing DMAs), adjust the end to
+ // mark end of range of the original.
+ *nEnd = Block::iterator(tagDeallocOp->getInstruction());
+
+ // Generate dealloc for the DMA buffer.
+ if (!existingBuf)
+ epilogue.create<DeallocOp>(loc, fastMemRef);
+
// Replace all uses of the old memref with the faster one while remapping
// access indices (subtracting out lower bound offsets for each dimension).
// Ex: to replace load %A[%i, %j] with load %Abuf[%i - %iT, %j - %jT],
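
The net effect of this file's change, as a condensed sketch (names taken from
the surrounding code; 'getResult()' and the elided DMA start/wait operands are
assumptions, not lines from this commit):

    // Allocate the fast buffer only once per memref; the tag is per-DMA.
    if (!existingBuf)
      fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType)->getResult();
    auto tagMemRef = prologue.create<AllocOp>(loc, tagMemRefType);
    // ... b->create<DmaStartOp>(...); b->create<DmaWaitOp>(...); ...
    epilogue.create<DeallocOp>(loc, tagMemRef);    // tag: always freed
    if (!existingBuf)
      epilogue.create<DeallocOp>(loc, fastMemRef); // buffer: freed by creator
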
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index cfa045f2279..5c2e38205e7 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -124,8 +124,9 @@ static bool doubleBuffer(Value *oldMemRef, OpPointer<AffineForOp> forOp) {
// replaceAllMemRefUsesWith will always succeed unless the forOp body has
// non-dereferencing uses of the memref.
- if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, {ivModTwoOp}, AffineMap(),
- {}, &*forOp->getBody()->begin())) {
+ if (!replaceAllMemRefUsesWith(
+ oldMemRef, newMemRef, {ivModTwoOp}, AffineMap(), {},
+ /*domInstFilter=*/&*forOp->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getInstruction()->erase();
@@ -284,10 +285,20 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
// If the old memref has no more uses, remove its 'dead' alloc if it was
// alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
// operation could have been used on it if it was dynamically shaped in
- // order to create the double buffer above)
- if (oldMemRef->use_empty())
- if (auto *allocInst = oldMemRef->getDefiningInst())
+ // order to create the double buffer above.)
+  // '-canonicalize' does this in a more general way, but we'll do the
+  // simple/common case here anyway so that the output / test cases look clear.
+ if (auto *allocInst = oldMemRef->getDefiningInst()) {
+ if (oldMemRef->use_empty()) {
allocInst->erase();
+ } else if (oldMemRef->hasOneUse()) {
+ auto *singleUse = oldMemRef->use_begin()->getOwner();
+ if (singleUse->isa<DeallocOp>()) {
+ singleUse->erase();
+ oldMemRef->getDefiningInst()->erase();
+ }
+ }
+ }
}
// Double the buffers for tag memrefs.
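
One subtlety worth spelling out: since replaceAllMemRefUsesWith now leaves
dealloc's untouched (see Utils.cpp below), the old memref can come out of the
replacement with exactly one remaining use - its dealloc - at which point both
it and its alloc are dead. A sketch of that case (same values as in the hunk
above):

    // %old = alloc() ... all loads/stores now target the double buffer ...
    // dealloc %old   <- sole remaining use; erase it, then the dead alloc.
    if (oldMemRef->hasOneUse()) {
      auto *singleUse = oldMemRef->use_begin()->getOwner();
      if (singleUse->isa<DeallocOp>()) {
        singleUse->erase();
        allocInst->erase();
      }
    }
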
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 41689be52fc..519885b3a50 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -91,6 +91,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
!postDomInfo->postDominates(postDomInstFilter, opInst))
continue;
+  // Skip dealloc's - no replacement is necessary, and a memref replacement
+  // at the other uses doesn't affect these dealloc's.
+ if (opInst->isa<DeallocOp>())
+ continue;
+
// Check if the memref was used in a non-dereferencing context. It is fine
// for the memref to be used in a non-dereferencing way outside of the region
// where this replacement is happening.
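
The reason for the early skip: a dealloc takes the memref as a plain
(non-dereferencing) operand, so without it the check below would reject the
whole replacement. A hedged sketch of where the new lines sit (the real
function walks uses differently, since it mutates them as it goes):

    for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end(); ++it) {
      auto *opInst = it->getOwner();
      // ... dominance/postdominance filters, as in the hunk above ...
      if (opInst->isa<DeallocOp>())
        continue; // leave the dealloc on the old memref
      // ... verify the use dereferences the memref, then remap its indices ...
    }
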
diff --git a/mlir/test/Transforms/dma-generate.mlir b/mlir/test/Transforms/dma-generate.mlir
index bf6062e78e8..4e4f8736d49 100644
--- a/mlir/test/Transforms/dma-generate.mlir
+++ b/mlir/test/Transforms/dma-generate.mlir
@@ -40,6 +40,10 @@ func @loop_nest_1d() {
// Already in faster memory space.
// CHECK: %11 = load %2[%i0] : memref<256xf32, 1>
// CHECK-NEXT: }
+ // CHECK-NEXT: dealloc %6 : memref<1xi32>
+ // CHECK-NEXT: dealloc %5 : memref<256xf32, 1>
+ // CHECK-NEXT: dealloc %4 : memref<1xi32>
+ // CHECK-NEXT: dealloc %3 : memref<256xf32, 1>
// CHECK-NEXT: return
for %i = 0 to 256 {
load %A[%i] : memref<256 x f32>
@@ -95,6 +99,13 @@ func @loop_nest_1d() {
// OUTGOING DMA for C.
// CHECK-NEXT: dma_start [[BUFC]][%c0, %c0], %arg2[%c0, %c0], %c16384, [[TAGC_W]][%c0] : memref<512x32xf32, 1>, memref<512x32xf32>, memref<1xi32>
// CHECK-NEXT: dma_wait [[TAGC_W]][%c0], %c16384 : memref<1xi32>
+// CHECK-NEXT: dealloc [[TAGC_W]] : memref<1xi32>
+// CHECK-NEXT: dealloc [[TAGC]] : memref<1xi32>
+// CHECK-NEXT: dealloc [[BUFC]] : memref<512x32xf32, 1>
+// CHECK-NEXT: dealloc [[TAGA]] : memref<1xi32>
+// CHECK-NEXT: dealloc [[BUFA]] : memref<512x32xf32, 1>
+// CHECK-NEXT: dealloc [[TAGB]] : memref<1xi32>
+// CHECK-NEXT: dealloc [[BUFB]] : memref<512x32xf32, 1>
// CHECK-NEXT: return
// CHECK-NEXT:}
func @loop_nest_high_d(%A: memref<512 x 32 x f32>,
@@ -144,6 +155,8 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>,
// ...
// ...
// CHECK: }
+// CHECK-NEXT: dealloc %3 : memref<1xi32>
+// CHECK-NEXT: dealloc %2 : memref<1x2xf32, 1>
// CHECK-NEXT: }
// CHECK-NEXT: return
func @loop_nest_modulo() {
@@ -183,7 +196,6 @@ func @loop_nest_tiled() -> memref<256x1024xf32> {
}
}
}
- // CHECK: return %0 : memref<256x1024xf32>
return %0 : memref<256x1024xf32>
}
@@ -229,7 +241,7 @@ func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) {
// CHECK-NEXT: %6 = load %1[%4, %5] : memref<100x100xf32, 1>
// CHECK-NEXT: }
// CHECK-NEXT: }
-// CHECK-NEXT: return
+// CHECK: return
}
// CHECK-LABEL: func @dma_with_symbolic_loop_bounds
@@ -357,6 +369,9 @@ func @multi_load_store_union() {
// CHECK-NEXT: }
// CHECK-NEXT: dma_start %1[%c0, %c0], %0[%c2, %c2_0], %c170372, %3[%c0], %c512, %c446 : memref<382x446xf32, 1>, memref<512x512xf32>, memref<1xi32>
// CHECK-NEXT: dma_wait %3[%c0], %c170372 : memref<1xi32>
+// CHECK-NEXT: dealloc %3 : memref<1xi32>
+// CHECK-NEXT: dealloc %2 : memref<1xi32>
+// CHECK-NEXT: dealloc %1 : memref<382x446xf32, 1>
// CHECK-NEXT: return
// CHECK-NEXT:}
@@ -385,6 +400,8 @@ func @dma_loop_straightline_interspersed() {
// CHECK-NEXT: dma_start %0[%c0], %1[%c0], %c1_1, %2[%c0] : memref<256xf32>, memref<1xf32, 1>, memref<1xi32>
// CHECK-NEXT: dma_wait %2[%c0], %c1_1 : memref<1xi32>
// CHECK-NEXT: %3 = load %1[%c0_2] : memref<1xf32, 1>
+// CHECK-NEXT: dealloc %2 : memref<1xi32>
+// CHECK-NEXT: dealloc %1 : memref<1xf32, 1>
// CHECK-NEXT: %4 = alloc() : memref<254xf32, 1>
// CHECK-NEXT: %5 = alloc() : memref<1xi32>
// CHECK-NEXT: dma_start %0[%c1_0], %4[%c0], %c254, %5[%c0] : memref<256xf32>, memref<254xf32, 1>, memref<1xi32>
@@ -393,6 +410,8 @@ func @dma_loop_straightline_interspersed() {
// CHECK-NEXT: %6 = affine.apply [[MAP_MINUS_ONE]](%i0)
// CHECK-NEXT: %7 = load %4[%6] : memref<254xf32, 1>
// CHECK-NEXT: }
+// CHECK-NEXT: dealloc %5 : memref<1xi32>
+// CHECK-NEXT: dealloc %4 : memref<254xf32, 1>
// CHECK-NEXT: %8 = alloc() : memref<256xf32, 1>
// CHECK-NEXT: %9 = alloc() : memref<1xi32>
// CHECK-NEXT: dma_start %0[%c0], %8[%c0], %c256, %9[%c0] : memref<256xf32>, memref<256xf32, 1>, memref<1xi32>
@@ -402,6 +421,9 @@ func @dma_loop_straightline_interspersed() {
// CHECK-NEXT: store %11, %8[%c0_2] : memref<256xf32, 1>
// CHECK-NEXT: dma_start %8[%c0], %0[%c0], %c1, %10[%c0] : memref<256xf32, 1>, memref<256xf32>, memref<1xi32>
// CHECK-NEXT: dma_wait %10[%c0], %c1 : memref<1xi32>
+// CHECK-NEXT: dealloc %10 : memref<1xi32>
+// CHECK-NEXT: dealloc %9 : memref<1xi32>
+// CHECK-NEXT: dealloc %8 : memref<256xf32, 1>
// CHECK-NEXT: return
// -----