diff options
| -rw-r--r-- | mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td | 4 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 30 | ||||
| -rw-r--r-- | mlir/test/Dialect/SPIRV/canonicalize.mlir | 180 |
3 files changed, 141 insertions, 73 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td index cbcd9303626..00ce72f5b2a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td @@ -292,6 +292,8 @@ def SPV_IAddOp : SPV_ArithmeticBinaryOp<"IAdd", SPV_Integer, [Commutative]> { ``` }]; + + let hasFolder = 1; } // ----- @@ -328,6 +330,8 @@ def SPV_IMulOp : SPV_ArithmeticBinaryOp<"IMul", SPV_Integer, [Commutative]> { ``` }]; + + let hasFolder = 1; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 6bb052d49d7..ae7643fa915 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -26,6 +26,7 @@ #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" @@ -1519,6 +1520,35 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) { } //===----------------------------------------------------------------------===// +// spv.IAdd +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) { + assert(operands.size() == 2 && "spv.IAdd expects two operands"); + // lhs + 0 = lhs + if (matchPattern(operand2(), m_Zero())) + return operand1(); + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// spv.IMul +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) { + assert(operands.size() == 2 && "spv.IMul expects two operands"); + // lhs * 0 == 0 + if (matchPattern(operand2(), m_Zero())) + return operand2(); + // lhs * 1 = lhs + if (matchPattern(operand2(), m_One())) + return operand1(); + + return nullptr; +} + +//===----------------------------------------------------------------------===// // spv.LoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir index 84c2544762b..9df7b09b8e4 100644 --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -59,6 +59,34 @@ func @dont_combine_access_chain_without_common_base() -> !spv.array<4xi32> { // ----- //===----------------------------------------------------------------------===// +// spv.Bitcast +//===----------------------------------------------------------------------===// + +func @convert_bitcast_full(%arg0 : vector<2xf32>) -> f64 { + // CHECK: %[[RESULT:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64 + // CHECK-NEXT: spv.ReturnValue %[[RESULT]] + %0 = spv.Bitcast %arg0 : vector<2xf32> to vector<2xi32> + %1 = spv.Bitcast %0 : vector<2xi32> to i64 + %2 = spv.Bitcast %1 : i64 to f64 + spv.ReturnValue %2 : f64 +} + +// ----- + +func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spv.ptr<i64, Uniform>) -> f64 { + // CHECK: %[[RESULT_0:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to i64 + // CHECK-NEXT: %[[RESULT_1:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64 + // CHECK-NEXT: spv.Store {{".*"}} {{%.*}}, %[[RESULT_0]] + // CHECK-NEXT: spv.ReturnValue %[[RESULT_1]] + %0 = spv.Bitcast %arg0 : vector<2xf32> to i64 + %1 = spv.Bitcast %0 : i64 to f64 + spv.Store "Uniform" %arg1, %0 : i64 + spv.ReturnValue %1 : f64 +} + +// ----- + +//===----------------------------------------------------------------------===// // spv.CompositeExtract //===----------------------------------------------------------------------===// @@ -135,29 +163,92 @@ func @deduplicate_composite_constant() -> (!spv.array<1 x vector<2xi32>>, !spv.a // ----- //===----------------------------------------------------------------------===// -// spv.Bitcast +// spv.IAdd //===----------------------------------------------------------------------===// -func @convert_bitcast_full(%arg0 : vector<2xf32>) -> f64 { - // CHECK: %[[RESULT:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64 - // CHECK-NEXT: spv.ReturnValue %[[RESULT]] - %0 = spv.Bitcast %arg0 : vector<2xf32> to vector<2xi32> - %1 = spv.Bitcast %0 : vector<2xi32> to i64 - %2 = spv.Bitcast %1 : i64 to f64 - spv.ReturnValue %2 : f64 +// CHECK-LABEL: @iadd_zero +// CHECK-SAME: (%[[ARG:.*]]: i32) +func @iadd_zero(%arg0: i32) -> (i32, i32) { + %zero = spv.constant 0 : i32 + %0 = spv.IAdd %arg0, %zero : i32 + %1 = spv.IAdd %zero, %arg0 : i32 + // CHECK: return %[[ARG]], %[[ARG]] + return %0, %1: i32, i32 } // ----- -func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spv.ptr<i64, Uniform>) -> f64 { - // CHECK: %[[RESULT_0:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to i64 - // CHECK-NEXT: %[[RESULT_1:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64 - // CHECK-NEXT: spv.Store {{".*"}} {{%.*}}, %[[RESULT_0]] +//===----------------------------------------------------------------------===// +// spv.IMul +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @imul_zero_one +// CHECK-SAME: (%[[ARG:.*]]: i32) +func @imul_zero_one(%arg0: i32) -> (i32, i32) { + // CHECK: %[[ZERO:.*]] = spv.constant 0 + %zero = spv.constant 0 : i32 + %one = spv.constant 1: i32 + %0 = spv.IMul %arg0, %zero : i32 + %1 = spv.IMul %one, %arg0 : i32 + // CHECK: return %[[ZERO]], %[[ARG]] + return %0, %1: i32, i32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.LogicalNot +//===----------------------------------------------------------------------===// + +func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> { + // CHECK: %[[RESULT:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64> + // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> + %2 = spv.IEqual %arg0, %arg1 : vector<3xi64> + %3 = spv.LogicalNot %2 : vector<3xi1> + spv.ReturnValue %3 : vector<3xi1> +} + +// ----- + +func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> { + // CHECK: %[[RESULT:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64> + // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> + %2 = spv.INotEqual %arg0, %arg1 : vector<3xi64> + %3 = spv.LogicalNot %2 : vector<3xi1> + spv.ReturnValue %3 : vector<3xi1> +} + +// ----- + +func @convert_logical_not_parent_multi_use(%arg0: vector<3xi64>, %arg1: vector<3xi64>, %arg2: !spv.ptr<vector<3xi1>, Uniform>) -> vector<3xi1> { + // CHECK: %[[RESULT_0:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64> + // CHECK-NEXT: %[[RESULT_1:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64> + // CHECK-NEXT: spv.Store "Uniform" {{%.*}}, %[[RESULT_0]] // CHECK-NEXT: spv.ReturnValue %[[RESULT_1]] - %0 = spv.Bitcast %arg0 : vector<2xf32> to i64 - %1 = spv.Bitcast %0 : i64 to f64 - spv.Store "Uniform" %arg1, %0 : i64 - spv.ReturnValue %1 : f64 + %0 = spv.INotEqual %arg0, %arg1 : vector<3xi64> + %1 = spv.LogicalNot %0 : vector<3xi1> + spv.Store "Uniform" %arg2, %0 : vector<3xi1> + spv.ReturnValue %1 : vector<3xi1> +} + +// ----- + +func @convert_logical_not_to_logical_not_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> { + // CHECK: %[[RESULT:.*]] = spv.LogicalNotEqual {{%.*}}, {{%.*}} : vector<3xi1> + // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> + %2 = spv.LogicalEqual %arg0, %arg1 : vector<3xi1> + %3 = spv.LogicalNot %2 : vector<3xi1> + spv.ReturnValue %3 : vector<3xi1> +} + +// ----- + +func @convert_logical_not_to_logical_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> { + // CHECK: %[[RESULT:.*]] = spv.LogicalEqual {{%.*}}, {{%.*}} : vector<3xi1> + // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> + %2 = spv.LogicalNotEqual %arg0, %arg1 : vector<3xi1> + %3 = spv.LogicalNot %2 : vector<3xi1> + spv.ReturnValue %3 : vector<3xi1> } // ----- @@ -391,60 +482,3 @@ func @cannot_canonicalize_selection_op_4(%cond: i1) -> () { } spv.Return } - -// ----- - -//===----------------------------------------------------------------------===// -// spv.LogicalNot -//===----------------------------------------------------------------------===// - -func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> { - // CHECK: %[[RESULT:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64> - // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> - %2 = spv.IEqual %arg0, %arg1 : vector<3xi64> - %3 = spv.LogicalNot %2 : vector<3xi1> - spv.ReturnValue %3 : vector<3xi1> -} - -// ----- - -func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> { - // CHECK: %[[RESULT:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64> - // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> - %2 = spv.INotEqual %arg0, %arg1 : vector<3xi64> - %3 = spv.LogicalNot %2 : vector<3xi1> - spv.ReturnValue %3 : vector<3xi1> -} - -// ----- - -func @convert_logical_not_parent_multi_use(%arg0: vector<3xi64>, %arg1: vector<3xi64>, %arg2: !spv.ptr<vector<3xi1>, Uniform>) -> vector<3xi1> { - // CHECK: %[[RESULT_0:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64> - // CHECK-NEXT: %[[RESULT_1:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64> - // CHECK-NEXT: spv.Store "Uniform" {{%.*}}, %[[RESULT_0]] - // CHECK-NEXT: spv.ReturnValue %[[RESULT_1]] - %0 = spv.INotEqual %arg0, %arg1 : vector<3xi64> - %1 = spv.LogicalNot %0 : vector<3xi1> - spv.Store "Uniform" %arg2, %0 : vector<3xi1> - spv.ReturnValue %1 : vector<3xi1> -} - -// ----- - -func @convert_logical_not_to_logical_not_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> { - // CHECK: %[[RESULT:.*]] = spv.LogicalNotEqual {{%.*}}, {{%.*}} : vector<3xi1> - // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> - %2 = spv.LogicalEqual %arg0, %arg1 : vector<3xi1> - %3 = spv.LogicalNot %2 : vector<3xi1> - spv.ReturnValue %3 : vector<3xi1> -} - -// ----- - -func @convert_logical_not_to_logical_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> { - // CHECK: %[[RESULT:.*]] = spv.LogicalEqual {{%.*}}, {{%.*}} : vector<3xi1> - // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> - %2 = spv.LogicalNotEqual %arg0, %arg1 : vector<3xi1> - %3 = spv.LogicalNot %2 : vector<3xi1> - spv.ReturnValue %3 : vector<3xi1> -} |

