author    Artem Belevich <tra@google.com>   2017-02-23 22:38:24 +0000
committer Artem Belevich <tra@google.com>   2017-02-23 22:38:24 +0000
commit    620db1f3dd08ebbba71b0e16f83c11323e04bc05 (patch)
tree      ec38800ee0d0ae0282c4d5e47a193336c80b82a6 /llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
parent    063a56e81c89dfbf7896c71c8cd849b911eb7098 (diff)
[NVPTX] Added support for .f16x2 instructions.
This patch enables support for .f16x2 operations. Added new register type
Float16x2. Added support for .f16x2 instructions. Added handling of
vectorized loads/stores of v2f16 values.

Differential Revision: https://reviews.llvm.org/D30057
Differential Revision: https://reviews.llvm.org/D30310

llvm-svn: 296032
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 332
1 file changed, 249 insertions(+), 83 deletions(-)
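Background for the diff below: PTX's .f16x2 operations treat one 32-bit register as two packed IEEE-754 half-precision lanes and operate on both at once. A minimal C++ model of that layout (purely illustrative, not part of the patch):

#include <cstdint>

// Illustrative model of an .f16x2 register: lane 0 lives in the low 16 bits,
// lane 1 in the high 16 bits, matching the packing used by this patch.
struct F16x2 {
  uint32_t Bits;
  uint16_t lane0() const { return uint16_t(Bits & 0xFFFFu); }
  uint16_t lane1() const { return uint16_t(Bits >> 16); }
};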
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 27d9f34850c..c2877c34f63 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -146,6 +146,9 @@ static bool IsPTXVectorType(MVT VT) {
case MVT::v2i32:
case MVT::v4i32:
case MVT::v2i64:
+ case MVT::v2f16:
+ case MVT::v4f16:
+ case MVT::v8f16: // <4 x f16x2>
case MVT::v2f32:
case MVT::v4f32:
case MVT::v2f64:
@@ -170,13 +173,24 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
EVT VT = TempVTs[i];
uint64_t Off = TempOffsets[i];
- if (VT.isVector())
- for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
- ValueVTs.push_back(VT.getVectorElementType());
+ // Split vectors into individual elements, except for v2f16, which
+ // we will pass as a single scalar.
+ if (VT.isVector()) {
+ unsigned NumElts = VT.getVectorNumElements();
+ EVT EltVT = VT.getVectorElementType();
+ // Vectors with an even number of f16 elements will be passed to
+ // us as an array of v2f16 elements. We must match this so we
+ // stay in sync with Ins/Outs.
+ if (EltVT == MVT::f16 && NumElts % 2 == 0) {
+ EltVT = MVT::v2f16;
+ NumElts /= 2;
+ }
+ for (unsigned j = 0; j != NumElts; ++j) {
+ ValueVTs.push_back(EltVT);
if (Offsets)
- Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize());
+ Offsets->push_back(Off + j * EltVT.getStoreSize());
}
- else {
+ } else {
ValueVTs.push_back(VT);
if (Offsets)
Offsets->push_back(Off);
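The splitting rule in the hunk above, restated as a standalone sketch (hypothetical helper, not from the patch): f16 vectors with an even lane count become v2f16 chunks; everything else is split into scalar elements.

#include <utility>

// Returns {number of values, whether each value is a v2f16 chunk}, mirroring
// the logic in ComputePTXValueVTs, which must stay in sync with Ins/Outs.
std::pair<unsigned, bool> splitRule(unsigned NumElts, bool EltIsF16) {
  if (EltIsF16 && NumElts % 2 == 0)
    return {NumElts / 2, true}; // e.g. v8f16 -> four v2f16 chunks
  return {NumElts, false};      // one value per scalar element
}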
@@ -331,6 +345,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
else
setSchedulingPreference(Sched::Source);
+ auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
+ LegalizeAction NoF16Action) {
+ setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
+ };
+
addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
@@ -338,13 +357,20 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass);
+ addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass);
+
+ // Conversion to/from FP16/FP16x2 is always legal.
+ setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal);
+ setOperationAction(ISD::FP_TO_SINT, MVT::f16, Legal);
+ setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f16, Custom);
- setOperationAction(ISD::SETCC, MVT::f16,
- STI.allowFP16Math() ? Legal : Promote);
+ setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
+ setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
// Operations not directly supported by NVPTX.
- setOperationAction(ISD::SELECT_CC, MVT::f16,
- STI.allowFP16Math() ? Expand : Promote);
+ setOperationAction(ISD::SELECT_CC, MVT::f16, Expand);
+ setOperationAction(ISD::SELECT_CC, MVT::v2f16, Expand);
setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
setOperationAction(ISD::SELECT_CC, MVT::i1, Expand);
@@ -352,8 +378,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::SELECT_CC, MVT::i16, Expand);
setOperationAction(ISD::SELECT_CC, MVT::i32, Expand);
setOperationAction(ISD::SELECT_CC, MVT::i64, Expand);
- setOperationAction(ISD::BR_CC, MVT::f16,
- STI.allowFP16Math() ? Expand : Promote);
+ setOperationAction(ISD::BR_CC, MVT::f16, Expand);
+ setOperationAction(ISD::BR_CC, MVT::v2f16, Expand);
setOperationAction(ISD::BR_CC, MVT::f32, Expand);
setOperationAction(ISD::BR_CC, MVT::f64, Expand);
setOperationAction(ISD::BR_CC, MVT::i1, Expand);
@@ -493,58 +519,53 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setTargetDAGCombine(ISD::SREM);
setTargetDAGCombine(ISD::UREM);
- if (!STI.allowFP16Math()) {
- // Promote fp16 arithmetic if fp16 hardware isn't available or the
- // user passed --nvptx-no-fp16-math. The flag is useful because,
- // although sm_53+ GPUs have some sort of FP16 support in
- // hardware, only sm_53 and sm_60 have full implementation. Others
- // only have token amount of hardware and are likely to run faster
- // by using fp32 units instead.
- setOperationAction(ISD::FADD, MVT::f16, Promote);
- setOperationAction(ISD::FMUL, MVT::f16, Promote);
- setOperationAction(ISD::FSUB, MVT::f16, Promote);
- setOperationAction(ISD::FMA, MVT::f16, Promote);
+ // setcc for f16x2 needs special handling to prevent legalizer's
+ // attempt to scalarize it due to v2i1 not being legal.
+ if (STI.allowFP16Math())
+ setTargetDAGCombine(ISD::SETCC);
+
+ // Promote fp16 arithmetic if fp16 hardware isn't available or the
+ // user passed --nvptx-no-fp16-math. The flag is useful because,
+ // although sm_53+ GPUs have some sort of FP16 support in
+ // hardware, only sm_53 and sm_60 have full implementation. Others
+ // only have a token amount of hardware and are likely to run faster
+ // by using fp32 units instead.
+ for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
+ setFP16OperationAction(Op, MVT::f16, Legal, Promote);
+ setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
}
- // There's no neg.f16 instruction.
+
+ // There's no neg.f16 instruction. Expand to (0-x).
setOperationAction(ISD::FNEG, MVT::f16, Expand);
+ setOperationAction(ISD::FNEG, MVT::v2f16, Expand);
+
+ // (would be) Library functions.
- // Library functions. These default to Expand, but we have instructions
- // for them.
- setOperationAction(ISD::FCEIL, MVT::f16, Legal);
- setOperationAction(ISD::FCEIL, MVT::f32, Legal);
- setOperationAction(ISD::FCEIL, MVT::f64, Legal);
- setOperationAction(ISD::FFLOOR, MVT::f16, Legal);
- setOperationAction(ISD::FFLOOR, MVT::f32, Legal);
- setOperationAction(ISD::FFLOOR, MVT::f64, Legal);
- setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
- setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal);
- setOperationAction(ISD::FRINT, MVT::f16, Legal);
- setOperationAction(ISD::FRINT, MVT::f32, Legal);
- setOperationAction(ISD::FRINT, MVT::f64, Legal);
- setOperationAction(ISD::FROUND, MVT::f16, Legal);
- setOperationAction(ISD::FROUND, MVT::f32, Legal);
- setOperationAction(ISD::FROUND, MVT::f64, Legal);
- setOperationAction(ISD::FTRUNC, MVT::f16, Legal);
- setOperationAction(ISD::FTRUNC, MVT::f32, Legal);
- setOperationAction(ISD::FTRUNC, MVT::f64, Legal);
- setOperationAction(ISD::FMINNUM, MVT::f32, Legal);
- setOperationAction(ISD::FMINNUM, MVT::f64, Legal);
- setOperationAction(ISD::FMAXNUM, MVT::f32, Legal);
- setOperationAction(ISD::FMAXNUM, MVT::f64, Legal);
+ // These map to conversion instructions for scalar FP types.
+ for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
+ ISD::FROUND, ISD::FTRUNC}) {
+ setOperationAction(Op, MVT::f16, Legal);
+ setOperationAction(Op, MVT::f32, Legal);
+ setOperationAction(Op, MVT::f64, Legal);
+ setOperationAction(Op, MVT::v2f16, Expand);
+ }
// 'Expand' implements FCOPYSIGN without calling an external library.
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
+ setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);
- // FP16 does not support these nodes in hardware, but we can perform
- // these ops using single-precision hardware.
- setOperationAction(ISD::FDIV, MVT::f16, Promote);
- setOperationAction(ISD::FREM, MVT::f16, Promote);
- setOperationAction(ISD::FSQRT, MVT::f16, Promote);
- setOperationAction(ISD::FSIN, MVT::f16, Promote);
- setOperationAction(ISD::FCOS, MVT::f16, Promote);
- setOperationAction(ISD::FABS, MVT::f16, Promote);
+ // These map to corresponding instructions for f32/f64. f16 must be
+ // promoted to f32. v2f16 is expanded to f16, which is then promoted
+ // to f32.
+ for (const auto &Op : {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS,
+ ISD::FABS, ISD::FMINNUM, ISD::FMAXNUM}) {
+ setOperationAction(Op, MVT::f16, Promote);
+ setOperationAction(Op, MVT::f32, Legal);
+ setOperationAction(Op, MVT::f64, Legal);
+ setOperationAction(Op, MVT::v2f16, Expand);
+ }
setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
setOperationAction(ISD::FMINNAN, MVT::f16, Promote);
@@ -660,6 +681,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "NVPTXISD::FUN_SHFR_CLAMP";
case NVPTXISD::IMAD:
return "NVPTXISD::IMAD";
+ case NVPTXISD::SETP_F16X2:
+ return "NVPTXISD::SETP_F16X2";
case NVPTXISD::Dummy:
return "NVPTXISD::Dummy";
case NVPTXISD::MUL_WIDE_SIGNED:
@@ -1158,7 +1181,8 @@ TargetLoweringBase::LegalizeTypeAction
NVPTXTargetLowering::getPreferredVectorAction(EVT VT) const {
if (VT.getVectorNumElements() != 1 && VT.getScalarType() == MVT::i1)
return TypeSplitVector;
-
+ if (VT == MVT::v2f16)
+ return TypeLegal;
return TargetLoweringBase::getPreferredVectorAction(VT);
}
@@ -1723,7 +1747,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
bool ExtendIntegerRetVal =
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
- for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
+ for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
bool needTruncate = false;
EVT TheLoadType = VTs[i];
EVT EltType = Ins[i].VT;
@@ -1765,11 +1789,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
llvm_unreachable("Invalid vector info.");
}
- SDValue VectorOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
- DAG.getConstant(Offsets[VecIdx], dl, MVT::i32),
- InFlag};
+ SDValue LoadOperands[] = {
+ Chain, DAG.getConstant(1, dl, MVT::i32),
+ DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InFlag};
SDValue RetVal = DAG.getMemIntrinsicNode(
- Op, dl, DAG.getVTList(LoadVTs), VectorOps, TheLoadType,
+ Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
MachinePointerInfo(), EltAlign);
for (unsigned j = 0; j < NumElts; ++j) {
@@ -1823,6 +1847,55 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
}
+// We can init constant f16x2 with a single .b32 move. Normally it
+// would get lowered as two constant loads and vector-packing move.
+// mov.b16 %h1, 0x4000;
+// mov.b16 %h2, 0x3C00;
+// mov.b32 %hh2, {%h2, %h1};
+// Instead we want just a constant move:
+// mov.b32 %hh2, 0x40003C00
+//
+// This results in better SASS code with CUDA 7.x. Ptxas in CUDA 8.0
+// generates good SASS in both cases.
+SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
+ SelectionDAG &DAG) const {
+ if (!(Op->getValueType(0) == MVT::v2f16 &&
+ isa<ConstantFPSDNode>(Op->getOperand(0)) &&
+ isa<ConstantFPSDNode>(Op->getOperand(1))))
+ return Op;
+
+ APInt E0 =
+ cast<ConstantFPSDNode>(Op->getOperand(0))->getValueAPF().bitcastToAPInt();
+ APInt E1 =
+ cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt();
+ SDValue Const =
+ DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
+ return DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const);
+}
+
+SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDValue Index = Op->getOperand(1);
+ // Constant index will be matched by tablegen.
+ if (isa<ConstantSDNode>(Index.getNode()))
+ return Op;
+
+ // Extract individual elements and select one of them.
+ SDValue Vector = Op->getOperand(0);
+ EVT VectorVT = Vector.getValueType();
+ assert(VectorVT == MVT::v2f16 && "Unexpected vector type.");
+ EVT EltVT = VectorVT.getVectorElementType();
+
+ SDLoc dl(Op.getNode());
+ SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
+ DAG.getIntPtrConstant(0, dl));
+ SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
+ DAG.getIntPtrConstant(1, dl));
+ return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1,
+ ISD::CondCode::SETEQ);
+}
+
/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
/// amount, or
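The constant case in LowerBUILD_VECTOR above is plain bit arithmetic: element 1 takes the high 16 bits, element 0 the low 16. A self-contained check of the comment's example (1.0h == 0x3C00, 2.0h == 0x4000), written as ordinary C++ rather than DAG nodes:

#include <cassert>
#include <cstdint>

// Mirrors DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), ...).
uint32_t packF16x2(uint16_t E0, uint16_t E1) {
  return (uint32_t(E1) << 16) | E0;
}

int main() {
  // Packed, the two halves form the single immediate from the comment:
  // mov.b32 %hh2, 0x40003C00;
  assert(packF16x2(0x3C00, 0x4000) == 0x40003C00);
}

The dynamic-index path of LowerEXTRACT_VECTOR_ELT is the standard trick for a two-element vector: extract both lanes unconditionally, then select on the runtime index.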
@@ -1956,8 +2029,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::INTRINSIC_W_CHAIN:
return Op;
case ISD::BUILD_VECTOR:
+ return LowerBUILD_VECTOR(Op, DAG);
case ISD::EXTRACT_SUBVECTOR:
return Op;
+ case ISD::EXTRACT_VECTOR_ELT:
+ return LowerEXTRACT_VECTOR_ELT(Op, DAG);
case ISD::CONCAT_VECTORS:
return LowerCONCAT_VECTORS(Op, DAG);
case ISD::STORE:
@@ -2054,12 +2130,15 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
case MVT::v2i16:
case MVT::v2i32:
case MVT::v2i64:
+ case MVT::v2f16:
case MVT::v2f32:
case MVT::v2f64:
case MVT::v4i8:
case MVT::v4i16:
case MVT::v4i32:
+ case MVT::v4f16:
case MVT::v4f32:
+ case MVT::v8f16: // <4 x f16x2>
// This is a "native" vector type
break;
}
@@ -2090,6 +2169,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
if (EltVT.getSizeInBits() < 16)
NeedExt = true;
+ bool StoreF16x2 = false;
switch (NumElts) {
default:
return SDValue();
@@ -2099,6 +2179,14 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
case 4:
Opcode = NVPTXISD::StoreV4;
break;
+ case 8:
+ // v8f16 is a special case. PTX doesn't have st.v8.f16
+ // instruction. Instead, we split the vector into v2f16 chunks and
+ // store them with st.v4.b32.
+ assert(EltVT == MVT::f16 && "Wrong type for the vector.");
+ Opcode = NVPTXISD::StoreV4;
+ StoreF16x2 = true;
+ break;
}
SmallVector<SDValue, 8> Ops;
@@ -2106,23 +2194,36 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
// First is the chain
Ops.push_back(N->getOperand(0));
- // Then the split values
- for (unsigned i = 0; i < NumElts; ++i) {
- SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
- DAG.getIntPtrConstant(i, DL));
- if (NeedExt)
- ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
- Ops.push_back(ExtVal);
+ if (StoreF16x2) {
+ // Combine f16,f16 -> v2f16
+ NumElts /= 2;
+ for (unsigned i = 0; i < NumElts; ++i) {
+ SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+ DAG.getIntPtrConstant(i * 2, DL));
+ SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+ DAG.getIntPtrConstant(i * 2 + 1, DL));
+ SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1);
+ Ops.push_back(V2);
+ }
+ } else {
+ // Then the split values
+ for (unsigned i = 0; i < NumElts; ++i) {
+ SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
+ DAG.getIntPtrConstant(i, DL));
+ if (NeedExt)
+ ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
+ Ops.push_back(ExtVal);
+ }
}
// Then any remaining arguments
Ops.append(N->op_begin() + 2, N->op_end());
- SDValue NewSt = DAG.getMemIntrinsicNode(
- Opcode, DL, DAG.getVTList(MVT::Other), Ops,
- MemSD->getMemoryVT(), MemSD->getMemOperand());
+ SDValue NewSt =
+ DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+ MemSD->getMemoryVT(), MemSD->getMemOperand());
- //return DCI.CombineTo(N, NewSt, true);
+ // return DCI.CombineTo(N, NewSt, true);
return NewSt;
}
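Since PTX has no st.v8.f16, the store path above regroups eight f16 lanes into four v2f16 operands for st.v4.b32. The index arithmetic, sketched outside the DAG (hypothetical helper):

#include <cstdint>

// Chunk i carries source lanes 2*i (low half) and 2*i+1 (high half),
// yielding the four 32-bit operands of st.v4.b32.
void pairLanesForStore(const uint16_t Lanes[8], uint32_t Chunks[4]) {
  for (unsigned i = 0; i < 4; ++i)
    Chunks[i] = (uint32_t(Lanes[2 * i + 1]) << 16) | Lanes[2 * i];
}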
@@ -2282,7 +2383,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
- assert(VTs.size() > 0 && "empty aggregate type not expected");
+ assert(VTs.size() > 0 && "Unexpected empty type.");
auto VectorInfo =
VectorizePTXValueVTs(VTs, Offsets, DL.getABITypeAlignment(Ty));
@@ -2299,7 +2400,15 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
unsigned NumElts = parti - VecIdx + 1;
EVT EltVT = VTs[parti];
// i1 is loaded/stored as i8.
- EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
+ EVT LoadVT = EltVT;
+ if (EltVT == MVT::i1)
+ LoadVT = MVT::i8;
+ else if (EltVT == MVT::v2f16)
+ // getLoad needs a vector type, but it can't handle
+ // vectors which contain v2f16 elements. So we must load
+ // using i32 here and then bitcast back.
+ LoadVT = MVT::i32;
+
EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
SDValue VecAddr =
DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
@@ -2319,15 +2428,20 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
// We've loaded i1 as an i8 and now must truncate it back to i1
if (EltVT == MVT::i1)
Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
- // Extend the element if necesary (e.g an i8 is loaded
+ // v2f16 was loaded as an i32. Now we must bitcast it back.
+ else if (EltVT == MVT::v2f16)
+ Elt = DAG.getNode(ISD::BITCAST, dl, MVT::v2f16, Elt);
+ // Extend the element if necessary (e.g. an i8 is loaded
// into an i16 register)
- if (Ins[InsIdx].VT.getSizeInBits() > LoadVT.getSizeInBits()) {
+ if (Ins[InsIdx].VT.isInteger() &&
+ Ins[InsIdx].VT.getSizeInBits() > LoadVT.getSizeInBits()) {
unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
: ISD::ZERO_EXTEND;
Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
}
InVals.push_back(Elt);
}
+
// Reset vector tracking state.
VecIdx = -1;
}
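The LowerFormalArguments change above works around getLoad's inability to handle vectors of v2f16 elements by loading 32 bits as an i32 and bitcasting afterwards. The same round trip in ordinary C++ (illustrative only):

#include <cstdint>
#include <cstring>

struct Half2 { uint16_t Lo, Hi; }; // stand-in for a v2f16 value

// Load the parameter bytes as a 32-bit integer, then reinterpret them as two
// halves -- the analogue of the i32 load plus ISD::BITCAST above.
Half2 loadV2F16(const void *Addr) {
  uint32_t Bits;
  std::memcpy(&Bits, Addr, sizeof(Bits)); // the i32 load
  Half2 V;
  std::memcpy(&V, &Bits, sizeof(V));      // the bitcast back to v2f16
  return V;
}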
@@ -2399,7 +2513,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
SmallVector<SDValue, 6> StoreOperands;
- for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
+ for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
// New load/store. Record chain and offset operands.
if (VectorInfo[i] & PVF_FIRST) {
assert(StoreOperands.empty() && "Orphaned operand list.");
@@ -4168,6 +4282,27 @@ static SDValue PerformSHLCombine(SDNode *N,
return SDValue();
}
+static SDValue PerformSETCCCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ EVT CCType = N->getValueType(0);
+ SDValue A = N->getOperand(0);
+ SDValue B = N->getOperand(1);
+
+ if (CCType != MVT::v2i1 || A.getValueType() != MVT::v2f16)
+ return SDValue();
+
+ SDLoc DL(N);
+ // setp.f16x2 returns two scalar predicates, which we need to
+ // convert back to v2i1. The returned result will be scalarized by
+ // the legalizer, but the comparison will remain a single vector
+ // instruction.
+ SDValue CCNode = DCI.DAG.getNode(NVPTXISD::SETP_F16X2, DL,
+ DCI.DAG.getVTList(MVT::i1, MVT::i1),
+ {A, B, N->getOperand(2)});
+ return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, CCType, CCNode.getValue(0),
+ CCNode.getValue(1));
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOpt::Level OptLevel = getTargetMachine().getOptLevel();
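PerformSETCCCombine above exists because v2i1 is not a legal type: left to the legalizer, the compare would be scalarized into two setp instructions. A scalar model of what the combine preserves (uses float lanes and a less-than compare purely for illustration; the real node takes the condition code as an operand):

#include <array>

// One f16x2 compare yields two i1 predicates, which are re-packed into the
// v2i1 BUILD_VECTOR the DAG expects; only the pack gets scalarized later.
std::array<bool, 2> setpF16x2Lt(const float A[2], const float B[2]) {
  return {A[0] < B[0], A[1] < B[1]};
}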
@@ -4185,6 +4320,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::UREM:
case ISD::SREM:
return PerformREMCombine(N, DCI, OptLevel);
+ case ISD::SETCC:
+ return PerformSETCCCombine(N, DCI);
}
return SDValue();
}
@@ -4208,12 +4345,15 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
case MVT::v2i16:
case MVT::v2i32:
case MVT::v2i64:
+ case MVT::v2f16:
case MVT::v2f32:
case MVT::v2f64:
case MVT::v4i8:
case MVT::v4i16:
case MVT::v4i32:
+ case MVT::v4f16:
case MVT::v4f32:
+ case MVT::v8f16: // <4 x f16x2>
// This is a "native" vector type
break;
}
@@ -4247,6 +4387,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
unsigned Opcode = 0;
SDVTList LdResVTs;
+ bool LoadF16x2 = false;
switch (NumElts) {
default:
@@ -4261,6 +4402,18 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
LdResVTs = DAG.getVTList(ListVTs);
break;
}
+ case 8: {
+ // v8f16 is a special case. PTX doesn't have ld.v8.f16
+ // instruction. Instead, we split the vector into v2f16 chunks and
+ // load them with ld.v4.b32.
+ assert(EltVT == MVT::f16 && "Unsupported v8 vector type.");
+ LoadF16x2 = true;
+ Opcode = NVPTXISD::LoadV4;
+ EVT ListVTs[] = {MVT::v2f16, MVT::v2f16, MVT::v2f16, MVT::v2f16,
+ MVT::Other};
+ LdResVTs = DAG.getVTList(ListVTs);
+ break;
+ }
}
// Copy regular operands
@@ -4274,13 +4427,26 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
LD->getMemoryVT(),
LD->getMemOperand());
- SmallVector<SDValue, 4> ScalarRes;
-
- for (unsigned i = 0; i < NumElts; ++i) {
- SDValue Res = NewLD.getValue(i);
- if (NeedTrunc)
- Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
- ScalarRes.push_back(Res);
+ SmallVector<SDValue, 8> ScalarRes;
+ if (LoadF16x2) {
+ // Split v2f16 subvectors back into individual elements.
+ NumElts /= 2;
+ for (unsigned i = 0; i < NumElts; ++i) {
+ SDValue SubVector = NewLD.getValue(i);
+ SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SubVector,
+ DAG.getIntPtrConstant(0, DL));
+ SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SubVector,
+ DAG.getIntPtrConstant(1, DL));
+ ScalarRes.push_back(E0);
+ ScalarRes.push_back(E1);
+ }
+ } else {
+ for (unsigned i = 0; i < NumElts; ++i) {
+ SDValue Res = NewLD.getValue(i);
+ if (NeedTrunc)
+ Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
+ ScalarRes.push_back(Res);
+ }
}
SDValue LoadChain = NewLD.getValue(NumElts);
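ReplaceLoadVector's v8f16 case is the mirror image of the store path: ld.v4.b32 yields four v2f16 values, each split back into two f16 lanes so the rest of the DAG sees eight scalar results. The unpacking as a plain C++ sketch (hypothetical helper):

#include <cstdint>

// Each loaded 32-bit chunk i restores lanes 2*i (low half) and 2*i+1
// (high half), recovering the original v8f16 element order.
void splitLoadedChunks(const uint32_t Chunks[4], uint16_t Lanes[8]) {
  for (unsigned i = 0; i < 4; ++i) {
    Lanes[2 * i] = uint16_t(Chunks[i] & 0xFFFFu);
    Lanes[2 * i + 1] = uint16_t(Chunks[i] >> 16);
  }
}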