diff options
| author | Uday Bondhugula <udayb@iisc.ac.in> | 2019-09-18 11:25:33 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-09-18 11:26:11 -0700 |
| commit | 727a50ae2db4492a8c3168647996abacd75d0622 (patch) | |
| tree | 15176e900a91c9f26634db0439318f7c384ce07a /mlir/lib/Transforms/Utils | |
| parent | 1c73be76d84a04499b7e9ac5dfe129c204880dd8 (diff) | |
| download | bcm5719-llvm-727a50ae2db4492a8c3168647996abacd75d0622.tar.gz bcm5719-llvm-727a50ae2db4492a8c3168647996abacd75d0622.zip | |
Support symbolic operands for memref replacement; fix memrefNormalize
- allow symbols in index remapping provided for memref replacement
- fix memref normalize crash on cases with layout maps with symbols
Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>
Reported by: Alex Zinenko
Closes tensorflow/mlir#139
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/139 from bondhugula:memref-rep-symbols 2f48c1fdb5d4c58915bbddbd9f07b18541819233
PiperOrigin-RevId: 269851182
Diffstat (limited to 'mlir/lib/Transforms/Utils')
| -rw-r--r-- | mlir/lib/Transforms/Utils/LoopUtils.cpp | 1 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/Utils.cpp | 31 |
2 files changed, 23 insertions, 9 deletions
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index e038512c0c0..0c9a666a6ec 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1548,6 +1548,7 @@ static LogicalResult generateCopy( replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/regionSymbols, + /*symbolOperands=*/{}, /*domInstFilter=*/&*begin, /*postDomInstFilter=*/&*postDomFilter); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index e57d40e5a1c..d6400ac50ed 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -62,14 +62,17 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, Operation *op, ArrayRef<Value *> extraIndices, AffineMap indexRemap, - ArrayRef<Value *> extraOperands) { + ArrayRef<Value *> extraOperands, + ArrayRef<Value *> symbolOperands) { unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank(); - (void)oldMemRefRank; + (void)oldMemRefRank; // unused in opt mode if (indexRemap) { - assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected"); - assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank); + assert(indexRemap.getNumSymbols() == symbolOperands.size() && + "symbolic operand count mistmatch"); + assert(indexRemap.getNumInputs() == + extraOperands.size() + oldMemRefRank + symbolOperands.size()); assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); } else { assert(oldMemRefRank + extraIndices.size() == newMemRefRank); @@ -131,9 +134,11 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, // provided. The indices of a memref come right after it, i.e., // at position memRefOperandPos + 1. SmallVector<Value *, 4> remapOperands; - remapOperands.reserve(extraOperands.size() + oldMemRefRank); + remapOperands.reserve(extraOperands.size() + oldMemRefRank + + symbolOperands.size()); remapOperands.append(extraOperands.begin(), extraOperands.end()); remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); + remapOperands.append(symbolOperands.begin(), symbolOperands.end()); SmallVector<Value *, 4> remapOutputs; remapOutputs.reserve(oldMemRefRank); @@ -226,6 +231,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, ArrayRef<Value *> extraIndices, AffineMap indexRemap, ArrayRef<Value *> extraOperands, + ArrayRef<Value *> symbolOperands, Operation *domInstFilter, Operation *postDomInstFilter) { unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank(); @@ -233,8 +239,10 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank(); (void)oldMemRefRank; if (indexRemap) { - assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected"); - assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank); + assert(indexRemap.getNumSymbols() == symbolOperands.size() && + "symbol operand count mismatch"); + assert(indexRemap.getNumInputs() == + extraOperands.size() + oldMemRefRank + symbolOperands.size()); assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); } else { assert(oldMemRefRank + extraIndices.size() == newMemRefRank); @@ -287,7 +295,8 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef, for (auto *op : opsToReplace) { if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices, - indexRemap, extraOperands))) + indexRemap, extraOperands, + symbolOperands))) llvm_unreachable("memref replacement guaranteed to succeed here"); } @@ -446,6 +455,8 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { } auto *oldMemRef = allocOp.getResult(); + SmallVector<Value *, 4> symbolOperands(allocOp.getSymbolicOperands()); + auto newMemRefType = b.getMemRefType(newShape, memrefType.getElementType(), b.getMultiDimIdentityMap(newRank)); auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType); @@ -453,7 +464,9 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { // Replace all uses of the old memref. if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, /*extraIndices=*/{}, - /*indexRemap=*/layoutMap))) { + /*indexRemap=*/layoutMap, + /*extraOperands=*/{}, + /*symbolOperands=*/symbolOperands))) { // If it failed (due to escapes for example), bail out. newAlloc.erase(); return failure(); |

