diff options
| author | Denis Khalikov <khalikov.denis@huawei.com> | 2020-01-07 21:47:49 -0500 |
|---|---|---|
| committer | Lei Zhang <antiagainst@google.com> | 2020-01-07 21:51:51 -0500 |
| commit | dd495e8a877784df413679e5ec380985b60c0b2c (patch) | |
| tree | ece71c2332a241313d227efdb6d50041adf47c97 /mlir/lib/Conversion | |
| parent | 9883b14cd1a4ea2dec8d7ed30df632671f56c69b (diff) | |
| download | bcm5719-llvm-dd495e8a877784df413679e5ec380985b60c0b2c.tar.gz bcm5719-llvm-dd495e8a877784df413679e5ec380985b60c0b2c.zip | |
[mlir][spirv] Add lowering for std cmp ops.
Differential Revision: https://reviews.llvm.org/D72296
Diffstat (limited to 'mlir/lib/Conversion')
| -rw-r--r-- | mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | 59 |
1 files changed, 55 insertions, 4 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index a02dee4419a..10533a4eed2 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -39,6 +39,16 @@ public: ConversionPatternRewriter &rewriter) const override; }; +/// Convert floating-point comparison operations to SPIR-V dialect. +class CmpFOpConversion final : public SPIRVOpLowering<CmpFOp> { +public: + using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + /// Convert compare operation to SPIR-V dialect. class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> { public: @@ -196,6 +206,46 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite( } //===----------------------------------------------------------------------===// +// CmpFOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + CmpFOpOperandAdaptor cmpFOpOperands(operands); + + switch (cmpFOp.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + rewriter.replaceOpWithNewOp<spirvOp>( \ + cmpFOp, cmpFOp.getResult()->getType(), cmpFOpOperands.lhs(), \ + cmpFOpOperands.rhs()); \ + return matchSuccess(); + + // Ordered. + DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); + DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); + DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); + DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); + DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); + DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); + // Unordered. + DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); + DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); + DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); + DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); + DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); + DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); + +#undef DISPATCH + + default: + break; + } + return matchFailure(); +} + +//===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// @@ -218,11 +268,12 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands, DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); + DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp); + DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp); + DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp); + DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp); #undef DISPATCH - - default: - break; } return matchFailure(); } @@ -302,7 +353,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context, OwningRewritePatternList &patterns) { // Add patterns that lower operations into SPIR-V dialect. populateWithGenerated(context, &patterns); - patterns.insert<ConstantIndexOpConversion, CmpIOpConversion, + patterns.insert<ConstantIndexOpConversion, CmpFOpConversion, CmpIOpConversion, IntegerOpConversion<AddIOp, spirv::IAddOp>, IntegerOpConversion<MulIOp, spirv::IMulOp>, IntegerOpConversion<SignedDivIOp, spirv::SDivOp>, |

