diff options
| author | Justin Lebar <jlebar@google.com> | 2017-01-31 05:58:22 +0000 |
|---|---|---|
| committer | Justin Lebar <jlebar@google.com> | 2017-01-31 05:58:22 +0000 |
| commit | 1c9692a46fd5650c65da38cb371b8e62a0303cfa (patch) | |
| tree | 16220e7e090164ccb56edf0fea98863ef2fdbff1 /llvm/lib/Target/NVPTX | |
| parent | 93590e09d517f3574a0a9130d1b56440b928933a (diff) | |
| download | bcm5719-llvm-1c9692a46fd5650c65da38cb371b8e62a0303cfa.tar.gz bcm5719-llvm-1c9692a46fd5650c65da38cb371b8e62a0303cfa.zip | |
[NVPTX] Implement NVPTXTargetLowering::getSqrtEstimate.
Summary:
This lets us lower to sqrt.approx and rsqrt.approx under more
circumstances.
* Now we emit sqrt.approx and rsqrt.approx for calls to @llvm.sqrt.f32,
when fast-math is enabled. Previously, we would only emit it for
calls to @llvm.nvvm.sqrt.f. (With this patch we no longer emit
sqrt.approx for calls to @llvm.nvvm.sqrt.f; we rely on instcombine to
simplify llvm.nvvm.sqrt.f into llvm.sqrt.f32.)
* Now we emit the ftz version of rsqrt.approx when ftz is enabled.
Previously, we only emitted rsqrt.approx when ftz was disabled.
Reviewers: hfinkel
Subscribers: llvm-commits, tra, jholewinski
Differential Revision: https://reviews.llvm.org/D28508
llvm-svn: 293605
Diffstat (limited to 'llvm/lib/Target/NVPTX')
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 44 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 4 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 11 |
3 files changed, 49 insertions, 10 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 1fb42496d95..194e46b0448 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -1043,6 +1043,50 @@ NVPTXTargetLowering::getPreferredVectorAction(EVT VT) const { return TargetLoweringBase::getPreferredVectorAction(VT); } +SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, + int Enabled, int &ExtraSteps, + bool &UseOneConst, + bool Reciprocal) const { + if (!(Enabled == ReciprocalEstimate::Enabled || + (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32()))) + return SDValue(); + + if (ExtraSteps == ReciprocalEstimate::Unspecified) + ExtraSteps = 0; + + SDLoc DL(Operand); + EVT VT = Operand.getValueType(); + bool Ftz = useF32FTZ(DAG.getMachineFunction()); + + auto MakeIntrinsicCall = [&](Intrinsic::ID IID) { + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getConstant(IID, DL, MVT::i32), Operand); + }; + + // The sqrt and rsqrt refinement processes assume we always start out with an + // approximation of the rsqrt. Therefore, if we're going to do any refinement + // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing + // any refinement, we must return a regular sqrt. + if (Reciprocal || ExtraSteps > 0) { + if (VT == MVT::f32) + return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f + : Intrinsic::nvvm_rsqrt_approx_f); + else if (VT == MVT::f64) + return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d); + else + return SDValue(); + } else { + if (VT == MVT::f32) + return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f + : Intrinsic::nvvm_sqrt_approx_f); + else { + // There's no sqrt.approx.f64 instruction, so we emit x * rsqrt(x). 
+ return DAG.getNode(ISD::FMUL, DL, VT, Operand, + MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d)); + } + } +} + SDValue NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const { SDLoc dl(Op); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 05c54018b73..f6494f6d37e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -526,6 +526,10 @@ public: // to sign-preserving zero. bool useF32FTZ(const MachineFunction &MF) const; + SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, + int &ExtraSteps, bool &UseOneConst, + bool Reciprocal) const override; + bool allowFMA(MachineFunction &MF, CodeGenOpt::Level OptLevel) const; bool allowUnsafeFPMath(MachineFunction &MF) const; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 8b703bd196e..3345ce8d3cb 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -966,18 +966,9 @@ def FDIV32ri_prec : Requires<[reqPTX20]>; // -// F32 rsqrt +// FMA // -def RSQRTF32approx1r : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$b), - "rsqrt.approx.f32 \t$dst, $b;", []>; - -// Convert 1.0f/sqrt(x) to rsqrt.approx.f32. (There is an rsqrt.approx.f64, but -// it's emulated in software.) -def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$b)), - (RSQRTF32approx1r Float32Regs:$b)>, - Requires<[do_DIVF32_FULL, do_SQRTF32_APPROX, doNoF32FTZ]>; - multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> { def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c), !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), |

