author     Artem Belevich <tra@google.com>    2017-02-23 22:38:24 +0000
committer  Artem Belevich <tra@google.com>    2017-02-23 22:38:24 +0000
commit     620db1f3dd08ebbba71b0e16f83c11323e04bc05 (patch)
tree       ec38800ee0d0ae0282c4d5e47a193336c80b82a6 /llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
parent     063a56e81c89dfbf7896c71c8cd849b911eb7098 (diff)
[NVPTX] Added support for .f16x2 instructions.
This patch enables support for .f16x2 operations.
Added new register type Float16x2.
Added support for .f16x2 instructions.
Added handling of vectorized loads/stores of v2f16 values.
Differential Revision: https://reviews.llvm.org/D30057
Differential Revision: https://reviews.llvm.org/D30310
llvm-svn: 296032
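
For readers skimming the patch: the new LowerBUILD_VECTOR (in the diff below) folds a constant v2f16 into a single 32-bit immediate, with element 0 in the low 16 bits. The following is a minimal standalone C++ sketch of that packing arithmetic, not part of the patch; packF16x2 is a hypothetical helper name.

    #include <cassert>
    #include <cstdint>

    // Pack two IEEE-754 binary16 bit patterns into one 32-bit word,
    // element 0 in the low half -- the same value LowerBUILD_VECTOR
    // emits as a single `mov.b32` immediate.
    uint32_t packF16x2(uint16_t e0, uint16_t e1) {
      return (static_cast<uint32_t>(e1) << 16) | e0;
    }

    int main() {
      // 0x3C00 is 1.0 and 0x4000 is 2.0 in binary16, so the constant
      // pair {1.0, 2.0} becomes the 0x40003C00 quoted in the patch.
      assert(packF16x2(0x3C00, 0x4000) == 0x40003C00);
      return 0;
    }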
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  332
1 file changed, 249 insertions(+), 83 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 27d9f34850c..c2877c34f63 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -146,6 +146,9 @@ static bool IsPTXVectorType(MVT VT) {
   case MVT::v2i32:
   case MVT::v4i32:
   case MVT::v2i64:
+  case MVT::v2f16:
+  case MVT::v4f16:
+  case MVT::v8f16: // <4 x f16x2>
   case MVT::v2f32:
   case MVT::v4f32:
   case MVT::v2f64:
@@ -170,13 +173,24 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
   for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
     EVT VT = TempVTs[i];
     uint64_t Off = TempOffsets[i];
-    if (VT.isVector())
-      for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
-        ValueVTs.push_back(VT.getVectorElementType());
+    // Split vectors into individual elements, except for v2f16, which
+    // we will pass as a single scalar.
+    if (VT.isVector()) {
+      unsigned NumElts = VT.getVectorNumElements();
+      EVT EltVT = VT.getVectorElementType();
+      // Vectors with an even number of f16 elements will be passed to
+      // us as an array of v2f16 elements. We must match this so we
+      // stay in sync with Ins/Outs.
+      if (EltVT == MVT::f16 && NumElts % 2 == 0) {
+        EltVT = MVT::v2f16;
+        NumElts /= 2;
+      }
+      for (unsigned j = 0; j != NumElts; ++j) {
+        ValueVTs.push_back(EltVT);
         if (Offsets)
-          Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize());
+          Offsets->push_back(Off + j * EltVT.getStoreSize());
       }
-    else {
+    } else {
       ValueVTs.push_back(VT);
       if (Offsets)
         Offsets->push_back(Off);
@@ -331,6 +345,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   else
     setSchedulingPreference(Sched::Source);

+  auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
+                                    LegalizeAction NoF16Action) {
+    setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
+  };
+
   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
@@ -338,13 +357,20 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
   addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass);
+  addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass);
+
+  // Conversion to/from FP16/FP16x2 is always legal.
+  setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal);
+  setOperationAction(ISD::FP_TO_SINT, MVT::f16, Legal);
+  setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
+  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f16, Custom);

-  setOperationAction(ISD::SETCC, MVT::f16,
-                     STI.allowFP16Math() ? Legal : Promote);
+  setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
+  setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);

   // Operations not directly supported by NVPTX.
-  setOperationAction(ISD::SELECT_CC, MVT::f16,
-                     STI.allowFP16Math() ? Expand : Promote);
+  setOperationAction(ISD::SELECT_CC, MVT::f16, Expand);
+  setOperationAction(ISD::SELECT_CC, MVT::v2f16, Expand);
   setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
   setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
   setOperationAction(ISD::SELECT_CC, MVT::i1, Expand);
@@ -352,8 +378,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::SELECT_CC, MVT::i16, Expand);
   setOperationAction(ISD::SELECT_CC, MVT::i32, Expand);
   setOperationAction(ISD::SELECT_CC, MVT::i64, Expand);
-  setOperationAction(ISD::BR_CC, MVT::f16,
-                     STI.allowFP16Math() ? Expand : Promote);
+  setOperationAction(ISD::BR_CC, MVT::f16, Expand);
+  setOperationAction(ISD::BR_CC, MVT::v2f16, Expand);
   setOperationAction(ISD::BR_CC, MVT::f32, Expand);
   setOperationAction(ISD::BR_CC, MVT::f64, Expand);
   setOperationAction(ISD::BR_CC, MVT::i1, Expand);
@@ -493,58 +519,53 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setTargetDAGCombine(ISD::SREM);
   setTargetDAGCombine(ISD::UREM);

-  if (!STI.allowFP16Math()) {
-    // Promote fp16 arithmetic if fp16 hardware isn't available or the
-    // user passed --nvptx-no-fp16-math. The flag is useful because,
-    // although sm_53+ GPUs have some sort of FP16 support in
-    // hardware, only sm_53 and sm_60 have full implementation. Others
-    // only have token amount of hardware and are likely to run faster
-    // by using fp32 units instead.
-    setOperationAction(ISD::FADD, MVT::f16, Promote);
-    setOperationAction(ISD::FMUL, MVT::f16, Promote);
-    setOperationAction(ISD::FSUB, MVT::f16, Promote);
-    setOperationAction(ISD::FMA, MVT::f16, Promote);
+  // setcc for f16x2 needs special handling to prevent legalizer's
+  // attempt to scalarize it due to v2i1 not being legal.
+  if (STI.allowFP16Math())
+    setTargetDAGCombine(ISD::SETCC);
+
+  // Promote fp16 arithmetic if fp16 hardware isn't available or the
+  // user passed --nvptx-no-fp16-math. The flag is useful because,
+  // although sm_53+ GPUs have some sort of FP16 support in
+  // hardware, only sm_53 and sm_60 have full implementation. Others
+  // only have token amount of hardware and are likely to run faster
+  // by using fp32 units instead.
+  for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
+    setFP16OperationAction(Op, MVT::f16, Legal, Promote);
+    setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
   }
-  // There's no neg.f16 instruction.
+
+  // There's no neg.f16 instruction. Expand to (0-x).
   setOperationAction(ISD::FNEG, MVT::f16, Expand);
+  setOperationAction(ISD::FNEG, MVT::v2f16, Expand);
+
+  // (would be) Library functions.

-  // Library functions. These default to Expand, but we have instructions
-  // for them.
-  setOperationAction(ISD::FCEIL, MVT::f16, Legal);
-  setOperationAction(ISD::FCEIL, MVT::f32, Legal);
-  setOperationAction(ISD::FCEIL, MVT::f64, Legal);
-  setOperationAction(ISD::FFLOOR, MVT::f16, Legal);
-  setOperationAction(ISD::FFLOOR, MVT::f32, Legal);
-  setOperationAction(ISD::FFLOOR, MVT::f64, Legal);
-  setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
-  setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal);
-  setOperationAction(ISD::FRINT, MVT::f16, Legal);
-  setOperationAction(ISD::FRINT, MVT::f32, Legal);
-  setOperationAction(ISD::FRINT, MVT::f64, Legal);
-  setOperationAction(ISD::FROUND, MVT::f16, Legal);
-  setOperationAction(ISD::FROUND, MVT::f32, Legal);
-  setOperationAction(ISD::FROUND, MVT::f64, Legal);
-  setOperationAction(ISD::FTRUNC, MVT::f16, Legal);
-  setOperationAction(ISD::FTRUNC, MVT::f32, Legal);
-  setOperationAction(ISD::FTRUNC, MVT::f64, Legal);
-  setOperationAction(ISD::FMINNUM, MVT::f32, Legal);
-  setOperationAction(ISD::FMINNUM, MVT::f64, Legal);
-  setOperationAction(ISD::FMAXNUM, MVT::f32, Legal);
-  setOperationAction(ISD::FMAXNUM, MVT::f64, Legal);
+  // These map to conversion instructions for scalar FP types.
+  for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
+                         ISD::FROUND, ISD::FTRUNC}) {
+    setOperationAction(Op, MVT::f16, Legal);
+    setOperationAction(Op, MVT::f32, Legal);
+    setOperationAction(Op, MVT::f64, Legal);
+    setOperationAction(Op, MVT::v2f16, Expand);
+  }

   // 'Expand' implements FCOPYSIGN without calling an external library.
   setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
+  setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);

-  // FP16 does not support these nodes in hardware, but we can perform
-  // these ops using single-precision hardware.
-  setOperationAction(ISD::FDIV, MVT::f16, Promote);
-  setOperationAction(ISD::FREM, MVT::f16, Promote);
-  setOperationAction(ISD::FSQRT, MVT::f16, Promote);
-  setOperationAction(ISD::FSIN, MVT::f16, Promote);
-  setOperationAction(ISD::FCOS, MVT::f16, Promote);
-  setOperationAction(ISD::FABS, MVT::f16, Promote);
+  // These map to corresponding instructions for f32/f64. f16 must be
+  // promoted to f32. v2f16 is expanded to f16, which is then promoted
+  // to f32.
+  for (const auto &Op : {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS,
+                         ISD::FABS, ISD::FMINNUM, ISD::FMAXNUM}) {
+    setOperationAction(Op, MVT::f16, Promote);
+    setOperationAction(Op, MVT::f32, Legal);
+    setOperationAction(Op, MVT::f64, Legal);
+    setOperationAction(Op, MVT::v2f16, Expand);
+  }
   setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
   setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
   setOperationAction(ISD::FMINNAN, MVT::f16, Promote);
@@ -660,6 +681,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     return "NVPTXISD::FUN_SHFR_CLAMP";
   case NVPTXISD::IMAD:
     return "NVPTXISD::IMAD";
+  case NVPTXISD::SETP_F16X2:
+    return "NVPTXISD::SETP_F16X2";
   case NVPTXISD::Dummy:
     return "NVPTXISD::Dummy";
   case NVPTXISD::MUL_WIDE_SIGNED:
@@ -1158,7 +1181,8 @@ TargetLoweringBase::LegalizeTypeAction
 NVPTXTargetLowering::getPreferredVectorAction(EVT VT) const {
   if (VT.getVectorNumElements() != 1 && VT.getScalarType() == MVT::i1)
     return TypeSplitVector;
-
+  if (VT == MVT::v2f16)
+    return TypeLegal;
   return TargetLoweringBase::getPreferredVectorAction(VT);
 }
@@ -1723,7 +1747,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   bool ExtendIntegerRetVal =
       RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;

-  for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
+  for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
     bool needTruncate = false;
     EVT TheLoadType = VTs[i];
     EVT EltType = Ins[i].VT;
@@ -1765,11 +1789,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       llvm_unreachable("Invalid vector info.");
     }

-    SDValue VectorOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
-                           DAG.getConstant(Offsets[VecIdx], dl, MVT::i32),
-                           InFlag};
+    SDValue LoadOperands[] = {
+        Chain, DAG.getConstant(1, dl, MVT::i32),
+        DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InFlag};
     SDValue RetVal = DAG.getMemIntrinsicNode(
-        Op, dl, DAG.getVTList(LoadVTs), VectorOps, TheLoadType,
+        Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
         MachinePointerInfo(), EltAlign);

     for (unsigned j = 0; j < NumElts; ++j) {
@@ -1823,6 +1847,55 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }

+// We can init constant f16x2 with a single .b32 move. Normally it
+// would get lowered as two constant loads and vector-packing move.
+//        mov.b16         %h1, 0x4000;
+//        mov.b16         %h2, 0x3C00;
+//        mov.b32         %hh2, {%h2, %h1};
+// Instead we want just a constant move:
+//        mov.b32         %hh2, 0x40003C00
+//
+// This results in better SASS code with CUDA 7.x. Ptxas in CUDA 8.0
+// generates good SASS in both cases.
+SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  //return Op;
+  if (!(Op->getValueType(0) == MVT::v2f16 &&
+        isa<ConstantFPSDNode>(Op->getOperand(0)) &&
+        isa<ConstantFPSDNode>(Op->getOperand(1))))
+    return Op;
+
+  APInt E0 =
+      cast<ConstantFPSDNode>(Op->getOperand(0))->getValueAPF().bitcastToAPInt();
+  APInt E1 =
+      cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt();
+  SDValue Const =
+      DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
+  return DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const);
+}
+
+SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
+                                                     SelectionDAG &DAG) const {
+  SDValue Index = Op->getOperand(1);
+  // Constant index will be matched by tablegen.
+  if (isa<ConstantSDNode>(Index.getNode()))
+    return Op;
+
+  // Extract individual elements and select one of them.
+  SDValue Vector = Op->getOperand(0);
+  EVT VectorVT = Vector.getValueType();
+  assert(VectorVT == MVT::v2f16 && "Unexpected vector type.");
+  EVT EltVT = VectorVT.getVectorElementType();
+
+  SDLoc dl(Op.getNode());
+  SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
+                           DAG.getIntPtrConstant(0, dl));
+  SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
+                           DAG.getIntPtrConstant(1, dl));
+  return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1,
+                         ISD::CondCode::SETEQ);
+}
+
 /// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
 /// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
 /// amount, or
@@ -1956,8 +2029,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::INTRINSIC_W_CHAIN:
     return Op;
   case ISD::BUILD_VECTOR:
+    return LowerBUILD_VECTOR(Op, DAG);
   case ISD::EXTRACT_SUBVECTOR:
     return Op;
+  case ISD::EXTRACT_VECTOR_ELT:
+    return LowerEXTRACT_VECTOR_ELT(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
   case ISD::STORE:
@@ -2054,12 +2130,15 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   case MVT::v2i16:
   case MVT::v2i32:
   case MVT::v2i64:
+  case MVT::v2f16:
   case MVT::v2f32:
   case MVT::v2f64:
   case MVT::v4i8:
   case MVT::v4i16:
   case MVT::v4i32:
+  case MVT::v4f16:
   case MVT::v4f32:
+  case MVT::v8f16: // <4 x f16x2>
     // This is a "native" vector type
     break;
   }
@@ -2090,6 +2169,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   if (EltVT.getSizeInBits() < 16)
     NeedExt = true;

+  bool StoreF16x2 = false;
   switch (NumElts) {
   default:
     return SDValue();
@@ -2099,6 +2179,14 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   case 4:
     Opcode = NVPTXISD::StoreV4;
     break;
+  case 8:
+    // v8f16 is a special case. PTX doesn't have st.v8.f16
+    // instruction. Instead, we split the vector into v2f16 chunks and
+    // store them with st.v4.b32.
+    assert(EltVT == MVT::f16 && "Wrong type for the vector.");
+    Opcode = NVPTXISD::StoreV4;
+    StoreF16x2 = true;
+    break;
   }

   SmallVector<SDValue, 8> Ops;
@@ -2106,23 +2194,36 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {

   // First is the chain
   Ops.push_back(N->getOperand(0));

-  // Then the split values
-  for (unsigned i = 0; i < NumElts; ++i) {
-    SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
-                                 DAG.getIntPtrConstant(i, DL));
-    if (NeedExt)
-      ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
-    Ops.push_back(ExtVal);
+  if (StoreF16x2) {
+    // Combine f16,f16 -> v2f16
+    NumElts /= 2;
+    for (unsigned i = 0; i < NumElts; ++i) {
+      SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+                               DAG.getIntPtrConstant(i * 2, DL));
+      SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+                               DAG.getIntPtrConstant(i * 2 + 1, DL));
+      SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1);
+      Ops.push_back(V2);
+    }
+  } else {
+    // Then the split values
+    for (unsigned i = 0; i < NumElts; ++i) {
+      SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
+                                   DAG.getIntPtrConstant(i, DL));
+      if (NeedExt)
+        ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
+      Ops.push_back(ExtVal);
+    }
   }

   // Then any remaining arguments
   Ops.append(N->op_begin() + 2, N->op_end());

-  SDValue NewSt = DAG.getMemIntrinsicNode(
-      Opcode, DL, DAG.getVTList(MVT::Other), Ops,
-      MemSD->getMemoryVT(), MemSD->getMemOperand());
+  SDValue NewSt =
+      DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+                              MemSD->getMemoryVT(), MemSD->getMemOperand());

-  //return DCI.CombineTo(N, NewSt, true);
+  // return DCI.CombineTo(N, NewSt, true);
   return NewSt;
 }
@@ -2282,7 +2383,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       SmallVector<EVT, 16> VTs;
       SmallVector<uint64_t, 16> Offsets;
       ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
-      assert(VTs.size() > 0 && "empty aggregate type not expected");
+      assert(VTs.size() > 0 && "Unexpected empty type.");

       auto VectorInfo =
           VectorizePTXValueVTs(VTs, Offsets, DL.getABITypeAlignment(Ty));
@@ -2299,7 +2400,15 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
             unsigned NumElts = parti - VecIdx + 1;
             EVT EltVT = VTs[parti];
             // i1 is loaded/stored as i8.
-            EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
+            EVT LoadVT = EltVT;
+            if (EltVT == MVT::i1)
+              LoadVT = MVT::i8;
+            else if (EltVT == MVT::v2f16)
+              // getLoad needs a vector type, but it can't handle
+              // vectors which contain v2f16 elements. So we must load
+              // using i32 here and then bitcast back.
+              LoadVT = MVT::i32;
+
             EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
             SDValue VecAddr =
                 DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
@@ -2319,15 +2428,20 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
               // We've loaded i1 as an i8 and now must truncate it back to i1
               if (EltVT == MVT::i1)
                 Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
-              // Extend the element if necesary (e.g an i8 is loaded
+              // v2f16 was loaded as an i32. Now we must bitcast it back.
+              else if (EltVT == MVT::v2f16)
+                Elt = DAG.getNode(ISD::BITCAST, dl, MVT::v2f16, Elt);
+              // Extend the element if necesary (e.g. an i8 is loaded
               // into an i16 register)
-              if (Ins[InsIdx].VT.getSizeInBits() > LoadVT.getSizeInBits()) {
+              if (Ins[InsIdx].VT.isInteger() &&
+                  Ins[InsIdx].VT.getSizeInBits() > LoadVT.getSizeInBits()) {
                 unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
                                                              : ISD::ZERO_EXTEND;
                 Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
               }
               InVals.push_back(Elt);
             }
+            // Reset vector tracking state.
             VecIdx = -1;
           }
@@ -2399,7 +2513,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
       RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;

   SmallVector<SDValue, 6> StoreOperands;
-  for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
+  for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
     // New load/store. Record chain and offset operands.
     if (VectorInfo[i] & PVF_FIRST) {
       assert(StoreOperands.empty() && "Orphaned operand list.");
@@ -4168,6 +4282,27 @@ static SDValue PerformSHLCombine(SDNode *N,
   return SDValue();
 }

+static SDValue PerformSETCCCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  EVT CCType = N->getValueType(0);
+  SDValue A = N->getOperand(0);
+  SDValue B = N->getOperand(1);
+
+  if (CCType != MVT::v2i1 || A.getValueType() != MVT::v2f16)
+    return SDValue();
+
+  SDLoc DL(N);
+  // setp.f16x2 returns two scalar predicates, which we need to
+  // convert back to v2i1. The returned result will be scalarized by
+  // the legalizer, but the comparison will remain a single vector
+  // instruction.
+  SDValue CCNode = DCI.DAG.getNode(NVPTXISD::SETP_F16X2, DL,
+                                   DCI.DAG.getVTList(MVT::i1, MVT::i1),
+                                   {A, B, N->getOperand(2)});
+  return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, CCType, CCNode.getValue(0),
+                         CCNode.getValue(1));
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOpt::Level OptLevel = getTargetMachine().getOptLevel();
@@ -4185,6 +4320,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     case ISD::UREM:
     case ISD::SREM:
       return PerformREMCombine(N, DCI, OptLevel);
+    case ISD::SETCC:
+      return PerformSETCCCombine(N, DCI);
   }
   return SDValue();
 }
@@ -4208,12 +4345,15 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
   case MVT::v2i16:
   case MVT::v2i32:
   case MVT::v2i64:
+  case MVT::v2f16:
   case MVT::v2f32:
   case MVT::v2f64:
   case MVT::v4i8:
   case MVT::v4i16:
   case MVT::v4i32:
+  case MVT::v4f16:
   case MVT::v4f32:
+  case MVT::v8f16: // <4 x f16x2>
     // This is a "native" vector type
     break;
   }
@@ -4247,6 +4387,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,

   unsigned Opcode = 0;
   SDVTList LdResVTs;
+  bool LoadF16x2 = false;

   switch (NumElts) {
   default:
@@ -4261,6 +4402,18 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
     LdResVTs = DAG.getVTList(ListVTs);
     break;
   }
+  case 8: {
+    // v8f16 is a special case. PTX doesn't have ld.v8.f16
+    // instruction. Instead, we split the vector into v2f16 chunks and
+    // load them with ld.v4.b32.
+    assert(EltVT == MVT::f16 && "Unsupported v8 vector type.");
+    LoadF16x2 = true;
+    Opcode = NVPTXISD::LoadV4;
+    EVT ListVTs[] = {MVT::v2f16, MVT::v2f16, MVT::v2f16, MVT::v2f16,
+                     MVT::Other};
+    LdResVTs = DAG.getVTList(ListVTs);
+    break;
+  }
   }

   // Copy regular operands
@@ -4274,13 +4427,26 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
                                           LD->getMemoryVT(),
                                           LD->getMemOperand());

-  SmallVector<SDValue, 4> ScalarRes;
-
-  for (unsigned i = 0; i < NumElts; ++i) {
-    SDValue Res = NewLD.getValue(i);
-    if (NeedTrunc)
-      Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
-    ScalarRes.push_back(Res);
+  SmallVector<SDValue, 8> ScalarRes;
+  if (LoadF16x2) {
+    // Split v2f16 subvectors back into individual elements.
+    NumElts /= 2;
+    for (unsigned i = 0; i < NumElts; ++i) {
+      SDValue SubVector = NewLD.getValue(i);
+      SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SubVector,
+                               DAG.getIntPtrConstant(0, DL));
+      SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SubVector,
+                               DAG.getIntPtrConstant(1, DL));
+      ScalarRes.push_back(E0);
+      ScalarRes.push_back(E1);
+    }
+  } else {
+    for (unsigned i = 0; i < NumElts; ++i) {
+      SDValue Res = NewLD.getValue(i);
+      if (NeedTrunc)
+        Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
+      ScalarRes.push_back(Res);
+    }
   }

   SDValue LoadChain = NewLD.getValue(NumElts);
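
Since PTX has no ld.v8.f16 or st.v8.f16, the v8f16 paths above move four f16x2 pairs with a single ld.v4.b32/st.v4.b32: elements (2*i, 2*i+1) travel together in 32-bit word i. A rough standalone C++ sketch of that round trip, under the same element-0-in-the-low-lane layout used by LowerBUILD_VECTOR; packV8F16/unpackV8F16 are illustrative names, not functions from the patch:

    #include <cassert>
    #include <cstdint>

    // Pack eight binary16 bit patterns into four 32-bit words, the way
    // LowerSTOREVector groups elements (2*i, 2*i+1) into chunk i.
    void packV8F16(const uint16_t elts[8], uint32_t words[4]) {
      for (int i = 0; i < 4; ++i)
        words[i] = (static_cast<uint32_t>(elts[2 * i + 1]) << 16) | elts[2 * i];
    }

    // Reverse mapping, the way ReplaceLoadVector splits each v2f16
    // result of ld.v4.b32 back into two scalar elements.
    void unpackV8F16(const uint32_t words[4], uint16_t elts[8]) {
      for (int i = 0; i < 4; ++i) {
        elts[2 * i] = static_cast<uint16_t>(words[i]);            // lane 0
        elts[2 * i + 1] = static_cast<uint16_t>(words[i] >> 16);  // lane 1
      }
    }

    int main() {
      const uint16_t in[8] = {0x3C00, 0x4000, 0x4200, 0x4400,
                              0x4500, 0x4600, 0x4700, 0x4800};
      uint32_t words[4];
      uint16_t out[8];
      packV8F16(in, words);
      unpackV8F16(words, out);
      for (int i = 0; i < 8; ++i)
        assert(in[i] == out[i]); // the round trip preserves every element
      return 0;
    }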