diff options
| author | Artem Belevich <tra@google.com> | 2017-01-13 20:56:17 +0000 |
|---|---|---|
| committer | Artem Belevich <tra@google.com> | 2017-01-13 20:56:17 +0000 |
| commit | 64dc9be7b48e3e3ec1f7dec270bf433d2084915a (patch) | |
| tree | ca4f810515a29fb080b1341aaee69917aadc06b0 /llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | |
| parent | 836a3b416d222f0703d2e964a8e08f3892dd0278 (diff) | |
| download | bcm5719-llvm-64dc9be7b48e3e3ec1f7dec270bf433d2084915a.tar.gz bcm5719-llvm-64dc9be7b48e3e3ec1f7dec270bf433d2084915a.zip | |
[NVPTX] Added support for half-precision floating point.
Only scalar half-precision operations are supported at the moment.
- Adds general support for 'half' type in NVPTX.
- fp16 math operations are supported on sm_53+ GPUs only
(can be disabled with --nvptx-no-f16-math).
- Type conversions to/from fp16 are supported on all GPU variants.
- On GPU variants that do not have full fp16 support (or if it's disabled),
fp16 operations are promoted to fp32 and results are converted back
to fp16 for storage.
Differential Revision: https://reviews.llvm.org/D28540
llvm-svn: 291956
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp')
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 71 |
1 files changed, 68 insertions, 3 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 4f3129c0774..6548dad1d58 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -42,7 +42,6 @@ FtzEnabled("nvptx-f32ftz", cl::ZeroOrMore, cl::Hidden, cl::desc("NVPTX Specific: Flush f32 subnormals to sign-preserving zero."), cl::init(false)); - /// createNVPTXISelDag - This pass converts a legalized DAG into a /// NVPTX-specific DAG, ready for instruction scheduling. FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM, @@ -520,6 +519,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) { case ISD::ADDRSPACECAST: SelectAddrSpaceCast(N); return; + case ISD::ConstantFP: + if (tryConstantFP16(N)) + return; + break; default: break; } @@ -541,6 +544,19 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) { } } +// There's no way to specify FP16 immediates in .f16 ops, so we have to +// load them into an .f16 register first. +bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) { + if (N->getValueType(0) != MVT::f16) + return false; + SDValue Val = CurDAG->getTargetConstantFP( + cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), MVT::f16); + SDNode *LoadConstF16 = + CurDAG->getMachineNode(NVPTX::LOAD_CONST_F16, SDLoc(N), MVT::f16, Val); + ReplaceNode(N, LoadConstF16); + return true; +} + static unsigned int getCodeAddrSpace(MemSDNode *N) { const Value *Src = N->getMemOperand()->getValue(); @@ -740,7 +756,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { if ((LD->getExtensionType() == ISD::SEXTLOAD)) fromType = NVPTX::PTXLdStInstCode::Signed; else if (ScalarVT.isFloatingPoint()) - fromType = NVPTX::PTXLdStInstCode::Float; + // f16 uses .b16 as its storage type. + fromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped + : NVPTX::PTXLdStInstCode::Float; else fromType = NVPTX::PTXLdStInstCode::Unsigned; @@ -766,6 +784,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_avar; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_avar; + break; case MVT::f32: Opcode = NVPTX::LD_f32_avar; break; @@ -794,6 +815,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_asi; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_asi; + break; case MVT::f32: Opcode = NVPTX::LD_f32_asi; break; @@ -823,6 +847,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_ari_64; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_ari_64; + break; case MVT::f32: Opcode = NVPTX::LD_f32_ari_64; break; @@ -846,6 +873,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_ari; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_ari; + break; case MVT::f32: Opcode = NVPTX::LD_f32_ari; break; @@ -875,6 +905,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_areg_64; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_areg_64; + break; case MVT::f32: Opcode = NVPTX::LD_f32_areg_64; break; @@ -898,6 +931,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_areg; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_areg; + break; case MVT::f32: Opcode = NVPTX::LD_f32_areg; break; @@ -2173,7 +2209,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { unsigned toTypeWidth = ScalarVT.getSizeInBits(); unsigned int toType; if (ScalarVT.isFloatingPoint()) - toType = NVPTX::PTXLdStInstCode::Float; + // f16 uses .b16 as its storage type. + toType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped + : NVPTX::PTXLdStInstCode::Float; else toType = NVPTX::PTXLdStInstCode::Unsigned; @@ -2200,6 +2238,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_avar; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_avar; + break; case MVT::f32: Opcode = NVPTX::ST_f32_avar; break; @@ -2229,6 +2270,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_asi; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_asi; + break; case MVT::f32: Opcode = NVPTX::ST_f32_asi; break; @@ -2259,6 +2303,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_ari_64; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_ari_64; + break; case MVT::f32: Opcode = NVPTX::ST_f32_ari_64; break; @@ -2282,6 +2329,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_ari; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_ari; + break; case MVT::f32: Opcode = NVPTX::ST_f32_ari; break; @@ -2312,6 +2362,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_areg_64; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_areg_64; + break; case MVT::f32: Opcode = NVPTX::ST_f32_areg_64; break; @@ -2335,6 +2388,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_areg; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_areg; + break; case MVT::f32: Opcode = NVPTX::ST_f32_areg; break; @@ -2786,6 +2842,9 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) { case MVT::i64: Opc = NVPTX::LoadParamMemI64; break; + case MVT::f16: + Opc = NVPTX::LoadParamMemF16; + break; case MVT::f32: Opc = NVPTX::LoadParamMemF32; break; @@ -2921,6 +2980,9 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) { case MVT::i64: Opcode = NVPTX::StoreRetvalI64; break; + case MVT::f16: + Opcode = NVPTX::StoreRetvalF16; + break; case MVT::f32: Opcode = NVPTX::StoreRetvalF32; break; @@ -3054,6 +3116,9 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) { case MVT::i64: Opcode = NVPTX::StoreParamI64; break; + case MVT::f16: + Opcode = NVPTX::StoreParamF16; + break; case MVT::f32: Opcode = NVPTX::StoreParamF32; break; |

