summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
diff options
context:
space:
mode:
author: Artem Belevich <tra@google.com> 2017-01-13 20:56:17 +0000
committer: Artem Belevich <tra@google.com> 2017-01-13 20:56:17 +0000
commit: 64dc9be7b48e3e3ec1f7dec270bf433d2084915a (patch)
tree: ca4f810515a29fb080b1341aaee69917aadc06b0 /llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
parent: 836a3b416d222f0703d2e964a8e08f3892dd0278 (diff)
download: bcm5719-llvm-64dc9be7b48e3e3ec1f7dec270bf433d2084915a.tar.gz
          bcm5719-llvm-64dc9be7b48e3e3ec1f7dec270bf433d2084915a.zip
[NVPTX] Added support for half-precision floating point.
Only scalar half-precision operations are supported at the moment.
- Adds general support for the 'half' type in NVPTX.
- fp16 math operations are supported on sm_53+ GPUs only (can be disabled with --nvptx-no-f16-math).
- Type conversions to/from fp16 are supported on all GPU variants.
- On GPU variants that do not have full fp16 support (or if it's disabled), fp16 operations are promoted to fp32 and results are converted back to fp16 for storage.

Differential Revision: https://reviews.llvm.org/D28540
llvm-svn: 291956
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp')
-rw-r--r-- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp 71
1 files changed, 68 insertions, 3 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 4f3129c0774..6548dad1d58 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -42,7 +42,6 @@ FtzEnabled("nvptx-f32ftz", cl::ZeroOrMore, cl::Hidden,
cl::desc("NVPTX Specific: Flush f32 subnormals to sign-preserving zero."),
cl::init(false));
-
/// createNVPTXISelDag - This pass converts a legalized DAG into a
/// NVPTX-specific DAG, ready for instruction scheduling.
FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -520,6 +519,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
case ISD::ADDRSPACECAST:
SelectAddrSpaceCast(N);
return;
+ case ISD::ConstantFP:
+ if (tryConstantFP16(N))
+ return;
+ break;
default:
break;
}
@@ -541,6 +544,19 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
}
}
+// Select an f16 ConstantFP node as LOAD_CONST_F16. There's no way to specify
+// FP16 immediates in .f16 ops, so we load them into an .f16 register first.
+bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) {
+  if (N->getValueType(0) != MVT::f16)
+    return false; // Only f16 constants are handled here.
+  SDValue Val = CurDAG->getTargetConstantFP(
+      cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), MVT::f16);
+  SDNode *LoadConstF16 =
+      CurDAG->getMachineNode(NVPTX::LOAD_CONST_F16, SDLoc(N), MVT::f16, Val);
+  ReplaceNode(N, LoadConstF16);
+  return true; // N was replaced by the LOAD_CONST_F16 machine node.
+}
+
static unsigned int getCodeAddrSpace(MemSDNode *N) {
const Value *Src = N->getMemOperand()->getValue();
@@ -740,7 +756,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
if ((LD->getExtensionType() == ISD::SEXTLOAD))
fromType = NVPTX::PTXLdStInstCode::Signed;
else if (ScalarVT.isFloatingPoint())
- fromType = NVPTX::PTXLdStInstCode::Float;
+ // f16 uses .b16 as its storage type.
+ fromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
+ : NVPTX::PTXLdStInstCode::Float;
else
fromType = NVPTX::PTXLdStInstCode::Unsigned;
@@ -766,6 +784,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_avar;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_avar;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_avar;
break;
@@ -794,6 +815,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_asi;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_asi;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_asi;
break;
@@ -823,6 +847,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_ari_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_ari_64;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_ari_64;
break;
@@ -846,6 +873,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_ari;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_ari;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_ari;
break;
@@ -875,6 +905,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_areg_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_areg_64;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_areg_64;
break;
@@ -898,6 +931,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LD_i64_areg;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_areg;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_areg;
break;
@@ -2173,7 +2209,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
unsigned toTypeWidth = ScalarVT.getSizeInBits();
unsigned int toType;
if (ScalarVT.isFloatingPoint())
- toType = NVPTX::PTXLdStInstCode::Float;
+ // f16 uses .b16 as its storage type.
+ toType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
+ : NVPTX::PTXLdStInstCode::Float;
else
toType = NVPTX::PTXLdStInstCode::Unsigned;
@@ -2200,6 +2238,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_avar;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_avar;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_avar;
break;
@@ -2229,6 +2270,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_asi;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_asi;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_asi;
break;
@@ -2259,6 +2303,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_ari_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_ari_64;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_ari_64;
break;
@@ -2282,6 +2329,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_ari;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_ari;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_ari;
break;
@@ -2312,6 +2362,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_areg_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_areg_64;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_areg_64;
break;
@@ -2335,6 +2388,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::ST_i64_areg;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_areg;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_areg;
break;
@@ -2786,6 +2842,9 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
case MVT::i64:
Opc = NVPTX::LoadParamMemI64;
break;
+ case MVT::f16:
+ Opc = NVPTX::LoadParamMemF16;
+ break;
case MVT::f32:
Opc = NVPTX::LoadParamMemF32;
break;
@@ -2921,6 +2980,9 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::StoreRetvalI64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::StoreRetvalF16;
+ break;
case MVT::f32:
Opcode = NVPTX::StoreRetvalF32;
break;
@@ -3054,6 +3116,9 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::StoreParamI64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::StoreParamF16;
+ break;
case MVT::f32:
Opcode = NVPTX::StoreParamF32;
break;
OpenPOWER on IntegriCloud