diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 44 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 4 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 11 |
3 files changed, 49 insertions, 10 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 1fb42496d95..194e46b0448 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -1043,6 +1043,50 @@ NVPTXTargetLowering::getPreferredVectorAction(EVT VT) const { return TargetLoweringBase::getPreferredVectorAction(VT); } +SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, + int Enabled, int &ExtraSteps, + bool &UseOneConst, + bool Reciprocal) const { + if (!(Enabled == ReciprocalEstimate::Enabled || + (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32()))) + return SDValue(); + + if (ExtraSteps == ReciprocalEstimate::Unspecified) + ExtraSteps = 0; + + SDLoc DL(Operand); + EVT VT = Operand.getValueType(); + bool Ftz = useF32FTZ(DAG.getMachineFunction()); + + auto MakeIntrinsicCall = [&](Intrinsic::ID IID) { + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getConstant(IID, DL, MVT::i32), Operand); + }; + + // The sqrt and rsqrt refinement processes assume we always start out with an + // approximation of the rsqrt. Therefore, if we're going to do any refinement + // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing + // any refinement, we must return a regular sqrt. + if (Reciprocal || ExtraSteps > 0) { + if (VT == MVT::f32) + return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f + : Intrinsic::nvvm_rsqrt_approx_f); + else if (VT == MVT::f64) + return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d); + else + return SDValue(); + } else { + if (VT == MVT::f32) + return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f + : Intrinsic::nvvm_sqrt_approx_f); + else { + // There's no sqrt.approx.f64 instruction, so we emit x * rsqrt(x). + return DAG.getNode(ISD::FMUL, DL, VT, Operand, + MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d)); + } + } +} + SDValue NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const { SDLoc dl(Op); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 05c54018b73..f6494f6d37e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -526,6 +526,10 @@ public: // to sign-preserving zero. bool useF32FTZ(const MachineFunction &MF) const; + SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, + int &ExtraSteps, bool &UseOneConst, + bool Reciprocal) const override; + bool allowFMA(MachineFunction &MF, CodeGenOpt::Level OptLevel) const; bool allowUnsafeFPMath(MachineFunction &MF) const; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 8b703bd196e..3345ce8d3cb 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -966,18 +966,9 @@ def FDIV32ri_prec : Requires<[reqPTX20]>; // -// F32 rsqrt +// FMA // -def RSQRTF32approx1r : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$b), - "rsqrt.approx.f32 \t$dst, $b;", []>; - -// Convert 1.0f/sqrt(x) to rsqrt.approx.f32. (There is an rsqrt.approx.f64, but -// it's emulated in software.) -def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$b)), - (RSQRTF32approx1r Float32Regs:$b)>, - Requires<[do_DIVF32_FULL, do_SQRTF32_APPROX, doNoF32FTZ]>; - multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> { def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c), !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), |

