diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 123 |
1 files changed, 67 insertions, 56 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 538a116902e..e87e7edab3c 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -1306,9 +1306,13 @@ void X86TargetLowering::resetOperationActions() { addRegisterClass(MVT::v8i64, &X86::VR512RegClass); addRegisterClass(MVT::v8f64, &X86::VR512RegClass); + addRegisterClass(MVT::i1, &X86::VK1RegClass); addRegisterClass(MVT::v8i1, &X86::VK8RegClass); addRegisterClass(MVT::v16i1, &X86::VK16RegClass); + setOperationAction(ISD::BR_CC, MVT::i1, Expand); + setOperationAction(ISD::SETCC, MVT::i1, Custom); + setOperationAction(ISD::XOR, MVT::i1, Legal); setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, Legal); setOperationAction(ISD::LOAD, MVT::v16f32, Legal); setOperationAction(ISD::LOAD, MVT::v8f64, Legal); @@ -1376,6 +1380,8 @@ void X86TargetLowering::resetOperationActions() { setOperationAction(ISD::MUL, MVT::v8i64, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v8i1, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v16i1, Custom); setOperationAction(ISD::BUILD_VECTOR, MVT::v8i1, Custom); setOperationAction(ISD::BUILD_VECTOR, MVT::v16i1, Custom); setOperationAction(ISD::SELECT, MVT::v8f64, Custom); @@ -2221,6 +2227,8 @@ X86TargetLowering::LowerFormalArguments(SDValue Chain, RC = &X86::VR128RegClass; else if (RegVT == MVT::x86mmx) RC = &X86::VR64RegClass; + else if (RegVT == MVT::i1) + RC = &X86::VK1RegClass; else if (RegVT == MVT::v8i1) RC = &X86::VK8RegClass; else if (RegVT == MVT::v16i1) @@ -7669,6 +7677,39 @@ static SDValue LowerEXTRACT_VECTOR_ELT_SSE4(SDValue Op, SelectionDAG &DAG) { return SDValue(); } +/// Extract one bit from mask vector, like v16i1 or v8i1. +/// AVX-512 feature. +static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG) { + SDValue Vec = Op.getOperand(0); + SDLoc dl(Vec); + MVT VecVT = Vec.getSimpleValueType(); + SDValue Idx = Op.getOperand(1); + MVT EltVT = Op.getSimpleValueType(); + + assert((EltVT == MVT::i1) && "Unexpected operands in ExtractBitFromMaskVector"); + + // variable index can't be handled in mask registers, + // extend vector to VR512 + if (!isa<ConstantSDNode>(Idx)) { + MVT ExtVT = (VecVT == MVT::v8i1 ? MVT::v8i64 : MVT::v16i32); + SDValue Ext = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Vec); + SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, + ExtVT.getVectorElementType(), Ext, Idx); + return DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt); + } + + unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); + if (IdxVal) { + unsigned MaxSift = VecVT.getSizeInBits() - 1; + Vec = DAG.getNode(X86ISD::VSHLI, dl, VecVT, Vec, + DAG.getConstant(MaxSift - IdxVal, MVT::i8)); + Vec = DAG.getNode(X86ISD::VSRLI, dl, VecVT, Vec, + DAG.getConstant(MaxSift, MVT::i8)); + } + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i1, Vec, + DAG.getIntPtrConstant(0)); +} + SDValue X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const { @@ -7676,6 +7717,10 @@ X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, SDValue Vec = Op.getOperand(0); MVT VecVT = Vec.getSimpleValueType(); SDValue Idx = Op.getOperand(1); + + if (Op.getSimpleValueType() == MVT::i1) + return ExtractBitFromMaskVector(Op, DAG); + if (!isa<ConstantSDNode>(Idx)) { if (VecVT.is512BitVector() || (VecVT.is256BitVector() && Subtarget->hasInt256() && @@ -9681,11 +9726,17 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, /// equivalent. SDValue X86TargetLowering::EmitCmp(SDValue Op0, SDValue Op1, unsigned X86CC, SelectionDAG &DAG) const { - if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op1)) + SDLoc dl(Op0); + if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op1)) { if (C->getAPIntValue() == 0) return EmitTest(Op0, X86CC, DAG); - SDLoc dl(Op0); + if (Op0.getValueType() == MVT::i1) { + Op0 = DAG.getNode(ISD::XOR, dl, MVT::i1, Op0, DAG.getConstant(-1, MVT::i1)); + return DAG.getNode(X86ISD::CMP, dl, MVT::i1, Op0, Op0); + } + } + if ((Op0.getValueType() == MVT::i8 || Op0.getValueType() == MVT::i16 || Op0.getValueType() == MVT::i32 || Op0.getValueType() == MVT::i64)) { // Do the comparison at i32 if it's smaller. This avoids subregister @@ -10121,7 +10172,8 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { if (VT.isVector()) return LowerVSETCC(Op, Subtarget, DAG); - assert(VT == MVT::i8 && "SetCC type must be 8-bit integer"); + assert((VT == MVT::i8 || (Subtarget->hasAVX512() && VT == MVT::i1)) + && "SetCC type must be 8-bit or 1-bit integer"); SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); SDLoc dl(Op); @@ -10234,8 +10286,12 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { cast<CondCodeSDNode>(Cond.getOperand(2))->get(), CondOp0, CondOp1); if (SSECC != 8) { - unsigned Opcode = VT == MVT::f32 ? X86ISD::FSETCCss : X86ISD::FSETCCsd; - SDValue Cmp = DAG.getNode(Opcode, DL, VT, CondOp0, CondOp1, + if (Subtarget->hasAVX512()) { + SDValue Cmp = DAG.getNode(X86ISD::FSETCC, DL, MVT::i1, CondOp0, CondOp1, + DAG.getConstant(SSECC, MVT::i8)); + return DAG.getNode(X86ISD::SELECT, DL, VT, Cmp, Op1, Op2); + } + SDValue Cmp = DAG.getNode(X86ISD::FSETCC, DL, VT, CondOp0, CondOp1, DAG.getConstant(SSECC, MVT::i8)); SDValue AndN = DAG.getNode(X86ISD::FANDN, DL, VT, Cmp, Op2); SDValue And = DAG.getNode(X86ISD::FAND, DL, VT, Cmp, Op1); @@ -13774,8 +13830,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::CMPMU: return "X86ISD::CMPMU"; case X86ISD::SETCC: return "X86ISD::SETCC"; case X86ISD::SETCC_CARRY: return "X86ISD::SETCC_CARRY"; - case X86ISD::FSETCCsd: return "X86ISD::FSETCCsd"; - case X86ISD::FSETCCss: return "X86ISD::FSETCCss"; + case X86ISD::FSETCC: return "X86ISD::FSETCC"; case X86ISD::CMOV: return "X86ISD::CMOV"; case X86ISD::BRCOND: return "X86ISD::BRCOND"; case X86ISD::RET_FLAG: return "X86ISD::RET_FLAG"; @@ -13870,7 +13925,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::TESTP: return "X86ISD::TESTP"; case X86ISD::TESTM: return "X86ISD::TESTM"; case X86ISD::KORTEST: return "X86ISD::KORTEST"; - case X86ISD::KTEST: return "X86ISD::KTEST"; case X86ISD::PALIGNR: return "X86ISD::PALIGNR"; case X86ISD::PSHUFD: return "X86ISD::PSHUFD"; case X86ISD::PSHUFHW: return "X86ISD::PSHUFHW"; @@ -16420,44 +16474,6 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG, EltNo); } -/// Extract one bit from mask vector, like v16i1 or v8i1. -/// AVX-512 feature. -static SDValue ExtractBitFromMaskVector(SDNode *N, SelectionDAG &DAG) { - SDValue Vec = N->getOperand(0); - SDLoc dl(Vec); - MVT VecVT = Vec.getSimpleValueType(); - SDValue Idx = N->getOperand(1); - MVT EltVT = N->getSimpleValueType(0); - - assert((VecVT.getVectorElementType() == MVT::i1 && EltVT == MVT::i8) || - "Unexpected operands in ExtractBitFromMaskVector"); - - // variable index - if (!isa<ConstantSDNode>(Idx)) { - MVT ExtVT = (VecVT == MVT::v8i1 ? MVT::v8i64 : MVT::v16i32); - SDValue Ext = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Vec); - SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, - ExtVT.getVectorElementType(), Ext); - return DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt); - } - - unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); - - MVT ScalarVT = MVT::getIntegerVT(VecVT.getSizeInBits()); - unsigned MaxShift = VecVT.getSizeInBits() - 1; - Vec = DAG.getNode(ISD::BITCAST, dl, ScalarVT, Vec); - Vec = DAG.getNode(ISD::SHL, dl, ScalarVT, Vec, - DAG.getConstant(MaxShift - IdxVal, ScalarVT)); - Vec = DAG.getNode(ISD::SRL, dl, ScalarVT, Vec, - DAG.getConstant(MaxShift, ScalarVT)); - - if (VecVT == MVT::v16i1) { - Vec = DAG.getNode(ISD::BITCAST, dl, MVT::i16, Vec); - return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Vec); - } - return DAG.getNode(ISD::BITCAST, dl, MVT::i8, Vec); -} - /// PerformEXTRACT_VECTOR_ELTCombine - Detect vector gather/scatter index /// generation and convert it from being a bunch of shuffles and extracts /// to a simple store and scalar loads to extract the elements. @@ -16469,10 +16485,6 @@ static SDValue PerformEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG, SDValue InputVector = N->getOperand(0); - if (InputVector.getValueType().getVectorElementType() == MVT::i1 && - !DCI.isBeforeLegalize()) - return ExtractBitFromMaskVector(N, DAG); - // Detect whether we are trying to convert from mmx to i32 and the bitcast // from mmx to v2i32 has a single usage. if (InputVector.getNode()->getOpcode() == llvm::ISD::BITCAST && @@ -17616,17 +17628,16 @@ static SDValue CMPEQCombine(SDNode *N, SelectionDAG &DAG, if ((cc0 == X86::COND_E && cc1 == X86::COND_NP) || (cc0 == X86::COND_NE && cc1 == X86::COND_P)) { bool is64BitFP = (CMP00.getValueType() == MVT::f64); - X86ISD::NodeType NTOperator = is64BitFP ? - X86ISD::FSETCCsd : X86ISD::FSETCCss; // FIXME: need symbolic constants for these magic numbers. // See X86ATTInstPrinter.cpp:printSSECC(). unsigned x86cc = (cc0 == X86::COND_E) ? 0 : 4; - SDValue OnesOrZeroesF = DAG.getNode(NTOperator, DL, MVT::f32, CMP00, CMP01, + SDValue OnesOrZeroesF = DAG.getNode(X86ISD::FSETCC, DL, CMP00.getValueType(), CMP00, CMP01, DAG.getConstant(x86cc, MVT::i8)); - SDValue OnesOrZeroesI = DAG.getNode(ISD::BITCAST, DL, MVT::i32, + MVT IntVT = (is64BitFP ? MVT::i64 : MVT::i32); + SDValue OnesOrZeroesI = DAG.getNode(ISD::BITCAST, DL, IntVT, OnesOrZeroesF); - SDValue ANDed = DAG.getNode(ISD::AND, DL, MVT::i32, OnesOrZeroesI, - DAG.getConstant(1, MVT::i32)); + SDValue ANDed = DAG.getNode(ISD::AND, DL, IntVT, OnesOrZeroesI, + DAG.getConstant(1, IntVT)); SDValue OneBitOfTruth = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ANDed); return OneBitOfTruth; } |

