// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s

// CHECK-LABEL: func @test_subi_zero
func @test_subi_zero(%arg0: i32) -> i32 {
  // CHECK-NEXT: %c0_i32 = constant 0 : i32
  // CHECK-NEXT: return %c0
  %y = subi %arg0, %arg0 : i32
  return %y: i32
}

// CHECK-LABEL: func @test_subi_zero_vector
func @test_subi_zero_vector(%arg0: vector<4xi32>) -> vector<4xi32> {
  //CHECK-NEXT: %cst = constant dense<0> : vector<4xi32>
  %y = subi %arg0, %arg0 : vector<4xi32>
  // CHECK-NEXT: return %cst
  return %y: vector<4xi32>
}

// CHECK-LABEL: func @test_subi_zero_tensor
func @test_subi_zero_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
  //CHECK-NEXT: %cst = constant dense<0> : tensor<4x5xi32>
  %y = subi %arg0, %arg0 : tensor<4x5xi32>
  // CHECK-NEXT: return %cst
  return %y: tensor<4x5xi32>
}

// CHECK-LABEL: func @dim
func @dim(%arg0: tensor<8x4xf32>) -> index {
  // CHECK: %c4 = constant 4 : index
  %0 = dim %arg0, 1 : tensor<8x4xf32>

  // CHECK-NEXT: return %c4
  return %0 : index
}

// CHECK-LABEL: func @test_commutative
func @test_commutative(%arg0: i32) -> (i32, i32) {
  // CHECK: %c42_i32 = constant 42 : i32
  %c42_i32 = constant 42 : i32
  // CHECK-NEXT: %0 = addi %arg0, %c42_i32 : i32
  %y = addi %c42_i32, %arg0 : i32

  // This should not be swapped.
  // CHECK-NEXT: %1 = subi %c42_i32, %arg0 : i32
  %z = subi %c42_i32, %arg0 : i32

  // CHECK-NEXT: return %0, %1
  return %y, %z: i32, i32
}

// CHECK-LABEL: func @trivial_dce
func @trivial_dce(%arg0: tensor<8x4xf32>) {
  %0 = dim %arg0, 1 : tensor<8x4xf32>
  // CHECK-NEXT: return
  return
}

// CHECK-LABEL: func @addi_zero
func @addi_zero(%arg0: i32) -> i32 {
  // CHECK-NEXT: return %arg0
  %c0_i32 = constant 0 : i32
  %y = addi %c0_i32, %arg0 : i32
  return %y: i32
}

// CHECK-LABEL: func @addi_zero_index
func @addi_zero_index(%arg0: index) -> index {
  // CHECK-NEXT: return %arg0
  %c0_index = constant 0 : index
  %y = addi %c0_index, %arg0 : index
  return %y: index
}

// CHECK-LABEL: func @addi_zero_vector
func @addi_zero_vector(%arg0: vector<4 x i32>) -> vector<4 x i32> {
  // CHECK-NEXT: return %arg0
  %c0_v4i32 = constant dense<0> : vector<4 x i32>
  %y = addi %c0_v4i32, %arg0 : vector<4 x i32>
  return %y: vector<4 x i32>
}

// CHECK-LABEL: func @addi_zero_tensor
func @addi_zero_tensor(%arg0: tensor<4 x 5 x i32>) -> tensor<4 x 5 x i32> {
  // CHECK-NEXT: return %arg0
  %c0_t45i32 = constant dense<0> : tensor<4 x 5 x i32>
  %y = addi %arg0, %c0_t45i32 : tensor<4 x 5 x i32>
  return %y: tensor<4 x 5 x i32>
}

// CHECK-LABEL: func @muli_zero
func @muli_zero(%arg0: i32) -> i32 {
  // CHECK-NEXT: %c0_i32 = constant 0 : i32
  %c0_i32 = constant 0 : i32

  %y = muli %c0_i32, %arg0 : i32

  // CHECK-NEXT: return %c0_i32
  return %y: i32
}

// CHECK-LABEL: func @muli_zero_index
func @muli_zero_index(%arg0: index) -> index {
  // CHECK-NEXT: %[[CST:.*]] = constant 0 : index
  %c0_index = constant 0 : index

  %y = muli %c0_index, %arg0 : index

  // CHECK-NEXT: return %[[CST]]
  return %y: index
}

// CHECK-LABEL: func @muli_zero_vector
func @muli_zero_vector(%arg0: vector<4 x i32>) -> vector<4 x i32> {
  // CHECK-NEXT: %cst = constant dense<0> : vector<4xi32>
  %cst = constant dense<0> : vector<4 x i32>

  %y = muli %cst, %arg0 : vector<4 x i32>

  // CHECK-NEXT: return %cst
  return %y: vector<4 x i32>
}

