summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion
diff options
context:
space:
mode:
authorDenis Khalikov <khalikov.denis@huawei.com>2020-01-07 21:47:49 -0500
committerLei Zhang <antiagainst@google.com>2020-01-07 21:51:51 -0500
commitdd495e8a877784df413679e5ec380985b60c0b2c (patch)
treeece71c2332a241313d227efdb6d50041adf47c97 /mlir/lib/Conversion
parent9883b14cd1a4ea2dec8d7ed30df632671f56c69b (diff)
downloadbcm5719-llvm-dd495e8a877784df413679e5ec380985b60c0b2c.tar.gz
bcm5719-llvm-dd495e8a877784df413679e5ec380985b60c0b2c.zip
[mlir][spirv] Add lowering for std cmp ops.
Differential Revision: https://reviews.llvm.org/D72296
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp59
1 files changed, 55 insertions, 4 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index a02dee4419a..10533a4eed2 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -39,6 +39,16 @@ public:
ConversionPatternRewriter &rewriter) const override;
};
+/// Convert floating-point comparison operations to SPIR-V dialect.
+class CmpFOpConversion final : public SPIRVOpLowering<CmpFOp> {
+public:
+ using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Convert compare operation to SPIR-V dialect.
class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
public:
@@ -196,6 +206,46 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
}
//===----------------------------------------------------------------------===//
+// CmpFOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ CmpFOpOperandAdaptor cmpFOpOperands(operands);
+
+ switch (cmpFOp.getPredicate()) {
+#define DISPATCH(cmpPredicate, spirvOp) \
+ case cmpPredicate: \
+ rewriter.replaceOpWithNewOp<spirvOp>( \
+ cmpFOp, cmpFOp.getResult()->getType(), cmpFOpOperands.lhs(), \
+ cmpFOpOperands.rhs()); \
+ return matchSuccess();
+
+ // Ordered.
+ DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
+ DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
+ DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
+ DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
+ DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
+ DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
+ // Unordered.
+ DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
+ DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
+ DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
+ DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
+ DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
+ DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
+
+#undef DISPATCH
+
+ default:
+ break;
+ }
+ return matchFailure();
+}
+
+//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
@@ -218,11 +268,12 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
+ DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp);
+ DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp);
+ DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp);
+ DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
#undef DISPATCH
-
- default:
- break;
}
return matchFailure();
}
@@ -302,7 +353,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
// Add patterns that lower operations into SPIR-V dialect.
populateWithGenerated(context, &patterns);
- patterns.insert<ConstantIndexOpConversion, CmpIOpConversion,
+ patterns.insert<ConstantIndexOpConversion, CmpFOpConversion, CmpIOpConversion,
IntegerOpConversion<AddIOp, spirv::IAddOp>,
IntegerOpConversion<MulIOp, spirv::IMulOp>,
IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,
OpenPOWER on IntegriCloud