summary refs log tree commit diff stats
path: root/mlir
diff options
context:
space:
mode:
authorAndy Davis <andydavis@google.com>2019-11-06 11:25:16 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-06 11:25:54 -0800
commitb5654d1311ffb2dc1f7f9803d36e4e503bfcc9dd (patch)
tree6a4eaf7306a999d40eec701647ade14ad50d6e6f /mlir
parent5967f91770af670278ae9e668760e8d5be6bbb48 (diff)
downloadbcm5719-llvm-b5654d1311ffb2dc1f7f9803d36e4e503bfcc9dd.tar.gz
bcm5719-llvm-b5654d1311ffb2dc1f7f9803d36e4e503bfcc9dd.zip
Add ViewOp verification for dynamic strides, and address some comments from previous change.
PiperOrigin-RevId: 278903187
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/Ops.td3
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp32
-rw-r--r--mlir/test/IR/core-ops.mlir20
-rw-r--r--mlir/test/IR/invalid-ops.mlir25
4 files changed, 57 insertions, 23 deletions
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index 4dbdcc47ff0..10f94381d22 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1169,7 +1169,8 @@ def ViewOp : Std_Op<"view"> {
(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * 4 + d2 + s1)
}];
- let arguments = (ins AnyMemRef:$source, Variadic<Index>:$operands);
+ let arguments = (ins MemRefRankOf<[I8], [1]>:$source,
+ Variadic<Index>:$operands);
let results = (outs AnyMemRef);
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index e6b99035f6e..5a452c5242a 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2376,17 +2376,8 @@ static void print(OpAsmPrinter &p, ViewOp op) {
}
static LogicalResult verify(ViewOp op) {
- auto baseType = op.getOperand(0)->getType().dyn_cast<MemRefType>();
- auto viewType = op.getResult()->getType().dyn_cast<MemRefType>();
-
- // Operand 0 type and ViewOp result type must be memref.
- if (!baseType || !viewType)
- return op.emitError("operand type ") << baseType << " and result type "
- << viewType << " are must be memref";
-
- // The base memref should be rank 1 with i8 element type.
- if (baseType.getRank() != 1 || !baseType.getElementType().isInteger(8))
- return op.emitError("unsupported shape for base memref type ") << baseType;
+ auto baseType = op.getOperand(0)->getType().cast<MemRefType>();
+ auto viewType = op.getResult()->getType().cast<MemRefType>();
// The base memref should have identity layout map (or none).
if (baseType.getAffineMaps().size() > 1 ||
@@ -2403,7 +2394,7 @@ static LogicalResult verify(ViewOp op) {
// Verify that the result memref type has a strided layout map. is strided
int64_t offset;
llvm::SmallVector<int64_t, 4> strides;
- if (failed(mlir::getStridesAndOffset(viewType, strides, offset)))
+ if (failed(getStridesAndOffset(viewType, strides, offset)))
return op.emitError("result type ") << viewType << " is not strided";
// Verify that we have the correct number of operands for the result type.
@@ -2414,6 +2405,23 @@ static LogicalResult verify(ViewOp op) {
if (op.getNumOperands() !=
memrefOperandCount + numDynamicDims + dynamicOffsetCount)
return op.emitError("incorrect number of operands for type ") << viewType;
+
+ // Verify dynamic strides symbols were added to correct dimensions based
+ // on dynamic sizes.
+ ArrayRef<int64_t> viewShape = viewType.getShape();
+ unsigned viewRank = viewType.getRank();
+ assert(viewRank == strides.size());
+ bool dynamicStrides = false;
+ for (int i = viewRank - 2; i >= 0; --i) {
+ // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag.
+ if (ShapedType::isDynamic(viewShape[i + 1]))
+ dynamicStrides = true;
+ // If stride at dim 'i' is not dynamic, return error.
+ if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset())
+ return op.emitError("incorrect dynamic strides in view memref type ")
+ << viewType;
+ }
+
return success();
}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 977ec661646..bbabe60d12d 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -10,8 +10,10 @@
// CHECK-DAG: #[[map_proj_d0d1_d0:map[0-9]+]] = (d0, d1) -> (d0)
// CHECK-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1)
// CHECK-DAG: #[[map_proj_d0d1_d1d0:map[0-9]+]] = (d0, d1) -> (d1, d0)
-// CHECK-DAG: #[[VIEW_MAP0:map[0-9]+]] = (d0, d1)[s0] -> (d0 * 4 + d1 + s0)
+
// CHECK-DAG: #[[VIEW_MAP1:map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1)
+// CHECK-DAG: #[[VIEW_MAP2:map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)
+// CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = (d0, d1)[s0] -> (d0 * s0 + d1)
// CHECK-LABEL: func @func_with_ops(%arg0: f32) {
func @func_with_ops(f32) {
@@ -478,24 +480,24 @@ func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>) {
func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = alloc() : memref<2048xi8>
// Test two dynamic sizes and dynamic offset.
- // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][%arg2] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP0]]>
+ // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][%arg2] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP2]]>
%1 = view %0[%arg0, %arg1][%arg2]
- : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
+ : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
// Test two dynamic sizes and static offset.
- // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP1]]>
+ // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP3]]>
%2 = view %0[%arg0, %arg1][]
- : memref<2048xi8> to memref<?x?xf32, (d0, d1) -> (d0 * 4 + d1)>
+ : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * s0 + d1)>
// Test one dynamic size and dynamic offset.
- // CHECK: %{{.*}} = std.view %0[%arg1][%arg2] : memref<2048xi8> to memref<4x?xf32, #[[VIEW_MAP0]]>
+ // CHECK: %{{.*}} = std.view %0[%arg1][%arg2] : memref<2048xi8> to memref<4x?xf32, #[[VIEW_MAP2]]>
%3 = view %0[%arg1][%arg2]
- : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
+ : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
// Test one dynamic size and static offset.
- // CHECK: %{{.*}} = std.view %0[%arg0][] : memref<2048xi8> to memref<?x16xf32, #[[VIEW_MAP1]]>
+ // CHECK: %{{.*}} = std.view %0[%arg0][] : memref<2048xi8> to memref<?x4xf32, #[[VIEW_MAP1]]>
%4 = view %0[%arg0][]
- : memref<2048xi8> to memref<?x16xf32, (d0, d1) -> (d0 * 4 + d1)>
+ : memref<2048xi8> to memref<?x4xf32, (d0, d1) -> (d0 * 4 + d1)>
// Test static sizes and static offset.
// CHECK: %{{.*}} = std.view %0[][] : memref<2048xi8> to memref<64x4xf32, #[[VIEW_MAP1]]>
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 4d45d95222d..4d1d853dadb 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -926,7 +926,7 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = alloc() : memref<2048xf32>
- // expected-error@+1 {{unsupported shape for base memref}}
+ // expected-error@+1 {{must be 1D memref of 8-bit integer values}}
%1 = view %0[%arg0, %arg1][]
: memref<2048xf32> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
return
@@ -953,3 +953,26 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0), 1>
return
}
+
+// -----
+
+func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
+ %0 = alloc() : memref<2048xi8>
+ // expected-error@+1 {{incorrect dynamic strides}}
+ %1 = view %0[%arg0, %arg1][]
+ : memref<2048xi8> to
+ memref<?x?x4xf32, (d0, d1, d2) -> (d0 * 777 + d1 * 4 + d2)>
+ return
+}
+
+// -----
+
+func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
+ %0 = alloc() : memref<2048xi8>
+ // expected-error@+1 {{incorrect dynamic strides}}
+ %1 = view %0[%arg0][]
+ : memref<2048xi8> to
+ memref<16x4x?xf32, (d0, d1, d2) -> (d0 * 777 + d1 * 4 + d2)>
+ return
+}
+
OpenPOWER on IntegriCloud