// CHECK-LABEL: func @muli_zero_tensor
func @muli_zero_tensor(%arg0: tensor<4 x 5 x i32>) -> tensor<4 x 5 x i32> {
  // CHECK-NEXT: %cst = constant dense<0> : tensor<4x5xi32>
  %cst = constant dense<0> : tensor<4 x 5 x i32>

  %y = muli %arg0, %cst : tensor<4 x 5 x i32>

  // CHECK-NEXT: return %cst
  return %y: tensor<4 x 5 x i32>
}
// CHECK-LABEL: func @muli_one
func @muli_one(%arg0: i32) -> i32 {
  // CHECK-NEXT: return %arg0
  %c0_i32 = constant 1 : i32
  %y = muli %c0_i32, %arg0 : i32
  return %y: i32
}

// CHECK-LABEL: func @muli_one_index
func @muli_one_index(%arg0: index) -> index {
  // CHECK-NEXT: return %arg0
  %c0_index = constant 1 : index
  %y = muli %c0_index, %arg0 : index
  return %y: index
}

// CHECK-LABEL: func @muli_one_vector
func @muli_one_vector(%arg0: vector<4 x i32>) -> vector<4 x i32> {
  // CHECK-NEXT: return %arg0
  %c1_v4i32 = constant dense<1> : vector<4 x i32>
  %y = muli %c1_v4i32, %arg0 : vector<4 x i32>
  return %y: vector<4 x i32>
}

// CHECK-LABEL: func @muli_one_tensor
func @muli_one_tensor(%arg0: tensor<4 x 5 x i32>) -> tensor<4 x 5 x i32> {
  // CHECK-NEXT: return %arg0
  %c1_t45i32 = constant dense<1> : tensor<4 x 5 x i32>
  %y = muli %arg0, %c1_t45i32 : tensor<4 x 5 x i32>
  return %y: tensor<4 x 5 x i32>
}

//CHECK-LABEL: func @and_self
func @and_self(%arg0: i32) -> i32 {
  //CHECK-NEXT: return %arg0
  %1 = and %arg0, %arg0 : i32
  return %1 : i32
}

//CHECK-LABEL: func @and_self_vector
func @and_self_vector(%arg0: vector<4xi32>) -> vector<4xi32> {
  //CHECK-NEXT: return %arg0
  %1 = and %arg0, %arg0 : vector<4xi32>
  return %1 : vector<4xi32>
}

//CHECK-LABEL: func @and_self_tensor
func @and_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
  //CHECK-NEXT: return %arg0
  %1 = and %arg0, %arg0 : tensor<4x5xi32>
  return %1 : tensor<4x5xi32>
}

//CHECK-LABEL: func @and_zero
func @and_zero(%arg0: i32) -> i32 {
  // CHECK-NEXT: %c0_i32 = constant 0 : i32
  %c0_i32 = constant 0 : i32
  // CHECK-NEXT: return %c0_i32
  %1 = and %arg0, %c0_i32 : i32
  return %1 : i32
}

//CHECK-LABEL: func @and_zero_index
func @and_zero_index(%arg0: index) -> index {
  // CHECK-NEXT: %[[CST:.*]] = constant 0 : index
  %c0_index = constant 0 : index
  // CHECK-NEXT: return %[[CST]]
  %1 = and %arg0, %c0_index : index
  return %1 : index
}

//CHECK-LABEL: func @and_zero_vector
func @and_zero_vector(%arg0: vector<4xi32>) -> vector<4xi32> {
  // CHECK-NEXT: %cst = constant dense<0> : vector<4xi32>
  %cst = constant dense<0> : vector<4xi32>
  // CHECK-NEXT: return %cst
  %1 = and %arg0, %cst : vector<4xi32>
  return %1 : vector<4xi32>
}

//CHECK-LABEL: func @and_zero_tensor
func @and_zero_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
  // CHECK-NEXT: %cst = constant dense<0> : tensor<4x5xi32>
  %cst = constant dense<0> : tensor<4x5xi32>
  // CHECK-NEXT: return %cst
  %1 = and %arg0, %cst : tensor<4x5xi32>
  return %1 : tensor<4x5xi32>
}

//CHECK-LABEL: func @or_self
func @or_self(%arg0: i32) -> i32 {
  //CHECK-NEXT: return %arg0
  %1 = or %arg0, %arg0 : i32
  return %1 : i32
}

//CHECK-LABEL: func @or_self_vector
func @or_self_vector(%arg0: vector<4xi32>) -> vector<4xi32> {
  //CHECK-NEXT: return %arg0
  %1 = or %arg0, %arg0 : vector<4xi32>
  return %1 : vector<4xi32>
}

//CHECK-LABEL: func @or_self_tensor
func @or_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
  //CHECK-NEXT: return %arg0
  %1 = or %arg0, %arg0 : tensor<4x5xi32>
  return %1 : tensor<4x5xi32>
}

//CHECK-LABEL: func @or_zero
func @or_zero(%arg0: i32) -> i32 {
  %c0_i32 = constant 0 : i32
  // CHECK-NEXT: return %arg0
  %1 = or %arg0, %c0_i32 : i32
  return %1 : i32
}

