// RUN: mlir-opt %s -linalg-fusion | FileCheck %s #map0 = affine_map<(d0) -> (d0 + 2)> #map1 = affine_map<(d0) -> (d0 + 4)> #map2 = affine_map<(d0) -> (d0 + 3)> #map3 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> #map4 = affine_map<(d0) -> (d0)> #map5 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> #map6 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)> func @f1(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %B, 1 : memref linalg.matmul(%A, %B, %C) : memref, memref, memref %c1 = constant 1 : index loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %5 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref %7 = std.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref %8 = std.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref linalg.matmul(%5, %7, %8) : memref, memref, memref } } } return %E : memref } // CHECK-LABEL: func @f1 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // No RAW dependences, the pass does not fuse RAR atm. // CHECK: linalg.matmul // CHECK: loop.for // CHECK: loop.for // CHECK: loop.for // CHECK: linalg.matmul func @f2(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index linalg.matmul(%A, %B, %C) : memref, memref, memref %0 = dim %C, 0 : memref %1 = dim %C, 1 : memref %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref linalg.matmul(%5, %7, %8) : memref, memref, memref } } } return %E : memref } // CHECK-LABEL: func @f2 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // CHECK-DAG: %[[C_0:.*]] = dim %[[C]], 0 : memref // CHECK-DAG: %[[C_1:.*]] = dim %[[C]], 1 : memref // CHECK-DAG: %[[D_1:.*]] = dim %[[D]], 1 : memref // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // CHECK: linalg.matmul // CHECK: linalg.matmul func @f3(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index linalg.matmul(%A, %B, %C) : memref, memref, memref %0 = dim %D, 0 : memref %1 = dim %D, 1 : memref %2 = dim %C, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %5 = std.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref %7 = std.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref linalg.matmul(%5, %7, %8) : memref, memref, memref } } } return %E : memref } // CHECK-LABEL: func @f3 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // CHECK: %[[D_0:.*]] = dim %[[D]], 0 : memref // CHECK: %[[D_1:.*]] = dim %[[D]], 1 : memref // CHECK: %[[C_1:.*]] = dim %[[C]], 1 : memref // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // CHECK: linalg.matmul // CHECK: linalg.matmul func @f4(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index linalg.matmul(%A, %B, %C) : memref, memref, memref linalg.matmul(%A, %B, %D) : memref, memref, memref %0 = dim %C, 0 : memref %1 = dim %C, 1 : memref %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref linalg.matmul(%5, %7, %8) : memref, memref, memref } } } return %E : memref } // CHECK-LABEL: func @f4 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // CHECK: %[[C_0:.*]] = dim %[[C]], 0 : memref // CHECK: %[[C_1:.*]] = dim %[[C]], 1 : memref // CHECK: %[[D_1:.*]] = dim %[[D]], 1 : memref // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // Fuse D then fuse C, no false dependence prevent it. // CHECK: linalg.matmul // CHECK: linalg.matmul // CHECK: linalg.matmul func @f5(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index %0 = dim %B, 1 : memref %1 = dim %D, 0 : memref %2 = dim %D, 1 : memref linalg.matmul(%A, %B, %C) : memref, memref, memref linalg.matmul(%C, %B, %D) : memref, memref, memref loop.for %arg5 = %c0 to %1 step %c2 { loop.for %arg6 = %c0 to %0 step %c3 { loop.for %arg7 = %c0 to %2 step %c4 { %5 = std.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref %7 = std.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref linalg.matmul(%5, %7, %8) : memref, memref, memref } } } return %E : memref } // CHECK-LABEL: func @f5 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // CHECK-DAG: %[[B_1:.*]] = dim %[[B]], 1 : memref // CHECK-DAG: %[[D_0:.*]] = dim %[[D]], 0 : memref // CHECK-DAG: %[[D_1:.*]] = dim %[[D]], 1 : memref // Don't fuse C due to false dependence, note that this is too conservative though. // CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[B_1]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // CHECK: linalg.matmul // CHECK: linalg.matmul func @f6(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index %0 = dim %C, 1 : memref linalg.matmul(%A, %B, %C) : memref, memref, memref linalg.matmul(%A, %C, %E) : memref, memref, memref %1 = dim %C, 0 : memref %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %1 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %0 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref %6 = affine.apply #map2(%arg6) %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref linalg.matmul(%5, %7, %8) : memref, memref, memref } } } return %E : memref } // CHECK-LABEL: func @f6 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // Cannot fuse C due to interleaved read of C that would be bypassed. // Cannot fuse E (WAW). // CHECK: linalg.matmul // CHECK: linalg.matmul // CHECK: loop.for // CHECK: loop.for // CHECK: loop.for // CHECK: linalg.matmul // CHECK-NOT: linalg.matmul func @f7(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %C, 1 : memref %3 = dim %C, 0 : memref %4 = dim %D, 1 : memref linalg.matmul(%A, %C, %E) : memref, memref, memref linalg.matmul(%A, %B, %C) : memref, memref, memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %7 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref %9 = std.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref linalg.matmul(%7, %9, %10) : memref, memref, memref } } } loop.for %arg5 = %c0 to %3 step %c2 { loop.for %arg6 = %c0 to %4 step %c3 { loop.for %arg7 = %c0 to %2 step %c4 { %7 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref %9 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref linalg.matmul(%7, %9, %10) : memref, memref, memref } } } return %E : memref } // CHECK-LABEL: func @f7 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // CHECK: %[[A_0:.*]] = dim %[[A]], 0 : memref // CHECK: %[[A_1:.*]] = dim %[[A]], 1 : memref // CHECK: %[[C_1:.*]] = dim %[[C]], 1 : memref // CHECK: %[[C_0:.*]] = dim %[[C]], 0 : memref // CHECK: %[[D_1:.*]] = dim %[[D]], 1 : memref // CHECK: linalg.matmul(%[[A]], %[[C]], %[[E]]) // CHECK: loop.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} { // CHECK: linalg.matmul // CHECK: linalg.matmul // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // CHECK: linalg.matmul // CHECK-NOT: linalg.matmul func @f8(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref linalg.matmul(%A, %C, %D) : memref, memref, memref linalg.matmul(%A, %B, %C) : memref, memref, memref %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) %5 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref %6 = affine.apply #map2(%arg6) %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref linalg.matmul(%5, %7, %8) : memref, memref, memref } } } return %E : memref } // CHECK-LABEL: func @f8 // CHECK: (%[[A:.*]]: memref{{.*}}, %[[B:.*]]: memref{{.*}}, %[[C:.*]]: memref{{.*}}, %[[D:.*]]: memref{{.*}}, %[[E:.*]]: memref{{.*}}) // CHECK: linalg.matmul // CHECK: linalg.matmul // CHECK: loop.for // CHECK: loop.for // CHECK: loop.for // CHECK: linalg.matmul // CHECK-NOT: linalg.matmul #id_2d = affine_map<(i, j) -> (i, j)> #pointwise_2d_trait = { args_in = 2, args_out = 1, indexing_maps = [#id_2d, #id_2d, #id_2d], iterator_types = ["parallel", "parallel"] } func @pointwise(%A: memref, %B: memref, %C: memref, %D: memref) { %c1 = constant 1 : index %c0 = constant 0 : index %c3 = constant 3 : index %c2 = constant 2 : index linalg.generic #pointwise_2d_trait %A, %A, %B { ^bb0(%E: f32, %arg5: f32, %arg6: f32): // no predecessors %2 = addf %E, %arg5 : f32 linalg.yield %2 : f32 }: memref, memref, memref %0 = dim %B, 0 : memref %1 = dim %B, 1 : memref loop.for %arg4 = %c0 to %0 step %c2 { loop.for %arg5 = %c0 to %1 step %c3 { %4 = std.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref %5 = std.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref %6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref linalg.generic #pointwise_2d_trait %4, %5, %6 { ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors %7 = mulf %arg6, %arg7 : f32 linalg.yield %7 : f32 }: memref, memref, memref } } return } // CHECK-LABEL: func @pointwise // CHECK: loop.for // CHECK: loop.for // CHECK-NOT: loop.for // CHECK: linalg.generic // CHECK: addf // CHECK: linalg.generic // CHECK: mulf func @pointwise_no_view(%M: index, %N: index) { %c1 = constant 1 : index %c0 = constant 0 : index %c3 = constant 3 : index %c2 = constant 2 : index %A = alloc (%M, %N): memref %B = alloc (%M, %N): memref %C = alloc (%M, %N): memref %D = alloc (%M, %N): memref %E = alloc (%M, %N): memref linalg.generic #pointwise_2d_trait %A, %A, %B { ^bb0(%e: f32, %arg5: f32, %arg6: f32): // no predecessors %2 = addf %e, %arg5 : f32 linalg.yield %2 : f32 }: memref, memref, memref %0 = dim %B, 0 : memref %1 = dim %B, 1 : memref loop.for %arg4 = %c0 to %0 step %c2 { loop.for %arg5 = %c0 to %1 step %c3 { %4 = std.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref %5 = std.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref %6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref linalg.generic #pointwise_2d_trait %4, %5, %6 { ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors %7 = mulf %arg6, %arg7 : f32 linalg.yield %7 : f32 }: memref, memref, memref } } return } // CHECK-LABEL: func @pointwise_no_view // CHECK: loop.for // CHECK: loop.for // CHECK-NOT: loop.for // CHECK: linalg.generic // CHECK: addf // CHECK: linalg.generic // CHECK: mulf func @indexed_generic_test(%A: memref, %B: memref, %C: memref, %D: memref) { linalg.generic #pointwise_2d_trait %A, %B, %C { ^bb0(%e: f32, %arg5: f32, %arg6: f32): // no predecessors %2 = addf %e, %arg5 : f32 linalg.yield %2 : f32 }: memref, memref, memref %c1 = constant 1 : index %c0 = constant 0 : index %c25 = constant 25 : index %c10 = constant 10 : index %0 = dim %C, 0 : memref %1 = dim %C, 1 : memref %2 = dim %D, 0 : memref %3 = dim %D, 1 : memref loop.for %arg2 = %c0 to %0 step %c10 { loop.for %arg3 = %c0 to %1 step %c25 { %4 = std.subview %C[%arg2, %arg3][%c10, %c25][%c1, %c1] : memref to memref %5 = std.subview %D[%arg2, %arg3][%c10, %c25][%c1, %c1] : memref to memref linalg.indexed_generic { indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel"], args_in = 1, args_out = 1 } %4, %5 { ^bb0(%arg4: index, %arg5: index, %arg6: f32, %arg7: f32): %6 = addi %arg4, %arg2 : index %7 = addi %arg5, %arg3 : index %8 = index_cast %6 : index to i32 %9 = sitofp %8 : i32 to f32 %10 = index_cast %7 : index to i32 %11 = sitofp %10 : i32 to f32 %12 = addf %9, %11 : f32 linalg.yield %12 : f32 }: memref, memref } } return } // CHECK-LABEL: func @indexed_generic_test // CHECK: loop.for // CHECK: loop.for // CHECK-NOT: loop.for // CHECK: linalg.generic // CHECK: addf // CHECK: linalg.indexed_generic // CHECK: index_cast