diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 28 |
1 files changed, 13 insertions, 15 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 05b565304ae..cee8f849669 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -23397,18 +23397,10 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, // Lower v16i8/v32i8/v64i8 mul as sign-extension to v8i16/v16i16/v32i16 // vector pairs, multiply and truncate. if (VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8) { - if (Subtarget.hasInt256()) { - // For 512-bit vectors, split into 256-bit vectors to allow the - // sign-extension to occur. - if (VT == MVT::v64i8) - return split512IntArith(Op, DAG); - - // For 256-bit vectors, split into 128-bit vectors to allow the - // sign-extension to occur. We don't need this on AVX512BW as we can - // safely sign-extend to v32i16. - if (VT == MVT::v32i8 && !Subtarget.canExtendTo512BW()) - return split256IntArith(Op, DAG); + int NumElts = VT.getVectorNumElements(); + if ((VT == MVT::v16i8 && Subtarget.hasInt256()) || + (VT == MVT::v32i8 && Subtarget.canExtendTo512BW())) { MVT ExVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements()); return DAG.getNode( ISD::TRUNCATE, dl, VT, @@ -23417,9 +23409,7 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, DAG.getNode(ISD::ANY_EXTEND, dl, ExVT, B))); } - assert(VT == MVT::v16i8 && - "Pre-AVX2 support only supports v16i8 multiplication"); - MVT ExVT = MVT::v8i16; + MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2); // Extract the lo parts to any extend to i16 // We're going to mask off the low byte of each result element of the @@ -23445,7 +23435,15 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, SDValue RHi = DAG.getNode(ISD::MUL, dl, ExVT, AHi, BHi); RLo = DAG.getNode(ISD::AND, dl, ExVT, RLo, DAG.getConstant(255, dl, ExVT)); RHi = DAG.getNode(ISD::AND, dl, ExVT, RHi, DAG.getConstant(255, dl, ExVT)); - return DAG.getNode(X86ISD::PACKUS, dl, VT, RLo, RHi); + RLo = DAG.getBitcast(VT, RLo); + RHi = DAG.getBitcast(VT, RHi); + + // For each 128-bit lane, we need to take the 8 even elements from RLo then + // the 8 even elements from RHi. + SmallVector<int, 64> PackMask; + createPackShuffleMask(VT, PackMask, /*Unary*/false); + + return DAG.getVectorShuffle(VT, dl, RLo, RHi, PackMask); } // Lower v4i32 mul as 2x shuffle, 2x pmuludq, 2x shuffle. |