//CHECK-LABEL: func @or_zero_index
func @or_zero_index(%arg0: index) -> index {
  %c0_index = constant 0 : index
  // CHECK-NEXT: return %arg0
  %1 = or %arg0, %c0_index : index
  return %1 : index
}

//CHECK-LABEL: func @or_zero_vector
func @or_zero_vector(%arg0: vector<4xi32>) -> vector<4xi32> {
  // CHECK-NEXT: return %arg0
  %cst = constant dense<0> : vector<4xi32>
  %1 = or %arg0, %cst : vector<4xi32>
  return %1 : vector<4xi32>
}

//CHECK-LABEL: func @or_zero_tensor
func @or_zero_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
  // CHECK-NEXT: return %arg0
  %cst = constant dense<0> : tensor<4x5xi32>
  %1 = or %arg0, %cst : tensor<4x5xi32>
  return %1 : tensor<4x5xi32>
}

//CHECK-LABEL: func @xor_self
func @xor_self(%arg0: i32) -> i32 {
  //CHECK-NEXT: %c0_i32 = constant 0
  %1 = xor %arg0, %arg0 : i32
  //CHECK-NEXT: return %c0_i32
  return %1 : i32
}

//CHECK-LABEL: func @xor_self_vector
func @xor_self_vector(%arg0: vector<4xi32>) -> vector<4xi32> {
  //CHECK-NEXT: %cst = constant dense<0> : vector<4xi32>
  %1 = xor %arg0, %arg0 : vector<4xi32>
  //CHECK-NEXT: return %cst
  return %1 : vector<4xi32>
}

//CHECK-LABEL: func @xor_self_tensor
func @xor_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
  //CHECK-NEXT: %cst = constant dense<0> : tensor<4x5xi32>
  %1 = xor %arg0, %arg0 : tensor<4x5xi32>
  //CHECK-NEXT: return %cst
  return %1 : tensor<4x5xi32>
}

// CHECK-LABEL: func @memref_cast_folding
func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
  %1 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>

  // CHECK-NEXT: %c0 = constant 0 : index
  %c0 = constant 0 : index
  %dim = dim %1, 0 : memref<?xf32>

  // CHECK-NEXT: affine.load %arg0[3]
  affine.load %1[%dim - 1] : memref<?xf32>

  // CHECK-NEXT: store %arg1, %arg0[%c0] : memref<4xf32>
  store %arg1, %1[%c0] : memref<?xf32>

  // CHECK-NEXT: %{{.*}} = load %arg0[%c0] : memref<4xf32>
  %0 = load %1[%c0] : memref<?xf32>

  // CHECK-NEXT: dealloc %arg0 : memref<4xf32>
  dealloc %1: memref<?xf32>

  // CHECK-NEXT: return %{{.*}}
  return %0 : f32
}

// CHECK-LABEL: func @alloc_const_fold
func @alloc_const_fold() -> memref<?xf32> {
  // CHECK-NEXT: %0 = alloc() : memref<4xf32>
  %c4 = constant 4 : index
  %a = alloc(%c4) : memref<?xf32>

  // CHECK-NEXT: %1 = memref_cast %0 : memref<4xf32> to memref<?xf32>
  // CHECK-NEXT: return %1 : memref<?xf32>
  return %a : memref<?xf32>
}

// CHECK-LABEL: func @dead_alloc_fold
func @dead_alloc_fold() {
  // CHECK-NEXT: return
  %c4 = constant 4 : index
  %a = alloc(%c4) : memref<?xf32>
  return
}

// CHECK-LABEL: func @dead_dealloc_fold
func @dead_dealloc_fold() {
  // CHECK-NEXT: return
  %a = alloc() : memref<4xf32>
  dealloc %a: memref<4xf32>
  return
}

// CHECK-LABEL: func @dead_dealloc_fold_multi_use
func @dead_dealloc_fold_multi_use(%cond : i1) {
  // CHECK-NEXT: cond_br
  %a = alloc() : memref<4xf32>
  cond_br %cond, ^bb1, ^bb2

  // CHECK-LABEL: bb1:
^bb1:
  // CHECK-NEXT: return
  dealloc %a: memref<4xf32>
  return

  // CHECK-LABEL: bb2:
^bb2:
  // CHECK-NEXT: return
  dealloc %a: memref<4xf32>
  return
}

// CHECK-LABEL: func @dead_block_elim
func @dead_block_elim() {
  // CHECK-NOT: ^bb
  func @nested() {
    return

  ^bb1:
    return
  }
  return

^bb1:
  return
}

// CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index)
func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x ? x f32>) {
  // CHECK: %c0 = constant 0 : index
  %zero = constant 0 : index
  // The constants below disappear after they propagate into shapes.
  %nine = constant 9 : index
  %N = constant 1024 : index
  %K = constant 512 : index

  // CHECK-NEXT: %0 = alloc(%arg0) : memref<?x1024xf32>
  %a = alloc(%L, %N) : memref<? x ? x f32>

  // CHECK-NEXT: %1 = alloc(%arg1) : memref<4x1024x8x512x?xf32>
  %b = alloc(%N, %K, %M) : memref<4 x ? x 8 x ? x ? x f32>
  // CHECK-NEXT: %2 = alloc() : memref<512x1024xi32>
  %c = alloc(%K, %N) : memref<? x ? x i32>

  // CHECK: affine.for
  affine.for %i = 0 to %L {
    // CHECK-NEXT: affine.for
    affine.for %j = 0 to 10 {
      // CHECK-NEXT: load %0[%arg2, %arg3] : memref<?x1024xf32>
      // CHECK-NEXT: store %{{.*}}, %1[%c0, %c0, %arg2, %arg3, %c0] : memref<4x1024x8x512x?xf32>
      %v = load %a[%i, %j] : memref<? x ? x f32>
      store %v, %b[%zero, %zero, %i, %j, %zero] : memref<4x?x8x?x?xf32>
    }
  }

  // CHECK: alloc() : memref<9x9xf32>
  %d = alloc(%nine, %nine) : memref<? x ? x f32>

  return %c, %d : memref<? x ? x i32>, memref<? x ? x f32>
}

#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
#map2 = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s2 + d1 * s1 + d2 + s0)>

