-rw-r--r--  llvm/lib/Target/X86/X86ISelLowering.cpp | 69
1 file changed, 33 insertions(+), 36 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c03d32cfbef..9983f1046f3 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20311,6 +20311,25 @@ static SDValue LowerVACOPY(SDValue Op, const X86Subtarget &Subtarget,
                        MachinePointerInfo(DstSV), MachinePointerInfo(SrcSV));
 }
 
+// Helper to get immediate/variable SSE shift opcode from other shift opcodes.
+static unsigned getTargetVShiftUniformOpcode(unsigned Opc, bool IsVariable) {
+  switch (Opc) {
+  case ISD::SHL:
+  case X86ISD::VSHL:
+  case X86ISD::VSHLI:
+    return IsVariable ? X86ISD::VSHL : X86ISD::VSHLI;
+  case ISD::SRL:
+  case X86ISD::VSRL:
+  case X86ISD::VSRLI:
+    return IsVariable ? X86ISD::VSRL : X86ISD::VSRLI;
+  case ISD::SRA:
+  case X86ISD::VSRA:
+  case X86ISD::VSRAI:
+    return IsVariable ? X86ISD::VSRA : X86ISD::VSRAI;
+  }
+  llvm_unreachable("Unknown target vector shift node");
+}
+
 /// Handle vector element shifts where the shift amount is a constant.
 /// Takes immediate version of shift as input.
 static SDValue getTargetVShiftByConstNode(unsigned Opc, const SDLoc &dl, MVT VT,
@@ -20406,13 +20425,8 @@ static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT,
     return getTargetVShiftByConstNode(Opc, dl, VT, SrcOp,
                                       CShAmt->getZExtValue(), DAG);
 
-  // Change opcode to non-immediate version
-  switch (Opc) {
-  default: llvm_unreachable("Unknown target vector shift node");
-  case X86ISD::VSHLI: Opc = X86ISD::VSHL; break;
-  case X86ISD::VSRLI: Opc = X86ISD::VSRL; break;
-  case X86ISD::VSRAI: Opc = X86ISD::VSRA; break;
-  }
+  // Change opcode to non-immediate version.
+  Opc = getTargetVShiftUniformOpcode(Opc, true);
 
   // Need to build a vector containing shift amount.
   // SSE/AVX packed shifts only use the lower 64-bit of the shift count.
@@ -23212,9 +23226,7 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
   SDLoc dl(Op);
   SDValue R = Op.getOperand(0);
   SDValue Amt = Op.getOperand(1);
-
-  unsigned X86Opc = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI :
-    (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
+  unsigned X86Opc = getTargetVShiftUniformOpcode(Op.getOpcode(), false);
 
   auto ArithmeticShiftRight64 = [&](uint64_t ShiftAmt) {
     assert((VT == MVT::v2i64 || VT == MVT::v4i64) && "Unexpected SRA type");
@@ -23459,12 +23471,8 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
   SDValue R = Op.getOperand(0);
   SDValue Amt = Op.getOperand(1);
   unsigned Opcode = Op.getOpcode();
-
-  unsigned X86OpcI = (Opcode == ISD::SHL) ? X86ISD::VSHLI :
-    (Opcode == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
-
-  unsigned X86OpcV = (Opcode == ISD::SHL) ? X86ISD::VSHL :
-    (Opcode == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA;
+  unsigned X86OpcI = getTargetVShiftUniformOpcode(Opcode, false);
+  unsigned X86OpcV = getTargetVShiftUniformOpcode(Opcode, true);
 
   Amt = peekThroughEXTRACT_SUBVECTORs(Amt);
 
@@ -23607,8 +23615,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       // Splat the shift amounts so the scalar shifts above will catch it.
      SDValue Amt0 = DAG.getVectorShuffle(VT, dl, Amt, Amt, {0, 0});
       SDValue Amt1 = DAG.getVectorShuffle(VT, dl, Amt, Amt, {1, 1});
-      SDValue R0 = DAG.getNode(Op->getOpcode(), dl, VT, R, Amt0);
-      SDValue R1 = DAG.getNode(Op->getOpcode(), dl, VT, R, Amt1);
+      SDValue R0 = DAG.getNode(Opc, dl, VT, R, Amt0);
+      SDValue R1 = DAG.getNode(Opc, dl, VT, R, Amt1);
       return DAG.getVectorShuffle(VT, dl, R0, R1, {0, 3});
     }
 
@@ -23714,19 +23722,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       Amt3 = DAG.getVectorShuffle(VT, dl, Amt, DAG.getUNDEF(VT), {3, 3, 3, 3});
     } else {
       // ISD::SHL is handled above but we include it here for completeness.
-      switch (Opc) {
-      default:
-        llvm_unreachable("Unknown target vector shift node");
-      case ISD::SHL:
-        ShOpc = X86ISD::VSHL;
-        break;
-      case ISD::SRL:
-        ShOpc = X86ISD::VSRL;
-        break;
-      case ISD::SRA:
-        ShOpc = X86ISD::VSRA;
-        break;
-      }
+      ShOpc = getTargetVShiftUniformOpcode(Opc, true);
+
       // The SSE2 shifts use the lower i64 as the same shift amount for
       // all lanes and the upper i64 is ignored. On AVX we're better off
       // just zero-extending, but for SSE just duplicating the top 16-bits is
@@ -23827,7 +23824,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       SDValue LoA = DAG.getBuildVector(VT16, dl, LoAmt);
       SDValue HiA = DAG.getBuildVector(VT16, dl, HiAmt);
 
-      unsigned ShiftOp = Opc == ISD::SRA ? X86ISD::VSRAI : X86ISD::VSRLI;
+      unsigned ShiftOp = getTargetVShiftUniformOpcode(Opc, false);
       SDValue LoR = DAG.getBitcast(VT16, getUnpackl(DAG, dl, VT, R, R));
       SDValue HiR = DAG.getBitcast(VT16, getUnpackh(DAG, dl, VT, R, R));
       LoR = DAG.getNode(ShiftOp, dl, VT16, LoR, Cst8);
@@ -23843,7 +23840,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       (VT == MVT::v32i8 && Subtarget.hasInt256() && !Subtarget.hasXOP()) ||
       (VT == MVT::v64i8 && Subtarget.hasBWI())) {
     MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2);
-    unsigned ShiftOpcode = Op->getOpcode();
+    unsigned ShiftOpcode = Opc;
 
     auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) {
       if (VT.is512BitVector()) {
@@ -23880,7 +23877,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
      Amt = DAG.getNode(ISD::SHL, dl, ExtVT, Amt,
                        DAG.getConstant(5, dl, ExtVT));
       Amt = DAG.getBitcast(VT, Amt);
-      if (Op->getOpcode() == ISD::SHL || Op->getOpcode() == ISD::SRL) {
+      if (Opc == ISD::SHL || Opc == ISD::SRL) {
         // r = VSELECT(r, shift(r, 4), a);
         SDValue M =
             DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(4, dl, VT));
@@ -23902,7 +23899,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       return R;
     }
 
-    if (Op->getOpcode() == ISD::SRA) {
+    if (Opc == ISD::SRA) {
       // For SRA we need to unpack each byte to the higher byte of a i16 vector
       // so we can correctly sign extend. We don't care what happens to the
       // lower byte.
@@ -23977,7 +23974,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
   }
 
   if (VT == MVT::v8i16) {
-    unsigned ShiftOpcode = Op->getOpcode();
+    unsigned ShiftOpcode = Opc;
 
     // If we have a constant shift amount, the non-SSE41 path is best as
     // avoiding bitcasts make it easier to constant fold and reduce to PBLENDW.
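
The patch collapses several hand-rolled opcode selections (two switches and three chained ternaries) into the single getTargetVShiftUniformOpcode helper. The following is a minimal standalone sketch of the mapping that helper performs; the Opcode enum and names here are illustrative stand-ins, not LLVM's real types (the actual enumerators come from ISD::NodeType and the X86ISD namespace in LLVM's headers):

// Standalone sketch of the mapping done by getTargetVShiftUniformOpcode.
// The enumerators below are hypothetical stand-ins for the LLVM opcodes.
#include <cassert>
#include <cstdio>

enum Opcode {
  ISD_SHL, ISD_SRL, ISD_SRA,       // generic shift nodes
  X86_VSHL, X86_VSRL, X86_VSRA,    // variable-amount SSE shifts
  X86_VSHLI, X86_VSRLI, X86_VSRAI  // immediate SSE shifts
};

// Collapse any shift opcode of a given direction to either the variable
// (PSLL/PSRL/PSRA) or the immediate (PSLLI/PSRLI/PSRAI) form.
static Opcode getVShiftUniformOpcode(Opcode Opc, bool IsVariable) {
  switch (Opc) {
  case ISD_SHL: case X86_VSHL: case X86_VSHLI:
    return IsVariable ? X86_VSHL : X86_VSHLI;
  case ISD_SRL: case X86_VSRL: case X86_VSRLI:
    return IsVariable ? X86_VSRL : X86_VSRLI;
  case ISD_SRA: case X86_VSRA: case X86_VSRAI:
    return IsVariable ? X86_VSRA : X86_VSRAI;
  }
  assert(false && "Unknown shift opcode"); // mirrors llvm_unreachable
  return Opc;
}

int main() {
  // The call sites in the patch all reduce to lookups like these:
  assert(getVShiftUniformOpcode(ISD_SHL, /*IsVariable=*/false) == X86_VSHLI);
  assert(getVShiftUniformOpcode(X86_VSRLI, /*IsVariable=*/true) == X86_VSRL);
  assert(getVShiftUniformOpcode(ISD_SRA, /*IsVariable=*/true) == X86_VSRA);
  std::puts("opcode mapping ok");
  return 0;
}

Accepting the generic ISD opcodes as well as both X86ISD forms is what lets every call site funnel through the one helper, whether it starts from a generic shift node or an already-lowered target-specific one.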