Diffstat (limited to 'mlir')
 mlir/include/mlir/Transforms/Utils.h         | 15
 mlir/lib/Transforms/DmaGeneration.cpp        | 23
 mlir/lib/Transforms/PipelineDataTransfer.cpp | 21
 mlir/lib/Transforms/Utils/Utils.cpp          |  5
 mlir/test/Transforms/dma-generate.mlir       | 26
 5 files changed, 68 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h
index 3b828db6ae9..3dab02a4cd7 100644
--- a/mlir/include/mlir/Transforms/Utils.h
+++ b/mlir/include/mlir/Transforms/Utils.h
@@ -40,10 +40,10 @@ class Module;
 class Function;
 
-/// Replaces all uses of oldMemRef with newMemRef while optionally remapping the
-/// old memref's indices using the supplied affine map, 'indexRemap'. The new
-/// memref could be of a different shape or rank. 'extraIndices' provides
-/// additional access indices to be added to the start.
+/// Replaces all "dereferencing" uses of oldMemRef with newMemRef while
+/// optionally remapping the old memref's indices using the supplied affine map,
+/// 'indexRemap'. The new memref could be of a different shape or rank.
+/// 'extraIndices' provides additional access indices to be added to the start.
 ///
 /// 'indexRemap' remaps indices of the old memref access to a new set of indices
 /// that are used to index the memref. Additional input operands to indexRemap
@@ -57,9 +57,10 @@ class Function;
 /// operations that are dominated by the former; similarly, `postDomInstFilter`
 /// restricts replacement to only those operations that are postdominated by it.
 ///
-/// Returns true on success and false if the replacement is not possible
-/// (whenever a memref is used as an operand in a non-deferencing scenario). See
-/// comments at function definition for an example.
+/// Returns true on success and false if the replacement is not possible
+/// (whenever a memref is used as an operand in a non-dereferencing context),
+/// except for dealloc's on the memref, which are left untouched. See comments
+/// at function definition for an example.
 //
 // Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]:
 // The SSA value corresponding to '%t mod 2' should be in 'extraIndices', and
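To make the doc comment's example concrete, here is a minimal before/after sketch of the rewrite it describes, in the IR syntax used by the test file further down. The SSA names %A, %Abuf, %t_mod_2 and %ii_minus_i are illustrative only, not from the commit; the value of '%t mod 2' is supplied through 'extraIndices' and the remaining index arithmetic through 'indexRemap'.

    // Before: a dereferencing use of the old memref.
    %v = load %A[%i, %j] : memref<512x32xf32>

    // After replaceAllMemRefUsesWith: the extra index is prepended and the
    // original indices remapped, yielding the access
    // %Abuf[%t mod 2, %ii - %i, %j] on the new (higher-rank) memref.
    %v = load %Abuf[%t_mod_2, %ii_minus_i, %j] : memref<2x512x32xf32>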
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 45e57416111..29cc435a8a9 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -59,8 +59,6 @@ namespace {
 /// by the latter. Only load op's handled for now.
 // TODO(bondhugula): We currently can't generate DMAs correctly when stores are
 // strided. Check for strided stores.
-// TODO(mlir-team): we don't insert dealloc's for the DMA buffers; this is thus
-// natural only for scoped allocations.
 struct DmaGeneration : public FunctionPass {
   explicit DmaGeneration(
       unsigned slowMemorySpace = 0, unsigned fastMemorySpace = 1,
@@ -331,10 +329,8 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
   Value *fastMemRef;
 
   // Check if a buffer was already created.
-  // TODO(bondhugula): union across all memory op's per buffer. For now assuming
-  // that multiple memory op's on the same memref have the *same* memory
-  // footprint.
-  if (fastBufferMap.count(memref) == 0) {
+  bool existingBuf = fastBufferMap.count(memref) > 0;
+  if (!existingBuf) {
     auto fastMemRefType = top.getMemRefType(
         fastBufferShape, memRefType.getElementType(), {}, fastMemorySpace);
 
@@ -358,6 +354,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
   // Create a tag (single element 1-d memref) for the DMA.
   auto tagMemRefType = top.getMemRefType({1}, top.getIntegerType(32));
   auto tagMemRef = prologue.create<AllocOp>(loc, tagMemRefType);
+  auto numElementsSSA = top.create<ConstantIndexOp>(loc, numElements.getValue());
@@ -397,13 +394,23 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
                             zeroIndex, stride, numEltPerStride);
     // Since new ops are being appended (for outgoing DMAs), adjust the end to
     // mark end of range of the original.
-    if (*nEnd == end)
-      *nEnd = Block::iterator(op->getInstruction());
+    *nEnd = Block::iterator(op->getInstruction());
   }
 
   // Matching DMA wait to block on completion; tag always has a 0 index.
   b->create<DmaWaitOp>(loc, tagMemRef, zeroIndex, numElementsSSA);
 
+  // Generate dealloc for the tag.
+  auto tagDeallocOp = epilogue.create<DeallocOp>(loc, tagMemRef);
+  // Since new ops are being appended (for outgoing DMAs), adjust the end to
+  // mark end of range of the original.
+  if (*nEnd == end)
+    *nEnd = Block::iterator(tagDeallocOp->getInstruction());
+
+  // Generate dealloc for the DMA buffer.
+  if (!existingBuf)
+    epilogue.create<DeallocOp>(loc, fastMemRef);
+
   // Replace all uses of the old memref with the faster one while remapping
   // access indices (subtracting out lower bound offsets for each dimension).
   // Ex: to replace load %A[%i, %j] with load %Abuf[%i - %iT, %j - %jT],
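Pieced together from the test expectations below, the code generated around each DMA now takes roughly this shape. This is a sketch with illustrative SSA names (%buf, %tag, %A), not literal pass output:

    %buf = alloc() : memref<256xf32, 1>   // fast-memory-space buffer
    %tag = alloc() : memref<1xi32>        // single-element DMA tag
    dma_start %A[%c0], %buf[%c0], %c256, %tag[%c0]
        : memref<256xf32>, memref<256xf32, 1>, memref<1xi32>
    dma_wait %tag[%c0], %c256 : memref<1xi32>
    // ... compute on %buf ...
    dealloc %tag : memref<1xi32>          // newly generated in the epilogue
    dealloc %buf : memref<256xf32, 1>     // newly generated; skipped when the
                                          // buffer was reused via fastBufferMap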
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index cfa045f2279..5c2e38205e7 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -124,8 +124,9 @@ static bool doubleBuffer(Value *oldMemRef, OpPointer<AffineForOp> forOp) {
 
   // replaceAllMemRefUsesWith will always succeed unless the forOp body has
   // non-dereferencing uses of the memref.
-  if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, {ivModTwoOp}, AffineMap(),
-                                {}, &*forOp->getBody()->begin())) {
+  if (!replaceAllMemRefUsesWith(
+          oldMemRef, newMemRef, {ivModTwoOp}, AffineMap(), {},
+          /*domInstFilter=*/&*forOp->getBody()->begin())) {
     LLVM_DEBUG(llvm::dbgs()
                    << "memref replacement for double buffering failed\n";);
     ivModTwoOp->getInstruction()->erase();
@@ -284,10 +285,20 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
   // If the old memref has no more uses, remove its 'dead' alloc if it was
   // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
   // operation could have been used on it if it was dynamically shaped in
-  // order to create the double buffer above)
-  if (oldMemRef->use_empty())
-    if (auto *allocInst = oldMemRef->getDefiningInst())
+  // order to create the double buffer above.)
+  // '-canonicalize' does this in a more general way, but we handle the
+  // simple/common case here so that the output / test cases look clear.
+  if (auto *allocInst = oldMemRef->getDefiningInst()) {
+    if (oldMemRef->use_empty()) {
       allocInst->erase();
+    } else if (oldMemRef->hasOneUse()) {
+      auto *singleUse = oldMemRef->use_begin()->getOwner();
+      if (singleUse->isa<DeallocOp>()) {
+        singleUse->erase();
+        oldMemRef->getDefiningInst()->erase();
+      }
+    }
+  }
 }
 
 // Double the buffers for tag memrefs.
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 41689be52fc..519885b3a50 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -91,6 +91,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
         !postDomInfo->postDominates(postDomInstFilter, opInst))
       continue;
 
+    // Skip dealloc's - no replacement is necessary; dealloc's on the old
+    // memref remain valid and are left untouched.
+    if (opInst->isa<DeallocOp>())
+      continue;
+
     // Check if the memref was used in a non-dereferencing context. It is fine
     // for the memref to be used in a non-dereferencing way outside of the
     // region where this replacement is happening.
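Taken together, the Utils.cpp and PipelineDataTransfer.cpp changes cover the pattern sketched below (%old is a hypothetical DMA buffer): replaceAllMemRefUsesWith now walks past the dealloc instead of reporting failure, leaving it as the sole remaining use, and runOnAffineForOp then erases the dealloc/alloc pair.

    %old = alloc() : memref<256xf32, 1>
    // ... every load/store on %old is replaced with the double buffer;
    // the dealloc below is skipped rather than treated as a failure ...
    dealloc %old : memref<256xf32, 1>   // erased afterwards, along with the alloc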
diff --git a/mlir/test/Transforms/dma-generate.mlir b/mlir/test/Transforms/dma-generate.mlir
index bf6062e78e8..4e4f8736d49 100644
--- a/mlir/test/Transforms/dma-generate.mlir
+++ b/mlir/test/Transforms/dma-generate.mlir
@@ -40,6 +40,10 @@ func @loop_nest_1d() {
   // Already in faster memory space.
   // CHECK: %11 = load %2[%i0] : memref<256xf32, 1>
   // CHECK-NEXT: }
+  // CHECK-NEXT: dealloc %6 : memref<1xi32>
+  // CHECK-NEXT: dealloc %5 : memref<256xf32, 1>
+  // CHECK-NEXT: dealloc %4 : memref<1xi32>
+  // CHECK-NEXT: dealloc %3 : memref<256xf32, 1>
   // CHECK-NEXT: return
   for %i = 0 to 256 {
     load %A[%i] : memref<256 x f32>
@@ -95,6 +99,13 @@ func @loop_nest_1d() {
 // OUTGOING DMA for C.
 // CHECK-NEXT: dma_start [[BUFC]][%c0, %c0], %arg2[%c0, %c0], %c16384, [[TAGC_W]][%c0] : memref<512x32xf32, 1>, memref<512x32xf32>, memref<1xi32>
 // CHECK-NEXT: dma_wait [[TAGC_W]][%c0], %c16384 : memref<1xi32>
+// CHECK-NEXT: dealloc [[TAGC_W]] : memref<1xi32>
+// CHECK-NEXT: dealloc [[TAGC]] : memref<1xi32>
+// CHECK-NEXT: dealloc [[BUFC]] : memref<512x32xf32, 1>
+// CHECK-NEXT: dealloc [[TAGA]] : memref<1xi32>
+// CHECK-NEXT: dealloc [[BUFA]] : memref<512x32xf32, 1>
+// CHECK-NEXT: dealloc [[TAGB]] : memref<1xi32>
+// CHECK-NEXT: dealloc [[BUFB]] : memref<512x32xf32, 1>
 // CHECK-NEXT: return
 // CHECK-NEXT:}
 func @loop_nest_high_d(%A: memref<512 x 32 x f32>,
@@ -144,6 +155,8 @@ func @loop_nest_high_d(%A: memref<512 x 32 x f32>,
   // ...
   // ...
   // CHECK: }
+  // CHECK-NEXT: dealloc %3 : memref<1xi32>
+  // CHECK-NEXT: dealloc %2 : memref<1x2xf32, 1>
   // CHECK-NEXT: }
   // CHECK-NEXT: return
 func @loop_nest_modulo() {
@@ -183,7 +196,6 @@ func @loop_nest_tiled() -> memref<256x1024xf32> {
       }
     }
   }
-  // CHECK: return %0 : memref<256x1024xf32>
   return %0 : memref<256x1024xf32>
 }
 
@@ -229,7 +241,7 @@ func @dma_with_symbolic_accesses(%A : memref<100x100xf32>, %M : index) {
 // CHECK-NEXT: %6 = load %1[%4, %5] : memref<100x100xf32, 1>
 // CHECK-NEXT: }
 // CHECK-NEXT: }
-// CHECK-NEXT: return
+// CHECK: return
 }
 
 // CHECK-LABEL: func @dma_with_symbolic_loop_bounds
@@ -357,6 +369,9 @@ func @multi_load_store_union() {
 // CHECK-NEXT: }
 // CHECK-NEXT: dma_start %1[%c0, %c0], %0[%c2, %c2_0], %c170372, %3[%c0], %c512, %c446 : memref<382x446xf32, 1>, memref<512x512xf32>, memref<1xi32>
 // CHECK-NEXT: dma_wait %3[%c0], %c170372 : memref<1xi32>
+// CHECK-NEXT: dealloc %3 : memref<1xi32>
+// CHECK-NEXT: dealloc %2 : memref<1xi32>
+// CHECK-NEXT: dealloc %1 : memref<382x446xf32, 1>
 // CHECK-NEXT: return
 // CHECK-NEXT:}
@@ -385,6 +400,8 @@ func @dma_loop_straightline_interspersed() {
 // CHECK-NEXT: dma_start %0[%c0], %1[%c0], %c1_1, %2[%c0] : memref<256xf32>, memref<1xf32, 1>, memref<1xi32>
 // CHECK-NEXT: dma_wait %2[%c0], %c1_1 : memref<1xi32>
 // CHECK-NEXT: %3 = load %1[%c0_2] : memref<1xf32, 1>
+// CHECK-NEXT: dealloc %2 : memref<1xi32>
+// CHECK-NEXT: dealloc %1 : memref<1xf32, 1>
 // CHECK-NEXT: %4 = alloc() : memref<254xf32, 1>
 // CHECK-NEXT: %5 = alloc() : memref<1xi32>
 // CHECK-NEXT: dma_start %0[%c1_0], %4[%c0], %c254, %5[%c0] : memref<256xf32>, memref<254xf32, 1>, memref<1xi32>
@@ -393,6 +410,8 @@
 // CHECK-NEXT: %6 = affine.apply [[MAP_MINUS_ONE]](%i0)
 // CHECK-NEXT: %7 = load %4[%6] : memref<254xf32, 1>
 // CHECK-NEXT: }
+// CHECK-NEXT: dealloc %5 : memref<1xi32>
+// CHECK-NEXT: dealloc %4 : memref<254xf32, 1>
 // CHECK-NEXT: %8 = alloc() : memref<256xf32, 1>
 // CHECK-NEXT: %9 = alloc() : memref<1xi32>
 // CHECK-NEXT: dma_start %0[%c0], %8[%c0], %c256, %9[%c0] : memref<256xf32>, memref<256xf32, 1>, memref<1xi32>
@@ -402,6 +421,9 @@
 // CHECK-NEXT: store %11, %8[%c0_2] : memref<256xf32, 1>
 // CHECK-NEXT: dma_start %8[%c0], %0[%c0], %c1, %10[%c0] : memref<256xf32, 1>, memref<256xf32>, memref<1xi32>
 // CHECK-NEXT: dma_wait %10[%c0], %c1 : memref<1xi32>
+// CHECK-NEXT: dealloc %10 : memref<1xi32>
+// CHECK-NEXT: dealloc %9 : memref<1xi32>
+// CHECK-NEXT: dealloc %8 : memref<256xf32, 1>
 // CHECK-NEXT: return
 
 // -----

