summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/StandardOps
diff options
context:
space:
mode:
authorStephan Herhut <herhut@google.com>2019-11-18 04:31:02 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-18 04:31:33 -0800
commitf0f3b71d67e218f2a556065e16ee3d0ec86067ef (patch)
tree228b719beacec73c1812ea6658cf91eeb20dde01 /mlir/lib/Dialect/StandardOps
parentb8dc3fd81273b5928bfe98519cf10ce5bf9c565d (diff)
downloadbcm5719-llvm-f0f3b71d67e218f2a556065e16ee3d0ec86067ef.tar.gz
bcm5719-llvm-f0f3b71d67e218f2a556065e16ee3d0ec86067ef.zip
Implement folding of pattern dim(subview(_)[...][s1, ..., sn][...], i) -> si.
PiperOrigin-RevId: 281042016
Diffstat (limited to 'mlir/lib/Dialect/StandardOps')
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp10
1 files changed, 9 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index c4abee3858e..e38ce065647 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1347,7 +1347,7 @@ static LogicalResult verify(DimOp op) {
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
// Constant fold dim when the size along the index referred to is a constant.
- auto opType = getOperand()->getType();
+ auto opType = memrefOrTensor()->getType();
int64_t indexSize = -1;
if (auto tensorType = opType.dyn_cast<RankedTensorType>())
indexSize = tensorType.getShape()[getIndex()];
@@ -1357,6 +1357,14 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
if (indexSize >= 0)
return IntegerAttr::get(IndexType::get(getContext()), indexSize);
+ // Fold dim to the size argument of a SubViewOp.
+ auto memref = memrefOrTensor()->getDefiningOp();
+ if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
+ auto sizes = subview.getDynamicSizes();
+ if (!sizes.empty())
+ return *(sizes.begin() + getIndex());
+ }
+
return {};
}
OpenPOWER on IntegriCloud