summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/StandardOps
diff options
context:
space:
mode:
authorAndy Davis <andydavis@google.com>2019-11-07 08:04:33 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-07 08:05:03 -0800
commit5fbdb67b0aa7f01b17dcca62e08e3db38d021fce (patch)
tree558cbcb37a4b93532aea6098de4616607602751a /mlir/lib/Dialect/StandardOps
parenta10d836c6de913445105274a0d92b0265da3bd2f (diff)
downloadbcm5719-llvm-5fbdb67b0aa7f01b17dcca62e08e3db38d021fce.tar.gz
bcm5719-llvm-5fbdb67b0aa7f01b17dcca62e08e3db38d021fce.zip
Add canonicalizer for ViewOp which folds constants into the ViewOp memref shape and layout map strides and offset.
PiperOrigin-RevId: 279088023
Diffstat (limited to 'mlir/lib/Dialect/StandardOps')
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp112
1 files changed, 112 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 82d4324dff8..60002649a21 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2419,6 +2419,118 @@ static LogicalResult verify(ViewOp op) {
return success();
}
+namespace {
+
+struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
+ using OpRewritePattern<ViewOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(ViewOp viewOp,
+ PatternRewriter &rewriter) const override {
+ // Return if none of the operands are constants.
+ if (llvm::none_of(viewOp.getOperands(), [](Value *operand) {
+ return matchPattern(operand, m_ConstantIndex());
+ }))
+ return matchFailure();
+
+ // Get result memref type.
+ auto memrefType = viewOp.getType();
+ if (memrefType.getAffineMaps().size() != 1)
+ return matchFailure();
+ auto map = memrefType.getAffineMaps()[0];
+
+ // Fold any dynamic dim operands which are produced by a constant.
+ SmallVector<int64_t, 4> newShapeConstants;
+ newShapeConstants.reserve(memrefType.getRank());
+ SmallVector<Value *, 4> newOperands;
+ SmallVector<Value *, 4> droppedOperands;
+
+ unsigned dynamicDimPos = 1;
+ unsigned rank = memrefType.getRank();
+ for (unsigned dim = 0, e = rank; dim < e; ++dim) {
+ int64_t dimSize = memrefType.getDimSize(dim);
+ // If this is already static dimension, keep it.
+ if (!ShapedType::isDynamic(dimSize)) {
+ newShapeConstants.push_back(dimSize);
+ continue;
+ }
+ auto *defOp = viewOp.getOperand(dynamicDimPos)->getDefiningOp();
+ if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
+ // Dynamic shape dimension will be folded.
+ newShapeConstants.push_back(constantIndexOp.getValue());
+ // Record to check for zero uses later below.
+ droppedOperands.push_back(constantIndexOp);
+ } else {
+ // Dynamic shape dimension not folded; copy operand from old memref.
+ newShapeConstants.push_back(dimSize);
+ newOperands.push_back(viewOp.getOperand(dynamicDimPos));
+ }
+ dynamicDimPos++;
+ }
+
+ // Get offset from old memref view type 'memRefType'.
+ int64_t oldOffset;
+ llvm::SmallVector<int64_t, 4> oldStrides;
+ if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
+ return matchFailure();
+
+ // Fold dynamic offset operand if it is produced by a constant.
+ auto *dynamicOffset = viewOp.getDynamicOffset();
+ int64_t newOffset = oldOffset;
+ unsigned dynamicOffsetOperandCount = 0;
+ if (dynamicOffset != nullptr) {
+ auto *defOp = dynamicOffset->getDefiningOp();
+ if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
+ // Dynamic offset will be folded into the map.
+ newOffset = constantIndexOp.getValue();
+ droppedOperands.push_back(dynamicOffset);
+ } else {
+ // Unable to fold dynamic offset. Add it to 'newOperands' list.
+ newOperands.push_back(dynamicOffset);
+ dynamicOffsetOperandCount = 1;
+ }
+ }
+
+ // Compute new strides based on 'newShapeConstants'.
+ SmallVector<int64_t, 4> newStrides(rank);
+ newStrides[rank - 1] = 1;
+ bool dynamicStrides = false;
+ for (int i = rank - 2; i >= 0; --i) {
+ if (ShapedType::isDynamic(newShapeConstants[i + 1]))
+ dynamicStrides = true;
+ if (dynamicStrides)
+ newStrides[i] = MemRefType::getDynamicStrideOrOffset();
+ else
+ newStrides[i] = newShapeConstants[i + 1] * newStrides[i + 1];
+ }
+
+ // Regenerate strided layout map with 'newStrides' and 'newOffset'.
+ map = makeStridedLinearLayoutMap(newStrides, newOffset,
+ rewriter.getContext());
+
+ // Create new memref type with constant folded dims and/or offset/strides.
+ auto newMemRefType =
+ MemRefType::get(newShapeConstants, memrefType.getElementType(), {map},
+ memrefType.getMemorySpace());
+ assert(static_cast<int64_t>(newOperands.size()) ==
+ dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
+
+ // Create new ViewOp.
+ auto newShapeCastOp = rewriter.create<ViewOp>(
+ viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), newOperands);
+ // Insert a cast so we have the same type as the old memref type.
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp,
+ newShapeCastOp, viewOp.getType());
+ return matchSuccess();
+ }
+};
+
+} // end anonymous namespace
+
+void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<ViewOpShapeFolder>(context);
+}
+
//===----------------------------------------------------------------------===//
// ZeroExtendIOp
//===----------------------------------------------------------------------===//
OpenPOWER on IntegriCloud