diff options
author | Artem Belevich <tra@google.com> | 2017-01-13 20:56:17 +0000 |
---|---|---|
committer | Artem Belevich <tra@google.com> | 2017-01-13 20:56:17 +0000 |
commit | 64dc9be7b48e3e3ec1f7dec270bf433d2084915a (patch) | |
tree | ca4f810515a29fb080b1341aaee69917aadc06b0 /llvm/lib | |
parent | 836a3b416d222f0703d2e964a8e08f3892dd0278 (diff) | |
download | bcm5719-llvm-64dc9be7b48e3e3ec1f7dec270bf433d2084915a.tar.gz bcm5719-llvm-64dc9be7b48e3e3ec1f7dec270bf433d2084915a.zip |
[NVPTX] Added support for half-precision floating point.
Only scalar half-precision operations are supported at the moment.
- Adds general support for 'half' type in NVPTX.
- fp16 math operations are supported on sm_53+ GPUs only
(can be disabled with --nvptx-no-f16-math).
- Type conversions to/from fp16 are supported on all GPU variants.
- On GPU variants that do not have full fp16 support (or if it's disabled),
fp16 operations are promoted to fp32 and results are converted back
to fp16 for storage.
Differential Revision: https://reviews.llvm.org/D28540
llvm-svn: 291956
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp | 9 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTX.h | 3 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 21 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 71 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h | 1 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 88 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 1 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp | 3 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 222 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 50 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp | 7 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXMCExpr.h | 10 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp | 48 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td | 2 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp | 9 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXSubtarget.h | 2 |
16 files changed, 449 insertions, 98 deletions
diff --git a/llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp index 4594c22b870..04ae3c2533d 100644 --- a/llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp +++ b/llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp @@ -61,6 +61,9 @@ void NVPTXInstPrinter::printRegName(raw_ostream &OS, unsigned RegNo) const { case 6: OS << "%fd"; break; + case 7: + OS << "%h"; + break; } unsigned VReg = RegNo & 0x0FFFFFFF; @@ -247,8 +250,12 @@ void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum, O << "s"; else if (Imm == NVPTX::PTXLdStInstCode::Unsigned) O << "u"; - else + else if (Imm == NVPTX::PTXLdStInstCode::Untyped) + O << "b"; + else if (Imm == NVPTX::PTXLdStInstCode::Float) O << "f"; + else + llvm_unreachable("Unknown register type"); } else if (!strcmp(Modifier, "vec")) { if (Imm == NVPTX::PTXLdStInstCode::V2) O << ".v2"; diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h index c455a437d8d..fef9f789bda 100644 --- a/llvm/lib/Target/NVPTX/NVPTX.h +++ b/llvm/lib/Target/NVPTX/NVPTX.h @@ -108,7 +108,8 @@ enum AddressSpace { enum FromType { Unsigned = 0, Signed, - Float + Float, + Untyped }; enum VecType { Scalar = 1, diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index 3c2594c77f4..fa4b389cdb2 100644 --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -320,6 +320,10 @@ bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO, switch (Cnt->getType()->getTypeID()) { default: report_fatal_error("Unsupported FP type"); break; + case Type::HalfTyID: + MCOp = MCOperand::createExpr( + NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext)); + break; case Type::FloatTyID: MCOp = MCOperand::createExpr( NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext)); @@ -357,6 +361,8 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) { Ret = (5 << 28); } else if (RC == &NVPTX::Float64RegsRegClass) { Ret = (6 << 28); + } else if (RC == &NVPTX::Float16RegsRegClass) { + Ret = (7 << 28); } else { report_fatal_error("Bad register class"); } @@ -396,12 +402,15 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) { unsigned size = 0; if (auto *ITy = dyn_cast<IntegerType>(Ty)) { size = ITy->getBitWidth(); - if (size < 32) - size = 32; } else { assert(Ty->isFloatingPointTy() && "Floating point type expected here"); size = Ty->getPrimitiveSizeInBits(); } + // PTX ABI requires all scalar return values to be at least 32 + // bits in size. fp16 normally uses .b16 as its storage type in + // PTX, so its size must be adjusted here, too. + if (size < 32) + size = 32; O << ".param .b" << size << " func_retval0"; } else if (isa<PointerType>(Ty)) { @@ -1376,6 +1385,9 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const { } break; } + case Type::HalfTyID: + // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly. + return "b16"; case Type::FloatTyID: return "f32"; case Type::DoubleTyID: @@ -1601,6 +1613,11 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { sz = 32; } else if (isa<PointerType>(Ty)) sz = thePointerTy.getSizeInBits(); + else if (Ty->isHalfTy()) + // PTX ABI requires all scalar parameters to be at least 32 + // bits in size. fp16 normally uses .b16 as its storage type + // in PTX, so its size must be adjusted here, too. + sz = 32; else sz = Ty->getPrimitiveSizeInBits(); if (isABI) diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 4f3129c0774..6548dad1d58 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -42,7 +42,6 @@ FtzEnabled("nvptx-f32ftz", cl::ZeroOrMore, cl::Hidden, cl::desc("NVPTX Specific: Flush f32 subnormals to sign-preserving zero."), cl::init(false)); - /// createNVPTXISelDag - This pass converts a legalized DAG into a /// NVPTX-specific DAG, ready for instruction scheduling. FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM, @@ -520,6 +519,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) { case ISD::ADDRSPACECAST: SelectAddrSpaceCast(N); return; + case ISD::ConstantFP: + if (tryConstantFP16(N)) + return; + break; default: break; } @@ -541,6 +544,19 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) { } } +// There's no way to specify FP16 immediates in .f16 ops, so we have to +// load them into an .f16 register first. +bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) { + if (N->getValueType(0) != MVT::f16) + return false; + SDValue Val = CurDAG->getTargetConstantFP( + cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), MVT::f16); + SDNode *LoadConstF16 = + CurDAG->getMachineNode(NVPTX::LOAD_CONST_F16, SDLoc(N), MVT::f16, Val); + ReplaceNode(N, LoadConstF16); + return true; +} + static unsigned int getCodeAddrSpace(MemSDNode *N) { const Value *Src = N->getMemOperand()->getValue(); @@ -740,7 +756,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { if ((LD->getExtensionType() == ISD::SEXTLOAD)) fromType = NVPTX::PTXLdStInstCode::Signed; else if (ScalarVT.isFloatingPoint()) - fromType = NVPTX::PTXLdStInstCode::Float; + // f16 uses .b16 as its storage type. + fromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped + : NVPTX::PTXLdStInstCode::Float; else fromType = NVPTX::PTXLdStInstCode::Unsigned; @@ -766,6 +784,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_avar; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_avar; + break; case MVT::f32: Opcode = NVPTX::LD_f32_avar; break; @@ -794,6 +815,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_asi; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_asi; + break; case MVT::f32: Opcode = NVPTX::LD_f32_asi; break; @@ -823,6 +847,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_ari_64; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_ari_64; + break; case MVT::f32: Opcode = NVPTX::LD_f32_ari_64; break; @@ -846,6 +873,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_ari; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_ari; + break; case MVT::f32: Opcode = NVPTX::LD_f32_ari; break; @@ -875,6 +905,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_areg_64; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_areg_64; + break; case MVT::f32: Opcode = NVPTX::LD_f32_areg_64; break; @@ -898,6 +931,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::i64: Opcode = NVPTX::LD_i64_areg; break; + case MVT::f16: + Opcode = NVPTX::LD_f16_areg; + break; case MVT::f32: Opcode = NVPTX::LD_f32_areg; break; @@ -2173,7 +2209,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { unsigned toTypeWidth = ScalarVT.getSizeInBits(); unsigned int toType; if (ScalarVT.isFloatingPoint()) - toType = NVPTX::PTXLdStInstCode::Float; + // f16 uses .b16 as its storage type. + toType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped + : NVPTX::PTXLdStInstCode::Float; else toType = NVPTX::PTXLdStInstCode::Unsigned; @@ -2200,6 +2238,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_avar; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_avar; + break; case MVT::f32: Opcode = NVPTX::ST_f32_avar; break; @@ -2229,6 +2270,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_asi; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_asi; + break; case MVT::f32: Opcode = NVPTX::ST_f32_asi; break; @@ -2259,6 +2303,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_ari_64; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_ari_64; + break; case MVT::f32: Opcode = NVPTX::ST_f32_ari_64; break; @@ -2282,6 +2329,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_ari; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_ari; + break; case MVT::f32: Opcode = NVPTX::ST_f32_ari; break; @@ -2312,6 +2362,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_areg_64; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_areg_64; + break; case MVT::f32: Opcode = NVPTX::ST_f32_areg_64; break; @@ -2335,6 +2388,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::i64: Opcode = NVPTX::ST_i64_areg; break; + case MVT::f16: + Opcode = NVPTX::ST_f16_areg; + break; case MVT::f32: Opcode = NVPTX::ST_f32_areg; break; @@ -2786,6 +2842,9 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) { case MVT::i64: Opc = NVPTX::LoadParamMemI64; break; + case MVT::f16: + Opc = NVPTX::LoadParamMemF16; + break; case MVT::f32: Opc = NVPTX::LoadParamMemF32; break; @@ -2921,6 +2980,9 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) { case MVT::i64: Opcode = NVPTX::StoreRetvalI64; break; + case MVT::f16: + Opcode = NVPTX::StoreRetvalF16; + break; case MVT::f32: Opcode = NVPTX::StoreRetvalF32; break; @@ -3054,6 +3116,9 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) { case MVT::i64: Opcode = NVPTX::StoreParamI64; break; + case MVT::f16: + Opcode = NVPTX::StoreParamF16; + break; case MVT::f32: Opcode = NVPTX::StoreParamF32; break; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h index b4cbc8a3f44..889575cdf7c 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -70,6 +70,7 @@ private: bool tryTextureIntrinsic(SDNode *N); bool trySurfaceIntrinsic(SDNode *N); bool tryBFE(SDNode *N); + bool tryConstantFP16(SDNode *N); inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) { return CurDAG->getTargetConstant(Imm, DL, MVT::i32); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 30870c6ee59..f553a808845 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -164,8 +164,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass); addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass); addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass); + addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass); + + setOperationAction(ISD::SETCC, MVT::f16, + STI.allowFP16Math() ? Legal : Promote); // Operations not directly supported by NVPTX. + setOperationAction(ISD::SELECT_CC, MVT::f16, + STI.allowFP16Math() ? Expand : Promote); setOperationAction(ISD::SELECT_CC, MVT::f32, Expand); setOperationAction(ISD::SELECT_CC, MVT::f64, Expand); setOperationAction(ISD::SELECT_CC, MVT::i1, Expand); @@ -173,6 +179,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction(ISD::SELECT_CC, MVT::i16, Expand); setOperationAction(ISD::SELECT_CC, MVT::i32, Expand); setOperationAction(ISD::SELECT_CC, MVT::i64, Expand); + setOperationAction(ISD::BR_CC, MVT::f16, + STI.allowFP16Math() ? Expand : Promote); setOperationAction(ISD::BR_CC, MVT::f32, Expand); setOperationAction(ISD::BR_CC, MVT::f64, Expand); setOperationAction(ISD::BR_CC, MVT::i1, Expand); @@ -259,6 +267,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // This is legal in NVPTX setOperationAction(ISD::ConstantFP, MVT::f64, Legal); setOperationAction(ISD::ConstantFP, MVT::f32, Legal); + setOperationAction(ISD::ConstantFP, MVT::f16, Legal); // TRAP can be lowered to PTX trap setOperationAction(ISD::TRAP, MVT::Other, Legal); @@ -305,18 +314,36 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setTargetDAGCombine(ISD::SREM); setTargetDAGCombine(ISD::UREM); + if (!STI.allowFP16Math()) { + // Promote fp16 arithmetic if fp16 hardware isn't available or the + // user passed --nvptx-no-fp16-math. The flag is useful because, + // although sm_53+ GPUs have some sort of FP16 support in + // hardware, only sm_53 and sm_60 have full implementation. Others + // only have token amount of hardware and are likely to run faster + // by using fp32 units instead. + setOperationAction(ISD::FADD, MVT::f16, Promote); + setOperationAction(ISD::FMUL, MVT::f16, Promote); + setOperationAction(ISD::FSUB, MVT::f16, Promote); + setOperationAction(ISD::FMA, MVT::f16, Promote); + } + // Library functions. These default to Expand, but we have instructions // for them. + setOperationAction(ISD::FCEIL, MVT::f16, Legal); setOperationAction(ISD::FCEIL, MVT::f32, Legal); setOperationAction(ISD::FCEIL, MVT::f64, Legal); + setOperationAction(ISD::FFLOOR, MVT::f16, Legal); setOperationAction(ISD::FFLOOR, MVT::f32, Legal); setOperationAction(ISD::FFLOOR, MVT::f64, Legal); setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal); setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal); + setOperationAction(ISD::FRINT, MVT::f16, Legal); setOperationAction(ISD::FRINT, MVT::f32, Legal); setOperationAction(ISD::FRINT, MVT::f64, Legal); + setOperationAction(ISD::FROUND, MVT::f16, Legal); setOperationAction(ISD::FROUND, MVT::f32, Legal); setOperationAction(ISD::FROUND, MVT::f64, Legal); + setOperationAction(ISD::FTRUNC, MVT::f16, Legal); setOperationAction(ISD::FTRUNC, MVT::f32, Legal); setOperationAction(ISD::FTRUNC, MVT::f64, Legal); setOperationAction(ISD::FMINNUM, MVT::f32, Legal); @@ -324,6 +351,24 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction(ISD::FMAXNUM, MVT::f32, Legal); setOperationAction(ISD::FMAXNUM, MVT::f64, Legal); + // 'Expand' implements FCOPYSIGN without calling an external library. + setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand); + setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand); + setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand); + + // FP16 does not support these nodes in hardware, but we can perform + // these ops using single-precision hardware. + setOperationAction(ISD::FDIV, MVT::f16, Promote); + setOperationAction(ISD::FREM, MVT::f16, Promote); + setOperationAction(ISD::FSQRT, MVT::f16, Promote); + setOperationAction(ISD::FSIN, MVT::f16, Promote); + setOperationAction(ISD::FCOS, MVT::f16, Promote); + setOperationAction(ISD::FABS, MVT::f16, Promote); + setOperationAction(ISD::FMINNUM, MVT::f16, Promote); + setOperationAction(ISD::FMAXNUM, MVT::f16, Promote); + setOperationAction(ISD::FMINNAN, MVT::f16, Promote); + setOperationAction(ISD::FMAXNAN, MVT::f16, Promote); + // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate. // No FPOW or FREM in PTX. @@ -967,19 +1012,21 @@ std::string NVPTXTargetLowering::getPrototype( unsigned size = 0; if (auto *ITy = dyn_cast<IntegerType>(retTy)) { size = ITy->getBitWidth(); - if (size < 32) - size = 32; } else { assert(retTy->isFloatingPointTy() && "Floating point type expected here"); size = retTy->getPrimitiveSizeInBits(); } + // PTX ABI requires all scalar return values to be at least 32 + // bits in size. fp16 normally uses .b16 as its storage type in + // PTX, so its size must be adjusted here, too. + if (size < 32) + size = 32; O << ".param .b" << size << " _"; } else if (isa<PointerType>(retTy)) { O << ".param .b" << PtrVT.getSizeInBits() << " _"; - } else if ((retTy->getTypeID() == Type::StructTyID) || - isa<VectorType>(retTy)) { + } else if (retTy->isAggregateType() || retTy->isVectorTy()) { auto &DL = CS->getCalledFunction()->getParent()->getDataLayout(); O << ".param .align " << retAlignment << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]"; @@ -1018,7 +1065,7 @@ std::string NVPTXTargetLowering::getPrototype( OIdx += len - 1; continue; } - // i8 types in IR will be i16 types in SDAG + // i8 types in IR will be i16 types in SDAG assert((getValueType(DL, Ty) == Outs[OIdx].VT || (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) && "type mismatch between callee prototype and arguments"); @@ -1028,8 +1075,13 @@ std::string NVPTXTargetLowering::getPrototype( sz = cast<IntegerType>(Ty)->getBitWidth(); if (sz < 32) sz = 32; - } else if (isa<PointerType>(Ty)) + } else if (isa<PointerType>(Ty)) { sz = PtrVT.getSizeInBits(); + } else if (Ty->isHalfTy()) + // PTX ABI requires all scalar parameters to be at least 32 + // bits in size. fp16 normally uses .b16 as its storage type + // in PTX, so its size must be adjusted here, too. + sz = 32; else sz = Ty->getPrimitiveSizeInBits(); O << ".param .b" << sz << " "; @@ -1340,7 +1392,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, needExtend = true; if (sz < 32) sz = 32; - } + } else if (VT.isFloatingPoint() && sz < 32) + // PTX ABI requires all scalar parameters to be at least 32 + // bits in size. fp16 normally uses .b16 as its storage type + // in PTX, so its size must be adjusted here, too. + sz = 32; SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); SDValue DeclareParamOps[] = { Chain, DAG.getConstant(paramCount, dl, MVT::i32), @@ -1952,12 +2008,15 @@ SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const { SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const { EVT ValVT = Op.getOperand(1).getValueType(); - if (ValVT == MVT::i1) + switch (ValVT.getSimpleVT().SimpleTy) { + case MVT::i1: return LowerSTOREi1(Op, DAG); - else if (ValVT.isVector()) - return LowerSTOREVector(Op, DAG); - else - return SDValue(); + default: + if (ValVT.isVector()) + return LowerSTOREVector(Op, DAG); + else + return SDValue(); + } } SDValue @@ -2557,8 +2616,9 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, // specifically not for aggregates. TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal); TheStoreType = MVT::i32; - } - else if (TmpVal.getValueSizeInBits() < 16) + } else if (RetTy->isHalfTy()) { + TheStoreType = MVT::f16; + } else if (TmpVal.getValueSizeInBits() < 16) TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal); SDValue Ops[] = { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index fa25fe1264a..0b6ebd8cf68 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -528,6 +528,7 @@ private: SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerSTOREf16(SDValue Op, SelectionDAG &DAG) const; SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const; SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp index 7f89742a321..67e6e252eb9 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp @@ -52,6 +52,9 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB, } else if (DestRC == &NVPTX::Int64RegsRegClass) { Op = (SrcRC == &NVPTX::Int64RegsRegClass ? NVPTX::IMOV64rr : NVPTX::BITCONVERT_64_F2I); + } else if (DestRC == &NVPTX::Float16RegsRegClass) { + Op = (SrcRC == &NVPTX::Float16RegsRegClass ? NVPTX::FMOV16rr + : NVPTX::BITCONVERT_16_I2F); } else if (DestRC == &NVPTX::Float32RegsRegClass) { Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32rr : NVPTX::BITCONVERT_32_I2F); diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index c276f177fab..7cc1141e20d 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -18,6 +18,10 @@ let hasSideEffects = 0 in { def NOP : NVPTXInst<(outs), (ins), "", []>; } +let OperandType = "OPERAND_IMMEDIATE" in { + def f16imm : Operand<f16>; +} + // List of vector specific properties def isVecLD : VecInstTypeEnum<1>; def isVecST : VecInstTypeEnum<2>; @@ -149,6 +153,7 @@ def true : Predicate<"true">; def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">; +def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; //===----------------------------------------------------------------------===// // Some Common Instruction Class Templates @@ -240,11 +245,11 @@ multiclass F3<string OpcStr, SDNode OpNode> { [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>; } -// Template for instructions which take three fp64 or fp32 args. The +// Template for instructions which take three FP args. The // instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64"). // // Also defines ftz (flush subnormal inputs and results to sign-preserving -// zero) variants for fp32 functions. +// zero) variants for fp32/fp16 functions. // // This multiclass should be used for nodes that can be folded to make fma ops. // In this case, we use the ".rn" variant when FMA is disabled, as this behaves @@ -287,6 +292,19 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> { [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>, Requires<[allowFMA]>; + def f16rr_ftz : + NVPTXInst<(outs Float16Regs:$dst), + (ins Float16Regs:$a, Float16Regs:$b), + !strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"), + [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + Requires<[useFP16Math, allowFMA, doF32FTZ]>; + def f16rr : + NVPTXInst<(outs Float16Regs:$dst), + (ins Float16Regs:$a, Float16Regs:$b), + !strconcat(OpcStr, ".f16 \t$dst, $a, $b;"), + [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + Requires<[useFP16Math, allowFMA]>; + // These have strange names so we don't perturb existing mir tests. def _rnf64rr : NVPTXInst<(outs Float64Regs:$dst), @@ -324,6 +342,18 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> { !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"), [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>, Requires<[noFMA]>; + def _rnf16rr_ftz : + NVPTXInst<(outs Float16Regs:$dst), + (ins Float16Regs:$a, Float16Regs:$b), + !strconcat(OpcStr, ".rn.ftz.f16 \t$dst, $a, $b;"), + [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + Requires<[useFP16Math, noFMA, doF32FTZ]>; + def _rnf16rr : + NVPTXInst<(outs Float16Regs:$dst), + (ins Float16Regs:$a, Float16Regs:$b), + !strconcat(OpcStr, ".rn.f16 \t$dst, $a, $b;"), + [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>, + Requires<[useFP16Math, noFMA]>; } // Template for operations which take two f32 or f64 operands. Provides three @@ -375,11 +405,6 @@ let hasSideEffects = 0 in { (ins Int16Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", FromName, ".u16\t$dst, $src;"), []>; - def _f16 : - NVPTXInst<(outs RC:$dst), - (ins Int16Regs:$src, CvtMode:$mode), - !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".f16\t$dst, $src;"), []>; def _s32 : NVPTXInst<(outs RC:$dst), (ins Int32Regs:$src, CvtMode:$mode), @@ -400,6 +425,11 @@ let hasSideEffects = 0 in { (ins Int64Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", FromName, ".u64\t$dst, $src;"), []>; + def _f16 : + NVPTXInst<(outs RC:$dst), + (ins Float16Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".f16\t$dst, $src;"), []>; def _f32 : NVPTXInst<(outs RC:$dst), (ins Float32Regs:$src, CvtMode:$mode), @@ -417,11 +447,11 @@ let hasSideEffects = 0 in { defm CVT_u8 : CVT_FROM_ALL<"u8", Int16Regs>; defm CVT_s16 : CVT_FROM_ALL<"s16", Int16Regs>; defm CVT_u16 : CVT_FROM_ALL<"u16", Int16Regs>; - defm CVT_f16 : CVT_FROM_ALL<"f16", Int16Regs>; defm CVT_s32 : CVT_FROM_ALL<"s32", Int32Regs>; defm CVT_u32 : CVT_FROM_ALL<"u32", Int32Regs>; defm CVT_s64 : CVT_FROM_ALL<"s64", Int64Regs>; defm CVT_u64 : CVT_FROM_ALL<"u64", Int64Regs>; + defm CVT_f16 : CVT_FROM_ALL<"f16", Float16Regs>; defm CVT_f32 : CVT_FROM_ALL<"f32", Float32Regs>; defm CVT_f64 : CVT_FROM_ALL<"f64", Float64Regs>; @@ -749,6 +779,15 @@ def DoubleConst1 : PatLeaf<(fpimm), [{ N->getValueAPF().convertToDouble() == 1.0; }]>; +// Loads FP16 constant into a register. +// +// ptxas does not have hex representation for fp16, so we can't use +// fp16 immediate values in .f16 instructions. Instead we have to load +// the constant into a register using mov.b16. +def LOAD_CONST_F16 : + NVPTXInst<(outs Float16Regs:$dst), (ins f16imm:$a), + "mov.b16 \t$dst, $a;", []>; + defm FADD : F3_fma_component<"add", fadd>; defm FSUB : F3_fma_component<"sub", fsub>; defm FMUL : F3_fma_component<"mul", fmul>; @@ -943,6 +982,15 @@ multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> Requires<[Pred]>; } +multiclass FMA_F16<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> { + def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c), + !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), + [(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>, + Requires<[useFP16Math, Pred]>; +} + +defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", Float16Regs, f16imm, doF32FTZ>; +defm FMA16 : FMA_F16<"fma.rn.f16", Float16Regs, f16imm, true>; defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>; defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, true>; defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, true>; @@ -1320,6 +1368,11 @@ defm SETP_s64 : SETP<"s64", Int64Regs, i64imm>; defm SETP_u64 : SETP<"u64", Int64Regs, i64imm>; defm SETP_f32 : SETP<"f32", Float32Regs, f32imm>; defm SETP_f64 : SETP<"f64", Float64Regs, f64imm>; +def SETP_f16rr : + NVPTXInst<(outs Int1Regs:$dst), + (ins Float16Regs:$a, Float16Regs:$b, CmpMode:$cmp), + "setp${cmp:base}${cmp:ftz}.f16 $dst, $a, $b;", + []>, Requires<[useFP16Math]>; // FIXME: This doesn't appear to be correct. The "set" mnemonic has the form // "set.CmpOp{.ftz}.dtype.stype", where dtype is the type of the destination @@ -1348,6 +1401,7 @@ defm SET_u32 : SET<"u32", Int32Regs, i32imm>; defm SET_b64 : SET<"b64", Int64Regs, i64imm>; defm SET_s64 : SET<"s64", Int64Regs, i64imm>; defm SET_u64 : SET<"u64", Int64Regs, i64imm>; +defm SET_f16 : SET<"f16", Float16Regs, f16imm>; defm SET_f32 : SET<"f32", Float32Regs, f32imm>; defm SET_f64 : SET<"f64", Float64Regs, f64imm>; @@ -1411,6 +1465,7 @@ defm SELP_u32 : SELP<"u32", Int32Regs, i32imm>; defm SELP_b64 : SELP_PATTERN<"b64", Int64Regs, i64imm, imm>; defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>; defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>; +defm SELP_f16 : SELP_PATTERN<"b16", Float16Regs, f16imm, fpimm>; defm SELP_f32 : SELP_PATTERN<"f32", Float32Regs, f32imm, fpimm>; defm SELP_f64 : SELP_PATTERN<"f64", Float64Regs, f64imm, fpimm>; @@ -1475,6 +1530,9 @@ let IsSimpleMove=1, hasSideEffects=0 in { def IMOV64rr : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$sss), "mov.u64 \t$dst, $sss;", []>; + def FMOV16rr : NVPTXInst<(outs Float16Regs:$dst), (ins Float16Regs:$src), + // We have to use .b16 here as there's no mov.f16. + "mov.b16 \t$dst, $src;", []>; def FMOV32rr : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src), "mov.f32 \t$dst, $src;", []>; def FMOV64rr : NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$src), @@ -1636,6 +1694,26 @@ def : Pat<(i32 (setne Int1Regs:$a, Int1Regs:$b)), multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> { + // f16 -> pred + def : Pat<(i1 (OpNode Float16Regs:$a, Float16Regs:$b)), + (SETP_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>, + Requires<[useFP16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode Float16Regs:$a, Float16Regs:$b)), + (SETP_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>, + Requires<[useFP16Math]>; + def : Pat<(i1 (OpNode Float16Regs:$a, fpimm:$b)), + (SETP_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), ModeFTZ)>, + Requires<[useFP16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode Float16Regs:$a, fpimm:$b)), + (SETP_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), Mode)>, + Requires<[useFP16Math]>; + def : Pat<(i1 (OpNode fpimm:$a, Float16Regs:$b)), + (SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, ModeFTZ)>, + Requires<[useFP16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode fpimm:$a, Float16Regs:$b)), + (SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>, + Requires<[useFP16Math]>; + // f32 -> pred def : Pat<(i1 (OpNode Float32Regs:$a, Float32Regs:$b)), (SETP_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>, @@ -1661,6 +1739,26 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> { def : Pat<(i1 (OpNode fpimm:$a, Float64Regs:$b)), (SETP_f64ir fpimm:$a, Float64Regs:$b, Mode)>; + // f16 -> i32 + def : Pat<(i32 (OpNode Float16Regs:$a, Float16Regs:$b)), + (SET_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>, + Requires<[useFP16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode Float16Regs:$a, Float16Regs:$b)), + (SET_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>, + Requires<[useFP16Math]>; + def : Pat<(i32 (OpNode Float16Regs:$a, fpimm:$b)), + (SET_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), ModeFTZ)>, + Requires<[useFP16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode Float16Regs:$a, fpimm:$b)), + (SET_f16rr Float16Regs:$a, (LOAD_CONST_F16 fpimm:$b), Mode)>, + Requires<[useFP16Math]>; + def : Pat<(i32 (OpNode fpimm:$a, Float16Regs:$b)), + (SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, ModeFTZ)>, + Requires<[useFP16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode fpimm:$a, Float16Regs:$b)), + (SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>, + Requires<[useFP16Math]>; + // f32 -> i32 def : Pat<(i32 (OpNode Float32Regs:$a, Float32Regs:$b)), (SET_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>, @@ -1944,6 +2042,7 @@ def LoadParamMemV2I8 : LoadParamV2MemInst<Int16Regs, ".b8">; def LoadParamMemV4I32 : LoadParamV4MemInst<Int32Regs, ".b32">; def LoadParamMemV4I16 : LoadParamV4MemInst<Int16Regs, ".b16">; def LoadParamMemV4I8 : LoadParamV4MemInst<Int16Regs, ".b8">; +def LoadParamMemF16 : LoadParamMemInst<Float16Regs, ".b16">; def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".f32">; def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".f64">; def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">; @@ -1964,6 +2063,7 @@ def StoreParamV4I32 : StoreParamV4Inst<Int32Regs, ".b32">; def StoreParamV4I16 : StoreParamV4Inst<Int16Regs, ".b16">; def StoreParamV4I8 : StoreParamV4Inst<Int16Regs, ".b8">; +def StoreParamF16 : StoreParamInst<Float16Regs, ".b16">; def StoreParamF32 : StoreParamInst<Float32Regs, ".f32">; def StoreParamF64 : StoreParamInst<Float64Regs, ".f64">; def StoreParamV2F32 : StoreParamV2Inst<Float32Regs, ".f32">; @@ -1984,6 +2084,7 @@ def StoreRetvalV4I8 : StoreRetvalV4Inst<Int16Regs, ".b8">; def StoreRetvalF64 : StoreRetvalInst<Float64Regs, ".f64">; def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".f32">; +def StoreRetvalF16 : StoreRetvalInst<Float16Regs, ".b16">; def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".f64">; def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".f32">; def StoreRetvalV4F32 : StoreRetvalV4Inst<Float32Regs, ".f32">; @@ -2071,6 +2172,7 @@ def MoveParamI16 : [(set Int16Regs:$dst, (MoveParam Int16Regs:$src))]>; def MoveParamF64 : MoveParamInst<Float64Regs, ".f64">; def MoveParamF32 : MoveParamInst<Float32Regs, ".f32">; +def MoveParamF16 : MoveParamInst<Float16Regs, ".f16">; class PseudoUseParamInst<NVPTXRegClass regclass> : NVPTXInst<(outs), (ins regclass:$src), @@ -2131,6 +2233,7 @@ let mayLoad=1, hasSideEffects=0 in { defm LD_i16 : LD<Int16Regs>; defm LD_i32 : LD<Int32Regs>; defm LD_i64 : LD<Int64Regs>; + defm LD_f16 : LD<Float16Regs>; defm LD_f32 : LD<Float32Regs>; defm LD_f64 : LD<Float64Regs>; } @@ -2179,6 +2282,7 @@ let mayStore=1, hasSideEffects=0 in { defm ST_i16 : ST<Int16Regs>; defm ST_i32 : ST<Int32Regs>; defm ST_i64 : ST<Int64Regs>; + defm ST_f16 : ST<Float16Regs>; defm ST_f32 : ST<Float32Regs>; defm ST_f64 : ST<Float64Regs>; } @@ -2371,6 +2475,8 @@ class F_BITCONVERT<string SzStr, NVPTXRegClass regclassIn, !strconcat("mov.b", !strconcat(SzStr, " \t $d, $a;")), [(set regclassOut:$d, (bitconvert regclassIn:$a))]>; +def BITCONVERT_16_I2F : F_BITCONVERT<"16", Int16Regs, Float16Regs>; +def BITCONVERT_16_F2I : F_BITCONVERT<"16", Float16Regs, Int16Regs>; def BITCONVERT_32_I2F : F_BITCONVERT<"32", Int32Regs, Float32Regs>; def BITCONVERT_32_F2I : F_BITCONVERT<"32", Float32Regs, Int32Regs>; def BITCONVERT_64_I2F : F_BITCONVERT<"64", Int64Regs, Float64Regs>; @@ -2380,6 +2486,26 @@ def BITCONVERT_64_F2I : F_BITCONVERT<"64", Float64Regs, Int64Regs>; // we cannot specify floating-point literals in isel patterns. Therefore, we // use an integer selp to select either 1 or 0 and then cvt to floating-point. +// sint -> f16 +def : Pat<(f16 (sint_to_fp Int1Regs:$a)), + (CVT_f16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(f16 (sint_to_fp Int16Regs:$a)), + (CVT_f16_s16 Int16Regs:$a, CvtRN)>; +def : Pat<(f16 (sint_to_fp Int32Regs:$a)), + (CVT_f16_s32 Int32Regs:$a, CvtRN)>; +def : Pat<(f16 (sint_to_fp Int64Regs:$a)), + (CVT_f16_s64 Int64Regs:$a, CvtRN)>; + +// uint -> f16 +def : Pat<(f16 (uint_to_fp Int1Regs:$a)), + (CVT_f16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(f16 (uint_to_fp Int16Regs:$a)), + (CVT_f16_u16 Int16Regs:$a, CvtRN)>; +def : Pat<(f16 (uint_to_fp Int32Regs:$a)), + (CVT_f16_u32 Int32Regs:$a, CvtRN)>; +def : Pat<(f16 (uint_to_fp Int64Regs:$a)), + (CVT_f16_u64 Int64Regs:$a, CvtRN)>; + // sint -> f32 def : Pat<(f32 (sint_to_fp Int1Regs:$a)), (CVT_f32_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; @@ -2421,6 +2547,38 @@ def : Pat<(f64 (uint_to_fp Int64Regs:$a)), (CVT_f64_u64 Int64Regs:$a, CvtRN)>; +// f16 -> sint +def : Pat<(i1 (fp_to_sint Float16Regs:$a)), + (SETP_b16ri (BITCONVERT_16_F2I Float16Regs:$a), 0, CmpEQ)>; +def : Pat<(i16 (fp_to_sint Float16Regs:$a)), + (CVT_s16_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i16 (fp_to_sint Float16Regs:$a)), + (CVT_s16_f16 Float16Regs:$a, CvtRZI)>; +def : Pat<(i32 (fp_to_sint Float16Regs:$a)), + (CVT_s32_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i32 (fp_to_sint Float16Regs:$a)), + (CVT_s32_f16 Float16Regs:$a, CvtRZI)>; +def : Pat<(i64 (fp_to_sint Float16Regs:$a)), + (CVT_s64_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i64 (fp_to_sint Float16Regs:$a)), + (CVT_s64_f16 Float16Regs:$a, CvtRZI)>; + +// f16 -> uint +def : Pat<(i1 (fp_to_uint Float16Regs:$a)), + (SETP_b16ri (BITCONVERT_16_F2I Float16Regs:$a), 0, CmpEQ)>; +def : Pat<(i16 (fp_to_uint Float16Regs:$a)), + (CVT_u16_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i16 (fp_to_uint Float16Regs:$a)), + (CVT_u16_f16 Float16Regs:$a, CvtRZI)>; +def : Pat<(i32 (fp_to_uint Float16Regs:$a)), + (CVT_u32_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i32 (fp_to_uint Float16Regs:$a)), + (CVT_u32_f16 Float16Regs:$a, CvtRZI)>; +def : Pat<(i64 (fp_to_uint Float16Regs:$a)), + (CVT_u64_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i64 (fp_to_uint Float16Regs:$a)), + (CVT_u64_f16 Float16Regs:$a, CvtRZI)>; + // f32 -> sint def : Pat<(i1 (fp_to_sint Float32Regs:$a)), (SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>; @@ -2650,12 +2808,36 @@ def : Pat<(ctpop Int64Regs:$a), (CVT_u64_u32 (POPCr64 Int64Regs:$a), CvtNONE)>; def : Pat<(ctpop Int16Regs:$a), (CVT_u16_u32 (POPCr32 (CVT_u32_u16 Int16Regs:$a, CvtNONE)), CvtNONE)>; +// fpround f32 -> f16 +def : Pat<(f16 (fpround Float32Regs:$a)), + (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f16 (fpround Float32Regs:$a)), + (CVT_f16_f32 Float32Regs:$a, CvtRN)>; + +// fpround f64 -> f16 +def : Pat<(f16 (fpround Float64Regs:$a)), + (CVT_f16_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f16 (fpround Float64Regs:$a)), + (CVT_f16_f64 Float64Regs:$a, CvtRN)>; + // fpround f64 -> f32 def : Pat<(f32 (fpround Float64Regs:$a)), (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(f32 (fpround Float64Regs:$a)), (CVT_f32_f64 Float64Regs:$a, CvtRN)>; +// fpextend f16 -> f32 +def : Pat<(f32 (fpextend Float16Regs:$a)), + (CVT_f32_f16 Float16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f32 (fpextend Float16Regs:$a)), + (CVT_f32_f16 Float16Regs:$a, CvtNONE)>; + +// fpextend f16 -> f64 +def : Pat<(f64 (fpextend Float16Regs:$a)), + (CVT_f64_f16 Float16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f64 (fpextend Float16Regs:$a)), + (CVT_f64_f16 Float16Regs:$a, CvtNONE)>; + // fpextend f32 -> f64 def : Pat<(f64 (fpextend Float32Regs:$a)), (CVT_f64_f32 Float32Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; @@ -2667,6 +2849,10 @@ def retflag : SDNode<"NVPTXISD::RET_FLAG", SDTNone, // fceil, ffloor, fround, ftrunc. +def : Pat<(fceil Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRPI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(fceil Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRPI)>, Requires<[doNoF32FTZ]>; def : Pat<(fceil Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRPI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(fceil Float32Regs:$a), @@ -2674,6 +2860,10 @@ def : Pat<(fceil Float32Regs:$a), def : Pat<(fceil Float64Regs:$a), (CVT_f64_f64 Float64Regs:$a, CvtRPI)>; +def : Pat<(ffloor Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRMI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(ffloor Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRMI)>, Requires<[doNoF32FTZ]>; def : Pat<(ffloor Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRMI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(ffloor Float32Regs:$a), @@ -2681,6 +2871,10 @@ def : Pat<(ffloor Float32Regs:$a), def : Pat<(ffloor Float64Regs:$a), (CVT_f64_f64 Float64Regs:$a, CvtRMI)>; +def : Pat<(fround Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f16 (fround Float16Regs:$a)), + (CVT_f16_f16 Float16Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>; def : Pat<(fround Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(f32 (fround Float32Regs:$a)), @@ -2688,6 +2882,10 @@ def : Pat<(f32 (fround Float32Regs:$a)), def : Pat<(f64 (fround Float64Regs:$a)), (CVT_f64_f64 Float64Regs:$a, CvtRNI)>; +def : Pat<(ftrunc Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(ftrunc Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRZI)>, Requires<[doNoF32FTZ]>; def : Pat<(ftrunc Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(ftrunc Float32Regs:$a), @@ -2699,6 +2897,10 @@ def : Pat<(ftrunc Float64Regs:$a), // strictly correct, because it causes us to ignore the rounding mode. But it // matches what CUDA's "libm" does. +def : Pat<(fnearbyint Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(fnearbyint Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>; def : Pat<(fnearbyint Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(fnearbyint Float32Regs:$a), @@ -2706,6 +2908,10 @@ def : Pat<(fnearbyint Float32Regs:$a), def : Pat<(fnearbyint Float64Regs:$a), (CVT_f64_f64 Float64Regs:$a, CvtRNI)>; +def : Pat<(frint Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(frint Float16Regs:$a), + (CVT_f16_f16 Float16Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>; def : Pat<(frint Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(frint Float32Regs:$a), diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index b0408f12f5b..8f0477c93a0 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -803,49 +803,13 @@ def : Pat<(int_nvvm_ull2d_rp Int64Regs:$a), (CVT_f64_u64 Int64Regs:$a, CvtRP)>; -// FIXME: Ideally, we could use these patterns instead of the scope-creating -// patterns, but ptxas does not like these since .s16 is not compatible with -// .f16. The solution is to use .bXX for all integer register types, but we -// are not there yet. -//def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a), -// (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>; -//def : Pat<(int_nvvm_f2h_rn Float32Regs:$a), -// (CVT_f16_f32 Float32Regs:$a, CvtRN)>; -// -//def : Pat<(int_nvvm_h2f Int16Regs:$a), -// (CVT_f32_f16 Int16Regs:$a, CvtNONE)>; - -def INT_NVVM_F2H_RN_FTZ : F_MATH_1<!strconcat("{{\n\t", - !strconcat(".reg .b16 %temp;\n\t", - !strconcat("cvt.rn.ftz.f16.f32 \t%temp, $src0;\n\t", - !strconcat("mov.b16 \t$dst, %temp;\n", - "}}")))), - Int16Regs, Float32Regs, int_nvvm_f2h_rn_ftz>; -def INT_NVVM_F2H_RN : F_MATH_1<!strconcat("{{\n\t", - !strconcat(".reg .b16 %temp;\n\t", - !strconcat("cvt.rn.f16.f32 \t%temp, $src0;\n\t", - !strconcat("mov.b16 \t$dst, %temp;\n", - "}}")))), - Int16Regs, Float32Regs, int_nvvm_f2h_rn>; - -def INT_NVVM_H2F : F_MATH_1<!strconcat("{{\n\t", - !strconcat(".reg .b16 %temp;\n\t", - !strconcat("mov.b16 \t%temp, $src0;\n\t", - !strconcat("cvt.f32.f16 \t$dst, %temp;\n\t", - "}}")))), - Float32Regs, Int16Regs, int_nvvm_h2f>; - -def : Pat<(f32 (f16_to_fp Int16Regs:$a)), - (CVT_f32_f16 Int16Regs:$a, CvtNONE)>; -def : Pat<(i16 (fp_to_f16 Float32Regs:$a)), - (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; -def : Pat<(i16 (fp_to_f16 Float32Regs:$a)), - (CVT_f16_f32 Float32Regs:$a, CvtRN)>; - -def : Pat<(f64 (f16_to_fp Int16Regs:$a)), - (CVT_f64_f16 Int16Regs:$a, CvtNONE)>; -def : Pat<(i16 (fp_to_f16 Float64Regs:$a)), - (CVT_f16_f64 Float64Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a), + (BITCONVERT_16_F2I (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ))>; +def : Pat<(int_nvvm_f2h_rn Float32Regs:$a), + (BITCONVERT_16_F2I (CVT_f16_f32 Float32Regs:$a, CvtRN))>; + +def : Pat<(int_nvvm_h2f Int16Regs:$a), + (CVT_f32_f16 (BITCONVERT_16_I2F Int16Regs:$a), CvtNONE)>; // // Bitcast diff --git a/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp b/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp index eab5ee80561..86a28f7d070 100644 --- a/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp @@ -27,6 +27,13 @@ void NVPTXFloatMCExpr::printImpl(raw_ostream &OS, const MCAsmInfo *MAI) const { switch (Kind) { default: llvm_unreachable("Invalid kind!"); + case VK_NVPTX_HALF_PREC_FLOAT: + // ptxas does not have a way to specify half-precision floats. + // Instead we have to print and load fp16 constants as .b16 + OS << "0x"; + NumHex = 4; + APF.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored); + break; case VK_NVPTX_SINGLE_PREC_FLOAT: OS << "0f"; NumHex = 8; diff --git a/llvm/lib/Target/NVPTX/NVPTXMCExpr.h b/llvm/lib/Target/NVPTX/NVPTXMCExpr.h index 7f833c42fa8..95741d9b045 100644 --- a/llvm/lib/Target/NVPTX/NVPTXMCExpr.h +++ b/llvm/lib/Target/NVPTX/NVPTXMCExpr.h @@ -22,8 +22,9 @@ class NVPTXFloatMCExpr : public MCTargetExpr { public: enum VariantKind { VK_NVPTX_None, - VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision - VK_NVPTX_DOUBLE_PREC_FLOAT // FP constant in double-precision + VK_NVPTX_HALF_PREC_FLOAT, // FP constant in half-precision + VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision + VK_NVPTX_DOUBLE_PREC_FLOAT // FP constant in double-precision }; private: @@ -40,6 +41,11 @@ public: static const NVPTXFloatMCExpr *create(VariantKind Kind, const APFloat &Flt, MCContext &Ctx); + static const NVPTXFloatMCExpr *createConstantFPHalf(const APFloat &Flt, + MCContext &Ctx) { + return create(VK_NVPTX_HALF_PREC_FLOAT, Flt, Ctx); + } + static const NVPTXFloatMCExpr *createConstantFPSingle(const APFloat &Flt, MCContext &Ctx) { return create(VK_NVPTX_SINGLE_PREC_FLOAT, Flt, Ctx); diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp index 6cbf0604d7e..9caedfb0fef 100644 --- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp @@ -27,12 +27,17 @@ using namespace llvm; namespace llvm { std::string getNVPTXRegClassName(TargetRegisterClass const *RC) { - if (RC == &NVPTX::Float32RegsRegClass) { + if (RC == &NVPTX::Float32RegsRegClass) return ".f32"; - } - if (RC == &NVPTX::Float64RegsRegClass) { + if (RC == &NVPTX::Float16RegsRegClass) + // Ideally fp16 registers should be .f16, but this syntax is only + // supported on sm_53+. On the other hand, .b16 registers are + // accepted for all supported fp16 instructions on all GPU + // variants, so we can use them instead. + return ".b16"; + if (RC == &NVPTX::Float64RegsRegClass) return ".f64"; - } else if (RC == &NVPTX::Int64RegsRegClass) { + if (RC == &NVPTX::Int64RegsRegClass) // We use untyped (.b) integer registers here as NVCC does. // Correctness of generated code does not depend on register type, // but using .s/.u registers runs into ptxas bug that prevents @@ -52,40 +57,35 @@ std::string getNVPTXRegClassName(TargetRegisterClass const *RC) { // add.f16v2 rb32,rb32,rb32; // OK // add.f16v2 rs32,rs32,rs32; // OK return ".b64"; - } else if (RC == &NVPTX::Int32RegsRegClass) { + if (RC == &NVPTX::Int32RegsRegClass) return ".b32"; - } else if (RC == &NVPTX::Int16RegsRegClass) { + if (RC == &NVPTX::Int16RegsRegClass) return ".b16"; - } else if (RC == &NVPTX::Int1RegsRegClass) { + if (RC == &NVPTX::Int1RegsRegClass) return ".pred"; - } else if (RC == &NVPTX::SpecialRegsRegClass) { + if (RC == &NVPTX::SpecialRegsRegClass) return "!Special!"; - } else { - return "INTERNAL"; - } - return ""; + return "INTERNAL"; } std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) { - if (RC == &NVPTX::Float32RegsRegClass) { + if (RC == &NVPTX::Float32RegsRegClass) return "%f"; - } - if (RC == &NVPTX::Float64RegsRegClass) { + if (RC == &NVPTX::Float16RegsRegClass) + return "%h"; + if (RC == &NVPTX::Float64RegsRegClass) return "%fd"; - } else if (RC == &NVPTX::Int64RegsRegClass) { + if (RC == &NVPTX::Int64RegsRegClass) return "%rd"; - } else if (RC == &NVPTX::Int32RegsRegClass) { + if (RC == &NVPTX::Int32RegsRegClass) return "%r"; - } else if (RC == &NVPTX::Int16RegsRegClass) { + if (RC == &NVPTX::Int16RegsRegClass) return "%rs"; - } else if (RC == &NVPTX::Int1RegsRegClass) { + if (RC == &NVPTX::Int1RegsRegClass) return "%p"; - } else if (RC == &NVPTX::SpecialRegsRegClass) { + if (RC == &NVPTX::SpecialRegsRegClass) return "!Special!"; - } else { - return "INTERNAL"; - } - return ""; + return "INTERNAL"; } } diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td index ff6ccc457db..fd255bdb6d2 100644 --- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td @@ -36,6 +36,7 @@ foreach i = 0-4 in { def RS#i : NVPTXReg<"%rs"#i>; // 16-bit def R#i : NVPTXReg<"%r"#i>; // 32-bit def RL#i : NVPTXReg<"%rd"#i>; // 64-bit + def H#i : NVPTXReg<"%h"#i>; // 16-bit float def F#i : NVPTXReg<"%f"#i>; // 32-bit float def FL#i : NVPTXReg<"%fd"#i>; // 64-bit float @@ -57,6 +58,7 @@ def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 4))>; def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 4))>; def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4))>; def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4))>; +def Float16Regs : NVPTXRegClass<[f16], 16, (add (sequence "H%u", 0, 4))>; def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>; def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>; def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>; diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp index 6e1f427ed02..acbee86ae38 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp @@ -23,6 +23,11 @@ using namespace llvm; #define GET_SUBTARGETINFO_CTOR #include "NVPTXGenSubtargetInfo.inc" +static cl::opt<bool> + NoF16Math("nvptx-no-f16-math", cl::ZeroOrMore, cl::Hidden, + cl::desc("NVPTX Specific: Disable generation of f16 math ops."), + cl::init(false)); + // Pin the vtable to this file. void NVPTXSubtarget::anchor() {} @@ -57,3 +62,7 @@ bool NVPTXSubtarget::hasImageHandles() const { // Disabled, otherwise return false; } + +bool NVPTXSubtarget::allowFP16Math() const { + return hasFP16Math() && NoF16Math == false; +} diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h index da020a94bcd..96618cf4637 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h @@ -101,6 +101,8 @@ public: inline bool hasROT32() const { return hasHWROT32() || hasSWROT32(); } inline bool hasROT64() const { return SmVersion >= 20; } bool hasImageHandles() const; + bool hasFP16Math() const { return SmVersion >= 53; } + bool allowFP16Math() const; unsigned int getSmVersion() const { return SmVersion; } std::string getTargetName() const { return TargetName; } |