summaryrefslogtreecommitdiffstats
path: root/mlir/test/Dialect
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2019-12-11 09:26:51 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-11 09:27:34 -0800
commit508d4e672e5de9b9c582fc5e96996d93bb92c74e (patch)
tree267dac5084c2c7e12bc56db7de79cc423ebefe50 /mlir/test/Dialect
parentc5fb4c1303837a41eb06b14137d3e4a5387023a3 (diff)
downloadbcm5719-llvm-508d4e672e5de9b9c582fc5e96996d93bb92c74e.tar.gz
bcm5719-llvm-508d4e672e5de9b9c582fc5e96996d93bb92c74e.zip
Continue refactoring StructuredOps utilities
This CL adds more common information to StructuredOpsUtils.h The n_view attribute is retired in favor of args_in + args_out but the CL is otherwise NFC. PiperOrigin-RevId: 285000621
Diffstat (limited to 'mlir/test/Dialect')
-rw-r--r--mlir/test/Dialect/Linalg/fusion.mlir5
-rw-r--r--mlir/test/Dialect/Linalg/invalid.mlir83
-rw-r--r--mlir/test/Dialect/Linalg/llvm.mlir6
-rw-r--r--mlir/test/Dialect/Linalg/loops.mlir12
-rw-r--r--mlir/test/Dialect/Linalg/roundtrip.mlir12
-rw-r--r--mlir/test/Dialect/Linalg/tile.mlir5
-rw-r--r--mlir/test/Dialect/Linalg/tile_indexed_generic.mlir10
-rw-r--r--mlir/test/Dialect/Linalg/transform-patterns.mlir16
8 files changed, 98 insertions, 51 deletions
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index 616b30d3d9d..cbb99a76673 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -305,9 +305,10 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
#id_2d = (i, j) -> (i, j)
#pointwise_2d_trait = {
+ args_in = 2,
+ args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
- iterator_types = ["parallel", "parallel"],
- n_views = [2, 1]
+ iterator_types = ["parallel", "parallel"]
}
func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>) {
%c1 = constant 1 : index
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 9ec345455dc..f99ee74ceea 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -57,9 +57,10 @@ func @yield_parent(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
func @generic_at_least_2_operands(%arg0: memref<f32>) {
// expected-error @+1 {{op expected 2 or more operands}}
linalg.generic {
+ args_in = 1,
+ args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
- n_views = [1, 1],
iterator_types = []
} %arg0: memref<f32>
}
@@ -69,9 +70,10 @@ func @generic_at_least_2_operands(%arg0: memref<f32>) {
func @generic_exactly_2_views(%arg0: memref<f32>) {
// expected-error @+1 {{op expected exactly 2 view operands}}
linalg.generic {
+ args_in = 1,
+ args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
- n_views = [1, 1],
iterator_types = []
} %arg0, %arg0, %arg0: memref<f32>, memref<f32>, memref<f32>
}
@@ -81,9 +83,10 @@ func @generic_exactly_2_views(%arg0: memref<f32>) {
func @generic_undefined_fun(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun attribute to refer to a defined symbol}}
linalg.generic {
+ args_in = 1,
+ args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
- n_views = [1, 1],
iterator_types = []
} %arg0, %arg0: memref<f32>, memref<f32>
}
@@ -95,9 +98,10 @@ func @foo() { return }
func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun arguments to match number of views}}
linalg.generic {
+ args_in = 0,
+ args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
- n_views = [0, 1],
iterator_types = []
} %arg0: memref<f32>
}
@@ -109,9 +113,10 @@ func @foo(%0: i32) { return }
func @generic_mismatched_num_returns(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun results to match number of output views}}
linalg.generic {
+ args_in = 0,
+ args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
- n_views = [0, 1],
iterator_types = []
} %arg0: memref<f32>
}
@@ -123,9 +128,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 }
func @generic_symbol_in_map(%arg0: memref<i32>) {
// expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
linalg.generic {
+ args_in = 0,
+ args_out = 1,
fun = @foo,
indexing_maps = [ ()[N] -> (0) ],
- n_views = [0, 1],
iterator_types = ["parallel"]
} %arg0: memref<i32>
}
@@ -137,9 +143,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 }
func @generic_wrong_dim_in_map(%arg0: memref<i32>) {
// expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
linalg.generic {
+ args_in = 0,
+ args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
- n_views = [0, 1],
iterator_types = ["parallel"]
} %arg0: memref<i32>
}
@@ -151,9 +158,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 }
func @generic_zero_d_view(%arg0: memref<i32>) {
// expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref<i32>'}}
linalg.generic {
+ args_in = 0,
+ args_out = 1,
fun = @foo,
indexing_maps = [ () -> (1) ],
- n_views = [0, 1],
iterator_types = []
} %arg0: memref<i32>
}
@@ -165,9 +173,10 @@ func @foo(%0: f32) -> f32 { return %0: f32 }
func @generic_one_d_view(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
// expected-error @+1 {{op expected indexing_map #0 results to match view rank: 'memref<?xf32, (d0)[s0] -> (d0 + s0)>'}}
linalg.generic {
+ args_in = 0,
+ args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0, 0) ],
- n_views = [0, 1],
iterator_types = []
} %arg0: memref<?xf32, (i)[off]->(off + i)>
}
@@ -182,9 +191,10 @@ func @foo(%0: i32) -> f32 {
func @generic_fun_arg_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
// expected-error @+1 {{op expected fun argument 0 of the same type as elemental type 'f32' of view 0}}
linalg.generic {
+ args_in = 0,
+ args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
- n_views = [0, 1],
iterator_types = []
} %arg0: memref<?xf32, (i)[off]->(off + i)>
}
@@ -199,9 +209,10 @@ func @foo(%0: f32) -> i4 {
func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
// expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'f32' of view 0}}
linalg.generic {
+ args_in = 0,
+ args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
- n_views = [0, 1],
iterator_types = []
} %arg0: memref<?xf32, (i)[off]->(off + i)>
}
@@ -213,12 +224,13 @@ func @foo(%0: f32, %1: f32) -> f32 { return %1: f32 }
func @generic_singular_maps(%arg0: memref<?xf32, (i)[off]->(off + i)>, %arg1: memref<?xf32, (i)[off]->(off + i)>) {
// expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}}
linalg.generic {
+ args_in = 1,
+ args_out = 1,
fun = @foo,
indexing_maps = [
(i, j) -> (i + j) ,
(i, j) -> (i + j)
],
- n_views = [1, 1],
iterator_types = ["parallel","parallel"]
} %arg0, %arg1: memref<?xf32, (i)[off]->(off + i)>, memref<?xf32, (i)[off]->(off + i)>
}
@@ -232,8 +244,9 @@ func @generic_singular_maps(%arg0: memref<?xf32, (i)[off]->(off + i)>, %arg1: me
func @generic_empty_region(%arg0: memref<f32>) {
// expected-error @+1 {{op expected region with 1 block}}
linalg.generic {
+ args_in = 1,
+ args_out = 1,
indexing_maps = [ () -> (0) ],
- n_views = [1, 1],
iterator_types = []
} %arg0, %arg0 {
^bb1:
@@ -246,8 +259,9 @@ func @generic_empty_region(%arg0: memref<f32>) {
func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
// expected-error @+1 {{op expected number of block arguments to match number of views}}
linalg.generic {
+ args_in = 0,
+ args_out = 1,
indexing_maps = [ () -> (0) ],
- n_views = [0, 1],
iterator_types = []
} %arg0 {
^bb:
@@ -259,8 +273,9 @@ func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
func @generic_block_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: 'memref<f32>'}}
linalg.generic {
+ args_in = 0,
+ args_out = 1,
indexing_maps = [ () -> (0) ],
- n_views = [0, 1],
iterator_types = []
} %arg0 {
^bb(%i: i1):
@@ -272,8 +287,9 @@ func @generic_block_arg_type(%arg0: memref<f32>) {
func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
// expected-error @+1 {{op expected number of block arguments to match number of views + number of loops}}
linalg.indexed_generic {
+ args_in = 0,
+ args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
- n_views = [0, 1],
iterator_types = ["parallel"]
} %arg0 {
^bb(%f: f32):
@@ -285,8 +301,9 @@ func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected block argument 0 to be of IndexType}}
linalg.indexed_generic {
+ args_in = 0,
+ args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
- n_views = [0, 1],
iterator_types = ["parallel"]
} %arg0 {
^bb(%i: f64, %f: f32):
@@ -298,8 +315,9 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output view: 'memref<f32>'}}
linalg.indexed_generic {
+ args_in = 0,
+ args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
- n_views = [0, 1],
iterator_types = ["parallel"]
} %arg0 {
^bb(%i: index, %f: i1):
@@ -314,8 +332,9 @@ func @foo(%f: f32) -> (f32) {
func @indexed_generic_fun_arg_count(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun arguments to match number of views + number of loops}}
linalg.indexed_generic {
+ args_in = 0,
+ args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
- n_views = [0, 1],
iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<f32>
@@ -329,7 +348,8 @@ func @foo(%i: i32, %val: f32) -> (f32) {
func @indexed_generic_fun_induction_var_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun argument 0 to be of IndexType}}
linalg.indexed_generic {
- n_views = [0, 1],
+ args_in = 0,
+ args_out = 1,
iterator_types = ["parallel"],
indexing_maps = [ (i) -> (i) ],
fun = @foo
@@ -344,8 +364,9 @@ func @foo(%i: index, %val: i1) -> (i1) {
func @indexed_generic_fun_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun argument 1 of the same type as elemental type 'f32' of view 0}}
linalg.indexed_generic {
+ args_in = 0,
+ args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
- n_views = [0, 1],
iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<f32>
@@ -359,8 +380,9 @@ func @foo(%i: index, %val: i1) -> (i1, i1) {
func @indexed_generic_fun_result_count(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun results to match number of output views}}
linalg.indexed_generic {
+ args_in = 0,
+ args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
- n_views = [0, 1],
iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<f32>
@@ -375,8 +397,9 @@ func @foo(%i: index, %val: i32) -> (f32) {
func @indexed_generic_fun_result_count(%arg0: memref<i32>) {
// expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'i32' of view 0}}
linalg.indexed_generic {
+ args_in = 0,
+ args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
- n_views = [0, 1],
iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<i32>
@@ -385,10 +408,11 @@ func @indexed_generic_fun_result_count(%arg0: memref<i32>) {
// -----
func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
- // expected-error @+8 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
+ // expected-error @+9 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
linalg.generic {
- indexing_maps = [ (i) -> (i) ],
- n_views = [0, 1],
+ args_in = 0,
+ args_out = 1,
+ indexing_maps = [ (i) -> (i) ],
iterator_types = ["parallel"]
} %arg0 {
^bb(%i: f32):
@@ -399,6 +423,13 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)
// -----
+func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{'linalg.dot' op expected 3 or more operands}}
+ linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
+}
+
+// -----
+
// expected-error @+1 {{unknown Linalg type}}
!invalid_type = type !linalg.unknown
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 76055b56b1c..7054a3d9316 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -138,7 +138,8 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
(m, n, k) -> (m, n)
]
#matmul_trait = {
- n_views = [2, 1],
+ args_in = 2,
+ args_out = 1,
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "external_outerproduct_matmul"
@@ -175,7 +176,8 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
#indexed_matmul_trait = {
- n_views = [2, 1],
+ args_in = 2,
+ args_out = 1,
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "external_indexed_outerproduct_matmul"
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 933280b9e24..1425b4ed3a4 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -222,7 +222,8 @@ func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
(i, j, k) -> (i, k, j)
]
#trait = {
- n_views = [1, 2],
+ args_in = 1,
+ args_out = 2,
iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
fun = @foo,
@@ -247,7 +248,8 @@ func @generic_function(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1
// CHECK: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
#trait2 = {
- n_views = [1, 2],
+ args_in = 1,
+ args_out = 2,
iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
library_call = "some_external_function_name_2",
@@ -280,7 +282,8 @@ func @indexed_foo(%i: index, %j: index, %k: index, %0: f32, %1: f32, %2: f32) ->
return %i_float, %i_float : f32, f32
}
#trait3 = {
- n_views = [1, 2],
+ args_in = 1,
+ args_out = 2,
iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
fun = @indexed_foo,
@@ -310,7 +313,8 @@ func @indexed_generic_function(
// CHECK: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
#trait4 = {
- n_views = [1, 2],
+ args_in = 1,
+ args_out = 2,
iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
library_call = "some_external_function_name_2",
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 29e04aba33a..75d732d540d 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -120,8 +120,9 @@ func @conv_view6(%arg0: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?,
(i, j, k) -> (i, k, i + j)
]
#trait = {
+ args_in = 1,
+ args_out = 1,
indexing_maps = #accesses,
- n_views = [1, 1],
iterator_types = ["parallel", "parallel", "parallel"],
fun = @foo,
library_call = "some_external_function_name_1"
@@ -136,11 +137,12 @@ func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, %ar
}
// CHECK-LABEL: func @foo
// CHECK-LABEL: func @generic
-// CHECK: linalg.generic {fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1", n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
#trait2 = {
+ args_in = 1,
+ args_out = 1,
indexing_maps = #accesses,
- n_views = [1, 1],
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_2"
}
@@ -152,7 +154,7 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
return
}
// CHECK-LABEL: func @generic_region
-// CHECK: linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} {
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2"} %{{.*}}, %{{.*}} {
// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // no predecessors
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
@@ -166,7 +168,7 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?,
return
}
// CHECK-LABEL: func @indexed_generic
-// CHECK: linalg.indexed_generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} {
+// CHECK: linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2"} %{{.*}}, %{{.*}} {
// CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index 4040063d77a..763b33b7973 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -213,9 +213,10 @@ func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) {
#id_2d = (i, j) -> (i, j)
#pointwise_2d_trait = {
+ args_in = 2,
+ args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
- iterator_types = ["parallel", "parallel"],
- n_views = [2, 1]
+ iterator_types = ["parallel", "parallel"]
}
func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
diff --git a/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir b/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir
index c17a4634c66..c7cd61b76e3 100644
--- a/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir
+++ b/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir
@@ -4,9 +4,10 @@
#id_1d = (i) -> (i)
#pointwise_1d_trait = {
+ args_in = 1,
+ args_out = 1,
indexing_maps = [#id_1d, #id_1d],
- iterator_types = ["parallel"],
- n_views = [1, 1]
+ iterator_types = ["parallel"]
}
func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) {
linalg.indexed_generic #pointwise_1d_trait %operand, %result {
@@ -43,12 +44,13 @@ func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>)
// TILE-0n25: linalg.indexed_generic
#combined_indices_trait = {
+ args_in = 1,
+ args_out = 1,
indexing_maps = [
(i, j) -> (j, i + j),
(i, j) -> (i, j)
],
- iterator_types = ["parallel", "parallel"],
- n_views = [1, 1]
+ iterator_types = ["parallel", "parallel"]
}
func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) {
linalg.indexed_generic #combined_indices_trait %operand, %result {
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index e986625bcee..4a9d8bc9564 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -83,11 +83,12 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
#some_generic_trait = {
+ args_in = 1,
+ args_out = 1,
indexing_maps = [
(i, j) -> (i, j),
(i, j) -> (i, j)
],
- n_views = [1, 1],
iterator_types = ["parallel", "parallel"]
}
func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -164,12 +165,13 @@ func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK : linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
#matmul_trait = {
+ args_in = 2,
+ args_out = 1,
indexing_maps = [
(m, n, k) -> (m, k),
(m, n, k) -> (k, n),
(m, n, k) -> (m, n)
],
- n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"],
__internal_linalg_transform__ = "_marked_matmul_"
}
@@ -204,10 +206,11 @@ func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
(m, n, k) -> (m, n)
]
#generic_matmul_trait = {
+ args_in = 2,
+ args_out = 1,
fun = @fma,
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul",
- n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"]
}
@@ -220,7 +223,7 @@ func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
}
// CHECK-LABEL : func @fma
// CHECK-LABEL : func @permute_generic
-// CHECK : linalg.generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+// CHECK : linalg.generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 {
%d = mulf %a, %b: f32
@@ -228,10 +231,11 @@ func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) ->
return %e: f32
}
#indexed_matmul_trait = {
+ args_in = 2,
+ args_out = 1,
fun = @fma_indexed,
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul_indexed",
- n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"]
}
func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -242,7 +246,7 @@ func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
}
// CHECK-LABEL : func @fma_indexed
// CHECK-LABEL : func @permute_generic_indexed
-// CHECK : linalg.indexed_generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+// CHECK : linalg.indexed_generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
func @dot_perm(%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>,
OpenPOWER on IntegriCloud