diff options
Diffstat (limited to 'mlir/lib/Transforms/MaterializeVectors.cpp')
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index e95bb7307e3..ebfdbc28d6c 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -463,10 +463,13 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, ++dim; }, superVectorType.getShape(), *optionalRatio); - auto projectionMap = AffineMap::get(optionalRatio->size(), 0, keep, {}); auto permutationMap = transfer->getPermutationMap(); - LLVM_DEBUG(projectionMap.print(dbgs() << "\nprojectionMap: ")); LLVM_DEBUG(permutationMap.print(dbgs() << "\npermutationMap: ")); + if (keep.empty()) { + return permutationMap; + } + auto projectionMap = AffineMap::get(optionalRatio->size(), 0, keep, {}); + LLVM_DEBUG(projectionMap.print(dbgs() << "\nprojectionMap: ")); return composeUnboundedMaps(projectionMap, permutationMap); } @@ -484,9 +487,13 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, map(makePtrDynCaster<Value>(), read->getIndices()); auto affineIndices = reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); + auto map = projectedPermutationMap(read, hwVectorType); + if (!map) { + return nullptr; + } auto cloned = b->create<VectorTransferReadOp>( - read->getLoc(), hwVectorType, read->getMemRef(), affineIndices, - projectedPermutationMap(read, hwVectorType), read->getPaddingValue()); + read->getLoc(), hwVectorType, read->getMemRef(), affineIndices, map, + read->getPaddingValue()); return cloned->getInstruction(); } |

