diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 46 |
1 files changed, 36 insertions, 10 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 7f1432c8128..14a0a51ec1d 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -23123,8 +23123,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, } // Determine if V is a splat value, and return the scalar. -// TODO: Add support for SUB(SPLAT_CST, SPLAT) cases to support rotate patterns. -static SDValue IsSplatValue(SDValue V, const SDLoc &dl, SelectionDAG &DAG) { +static SDValue IsSplatValue(MVT VT, SDValue V, const SDLoc &dl, + SelectionDAG &DAG, const X86Subtarget &Subtarget, + unsigned Opcode) { // Check if this is a splat build_vector node. if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(V)) { SDValue SplatAmt = BV->getSplatValue(); @@ -23133,6 +23134,30 @@ static SDValue IsSplatValue(SDValue V, const SDLoc &dl, SelectionDAG &DAG) { return SplatAmt; } + // Check for SUB(SPLAT_BV, SPLAT) cases from rotate patterns. + if (V.getOpcode() == ISD::SUB && + !SupportedVectorVarShift(VT, Subtarget, Opcode)) { + // Peek through any EXTRACT_SUBVECTORs. + SDValue LHS = V.getOperand(0); + SDValue RHS = V.getOperand(1); + while (LHS.getOpcode() == ISD::EXTRACT_SUBVECTOR) + LHS = LHS.getOperand(0); + while (RHS.getOpcode() == ISD::EXTRACT_SUBVECTOR) + RHS = RHS.getOperand(0); + + // Ensure that the corresponding splat BV element is not UNDEF. + BitVector UndefElts; + BuildVectorSDNode *BV0 = dyn_cast<BuildVectorSDNode>(LHS); + ShuffleVectorSDNode *SVN1 = dyn_cast<ShuffleVectorSDNode>(RHS); + if (BV0 && SVN1 && BV0->getSplatValue(&UndefElts) && SVN1->isSplat()) { + unsigned SplatIdx = (unsigned)SVN1->getSplatIndex(); + if (!UndefElts[SplatIdx]) + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, + VT.getVectorElementType(), V, + DAG.getIntPtrConstant(SplatIdx, dl)); + } + } + // Check if this is a shuffle node doing a splat. ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(V); if (!SVN || !SVN->isSplat()) @@ -23141,7 +23166,7 @@ static SDValue IsSplatValue(SDValue V, const SDLoc &dl, SelectionDAG &DAG) { unsigned SplatIdx = (unsigned)SVN->getSplatIndex(); SDValue InVec = V.getOperand(0); if (InVec.getOpcode() == ISD::BUILD_VECTOR) { - assert((SplatIdx < InVec.getSimpleValueType().getVectorNumElements()) && + assert((SplatIdx < VT.getVectorNumElements()) && "Unexpected shuffle index found!"); return InVec.getOperand(SplatIdx); } else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) { @@ -23152,7 +23177,7 @@ static SDValue IsSplatValue(SDValue V, const SDLoc &dl, SelectionDAG &DAG) { // Avoid introducing an extract element from a shuffle. return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, - V.getValueType().getVectorElementType(), InVec, + VT.getVectorElementType(), InVec, DAG.getIntPtrConstant(SplatIdx, dl)); } @@ -23162,19 +23187,20 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, SDLoc dl(Op); SDValue R = Op.getOperand(0); SDValue Amt = Op.getOperand(1); + unsigned Opcode = Op.getOpcode(); - unsigned X86OpcI = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI : - (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI; + unsigned X86OpcI = (Opcode == ISD::SHL) ? X86ISD::VSHLI : + (Opcode == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI; - unsigned X86OpcV = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHL : - (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA; + unsigned X86OpcV = (Opcode == ISD::SHL) ? X86ISD::VSHL : + (Opcode == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA; // Peek through any EXTRACT_SUBVECTORs. while (Amt.getOpcode() == ISD::EXTRACT_SUBVECTOR) Amt = Amt.getOperand(0); - if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode())) { - if (SDValue BaseShAmt = IsSplatValue(Amt, dl, DAG)) { + if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode)) { + if (SDValue BaseShAmt = IsSplatValue(VT, Amt, dl, DAG, Subtarget, Opcode)) { MVT EltVT = VT.getVectorElementType(); assert(EltVT.bitsLE(MVT::i64) && "Unexpected element type!"); if (EltVT != MVT::i64 && EltVT.bitsGT(MVT::i32)) |

