summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/Utils
diff options
context:
space:
mode:
authorUday Bondhugula <udayb@iisc.ac.in>2019-09-18 11:25:33 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-09-18 11:26:11 -0700
commit727a50ae2db4492a8c3168647996abacd75d0622 (patch)
tree15176e900a91c9f26634db0439318f7c384ce07a /mlir/lib/Transforms/Utils
parent1c73be76d84a04499b7e9ac5dfe129c204880dd8 (diff)
downloadbcm5719-llvm-727a50ae2db4492a8c3168647996abacd75d0622.tar.gz
bcm5719-llvm-727a50ae2db4492a8c3168647996abacd75d0622.zip
Support symbolic operands for memref replacement; fix memrefNormalize
- allow symbols in index remapping provided for memref replacement - fix memref normalize crash on cases with layout maps with symbols Signed-off-by: Uday Bondhugula <uday@polymagelabs.com> Reported by: Alex Zinenko Closes tensorflow/mlir#139 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/139 from bondhugula:memref-rep-symbols 2f48c1fdb5d4c58915bbddbd9f07b18541819233 PiperOrigin-RevId: 269851182
Diffstat (limited to 'mlir/lib/Transforms/Utils')
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp1
-rw-r--r--mlir/lib/Transforms/Utils/Utils.cpp31
2 files changed, 23 insertions, 9 deletions
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index e038512c0c0..0c9a666a6ec 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1548,6 +1548,7 @@ static LogicalResult generateCopy(
replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/regionSymbols,
+ /*symbolOperands=*/{},
/*domInstFilter=*/&*begin,
/*postDomInstFilter=*/&*postDomFilter);
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index e57d40e5a1c..d6400ac50ed 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -62,14 +62,17 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
Operation *op,
ArrayRef<Value *> extraIndices,
AffineMap indexRemap,
- ArrayRef<Value *> extraOperands) {
+ ArrayRef<Value *> extraOperands,
+ ArrayRef<Value *> symbolOperands) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
- (void)oldMemRefRank;
+ (void)oldMemRefRank; // unused in opt mode
if (indexRemap) {
- assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
- assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
+ assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
+ "symbolic operand count mistmatch");
+ assert(indexRemap.getNumInputs() ==
+ extraOperands.size() + oldMemRefRank + symbolOperands.size());
assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
} else {
assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
@@ -131,9 +134,11 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
// provided. The indices of a memref come right after it, i.e.,
// at position memRefOperandPos + 1.
SmallVector<Value *, 4> remapOperands;
- remapOperands.reserve(extraOperands.size() + oldMemRefRank);
+ remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+ symbolOperands.size());
remapOperands.append(extraOperands.begin(), extraOperands.end());
remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+ remapOperands.append(symbolOperands.begin(), symbolOperands.end());
SmallVector<Value *, 4> remapOutputs;
remapOutputs.reserve(oldMemRefRank);
@@ -226,6 +231,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
ArrayRef<Value *> extraIndices,
AffineMap indexRemap,
ArrayRef<Value *> extraOperands,
+ ArrayRef<Value *> symbolOperands,
Operation *domInstFilter,
Operation *postDomInstFilter) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
@@ -233,8 +239,10 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
(void)oldMemRefRank;
if (indexRemap) {
- assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
- assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
+ assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
+ "symbol operand count mismatch");
+ assert(indexRemap.getNumInputs() ==
+ extraOperands.size() + oldMemRefRank + symbolOperands.size());
assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
} else {
assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
@@ -287,7 +295,8 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
for (auto *op : opsToReplace) {
if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
- indexRemap, extraOperands)))
+ indexRemap, extraOperands,
+ symbolOperands)))
llvm_unreachable("memref replacement guaranteed to succeed here");
}
@@ -446,6 +455,8 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
}
auto *oldMemRef = allocOp.getResult();
+ SmallVector<Value *, 4> symbolOperands(allocOp.getSymbolicOperands());
+
auto newMemRefType = b.getMemRefType(newShape, memrefType.getElementType(),
b.getMultiDimIdentityMap(newRank));
auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
@@ -453,7 +464,9 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
// Replace all uses of the old memref.
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
/*extraIndices=*/{},
- /*indexRemap=*/layoutMap))) {
+ /*indexRemap=*/layoutMap,
+ /*extraOperands=*/{},
+ /*symbolOperands=*/symbolOperands))) {
// If it failed (due to escapes for example), bail out.
newAlloc.erase();
return failure();
OpenPOWER on IntegriCloud