// CHECK-LABEL: func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index,
func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index, %BUF: memref<?xi8>, %M : index, %N : index, %K : index) {
// CHECK-SAME: [[M:arg[0-9]+]]: index
// CHECK-SAME: [[N:arg[0-9]+]]: index
// CHECK-SAME: [[K:arg[0-9]+]]: index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = alloc(%arg0, %arg1) : memref<?x?xf32>
  %1 = alloc(%arg1, %arg2) : memref<?x8x?xf32>
  %2 = dim %1, 2 : memref<?x8x?xf32>
  affine.for %arg3 = 0 to %2 {
    %3 = alloc(%arg0) : memref<?xi8>
    %ub = dim %3, 0 : memref<?xi8>
    affine.for %arg4 = 0 to %ub {
      %s = dim %0, 0 : memref<?x?xf32>
      %v = std.view %3[%c0][%arg4, %s] : memref<?xi8> to memref<?x?xf32, #map1>
      %sv = std.subview %0[%c0, %c0][%s,%arg4][%c1,%c1] : memref<?x?xf32> to memref<?x?xf32, #map1>
      %l = dim %v, 1 : memref<?x?xf32, #map1>
      %u = dim %sv, 0 : memref<?x?xf32, #map1>
      affine.for %arg5 = %l to %u {
        "foo"() : () -> ()
      }
    }
  }
  // CHECK-NEXT: %c0 = constant 0 : index
  // CHECK-NEXT: %c1 = constant 1 : index
  // CHECK-NEXT: affine.for %arg7 = 0 to %arg2 {
  // CHECK-NEXT:   affine.for %arg8 = 0 to %arg0 {
  // CHECK-NEXT:     affine.for %arg9 = %arg0 to %arg0 {
  // CHECK-NEXT:       "foo"() : () -> ()
  // CHECK-NEXT:     }
  // CHECK-NEXT:   }
  // CHECK-NEXT: }

  %A = view %BUF[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
  %B = view %BUF[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
  %C = view %BUF[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>

  %M_ = dim %A, 0 : memref<?x?xf32>
  %K_ = dim %A, 1 : memref<?x?xf32>
  %N_ = dim %C, 1 : memref<?x?xf32>

  loop.for %i = %c0 to %M_ step %c1 {
    loop.for %j = %c0 to %N_ step %c1 {
      loop.for %k = %c0 to %K_ step %c1 {
      }
    }
  }
  // CHECK: loop.for %{{.*}} = %c0 to %[[M]] step %c1 {
  // CHECK:   loop.for %arg8 = %c0 to %[[N]] step %c1 {
  // CHECK:     loop.for %arg9 = %c0 to %[[K]] step %c1 {
  return
}

// CHECK-LABEL: func @merge_constants
func @merge_constants() -> (index, index) {
  // CHECK-NEXT: %c42 = constant 42 : index
  %0 = constant 42 : index
  %1 = constant 42 : index
  // CHECK-NEXT: return %c42, %c42
  return %0, %1: index, index
}

// CHECK-LABEL: func @hoist_constant
func @hoist_constant(%arg0: memref<8xi32>) {
  // CHECK-NEXT: %c42_i32 = constant 42 : i32
  // CHECK-NEXT: affine.for %arg1 = 0 to 8 {
  affine.for %arg1 = 0 to 8 {
    // CHECK-NEXT: store %c42_i32, %arg0[%arg1]
    %c42_i32 = constant 42 : i32
    store %c42_i32, %arg0[%arg1] : memref<8xi32>
  }
  return
}

// CHECK-LABEL: func @const_fold_propagate
func @const_fold_propagate() -> memref<?x?xf32> {
  %VT_i = constant 512 : index

  %VT_i_s = affine.apply affine_map<(d0) -> (d0 floordiv 8)> (%VT_i)
  %VT_k_l = affine.apply affine_map<(d0) -> (d0 floordiv 16)> (%VT_i)

  // CHECK: = alloc() : memref<64x32xf32>
  %Av = alloc(%VT_i_s, %VT_k_l) : memref<?x?xf32>
  return %Av : memref<?x?xf32>
}

// CHECK-LABEL: func @br_folding
func @br_folding() -> i32 {
  // CHECK-NEXT: %[[CST:.*]] = constant 0 : i32
  // CHECK-NEXT: return %[[CST]] : i32
  %c0_i32 = constant 0 : i32
  br ^bb1(%c0_i32 : i32)
^bb1(%x : i32):
  return %x : i32
}

// CHECK-LABEL: func @cond_br_folding
func @cond_br_folding(%cond : i1, %a : i32) {
  %false_cond = constant 0 : i1
  %true_cond = constant 1 : i1
  cond_br %cond, ^bb1, ^bb2(%a : i32)

^bb1:
  // CHECK: ^bb1:
  // CHECK-NEXT: br ^bb3
  cond_br %true_cond, ^bb3, ^bb2(%a : i32)

^bb2(%x : i32):
  // CHECK: ^bb2
  // CHECK: br ^bb3
  cond_br %false_cond, ^bb2(%x : i32), ^bb3

^bb3:
  return
}

// CHECK-LABEL: func @cond_br_and_br_folding
func @cond_br_and_br_folding(%a : i32) {
  // Test the compound folding of conditional and unconditional branches.
  // CHECK-NEXT: return
  %false_cond = constant 0 : i1
  %true_cond = constant 1 : i1
  cond_br %true_cond, ^bb2, ^bb1(%a : i32)

^bb1(%x : i32):
  cond_br %false_cond, ^bb1(%x : i32), ^bb2

^bb2:
  return
}

// CHECK-LABEL: func @indirect_call_folding
func @indirect_target() {
  return
}

func @indirect_call_folding() {
  // CHECK-NEXT: call @indirect_target() : () -> ()
  // CHECK-NEXT: return
  %indirect_fn = constant @indirect_target : () -> ()
  call_indirect %indirect_fn() : () -> ()
  return
}

//
// IMPORTANT NOTE: the operations in this test are exactly those produced by
// lowering affine.apply affine_map<(i) -> (i mod 42)> to standard operations. Please only
// change these operations together with the affine lowering pass tests.
//
// CHECK-LABEL: @lowered_affine_mod
func @lowered_affine_mod() -> (index, index) {
  // CHECK-NEXT: {{.*}} = constant 41 : index
  %c-43 = constant -43 : index
  %c42 = constant 42 : index
  %0 = remi_signed %c-43, %c42 : index
  %c0 = constant 0 : index
  %1 = cmpi "slt", %0, %c0 : index
  %2 = addi %0, %c42 : index
  %3 = select %1, %2, %0 : index
  // CHECK-NEXT: {{.*}} = constant 1 : index
  %c43 = constant 43 : index
  %c42_0 = constant 42 : index
  %4 = remi_signed %c43, %c42_0 : index
  %c0_1 = constant 0 : index
  %5 = cmpi "slt", %4, %c0_1 : index
  %6 = addi %4, %c42_0 : index
  %7 = select %5, %6, %4 : index
  return %3, %7 : index, index
}

//
// IMPORTANT NOTE: the operations in this test are exactly those produced by
// lowering affine.apply affine_map<(i) -> (i floordiv 42)> to standard operations. Please only
// change these operations together with the affine lowering pass tests.
//
// CHECK-LABEL: func @lowered_affine_floordiv
func @lowered_affine_floordiv() -> (index, index) {
  // CHECK-NEXT: %c-2 = constant -2 : index
  %c-43 = constant -43 : index
  %c42 = constant 42 : index
  %c0 = constant 0 : index
  %c-1 = constant -1 : index
  %0 = cmpi "slt", %c-43, %c0 : index
  %1 = subi %c-1, %c-43 : index
  %2 = select %0, %1, %c-43 : index
  %3 = divi_signed %2, %c42 : index
  %4 = subi %c-1, %3 : index
  %5 = select %0, %4, %3 : index
  // CHECK-NEXT: %c1 = constant 1 : index
  %c43 = constant 43 : index
  %c42_0 = constant 42 : index
  %c0_1 = constant 0 : index
  %c-1_2 = constant -1 : index
  %6 = cmpi "slt", %c43, %c0_1 : index
  %7 = subi %c-1_2, %c43 : index
  %8 = select %6, %7, %c43 : index
  %9 = divi_signed %8, %c42_0 : index
  %10 = subi %c-1_2, %9 : index
  %11 = select %6, %10, %9 : index
  return %5, %11 : index, index
}

//
// IMPORTANT NOTE: the operations in this test are exactly those produced by
// lowering affine.apply affine_map<(i) -> (i ceildiv 42)> to standard operations. Please only
// change these operations together with the affine lowering pass tests.
//
// CHECK-LABEL: func @lowered_affine_ceildiv
func @lowered_affine_ceildiv() -> (index, index) {
  // CHECK-NEXT: %c-1 = constant -1 : index
  %c-43 = constant -43 : index
  %c42 = constant 42 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = cmpi "sle", %c-43, %c0 : index
  %1 = subi %c0, %c-43 : index
  %2 = subi %c-43, %c1 : index
  %3 = select %0, %1, %2 : index
  %4 = divi_signed %3, %c42 : index
  %5 = subi %c0, %4 : index
  %6 = addi %4, %c1 : index
  %7 = select %0, %5, %6 : index
  // CHECK-NEXT: %c2 = constant 2 : index
  %c43 = constant 43 : index
  %c42_0 = constant 42 : index
  %c0_1 = constant 0 : index
  %c1_2 = constant 1 : index
  %8 = cmpi "sle", %c43, %c0_1 : index
  %9 = subi %c0_1, %c43 : index
  %10 = subi %c43, %c1_2 : index
  %11 = select %8, %9, %10 : index
  %12 = divi_signed %11, %c42_0 : index
  %13 = subi %c0_1, %12 : index
  %14 = addi %12, %c1_2 : index
  %15 = select %8, %13, %14 : index
  return %7, %15 : index, index
}

// Checks that NOP casts are removed.
// CHECK-LABEL: cast_values
func @cast_values(%arg0: tensor<*xi32>, %arg1: memref<?xi32>) -> (tensor<2xi32>, memref<2xi32>) {
  // NOP casts
  %0 = tensor_cast %arg0 : tensor<*xi32> to tensor<*xi32>
  %1 = memref_cast %arg1 : memref<?xi32> to memref<?xi32>

  // CHECK-NEXT: %0 = tensor_cast %arg0 : tensor<*xi32> to tensor<2xi32>
  // CHECK-NEXT: %1 = memref_cast %arg1 : memref<?xi32> to memref<2xi32>
  %2 = tensor_cast %0 : tensor<*xi32> to tensor<2xi32>
  %3 = memref_cast %1 : memref<?xi32> to memref<2xi32>

  // NOP casts
  %4 = tensor_cast %2 : tensor<2xi32> to tensor<2xi32>
  %5 = memref_cast %3 : memref<2xi32> to memref<2xi32>

  // CHECK-NEXT: return %0, %1 : tensor<2xi32>, memref<2xi32>
  return %4, %5 : tensor<2xi32>, memref<2xi32>
}

// -----

#TEST_VIEW_MAP0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)>
#TEST_VIEW_MAP1 = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + d1 * s0 + d2)>
#TEST_VIEW_MAP2 = affine_map<(d0, d1)[s0] -> (d0 * 4 + d1 + s0)>

// CHECK-DAG: #[[VIEW_MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 11 + d1 + 15)>
// CHECK-DAG: #[[VIEW_MAP1:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 11 + s0 + d1)>
// CHECK-DAG: #[[VIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1 + 15)>
// CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 7 + d2)>
// CHECK-DAG: #[[VIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 15)>
// CHECK-DAG: #[[VIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 7 + d1)>

// CHECK-LABEL: func @view
func @view(%arg0 : index) {
  // CHECK: %[[ALLOC_MEM:.*]] = alloc() : memref<2048xi8>
  %0 = alloc() : memref<2048xi8>
  %c0 = constant 0 : index
  %c7 = constant 7 : index
  %c11 = constant 11 : index
  %c15 = constant 15 : index

  // Test: fold constant sizes and offset, update map with static stride/offset.
  // CHECK: std.view %[[ALLOC_MEM]][][] : memref<2048xi8> to memref<7x11xf32, #[[VIEW_MAP0]]>
  %1 = view %0[%c15][%c7, %c11]
    : memref<2048xi8> to memref<?x?xf32, #TEST_VIEW_MAP0>
  load %1[%c0, %c0] : memref<?x?xf32, #TEST_VIEW_MAP0>

  // Test: fold constant sizes but not offset, update map with static stride.
  // Test that we do not fold a dynamic dim which is not produced by a constant.
  // CHECK: std.view %[[ALLOC_MEM]][%arg0][] : memref<2048xi8> to memref<7x11xf32, #[[VIEW_MAP1]]>
  %2 = view %0[%arg0][%c7, %c11]
    : memref<2048xi8> to memref<?x?xf32, #TEST_VIEW_MAP0>
  load %2[%c0, %c0] : memref<?x?xf32, #TEST_VIEW_MAP0>

  // Test: fold constant offset but not sizes, update map with constant offset.
  // Test that we fold constant offset but not dynamic dims.
  // CHECK: std.view %[[ALLOC_MEM]][][%arg0, %arg0] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP2]]>
  %3 = view %0[%c15][%arg0, %arg0]
    : memref<2048xi8> to memref<?x?xf32, #TEST_VIEW_MAP0>
  load %3[%c0, %c0] : memref<?x?xf32, #TEST_VIEW_MAP0>

  // Test: fold one constant dim, no offset, should update with constant
  // stride on dim 1, but leave dynamic stride on dim 0.
  // CHECK: std.view %[[ALLOC_MEM]][][%arg0, %arg0] : memref<2048xi8> to memref<?x?x7xf32, #[[VIEW_MAP3]]>
  %4 = view %0[][%arg0, %arg0, %c7]
    : memref<2048xi8> to memref<?x?x?xf32, #TEST_VIEW_MAP1>
  load %4[%c0, %c0, %c0] : memref<?x?x?xf32, #TEST_VIEW_MAP1>

  // Test: preserve an existing static dim size while folding a dynamic
  // dimension and offset.
  // CHECK: std.view %[[ALLOC_MEM]][][] : memref<2048xi8> to memref<7x4xf32, #[[VIEW_MAP4]]>
  %5 = view %0[%c15][%c7] : memref<2048xi8> to memref<?x4xf32, #TEST_VIEW_MAP2>
  load %5[%c0, %c0] : memref<?x4xf32, #TEST_VIEW_MAP2>

  // Test: folding static alloc and memref_cast into a view.
  // CHECK: std.view %[[ALLOC_MEM]][][] : memref<2048xi8> to memref<15x7xf32, #[[VIEW_MAP5]]>
  %6 = memref_cast %0 : memref<2048xi8> to memref<?xi8>
  %7 = view %6[%c15][%c7] : memref<?xi8> to memref<?x?xf32>
  load %7[%c0, %c0] : memref<?x?xf32>
  return
}

// -----

// CHECK-DAG: #[[BASE_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
// CHECK-DAG: #[[SUBVIEW_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 64 + s0 + d1 * 4 + d2)>
// CHECK-DAG: #[[SUBVIEW_MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 79)>
// CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 28 + d2 * 11)>
// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
// CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 128 + s0 + d1 * 28 + d2 * 11)>
// CHECK-DAG: #[[SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2 + 79)>
// CHECK-DAG: #[[SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
// CHECK-DAG: #[[SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 12)>

// CHECK-LABEL: func @subview
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
  // CHECK: %[[C0:.*]] = constant 0 : index
  %c0 = constant 0 : index
  // CHECK: %[[C1:.*]] = constant 1 : index
  %c1 = constant 1 : index
  // CHECK: %[[C2:.*]] = constant 2 : index
  %c2 = constant 2 : index
  // CHECK: %[[C7:.*]] = constant 7 : index
  %c7 = constant 7 : index
  // CHECK: %[[C11:.*]] = constant 11 : index
  %c11 = constant 11 : index
  %c15 = constant 15 : index

  // CHECK: %[[ALLOC0:.*]] = alloc()
  %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>

  // Test: subview with constant base memref and constant operands is folded.
  // Note that the subview uses the base memref's layout map because it used
  // zero offset and unit stride arguments.
  // CHECK: std.subview %[[ALLOC0]][][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[BASE_MAP0]]>
  %1 = subview %0[%c0, %c0, %c0][%c7, %c11, %c2][%c1, %c1, %c1]
    : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
      memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
  load %1[%c0, %c0, %c0] : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>

  // Test: subview with one dynamic operand should not be folded.
  // CHECK: std.subview %[[ALLOC0]][%[[C0]], %[[ARG0]], %[[C0]]][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]>
  %2 = subview %0[%c0, %arg0, %c0][%c7, %c11, %c15][%c1, %c1, %c1]
    : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
      memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
  load %2[%c0, %c0, %c0] : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>

  // CHECK: %[[ALLOC1:.*]] = alloc(%[[ARG0]])
  %3 = alloc(%arg0) : memref<?x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>

  // Test: subview with constant operands but dynamic base memref is folded as
  // long as the strides and offset of the base memref are static.
  // CHECK: std.subview %[[ALLOC1]][][][] : memref<?x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x15xf32, #[[BASE_MAP0]]>
  %4 = subview %3[%c0, %c0, %c0][%c7, %c11, %c15][%c1, %c1, %c1]
    : memref<?x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
      memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
  load %4[%c0, %c0, %c0] : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>

  // Test: subview offset operands are folded correctly w.r.t. base strides.
  // CHECK: std.subview %[[ALLOC0]][][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP1]]>
  %5 = subview %0[%c1, %c2, %c7][%c7, %c11, %c2][%c1, %c1, %c1]
    : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
      memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
  load %5[%c0, %c0, %c0] : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>

  // Test: subview stride operands are folded correctly w.r.t. base strides.
  // CHECK: std.subview %[[ALLOC0]][][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP2]]>
  %6 = subview %0[%c0, %c0, %c0][%c7, %c11, %c2][%c2, %c7, %c11]
    : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
      memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
  load %6[%c0, %c0, %c0] : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>

  // Test: subview shape is folded, but offsets and strides are not, even if the base memref is static.
  // CHECK: std.subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]][][%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
  %10 = subview %0[%arg0, %arg0, %arg0][%c7, %c11, %c2][%arg1, %arg1, %arg1]
    : memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to
      memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
  load %10[%arg1, %arg1, %arg1] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>

  // Test: subview strides are folded, but offsets and shape are not, even if the base memref is static.
  // CHECK: std.subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]][%[[ARG1]], %[[ARG1]], %[[ARG1]]][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP4]]>
  %11 = subview %0[%arg0, %arg0, %arg0][%arg1, %arg1, %arg1][%c2, %c7, %c11]
    : memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to
      memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
  load %11[%arg0, %arg0, %arg0] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>

  // Test: subview offsets are folded, but strides and shape are not, even if the base memref is static.
  // CHECK: std.subview %[[ALLOC0]][][%[[ARG1]], %[[ARG1]], %[[ARG1]]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP5]]>
  %13 = subview %0[%c1, %c2, %c7][%arg1, %arg1, %arg1][%arg0, %arg0, %arg0]
    : memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to
      memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
  load %13[%arg1, %arg1, %arg1] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>

  // CHECK: %[[ALLOC2:.*]] = alloc(%[[ARG0]], %[[ARG0]], %[[ARG1]])
  %14 = alloc(%arg0, %arg0, %arg1) : memref<?x?x?xf32>

  // Test: subview shape is folded, even if the base memref is not static.
  // CHECK: std.subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]][][%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref<?x?x?xf32> to memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
  %15 = subview %14[%arg0, %arg0, %arg0][%c7, %c11, %c2][%arg1, %arg1, %arg1]
    : memref<?x?x?xf32> to
      memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
  load %15[%arg1, %arg1, %arg1] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>

  // TEST: subview strides are not folded when the base memref is not static.
  // CHECK: std.subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]][%[[ARG1]], %[[ARG1]], %[[ARG1]]][%[[C2]], %[[C2]], %[[C2]]] : memref<?x?x?xf32> to memref<?x?x?xf32, #[[SUBVIEW_MAP3]]>
  %16 = subview %14[%arg0, %arg0, %arg0][%arg1, %arg1, %arg1][%c2, %c2, %c2]
    : memref<?x?x?xf32> to
      memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
  load %16[%arg0, %arg0, %arg0] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>

  // TEST: subview offsets are not folded when the base memref is not static.
  // CHECK: std.subview %[[ALLOC2]][%[[C1]], %[[C1]], %[[C1]]][%[[ARG0]], %[[ARG0]], %[[ARG0]]][%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref<?x?x?xf32> to memref<?x?x?xf32, #[[SUBVIEW_MAP3]]>
  %17 = subview %14[%c1, %c1, %c1][%arg0, %arg0, %arg0][%arg1, %arg1, %arg1]
    : memref<?x?x?xf32> to
      memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
  load %17[%arg0, %arg0, %arg0] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>

  // CHECK: %[[ALLOC3:.*]] = alloc() : memref<12x4xf32>
  %18 = alloc() : memref<12x4xf32>
  %c4 = constant 4 : index

  // TEST: subview strides are maintained when sizes are folded.
  // CHECK: std.subview %[[ALLOC3]][%arg1, %arg1][][] : memref<12x4xf32> to memref<2x4xf32, #[[SUBVIEW_MAP6]]>
  %19 = subview %18[%arg1, %arg1][%c2, %c4][]
    : memref<12x4xf32> to memref<?x?xf32, offset: ?, strides:[4, 1]>
  load %19[%arg1, %arg1] : memref<?x?xf32, offset: ?, strides:[4, 1]>

  // TEST: subview strides and sizes are maintained when offsets are folded.
  // CHECK: std.subview %[[ALLOC3]][][][] : memref<12x4xf32> to memref<12x4xf32, #[[SUBVIEW_MAP7]]>
  %20 = subview %18[%c2, %c4][][]
    : memref<12x4xf32> to memref<12x4xf32, offset: ?, strides:[4, 1]>
  load %20[%arg1, %arg1] : memref<12x4xf32, offset: ?, strides:[4, 1]>

  // Test: dim on subview is rewritten to size operand.
  %7 = dim %4, 0 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
  %8 = dim %4, 1 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>

  // CHECK: return %[[C7]], %[[C11]]
  return %7, %8 : index, index
}