summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp86
-rw-r--r--mlir/test/Conversion/StandardToSPIRV/op_conversion.mlir45
2 files changed, 116 insertions, 15 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 56b243c2971..6bb9deaf85a 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -314,8 +314,9 @@ public:
return matchFailure();
}
- // Use the bitwidth set in the value attribute to decide the result type of
- // the SPIR-V constant operation since SPIR-V does not support index types.
+ // Use the bitwidth set in the value attribute to decide the result type
+ // of the SPIR-V constant operation since SPIR-V does not support index
+ // types.
auto constVal = constAttr.getValue();
auto constValType = constAttr.getType().dyn_cast<IndexType>();
if (!constValType) {
@@ -331,11 +332,47 @@ public:
}
};
-/// Convert integer binary operations to SPIR-V operations. Cannot use tablegen
-/// for this. If the integer operation is on variables of IndexType, the type of
-/// the return value of the replacement operation differs from that of the
-/// replaced operation. This is not handled in tablegen-based pattern
-/// specification.
+/// Convert compare operation to SPIR-V dialect.
+class CmpIOpConversion final : public ConversionPattern {
+public:
+ CmpIOpConversion(MLIRContext *context)
+ : ConversionPattern(CmpIOp::getOperationName(), 1, context) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto cmpIOp = cast<CmpIOp>(op);
+ CmpIOpOperandAdaptor cmpIOpOperands(operands);
+
+ switch (cmpIOp.getPredicate()) {
+#define DISPATCH(cmpPredicate, spirvOp) \
+ case cmpPredicate: \
+ rewriter.replaceOpWithNewOp<spirvOp>(op, op->getResult(0)->getType(), \
+ cmpIOpOperands.lhs(), \
+ cmpIOpOperands.rhs()); \
+ return matchSuccess();
+
+ DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp);
+ DISPATCH(CmpIPredicate::NE, spirv::INotEqualOp);
+ DISPATCH(CmpIPredicate::SLT, spirv::SLessThanOp);
+ DISPATCH(CmpIPredicate::SLE, spirv::SLessThanEqualOp);
+ DISPATCH(CmpIPredicate::SGT, spirv::SGreaterThanOp);
+ DISPATCH(CmpIPredicate::SGE, spirv::SGreaterThanEqualOp);
+
+#undef DISPATCH
+
+ default:
+ break;
+ }
+ return matchFailure();
+ }
+};
+
+/// Convert integer binary operations to SPIR-V operations. Cannot use
+/// tablegen for this. If the integer operation is on variables of IndexType,
+/// the type of the return value of the replacement operation differs from
+/// that of the replaced operation. This is not handled in tablegen-based
+/// pattern specification.
template <typename StdOp, typename SPIRVOp>
class IntegerOpConversion final : public ConversionPattern {
public:
@@ -396,9 +433,25 @@ public:
}
};
-/// Convert store -> spv.StoreOp. The operands of the replaced operation are of
-/// IndexType while that of the replacement operation are of type i32. This is
-/// not supported in tablegen based pattern specification.
+/// Convert select -> spv.Select
+class SelectOpConversion : public ConversionPattern {
+public:
+ SelectOpConversion(MLIRContext *context)
+ : ConversionPattern(SelectOp::getOperationName(), 1, context) {}
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ SelectOpOperandAdaptor selectOperands(operands);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
+ selectOperands.true_value(),
+ selectOperands.false_value());
+ return matchSuccess();
+ }
+};
+
+/// Convert store -> spv.StoreOp. The operands of the replaced operation are
+/// of IndexType while that of the replacement operation are of type i32. This
+/// is not supported in tablegen based pattern specification.
// TODO(ravishankarm) : These could potentially be templated on the operation
// being converted, since the same logic should work for linalg.store.
class StoreOpConversion final : public ConversionPattern {
@@ -437,9 +490,14 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns);
// Add the return op conversion.
- patterns.insert<ConstantIndexOpConversion,
- IntegerOpConversion<AddIOp, spirv::IAddOp>,
- IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion,
- ReturnToSPIRVConversion, StoreOpConversion>(context);
+ patterns
+ .insert<ConstantIndexOpConversion, CmpIOpConversion,
+ IntegerOpConversion<AddIOp, spirv::IAddOp>,
+ IntegerOpConversion<MulIOp, spirv::IMulOp>,
+ IntegerOpConversion<DivISOp, spirv::SDivOp>,
+ IntegerOpConversion<RemISOp, spirv::SModOp>,
+ IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
+ ReturnToSPIRVConversion, SelectOpConversion, StoreOpConversion>(
+ context);
}
} // namespace mlir
diff --git a/mlir/test/Conversion/StandardToSPIRV/op_conversion.mlir b/mlir/test/Conversion/StandardToSPIRV/op_conversion.mlir
index 334920c3626..d0effdd3fe4 100644
--- a/mlir/test/Conversion/StandardToSPIRV/op_conversion.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/op_conversion.mlir
@@ -57,4 +57,47 @@ func @constval() {
// CHECK: spv.constant 1 : i32
%4 = constant 1 : index
return
-} \ No newline at end of file
+}
+
+// CHECK-LABEL: @cmpiop
+func @cmpiop(%arg0 : i32, %arg1 : i32) {
+ // CHECK: spv.IEqual
+ %0 = cmpi "eq", %arg0, %arg1 : i32
+ // CHECK: spv.INotEqual
+ %1 = cmpi "ne", %arg0, %arg1 : i32
+ // CHECK: spv.SLessThan
+ %2 = cmpi "slt", %arg0, %arg1 : i32
+ // CHECK: spv.SLessThanEqual
+ %3 = cmpi "sle", %arg0, %arg1 : i32
+ // CHECK: spv.SGreaterThan
+ %4 = cmpi "sgt", %arg0, %arg1 : i32
+ // CHECK: spv.SGreaterThanEqual
+ %5 = cmpi "sge", %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @select
+func @selectOp(%arg0 : i32, %arg1 : i32) {
+ %0 = cmpi "sle", %arg0, %arg1 : i32
+ // CHECK: spv.Select
+ %1 = select %0, %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @div_rem
+func @div_rem(%arg0 : i32, %arg1 : i32) {
+ // CHECK: spv.SDiv
+ %0 = divis %arg0, %arg1 : i32
+ // CHECK: spv.SMod
+ %1 = remis %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @add_sub
+func @add_sub(%arg0 : i32, %arg1 : i32) {
+ // CHECK: spv.IAdd
+ %0 = addi %arg0, %arg1 : i32
+ // CHECK: spv.ISub
+ %1 = subi %arg0, %arg1 : i32
+ return
+}
OpenPOWER on IntegriCloud