diff options
Diffstat (limited to 'llvm/lib/Target')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 166 |
1 files changed, 65 insertions, 101 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 5ea17a4e8c0..512a9f91d02 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -1507,6 +1507,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::AND, MVT::v4i32, Legal); setOperationAction(ISD::OR, MVT::v4i32, Legal); setOperationAction(ISD::XOR, MVT::v4i32, Legal); + setOperationAction(ISD::SRA, MVT::v2i64, Custom); + setOperationAction(ISD::SRA, MVT::v4i64, Custom); } // We want to custom lower some of our intrinsics. @@ -16328,6 +16330,53 @@ static SDValue LowerMUL_LOHI(SDValue Op, const X86Subtarget *Subtarget, return DAG.getMergeValues(Ops, dl); } +// Return true if the required (according to Opcode) shift-imm form is natively +// supported by the Subtarget +static bool SupportedVectorShiftWithImm(MVT VT, const X86Subtarget *Subtarget, + unsigned Opcode) { + if (VT.getScalarSizeInBits() < 16) + return false; + + if (VT.is512BitVector() && + (VT.getScalarSizeInBits() > 16 || Subtarget->hasBWI())) + return true; + + bool LShift = VT.is128BitVector() || + (VT.is256BitVector() && Subtarget->hasInt256()); + + bool AShift = LShift && (Subtarget->hasVLX() || + (VT != MVT::v2i64 && VT != MVT::v4i64)); + return (Opcode == ISD::SRA) ? AShift : LShift; +} + +// The shift amount is a variable, but it is the same for all vector lanes. +// These instructions are defined together with shift-immediate. 
+static
bool SupportedVectorShiftWithBaseAmnt(MVT VT, const X86Subtarget *Subtarget, + unsigned Opcode) { + return SupportedVectorShiftWithImm(VT, Subtarget, Opcode); +} + +// Return true if the required (according to Opcode) variable-shift form is +// natively supported by the Subtarget +static bool SupportedVectorVarShift(MVT VT, const X86Subtarget *Subtarget, + unsigned Opcode) { + + if (!Subtarget->hasInt256() || VT.getScalarSizeInBits() < 16) + return false; + + // vXi16 supported only on AVX-512, BWI + if (VT.getScalarSizeInBits() == 16 && !Subtarget->hasBWI()) + return false; + + if (VT.is512BitVector() || Subtarget->hasVLX()) + return true; + + bool LShift = VT.is128BitVector() || VT.is256BitVector(); + bool AShift = LShift && VT != MVT::v2i64 && VT != MVT::v4i64; + return (Opcode == ISD::SRA) ? AShift : LShift; +} + static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, const X86Subtarget *Subtarget) { MVT VT = Op.getSimpleValueType(); @@ -16335,26 +16384,16 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, SDValue R = Op.getOperand(0); SDValue Amt = Op.getOperand(1); + unsigned X86Opc = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI : + (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI; + + // Optimize shl/srl/sra with constant shift amount. 
if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) { if (auto *ShiftConst = BVAmt->getConstantSplatNode()) { uint64_t ShiftAmt = ShiftConst->getZExtValue(); - if (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16 || - (Subtarget->hasInt256() && - (VT == MVT::v4i64 || VT == MVT::v8i32 || VT == MVT::v16i16)) || - (Subtarget->hasAVX512() && - (VT == MVT::v8i64 || VT == MVT::v16i32))) { - if (Op.getOpcode() == ISD::SHL) - return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt, - DAG); - if (Op.getOpcode() == ISD::SRL) - return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt, - DAG); - if (Op.getOpcode() == ISD::SRA && VT != MVT::v2i64 && VT != MVT::v4i64) - return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt, - DAG); - } + if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode())) + return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG); if (VT == MVT::v16i8 || (Subtarget->hasInt256() && VT == MVT::v32i8)) { unsigned NumElts = VT.getVectorNumElements(); @@ -16435,19 +16474,7 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, if (ShAmt != ShiftAmt) return SDValue(); } - switch (Op.getOpcode()) { - default: - llvm_unreachable("Unknown shift opcode!"); - case ISD::SHL: - return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt, - DAG); - case ISD::SRL: - return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt, - DAG); - case ISD::SRA: - return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt, - DAG); - } + return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG); } return SDValue(); @@ -16460,12 +16487,13 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, SDValue R = Op.getOperand(0); SDValue Amt = Op.getOperand(1); - if ((VT == MVT::v2i64 && Op.getOpcode() != ISD::SRA) || - VT == MVT::v4i32 || VT == MVT::v8i16 || - (Subtarget->hasInt256() && - ((VT == MVT::v4i64 && Op.getOpcode() != ISD::SRA) || - VT == 
MVT::v8i32 || VT == MVT::v16i16)) || - (Subtarget->hasAVX512() && (VT == MVT::v8i64 || VT == MVT::v16i32))) { + unsigned X86OpcI = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI : + (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI; + + unsigned X86OpcV = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHL : + (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA; + + if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode())) { SDValue BaseShAmt; EVT EltVT = VT.getVectorElementType(); @@ -16509,47 +16537,7 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, else if (EltVT.bitsLT(MVT::i32)) BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, BaseShAmt); - switch (Op.getOpcode()) { - default: - llvm_unreachable("Unknown shift opcode!"); - case ISD::SHL: - switch (VT.SimpleTy) { - default: return SDValue(); - case MVT::v2i64: - case MVT::v4i32: - case MVT::v8i16: - case MVT::v4i64: - case MVT::v8i32: - case MVT::v16i16: - case MVT::v16i32: - case MVT::v8i64: - return getTargetVShiftNode(X86ISD::VSHLI, dl, VT, R, BaseShAmt, DAG); - } - case ISD::SRA: - switch (VT.SimpleTy) { - default: return SDValue(); - case MVT::v4i32: - case MVT::v8i16: - case MVT::v8i32: - case MVT::v16i16: - case MVT::v16i32: - case MVT::v8i64: - return getTargetVShiftNode(X86ISD::VSRAI, dl, VT, R, BaseShAmt, DAG); - } - case ISD::SRL: - switch (VT.SimpleTy) { - default: return SDValue(); - case MVT::v2i64: - case MVT::v4i32: - case MVT::v8i16: - case MVT::v4i64: - case MVT::v8i32: - case MVT::v16i16: - case MVT::v16i32: - case MVT::v8i64: - return getTargetVShiftNode(X86ISD::VSRLI, dl, VT, R, BaseShAmt, DAG); - } - } + return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, DAG); } } @@ -16568,18 +16556,8 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, if (Vals[j] != Amt.getOperand(i + j)) return SDValue(); } - switch (Op.getOpcode()) { - default: - llvm_unreachable("Unknown shift opcode!"); - case ISD::SHL: - return 
DAG.getNode(X86ISD::VSHL, dl, VT, R, Op.getOperand(1)); - case ISD::SRL: - return DAG.getNode(X86ISD::VSRL, dl, VT, R, Op.getOperand(1)); - case ISD::SRA: - return DAG.getNode(X86ISD::VSRA, dl, VT, R, Op.getOperand(1)); - } + return DAG.getNode(X86OpcV, dl, VT, R, Op.getOperand(1)); } - return SDValue(); } @@ -16599,23 +16577,9 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget, if (SDValue V = LowerScalarVariableShift(Op, DAG, Subtarget)) return V; - if (Subtarget->hasAVX512() && (VT == MVT::v16i32 || VT == MVT::v8i64)) + if (SupportedVectorVarShift(VT, Subtarget, Op.getOpcode())) return Op; - // AVX2 has VPSLLV/VPSRAV/VPSRLV. - if (Subtarget->hasInt256()) { - if (Op.getOpcode() == ISD::SRL && - (VT == MVT::v2i64 || VT == MVT::v4i32 || - VT == MVT::v4i64 || VT == MVT::v8i32)) - return Op; - if (Op.getOpcode() == ISD::SHL && - (VT == MVT::v2i64 || VT == MVT::v4i32 || - VT == MVT::v4i64 || VT == MVT::v8i32)) - return Op; - if (Op.getOpcode() == ISD::SRA && (VT == MVT::v4i32 || VT == MVT::v8i32)) - return Op; - } - // 2i64 vector logical shifts can efficiently avoid scalarization - do the // shifts per-lane and then shuffle the partial results back together. if (VT == MVT::v2i64 && Op.getOpcode() != ISD::SRA) { |

