diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 35 |
1 files changed, 28 insertions, 7 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 98616119fa0..e50ca6d08f1 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -1560,6 +1560,7 @@ void X86TargetLowering::resetOperationActions() { setOperationAction(ISD::SETCC, MVT::v4i1, Custom); setOperationAction(ISD::SETCC, MVT::v2i1, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v8i1, Legal); } // SIGN_EXTEND_INREGs are evaluated by the extend type. Handle the expansion @@ -15867,6 +15868,8 @@ static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask, EVT VT = Op.getValueType(); EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, VT.getVectorNumElements()); + EVT BitcastVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, + Mask.getValueType().getSizeInBits()); SDLoc dl(Op); assert(MaskVT.isSimple() && "invalid mask type"); @@ -15874,19 +15877,22 @@ static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask, if (isAllOnes(Mask)) return Op; + // In case when MaskVT equals v2i1 or v4i1, low 2 or 4 elements + // are extracted by EXTRACT_SUBVECTOR. + SDValue VMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MaskVT, + DAG.getNode(ISD::BITCAST, dl, BitcastVT, Mask), + DAG.getIntPtrConstant(0)); + switch (Op.getOpcode()) { default: break; case X86ISD::PCMPEQM: case X86ISD::PCMPGTM: case X86ISD::CMPM: case X86ISD::CMPMU: - return DAG.getNode(ISD::AND, dl, VT, Op, - DAG.getNode(ISD::BITCAST, dl, MaskVT, Mask)); + return DAG.getNode(ISD::AND, dl, VT, Op, VMask); } - return DAG.getNode(ISD::VSELECT, dl, VT, - DAG.getNode(ISD::BITCAST, dl, MaskVT, Mask), - Op, PreservedSrc); + return DAG.getNode(ISD::VSELECT, dl, VT, VMask, Op, PreservedSrc); } static unsigned getOpcodeForFMAIntrinsic(unsigned IntNo) { @@ -15953,13 +15959,28 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) { return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); case CMP_MASK: { + // Comparison intrinsics with masks. + // Example of transformation: + // (i8 (int_x86_avx512_mask_pcmpeq_q_128 + // (v2i64 %a), (v2i64 %b), (i8 %mask))) -> + // (i8 (bitcast + // (v8i1 (insert_subvector undef, + // (v2i1 (and (PCMPEQM %a, %b), + // (extract_subvector + // (v8i1 (bitcast %mask)), 0))), 0)))) EVT VT = Op.getOperand(1).getValueType(); EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, VT.getVectorNumElements()); + SDValue Mask = Op.getOperand(3); + EVT BitcastVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, + Mask.getValueType().getSizeInBits()); SDValue Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1), Op.getOperand(2)); - SDValue Res = getVectorMaskingNode(Cmp, Op.getOperand(3), - DAG.getTargetConstant(0, MaskVT), DAG); + SDValue CmpMask = getVectorMaskingNode(Cmp, Op.getOperand(3), + DAG.getTargetConstant(0, MaskVT), DAG); + SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, BitcastVT, + DAG.getUNDEF(BitcastVT), CmpMask, + DAG.getIntPtrConstant(0)); return DAG.getNode(ISD::BITCAST, dl, Op.getValueType(), Res); } case COMI: { // Comparison intrinsics |