diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 101 |
1 files changed, 47 insertions, 54 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 890103dfd9f..475a1c646bf 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -23656,69 +23656,62 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, // and then ashr/lshr the upper bits down to the lower bits before multiply. unsigned ExAVX = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; - // For 512-bit vectors, split into 256-bit vectors to allow the + if ((VT == MVT::v16i8 && Subtarget.hasInt256()) || + (VT == MVT::v32i8 && Subtarget.canExtendTo512BW())) { + MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts); + SDValue ExA = DAG.getNode(ExAVX, dl, ExVT, A); + SDValue ExB = DAG.getNode(ExAVX, dl, ExVT, B); + SDValue Mul = DAG.getNode(ISD::MUL, dl, ExVT, ExA, ExB); + Mul = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Mul, 8, DAG); + return DAG.getNode(ISD::TRUNCATE, dl, VT, Mul); + } + + // For signed 512-bit vectors, split into 256-bit vectors to allow the // sign-extension to occur. - if (VT == MVT::v64i8) + if (VT == MVT::v64i8 && IsSigned) return split512IntArith(Op, DAG); - // AVX2 implementations - extend xmm subvectors to ymm. - if (Subtarget.hasInt256()) { + // Signed AVX2 implementation - extend xmm subvectors to ymm. + if (VT == MVT::v32i8 && IsSigned) { SDValue Lo = DAG.getIntPtrConstant(0, dl); SDValue Hi = DAG.getIntPtrConstant(NumElts / 2, dl); - if (VT == MVT::v32i8) { - if (Subtarget.canExtendTo512BW()) { - MVT ExVT = MVT::v32i16; - SDValue ExA = DAG.getNode(ExAVX, dl, ExVT, A); - SDValue ExB = DAG.getNode(ExAVX, dl, ExVT, B); - SDValue Mul = DAG.getNode(ISD::MUL, dl, ExVT, ExA, ExB); - Mul = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Mul, 8, DAG); - return DAG.getNode(ISD::TRUNCATE, dl, VT, Mul); - } - MVT ExVT = MVT::v16i16; - SDValue ALo = extract128BitVector(A, 0, DAG, dl); - SDValue BLo = extract128BitVector(B, 0, DAG, dl); - SDValue AHi = extract128BitVector(A, NumElts / 2, DAG, dl); - SDValue BHi = extract128BitVector(B, NumElts / 2, DAG, dl); - ALo = DAG.getNode(ExAVX, dl, ExVT, ALo); - BLo = DAG.getNode(ExAVX, dl, ExVT, BLo); - AHi = DAG.getNode(ExAVX, dl, ExVT, AHi); - BHi = DAG.getNode(ExAVX, dl, ExVT, BHi); - Lo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo); - Hi = DAG.getNode(ISD::MUL, dl, ExVT, AHi, BHi); - Lo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Lo, 8, DAG); - Hi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Hi, 8, DAG); - - // Bitcast back to VT and then pack all the even elements from Lo and Hi. - // Shuffle lowering should turn this into PACKUS+PERMQ - Lo = DAG.getBitcast(VT, Lo); - Hi = DAG.getBitcast(VT, Hi); - return DAG.getVectorShuffle(VT, dl, Lo, Hi, - { 0, 2, 4, 6, 8, 10, 12, 14, - 16, 18, 20, 22, 24, 26, 28, 30, - 32, 34, 36, 38, 40, 42, 44, 46, - 48, 50, 52, 54, 56, 58, 60, 62}); - } - - assert(VT == MVT::v16i8 && "Unexpected VT"); - - SDValue ExA = DAG.getNode(ExAVX, dl, MVT::v16i16, A); - SDValue ExB = DAG.getNode(ExAVX, dl, MVT::v16i16, B); - SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v16i16, ExA, ExB); - Mul = - getTargetVShiftByConstNode(X86ISD::VSRLI, dl, MVT::v16i16, Mul, 8, DAG); - return DAG.getNode(ISD::TRUNCATE, dl, VT, Mul); - } - - assert(VT == MVT::v16i8 && - "Pre-AVX2 support only supports v16i8 multiplication"); - MVT ExVT = MVT::v8i16; + MVT ExVT = MVT::v16i16; + SDValue ALo = extract128BitVector(A, 0, DAG, dl); + SDValue BLo = extract128BitVector(B, 0, DAG, dl); + SDValue AHi = extract128BitVector(A, NumElts / 2, DAG, dl); + SDValue BHi = extract128BitVector(B, NumElts / 2, DAG, dl); + ALo = DAG.getNode(ExAVX, dl, ExVT, ALo); + BLo = DAG.getNode(ExAVX, dl, ExVT, BLo); + AHi = DAG.getNode(ExAVX, dl, ExVT, AHi); + BHi = DAG.getNode(ExAVX, dl, ExVT, BHi); + Lo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo); + Hi = DAG.getNode(ISD::MUL, dl, ExVT, AHi, BHi); + Lo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Lo, 8, DAG); + Hi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Hi, 8, DAG); + + // Bitcast back to VT and then pack all the even elements from Lo and Hi. + // Shuffle lowering should turn this into PACKUS+PERMQ + Lo = DAG.getBitcast(VT, Lo); + Hi = DAG.getBitcast(VT, Hi); + return DAG.getVectorShuffle(VT, dl, Lo, Hi, + { 0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, + 48, 50, 52, 54, 56, 58, 60, 62}); + } + + // For signed v16i8 and all unsigned vXi8 we will unpack the low and high + // half of each 128 bit lane to widen to a vXi16 type. Do the multiplies, + // shift the results and pack the half lane results back together. + + MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2); unsigned ExSSE41 = IsSigned ? ISD::SIGN_EXTEND_VECTOR_INREG : ISD::ZERO_EXTEND_VECTOR_INREG; // Extract the lo parts and zero/sign extend to i16. SDValue ALo, BLo; - if (Subtarget.hasSSE41()) { + if (VT == MVT::v16i8 && Subtarget.hasSSE41()) { ALo = DAG.getNode(ExSSE41, dl, ExVT, A); BLo = DAG.getNode(ExSSE41, dl, ExVT, B); } else if (IsSigned) { @@ -23737,7 +23730,7 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, // Extract the hi parts and zero/sign extend to i16. SDValue AHi, BHi; - if (Subtarget.hasSSE41()) { + if (VT == MVT::v16i8 && Subtarget.hasSSE41()) { const int ShufMask[] = { 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1}; AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask); @@ -23759,7 +23752,7 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, } // Multiply, lshr the upper 8bits to the lower 8bits of the lo/hi results and - // pack back to v16i8. + // pack back to vXi8. SDValue RLo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo); SDValue RHi = DAG.getNode(ISD::MUL, dl, ExVT, AHi, BHi); RLo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, RLo, 8, DAG); |

