diff options
| author | Nicolas Vasilache <ntv@google.com> | 2019-12-11 09:26:51 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-11 09:27:34 -0800 |
| commit | 508d4e672e5de9b9c582fc5e96996d93bb92c74e (patch) | |
| tree | 267dac5084c2c7e12bc56db7de79cc423ebefe50 /mlir/test/Dialect | |
| parent | c5fb4c1303837a41eb06b14137d3e4a5387023a3 (diff) | |
| download | bcm5719-llvm-508d4e672e5de9b9c582fc5e96996d93bb92c74e.tar.gz bcm5719-llvm-508d4e672e5de9b9c582fc5e96996d93bb92c74e.zip | |
Continue refactoring StructuredOps utilities
This CL adds more common information to StructuredOpsUtils.h
The n_view attribute is retired in favor of args_in + args_out but the CL is otherwise NFC.
PiperOrigin-RevId: 285000621
Diffstat (limited to 'mlir/test/Dialect')
| -rw-r--r-- | mlir/test/Dialect/Linalg/fusion.mlir | 5 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/invalid.mlir | 83 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/llvm.mlir | 6 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/loops.mlir | 12 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/roundtrip.mlir | 12 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/tile.mlir | 5 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/tile_indexed_generic.mlir | 10 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/transform-patterns.mlir | 16 |
8 files changed, 98 insertions, 51 deletions
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir index 616b30d3d9d..cbb99a76673 100644 --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -305,9 +305,10 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of #id_2d = (i, j) -> (i, j) #pointwise_2d_trait = { + args_in = 2, + args_out = 1, indexing_maps = [#id_2d, #id_2d, #id_2d], - iterator_types = ["parallel", "parallel"], - n_views = [2, 1] + iterator_types = ["parallel", "parallel"] } func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>) { %c1 = constant 1 : index diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 9ec345455dc..f99ee74ceea 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -57,9 +57,10 @@ func @yield_parent(%arg0: memref<?xf32, (i)[off]->(off + i)>) { func @generic_at_least_2_operands(%arg0: memref<f32>) { // expected-error @+1 {{op expected 2 or more operands}} linalg.generic { + args_in = 1, + args_out = 1, fun = @foo, indexing_maps = [ () -> (0) ], - n_views = [1, 1], iterator_types = [] } %arg0: memref<f32> } @@ -69,9 +70,10 @@ func @generic_at_least_2_operands(%arg0: memref<f32>) { func @generic_exactly_2_views(%arg0: memref<f32>) { // expected-error @+1 {{op expected exactly 2 view operands}} linalg.generic { + args_in = 1, + args_out = 1, fun = @foo, indexing_maps = [ () -> (0) ], - n_views = [1, 1], iterator_types = [] } %arg0, %arg0, %arg0: memref<f32>, memref<f32>, memref<f32> } @@ -81,9 +83,10 @@ func @generic_exactly_2_views(%arg0: memref<f32>) { func @generic_undefined_fun(%arg0: memref<f32>) { // expected-error @+1 {{op expected fun attribute to refer to a defined symbol}} linalg.generic { + args_in = 1, + args_out = 1, fun = @foo, indexing_maps = [ () -> (0) ], - n_views = [1, 1], iterator_types = [] } %arg0, %arg0: memref<f32>, memref<f32> } @@ -95,9 +98,10 @@ func @foo() { return } func @generic_mismatched_num_arguments(%arg0: memref<f32>) { // expected-error @+1 {{op expected fun arguments to match number of views}} linalg.generic { + args_in = 0, + args_out = 1, fun = @foo, indexing_maps = [ () -> (0) ], - n_views = [0, 1], iterator_types = [] } %arg0: memref<f32> } @@ -109,9 +113,10 @@ func @foo(%0: i32) { return } func @generic_mismatched_num_returns(%arg0: memref<f32>) { // expected-error @+1 {{op expected fun results to match number of output views}} linalg.generic { + args_in = 0, + args_out = 1, fun = @foo, indexing_maps = [ () -> (0) ], - n_views = [0, 1], iterator_types = [] } %arg0: memref<f32> } @@ -123,9 +128,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 } func @generic_symbol_in_map(%arg0: memref<i32>) { // expected-error @+1 {{op expected indexing_map #0 to have no symbols}} linalg.generic { + args_in = 0, + args_out = 1, fun = @foo, indexing_maps = [ ()[N] -> (0) ], - n_views = [0, 1], iterator_types = ["parallel"] } %arg0: memref<i32> } @@ -137,9 +143,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 } func @generic_wrong_dim_in_map(%arg0: memref<i32>) { // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}} linalg.generic { + args_in = 0, + args_out = 1, fun = @foo, indexing_maps = [ () -> (0) ], - n_views = [0, 1], iterator_types = ["parallel"] } %arg0: memref<i32> } @@ -151,9 +158,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 } func @generic_zero_d_view(%arg0: memref<i32>) { // expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref<i32>'}} linalg.generic { + args_in = 0, + args_out = 1, fun = @foo, indexing_maps = [ () -> (1) ], - n_views = [0, 1], iterator_types = [] } %arg0: memref<i32> } @@ -165,9 +173,10 @@ func @foo(%0: f32) -> f32 { return %0: f32 } func @generic_one_d_view(%arg0: memref<?xf32, (i)[off]->(off + i)>) { // expected-error @+1 {{op expected indexing_map #0 results to match view rank: 'memref<?xf32, (d0)[s0] -> (d0 + s0)>'}} linalg.generic { + args_in = 0, + args_out = 1, fun = @foo, indexing_maps = [ () -> (0, 0) ], - n_views = [0, 1], iterator_types = [] } %arg0: memref<?xf32, (i)[off]->(off + i)> } @@ -182,9 +191,10 @@ func @foo(%0: i32) -> f32 { func @generic_fun_arg_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) { // expected-error @+1 {{op expected fun argument 0 of the same type as elemental type 'f32' of view 0}} linalg.generic { + args_in = 0, + args_out = 1, fun = @foo, indexing_maps = [ () -> (0) ], - n_views = [0, 1], iterator_types = [] } %arg0: memref<?xf32, (i)[off]->(off + i)> } @@ -199,9 +209,10 @@ func @foo(%0: f32) -> i4 { func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) { // expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'f32' of view 0}} linalg.generic { + args_in = 0, + args_out = 1, fun = @foo, indexing_maps = [ () -> (0) ], - n_views = [0, 1], iterator_types = [] } %arg0: memref<?xf32, (i)[off]->(off + i)> } @@ -213,12 +224,13 @@ func @foo(%0: f32, %1: f32) -> f32 { return %1: f32 } func @generic_singular_maps(%arg0: memref<?xf32, (i)[off]->(off + i)>, %arg1: memref<?xf32, (i)[off]->(off + i)>) { // expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}} linalg.generic { + args_in = 1, + args_out = 1, fun = @foo, indexing_maps = [ (i, j) -> (i + j) , (i, j) -> (i + j) ], - n_views = [1, 1], iterator_types = ["parallel","parallel"] } %arg0, %arg1: memref<?xf32, (i)[off]->(off + i)>, memref<?xf32, (i)[off]->(off + i)> } @@ -232,8 +244,9 @@ func @generic_singular_maps(%arg0: memref<?xf32, (i)[off]->(off + i)>, %arg1: me func @generic_empty_region(%arg0: memref<f32>) { // expected-error @+1 {{op expected region with 1 block}} linalg.generic { + args_in = 1, + args_out = 1, indexing_maps = [ () -> (0) ], - n_views = [1, 1], iterator_types = [] } %arg0, %arg0 { ^bb1: @@ -246,8 +259,9 @@ func @generic_empty_region(%arg0: memref<f32>) { func @generic_mismatched_num_arguments(%arg0: memref<f32>) { // expected-error @+1 {{op expected number of block arguments to match number of views}} linalg.generic { + args_in = 0, + args_out = 1, indexing_maps = [ () -> (0) ], - n_views = [0, 1], iterator_types = [] } %arg0 { ^bb: @@ -259,8 +273,9 @@ func @generic_mismatched_num_arguments(%arg0: memref<f32>) { func @generic_block_arg_type(%arg0: memref<f32>) { // expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: 'memref<f32>'}} linalg.generic { + args_in = 0, + args_out = 1, indexing_maps = [ () -> (0) ], - n_views = [0, 1], iterator_types = [] } %arg0 { ^bb(%i: i1): @@ -272,8 +287,9 @@ func @generic_block_arg_type(%arg0: memref<f32>) { func @indexed_generic_block_arg_count(%arg0: memref<f32>) { // expected-error @+1 {{op expected number of block arguments to match number of views + number of loops}} linalg.indexed_generic { + args_in = 0, + args_out = 1, indexing_maps = [ (d0) -> (d0) ], - n_views = [0, 1], iterator_types = ["parallel"] } %arg0 { ^bb(%f: f32): @@ -285,8 +301,9 @@ func @indexed_generic_block_arg_count(%arg0: memref<f32>) { func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) { // expected-error @+1 {{op expected block argument 0 to be of IndexType}} linalg.indexed_generic { + args_in = 0, + args_out = 1, indexing_maps = [ (d0) -> (d0) ], - n_views = [0, 1], iterator_types = ["parallel"] } %arg0 { ^bb(%i: f64, %f: f32): @@ -298,8 +315,9 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) { func @indexed_generic_block_arg_type(%arg0: memref<f32>) { // expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output view: 'memref<f32>'}} linalg.indexed_generic { + args_in = 0, + args_out = 1, indexing_maps = [ (d0) -> (d0) ], - n_views = [0, 1], iterator_types = ["parallel"] } %arg0 { ^bb(%i: index, %f: i1): @@ -314,8 +332,9 @@ func @foo(%f: f32) -> (f32) { func @indexed_generic_fun_arg_count(%arg0: memref<f32>) { // expected-error @+1 {{op expected fun arguments to match number of views + number of loops}} linalg.indexed_generic { + args_in = 0, + args_out = 1, indexing_maps = [ (d0) -> (d0) ], - n_views = [0, 1], iterator_types = ["parallel"], fun = @foo } %arg0: memref<f32> @@ -329,7 +348,8 @@ func @foo(%i: i32, %val: f32) -> (f32) { func @indexed_generic_fun_induction_var_arg_type(%arg0: memref<f32>) { // expected-error @+1 {{op expected fun argument 0 to be of IndexType}} linalg.indexed_generic { - n_views = [0, 1], + args_in = 0, + args_out = 1, iterator_types = ["parallel"], indexing_maps = [ (i) -> (i) ], fun = @foo @@ -344,8 +364,9 @@ func @foo(%i: index, %val: i1) -> (i1) { func @indexed_generic_fun_arg_type(%arg0: memref<f32>) { // expected-error @+1 {{op expected fun argument 1 of the same type as elemental type 'f32' of view 0}} linalg.indexed_generic { + args_in = 0, + args_out = 1, indexing_maps = [ (d0) -> (d0) ], - n_views = [0, 1], iterator_types = ["parallel"], fun = @foo } %arg0: memref<f32> @@ -359,8 +380,9 @@ func @foo(%i: index, %val: i1) -> (i1, i1) { func @indexed_generic_fun_result_count(%arg0: memref<f32>) { // expected-error @+1 {{op expected fun results to match number of output views}} linalg.indexed_generic { + args_in = 0, + args_out = 1, indexing_maps = [ (d0) -> (d0) ], - n_views = [0, 1], iterator_types = ["parallel"], fun = @foo } %arg0: memref<f32> @@ -375,8 +397,9 @@ func @foo(%i: index, %val: i32) -> (f32) { func @indexed_generic_fun_result_count(%arg0: memref<i32>) { // expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'i32' of view 0}} linalg.indexed_generic { + args_in = 0, + args_out = 1, indexing_maps = [ (d0) -> (d0) ], - n_views = [0, 1], iterator_types = ["parallel"], fun = @foo } %arg0: memref<i32> @@ -385,10 +408,11 @@ func @indexed_generic_fun_result_count(%arg0: memref<i32>) { // ----- func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) { - // expected-error @+8 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}} + // expected-error @+9 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}} linalg.generic { - indexing_maps = [ (i) -> (i) ], - n_views = [0, 1], + args_in = 0, + args_out = 1, + indexing_maps = [ (i) -> (i) ], iterator_types = ["parallel"] } %arg0 { ^bb(%i: f32): @@ -399,6 +423,13 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i) // ----- +func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) { + // expected-error @+1 {{'linalg.dot' op expected 3 or more operands}} + linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32> +} + +// ----- + // expected-error @+1 {{unknown Linalg type}} !invalid_type = type !linalg.unknown diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir index 76055b56b1c..7054a3d9316 100644 --- a/mlir/test/Dialect/Linalg/llvm.mlir +++ b/mlir/test/Dialect/Linalg/llvm.mlir @@ -138,7 +138,8 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a (m, n, k) -> (m, n) ] #matmul_trait = { - n_views = [2, 1], + args_in = 2, + args_out = 1, iterator_types = ["parallel", "parallel", "reduction"], indexing_maps = #matmul_accesses, library_call = "external_outerproduct_matmul" @@ -175,7 +176,8 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C #indexed_matmul_trait = { - n_views = [2, 1], + args_in = 2, + args_out = 1, iterator_types = ["parallel", "parallel", "reduction"], indexing_maps = #matmul_accesses, library_call = "external_indexed_outerproduct_matmul" diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir index 933280b9e24..1425b4ed3a4 100644 --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -222,7 +222,8 @@ func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) { (i, j, k) -> (i, k, j) ] #trait = { - n_views = [1, 2], + args_in = 1, + args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = #accesses, fun = @foo, @@ -247,7 +248,8 @@ func @generic_function(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1 // CHECK: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]> #trait2 = { - n_views = [1, 2], + args_in = 1, + args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = #accesses, library_call = "some_external_function_name_2", @@ -280,7 +282,8 @@ func @indexed_foo(%i: index, %j: index, %k: index, %0: f32, %1: f32, %2: f32) -> return %i_float, %i_float : f32, f32 } #trait3 = { - n_views = [1, 2], + args_in = 1, + args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = #accesses, fun = @indexed_foo, @@ -310,7 +313,8 @@ func @indexed_generic_function( // CHECK: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]> #trait4 = { - n_views = [1, 2], + args_in = 1, + args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = #accesses, library_call = "some_external_function_name_2", diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 29e04aba33a..75d732d540d 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -120,8 +120,9 @@ func @conv_view6(%arg0: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, (i, j, k) -> (i, k, i + j) ] #trait = { + args_in = 1, + args_out = 1, indexing_maps = #accesses, - n_views = [1, 1], iterator_types = ["parallel", "parallel", "parallel"], fun = @foo, library_call = "some_external_function_name_1" @@ -136,11 +137,12 @@ func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, %ar } // CHECK-LABEL: func @foo // CHECK-LABEL: func @generic -// CHECK: linalg.generic {fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1", n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]> +// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]> #trait2 = { + args_in = 1, + args_out = 1, indexing_maps = #accesses, - n_views = [1, 1], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2" } @@ -152,7 +154,7 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1 return } // CHECK-LABEL: func @generic_region -// CHECK: linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} { +// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2"} %{{.*}}, %{{.*}} { // CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // no predecessors // CHECK: linalg.yield %{{.*}} : f32 // CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]> @@ -166,7 +168,7 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, return } // CHECK-LABEL: func @indexed_generic -// CHECK: linalg.indexed_generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} { +// CHECK: linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2"} %{{.*}}, %{{.*}} { // CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // CHECK: linalg.yield %{{.*}} : f32 // CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]> diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir index 4040063d77a..763b33b7973 100644 --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -213,9 +213,10 @@ func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) { #id_2d = (i, j) -> (i, j) #pointwise_2d_trait = { + args_in = 2, + args_out = 1, indexing_maps = [#id_2d, #id_2d, #id_2d], - iterator_types = ["parallel", "parallel"], - n_views = [2, 1] + iterator_types = ["parallel", "parallel"] } func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>, diff --git a/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir b/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir index c17a4634c66..c7cd61b76e3 100644 --- a/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir +++ b/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir @@ -4,9 +4,10 @@ #id_1d = (i) -> (i) #pointwise_1d_trait = { + args_in = 1, + args_out = 1, indexing_maps = [#id_1d, #id_1d], - iterator_types = ["parallel"], - n_views = [1, 1] + iterator_types = ["parallel"] } func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) { linalg.indexed_generic #pointwise_1d_trait %operand, %result { @@ -43,12 +44,13 @@ func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) // TILE-0n25: linalg.indexed_generic #combined_indices_trait = { + args_in = 1, + args_out = 1, indexing_maps = [ (i, j) -> (j, i + j), (i, j) -> (i, j) ], - iterator_types = ["parallel", "parallel"], - n_views = [1, 1] + iterator_types = ["parallel", "parallel"] } func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) { linalg.indexed_generic #combined_indices_trait %operand, %result { diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index e986625bcee..4a9d8bc9564 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -83,11 +83,12 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, // CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]> #some_generic_trait = { + args_in = 1, + args_out = 1, indexing_maps = [ (i, j) -> (i, j), (i, j) -> (i, j) ], - n_views = [1, 1], iterator_types = ["parallel", "parallel"] } func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, @@ -164,12 +165,13 @@ func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, // CHECK : linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]> #matmul_trait = { + args_in = 2, + args_out = 1, indexing_maps = [ (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n) ], - n_views = [2, 1], iterator_types = ["parallel", "parallel", "reduction"], __internal_linalg_transform__ = "_marked_matmul_" } @@ -204,10 +206,11 @@ func @fma(%a: f32, %b: f32, %c: f32) -> f32 { (m, n, k) -> (m, n) ] #generic_matmul_trait = { + args_in = 2, + args_out = 1, fun = @fma, indexing_maps = #matmul_accesses, library_call = "linalg_matmul", - n_views = [2, 1], iterator_types = ["parallel", "parallel", "reduction"] } @@ -220,7 +223,7 @@ func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, } // CHECK-LABEL : func @fma // CHECK-LABEL : func @permute_generic -// CHECK : linalg.generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]> +// CHECK : linalg.generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]> func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 { %d = mulf %a, %b: f32 @@ -228,10 +231,11 @@ func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> return %e: f32 } #indexed_matmul_trait = { + args_in = 2, + args_out = 1, fun = @fma_indexed, indexing_maps = #matmul_accesses, library_call = "linalg_matmul_indexed", - n_views = [2, 1], iterator_types = ["parallel", "parallel", "reduction"] } func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, @@ -242,7 +246,7 @@ func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, } // CHECK-LABEL : func @fma_indexed // CHECK-LABEL : func @permute_generic_indexed -// CHECK : linalg.indexed_generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]> +// CHECK : linalg.indexed_generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]> func @dot_perm(%x: memref<?xf32, offset: ?, strides: [1]>, %y: memref<?xf32, offset: ?, strides: [1]>, |

