diff options
| author | Andy Davis <andydavis@google.com> | 2019-11-06 11:25:16 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-06 11:25:54 -0800 |
| commit | b5654d1311ffb2dc1f7f9803d36e4e503bfcc9dd (patch) | |
| tree | 6a4eaf7306a999d40eec701647ade14ad50d6e6f /mlir/lib | |
| parent | 5967f91770af670278ae9e668760e8d5be6bbb48 (diff) | |
| download | bcm5719-llvm-b5654d1311ffb2dc1f7f9803d36e4e503bfcc9dd.tar.gz bcm5719-llvm-b5654d1311ffb2dc1f7f9803d36e4e503bfcc9dd.zip | |
Add ViewOp verification for dynamic strides, and address some comments from previous change.
PiperOrigin-RevId: 278903187
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 32 |
1 files changed, 20 insertions, 12 deletions
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index e6b99035f6e..5a452c5242a 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -2376,17 +2376,8 @@ static void print(OpAsmPrinter &p, ViewOp op) { } static LogicalResult verify(ViewOp op) { - auto baseType = op.getOperand(0)->getType().dyn_cast<MemRefType>(); - auto viewType = op.getResult()->getType().dyn_cast<MemRefType>(); - - // Operand 0 type and ViewOp result type must be memref. - if (!baseType || !viewType) - return op.emitError("operand type ") << baseType << " and result type " - << viewType << " are must be memref"; - - // The base memref should be rank 1 with i8 element type. - if (baseType.getRank() != 1 || !baseType.getElementType().isInteger(8)) - return op.emitError("unsupported shape for base memref type ") << baseType; + auto baseType = op.getOperand(0)->getType().cast<MemRefType>(); + auto viewType = op.getResult()->getType().cast<MemRefType>(); // The base memref should have identity layout map (or none). if (baseType.getAffineMaps().size() > 1 || @@ -2403,7 +2394,7 @@ static LogicalResult verify(ViewOp op) { // Verify that the result memref type has a strided layout map. is strided int64_t offset; llvm::SmallVector<int64_t, 4> strides; - if (failed(mlir::getStridesAndOffset(viewType, strides, offset))) + if (failed(getStridesAndOffset(viewType, strides, offset))) return op.emitError("result type ") << viewType << " is not strided"; // Verify that we have the correct number of operands for the result type. @@ -2414,6 +2405,23 @@ static LogicalResult verify(ViewOp op) { if (op.getNumOperands() != memrefOperandCount + numDynamicDims + dynamicOffsetCount) return op.emitError("incorrect number of operands for type ") << viewType; + + // Verify dynamic strides symbols were added to correct dimensions based + // on dynamic sizes. + ArrayRef<int64_t> viewShape = viewType.getShape(); + unsigned viewRank = viewType.getRank(); + assert(viewRank == strides.size()); + bool dynamicStrides = false; + for (int i = viewRank - 2; i >= 0; --i) { + // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag. + if (ShapedType::isDynamic(viewShape[i + 1])) + dynamicStrides = true; + // If stride at dim 'i' is not dynamic, return error. + if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset()) + return op.emitError("incorrect dynamic strides in view memref type ") + << viewType; + } + return success(); } |

