summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td4
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp30
-rw-r--r--mlir/test/Dialect/SPIRV/canonicalize.mlir180
3 files changed, 141 insertions, 73 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
index cbcd9303626..00ce72f5b2a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
@@ -292,6 +292,8 @@ def SPV_IAddOp : SPV_ArithmeticBinaryOp<"IAdd", SPV_Integer, [Commutative]> {
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -328,6 +330,8 @@ def SPV_IMulOp : SPV_ArithmeticBinaryOp<"IMul", SPV_Integer, [Commutative]> {
```
}];
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 6bb052d49d7..ae7643fa915 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -26,6 +26,7 @@
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
@@ -1519,6 +1520,35 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
}
//===----------------------------------------------------------------------===//
+// spv.IAdd
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "spv.IAdd expects two operands");
+ // lhs + 0 = lhs
+ if (matchPattern(operand2(), m_Zero()))
+ return operand1();
+
+ return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// spv.IMul
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "spv.IMul expects two operands");
+ // lhs * 0 == 0
+ if (matchPattern(operand2(), m_Zero()))
+ return operand2();
+ // lhs * 1 = lhs
+ if (matchPattern(operand2(), m_One()))
+ return operand1();
+
+ return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
// spv.LoadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir
index 84c2544762b..9df7b09b8e4 100644
--- a/mlir/test/Dialect/SPIRV/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir
@@ -59,6 +59,34 @@ func @dont_combine_access_chain_without_common_base() -> !spv.array<4xi32> {
// -----
//===----------------------------------------------------------------------===//
+// spv.Bitcast
+//===----------------------------------------------------------------------===//
+
+func @convert_bitcast_full(%arg0 : vector<2xf32>) -> f64 {
+ // CHECK: %[[RESULT:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
+ // CHECK-NEXT: spv.ReturnValue %[[RESULT]]
+ %0 = spv.Bitcast %arg0 : vector<2xf32> to vector<2xi32>
+ %1 = spv.Bitcast %0 : vector<2xi32> to i64
+ %2 = spv.Bitcast %1 : i64 to f64
+ spv.ReturnValue %2 : f64
+}
+
+// -----
+
+func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spv.ptr<i64, Uniform>) -> f64 {
+ // CHECK: %[[RESULT_0:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to i64
+ // CHECK-NEXT: %[[RESULT_1:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
+ // CHECK-NEXT: spv.Store {{".*"}} {{%.*}}, %[[RESULT_0]]
+ // CHECK-NEXT: spv.ReturnValue %[[RESULT_1]]
+ %0 = spv.Bitcast %arg0 : vector<2xf32> to i64
+ %1 = spv.Bitcast %0 : i64 to f64
+ spv.Store "Uniform" %arg1, %0 : i64
+ spv.ReturnValue %1 : f64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spv.CompositeExtract
//===----------------------------------------------------------------------===//
@@ -135,29 +163,92 @@ func @deduplicate_composite_constant() -> (!spv.array<1 x vector<2xi32>>, !spv.a
// -----
//===----------------------------------------------------------------------===//
-// spv.Bitcast
+// spv.IAdd
//===----------------------------------------------------------------------===//
-func @convert_bitcast_full(%arg0 : vector<2xf32>) -> f64 {
- // CHECK: %[[RESULT:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
- // CHECK-NEXT: spv.ReturnValue %[[RESULT]]
- %0 = spv.Bitcast %arg0 : vector<2xf32> to vector<2xi32>
- %1 = spv.Bitcast %0 : vector<2xi32> to i64
- %2 = spv.Bitcast %1 : i64 to f64
- spv.ReturnValue %2 : f64
+// CHECK-LABEL: @iadd_zero
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func @iadd_zero(%arg0: i32) -> (i32, i32) {
+ %zero = spv.constant 0 : i32
+ %0 = spv.IAdd %arg0, %zero : i32
+ %1 = spv.IAdd %zero, %arg0 : i32
+ // CHECK: return %[[ARG]], %[[ARG]]
+ return %0, %1: i32, i32
}
// -----
-func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spv.ptr<i64, Uniform>) -> f64 {
- // CHECK: %[[RESULT_0:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to i64
- // CHECK-NEXT: %[[RESULT_1:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
- // CHECK-NEXT: spv.Store {{".*"}} {{%.*}}, %[[RESULT_0]]
+//===----------------------------------------------------------------------===//
+// spv.IMul
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @imul_zero_one
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func @imul_zero_one(%arg0: i32) -> (i32, i32) {
+ // CHECK: %[[ZERO:.*]] = spv.constant 0
+ %zero = spv.constant 0 : i32
+ %one = spv.constant 1: i32
+ %0 = spv.IMul %arg0, %zero : i32
+ %1 = spv.IMul %one, %arg0 : i32
+ // CHECK: return %[[ZERO]], %[[ARG]]
+ return %0, %1: i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalNot
+//===----------------------------------------------------------------------===//
+
+func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
+ // CHECK: %[[RESULT:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64>
+ // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
+ %2 = spv.IEqual %arg0, %arg1 : vector<3xi64>
+ %3 = spv.LogicalNot %2 : vector<3xi1>
+ spv.ReturnValue %3 : vector<3xi1>
+}
+
+// -----
+
+func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
+ // CHECK: %[[RESULT:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64>
+ // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
+ %2 = spv.INotEqual %arg0, %arg1 : vector<3xi64>
+ %3 = spv.LogicalNot %2 : vector<3xi1>
+ spv.ReturnValue %3 : vector<3xi1>
+}
+
+// -----
+
+func @convert_logical_not_parent_multi_use(%arg0: vector<3xi64>, %arg1: vector<3xi64>, %arg2: !spv.ptr<vector<3xi1>, Uniform>) -> vector<3xi1> {
+ // CHECK: %[[RESULT_0:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64>
+ // CHECK-NEXT: %[[RESULT_1:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64>
+ // CHECK-NEXT: spv.Store "Uniform" {{%.*}}, %[[RESULT_0]]
// CHECK-NEXT: spv.ReturnValue %[[RESULT_1]]
- %0 = spv.Bitcast %arg0 : vector<2xf32> to i64
- %1 = spv.Bitcast %0 : i64 to f64
- spv.Store "Uniform" %arg1, %0 : i64
- spv.ReturnValue %1 : f64
+ %0 = spv.INotEqual %arg0, %arg1 : vector<3xi64>
+ %1 = spv.LogicalNot %0 : vector<3xi1>
+ spv.Store "Uniform" %arg2, %0 : vector<3xi1>
+ spv.ReturnValue %1 : vector<3xi1>
+}
+
+// -----
+
+func @convert_logical_not_to_logical_not_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> {
+ // CHECK: %[[RESULT:.*]] = spv.LogicalNotEqual {{%.*}}, {{%.*}} : vector<3xi1>
+ // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
+ %2 = spv.LogicalEqual %arg0, %arg1 : vector<3xi1>
+ %3 = spv.LogicalNot %2 : vector<3xi1>
+ spv.ReturnValue %3 : vector<3xi1>
+}
+
+// -----
+
+func @convert_logical_not_to_logical_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> {
+ // CHECK: %[[RESULT:.*]] = spv.LogicalEqual {{%.*}}, {{%.*}} : vector<3xi1>
+ // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
+ %2 = spv.LogicalNotEqual %arg0, %arg1 : vector<3xi1>
+ %3 = spv.LogicalNot %2 : vector<3xi1>
+ spv.ReturnValue %3 : vector<3xi1>
}
// -----
@@ -391,60 +482,3 @@ func @cannot_canonicalize_selection_op_4(%cond: i1) -> () {
}
spv.Return
}
-
-// -----
-
-//===----------------------------------------------------------------------===//
-// spv.LogicalNot
-//===----------------------------------------------------------------------===//
-
-func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
- // CHECK: %[[RESULT:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64>
- // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
- %2 = spv.IEqual %arg0, %arg1 : vector<3xi64>
- %3 = spv.LogicalNot %2 : vector<3xi1>
- spv.ReturnValue %3 : vector<3xi1>
-}
-
-// -----
-
-func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
- // CHECK: %[[RESULT:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64>
- // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
- %2 = spv.INotEqual %arg0, %arg1 : vector<3xi64>
- %3 = spv.LogicalNot %2 : vector<3xi1>
- spv.ReturnValue %3 : vector<3xi1>
-}
-
-// -----
-
-func @convert_logical_not_parent_multi_use(%arg0: vector<3xi64>, %arg1: vector<3xi64>, %arg2: !spv.ptr<vector<3xi1>, Uniform>) -> vector<3xi1> {
- // CHECK: %[[RESULT_0:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64>
- // CHECK-NEXT: %[[RESULT_1:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64>
- // CHECK-NEXT: spv.Store "Uniform" {{%.*}}, %[[RESULT_0]]
- // CHECK-NEXT: spv.ReturnValue %[[RESULT_1]]
- %0 = spv.INotEqual %arg0, %arg1 : vector<3xi64>
- %1 = spv.LogicalNot %0 : vector<3xi1>
- spv.Store "Uniform" %arg2, %0 : vector<3xi1>
- spv.ReturnValue %1 : vector<3xi1>
-}
-
-// -----
-
-func @convert_logical_not_to_logical_not_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> {
- // CHECK: %[[RESULT:.*]] = spv.LogicalNotEqual {{%.*}}, {{%.*}} : vector<3xi1>
- // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
- %2 = spv.LogicalEqual %arg0, %arg1 : vector<3xi1>
- %3 = spv.LogicalNot %2 : vector<3xi1>
- spv.ReturnValue %3 : vector<3xi1>
-}
-
-// -----
-
-func @convert_logical_not_to_logical_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> {
- // CHECK: %[[RESULT:.*]] = spv.LogicalEqual {{%.*}}, {{%.*}} : vector<3xi1>
- // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
- %2 = spv.LogicalNotEqual %arg0, %arg1 : vector<3xi1>
- %3 = spv.LogicalNot %2 : vector<3xi1>
- spv.ReturnValue %3 : vector<3xi1>
-}
OpenPOWER on IntegriCloud