summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp46
1 files changed, 36 insertions, 10 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 7f1432c8128..14a0a51ec1d 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -23123,8 +23123,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
}
// Determine if V is a splat value, and return the scalar.
-// TODO: Add support for SUB(SPLAT_CST, SPLAT) cases to support rotate patterns.
-static SDValue IsSplatValue(SDValue V, const SDLoc &dl, SelectionDAG &DAG) {
+static SDValue IsSplatValue(MVT VT, SDValue V, const SDLoc &dl,
+ SelectionDAG &DAG, const X86Subtarget &Subtarget,
+ unsigned Opcode) {
// Check if this is a splat build_vector node.
if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(V)) {
SDValue SplatAmt = BV->getSplatValue();
@@ -23133,6 +23134,30 @@ static SDValue IsSplatValue(SDValue V, const SDLoc &dl, SelectionDAG &DAG) {
return SplatAmt;
}
+ // Check for SUB(SPLAT_BV, SPLAT) cases from rotate patterns.
+ if (V.getOpcode() == ISD::SUB &&
+ !SupportedVectorVarShift(VT, Subtarget, Opcode)) {
+ // Peek through any EXTRACT_SUBVECTORs.
+ SDValue LHS = V.getOperand(0);
+ SDValue RHS = V.getOperand(1);
+ while (LHS.getOpcode() == ISD::EXTRACT_SUBVECTOR)
+ LHS = LHS.getOperand(0);
+ while (RHS.getOpcode() == ISD::EXTRACT_SUBVECTOR)
+ RHS = RHS.getOperand(0);
+
+ // Ensure that the corresponding splat BV element is not UNDEF.
+ BitVector UndefElts;
+ BuildVectorSDNode *BV0 = dyn_cast<BuildVectorSDNode>(LHS);
+ ShuffleVectorSDNode *SVN1 = dyn_cast<ShuffleVectorSDNode>(RHS);
+ if (BV0 && SVN1 && BV0->getSplatValue(&UndefElts) && SVN1->isSplat()) {
+ unsigned SplatIdx = (unsigned)SVN1->getSplatIndex();
+ if (!UndefElts[SplatIdx])
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
+ VT.getVectorElementType(), V,
+ DAG.getIntPtrConstant(SplatIdx, dl));
+ }
+ }
+
// Check if this is a shuffle node doing a splat.
ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(V);
if (!SVN || !SVN->isSplat())
@@ -23141,7 +23166,7 @@ static SDValue IsSplatValue(SDValue V, const SDLoc &dl, SelectionDAG &DAG) {
unsigned SplatIdx = (unsigned)SVN->getSplatIndex();
SDValue InVec = V.getOperand(0);
if (InVec.getOpcode() == ISD::BUILD_VECTOR) {
- assert((SplatIdx < InVec.getSimpleValueType().getVectorNumElements()) &&
+ assert((SplatIdx < VT.getVectorNumElements()) &&
"Unexpected shuffle index found!");
return InVec.getOperand(SplatIdx);
} else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) {
@@ -23152,7 +23177,7 @@ static SDValue IsSplatValue(SDValue V, const SDLoc &dl, SelectionDAG &DAG) {
// Avoid introducing an extract element from a shuffle.
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
- V.getValueType().getVectorElementType(), InVec,
+ VT.getVectorElementType(), InVec,
DAG.getIntPtrConstant(SplatIdx, dl));
}
@@ -23162,19 +23187,20 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
SDLoc dl(Op);
SDValue R = Op.getOperand(0);
SDValue Amt = Op.getOperand(1);
+ unsigned Opcode = Op.getOpcode();
- unsigned X86OpcI = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI :
- (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
+ unsigned X86OpcI = (Opcode == ISD::SHL) ? X86ISD::VSHLI :
+ (Opcode == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
- unsigned X86OpcV = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHL :
- (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA;
+ unsigned X86OpcV = (Opcode == ISD::SHL) ? X86ISD::VSHL :
+ (Opcode == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA;
// Peek through any EXTRACT_SUBVECTORs.
while (Amt.getOpcode() == ISD::EXTRACT_SUBVECTOR)
Amt = Amt.getOperand(0);
- if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode())) {
- if (SDValue BaseShAmt = IsSplatValue(Amt, dl, DAG)) {
+ if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode)) {
+ if (SDValue BaseShAmt = IsSplatValue(VT, Amt, dl, DAG, Subtarget, Opcode)) {
MVT EltVT = VT.getVectorElementType();
assert(EltVT.bitsLE(MVT::i64) && "Unexpected element type!");
if (EltVT != MVT::i64 && EltVT.bitsGT(MVT::i32))
OpenPOWER on IntegriCloud