summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorAndy Davis <andydavis@google.com>2019-11-06 11:25:16 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-06 11:25:54 -0800
commitb5654d1311ffb2dc1f7f9803d36e4e503bfcc9dd (patch)
tree6a4eaf7306a999d40eec701647ade14ad50d6e6f /mlir/lib
parent5967f91770af670278ae9e668760e8d5be6bbb48 (diff)
downloadbcm5719-llvm-b5654d1311ffb2dc1f7f9803d36e4e503bfcc9dd.tar.gz
bcm5719-llvm-b5654d1311ffb2dc1f7f9803d36e4e503bfcc9dd.zip
Add ViewOp verification for dynamic strides, and address some comments from previous change.
PiperOrigin-RevId: 278903187
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp32
1 files changed, 20 insertions, 12 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index e6b99035f6e..5a452c5242a 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2376,17 +2376,8 @@ static void print(OpAsmPrinter &p, ViewOp op) {
}
static LogicalResult verify(ViewOp op) {
- auto baseType = op.getOperand(0)->getType().dyn_cast<MemRefType>();
- auto viewType = op.getResult()->getType().dyn_cast<MemRefType>();
-
- // Operand 0 type and ViewOp result type must be memref.
- if (!baseType || !viewType)
- return op.emitError("operand type ") << baseType << " and result type "
- << viewType << " are must be memref";
-
- // The base memref should be rank 1 with i8 element type.
- if (baseType.getRank() != 1 || !baseType.getElementType().isInteger(8))
- return op.emitError("unsupported shape for base memref type ") << baseType;
+ auto baseType = op.getOperand(0)->getType().cast<MemRefType>();
+ auto viewType = op.getResult()->getType().cast<MemRefType>();
// The base memref should have identity layout map (or none).
if (baseType.getAffineMaps().size() > 1 ||
@@ -2403,7 +2394,7 @@ static LogicalResult verify(ViewOp op) {
// Verify that the result memref type has a strided layout map. is strided
int64_t offset;
llvm::SmallVector<int64_t, 4> strides;
- if (failed(mlir::getStridesAndOffset(viewType, strides, offset)))
+ if (failed(getStridesAndOffset(viewType, strides, offset)))
return op.emitError("result type ") << viewType << " is not strided";
// Verify that we have the correct number of operands for the result type.
@@ -2414,6 +2405,23 @@ static LogicalResult verify(ViewOp op) {
if (op.getNumOperands() !=
memrefOperandCount + numDynamicDims + dynamicOffsetCount)
return op.emitError("incorrect number of operands for type ") << viewType;
+
+ // Verify dynamic strides symbols were added to correct dimensions based
+ // on dynamic sizes.
+ ArrayRef<int64_t> viewShape = viewType.getShape();
+ unsigned viewRank = viewType.getRank();
+ assert(viewRank == strides.size());
+ bool dynamicStrides = false;
+ for (int i = viewRank - 2; i >= 0; --i) {
+ // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag.
+ if (ShapedType::isDynamic(viewShape[i + 1]))
+ dynamicStrides = true;
+ // If stride at dim 'i' is not dynamic, return error.
+ if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset())
+ return op.emitError("incorrect dynamic strides in view memref type ")
+ << viewType;
+ }
+
return success();
}
OpenPOWER on IntegriCloud