diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 117 |
1 files changed, 75 insertions, 42 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index c9c6ec22d9e..0d178043469 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -23447,7 +23447,7 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, // Lower v16i8/v32i8/v64i8 mul as sign-extension to v8i16/v16i16/v32i16 // vector pairs, multiply and truncate. if (VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8) { - int NumElts = VT.getVectorNumElements(); + unsigned NumElts = VT.getVectorNumElements(); if ((VT == MVT::v16i8 && Subtarget.hasInt256()) || (VT == MVT::v32i8 && Subtarget.canExtendTo512BW())) { @@ -23461,24 +23461,33 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2); - // Extract the lo parts to any extend to i16. + // Extract the lo/hi parts to any extend to i16. // We're going to mask off the low byte of each result element of the // pmullw, so it doesn't matter what's in the high byte of each 16-bit // element. SDValue Undef = DAG.getUNDEF(VT); - SDValue ALo = getUnpackl(DAG, dl, VT, A, Undef); - SDValue BLo = getUnpackl(DAG, dl, VT, B, Undef); - ALo = DAG.getBitcast(ExVT, ALo); - BLo = DAG.getBitcast(ExVT, BLo); + SDValue ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, Undef)); + SDValue AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, Undef)); - // Extract the hi parts to any extend to i16. - // We're going to mask off the low byte of each result element of the - // pmullw, so it doesn't matter what's in the high byte of each 16-bit - // element. - SDValue AHi = getUnpackh(DAG, dl, VT, A, Undef); - SDValue BHi = getUnpackh(DAG, dl, VT, B, Undef); - AHi = DAG.getBitcast(ExVT, AHi); - BHi = DAG.getBitcast(ExVT, BHi); + SDValue BLo, BHi; + if (ISD::isBuildVectorOfConstantSDNodes(B.getNode())) { + // If the LHS is a constant, manually unpackl/unpackh. + SmallVector<SDValue, 16> LoOps, HiOps; + for (unsigned i = 0; i != NumElts; i += 16) { + for (unsigned j = 0; j != 8; ++j) { + LoOps.push_back(DAG.getAnyExtOrTrunc(B.getOperand(i + j), dl, + MVT::i16)); + HiOps.push_back(DAG.getAnyExtOrTrunc(B.getOperand(i + j + 8), dl, + MVT::i16)); + } + } + + BLo = DAG.getBuildVector(ExVT, dl, LoOps); + BHi = DAG.getBuildVector(ExVT, dl, HiOps); + } else { + BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, B, Undef)); + BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, B, Undef)); + } // Multiply, mask the lower 8bits of the lo/hi results and pack. SDValue RLo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo); @@ -23707,51 +23716,75 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2); + static const int PSHUFDMask[] = { 8, 9, 10, 11, 12, 13, 14, 15, + -1, -1, -1, -1, -1, -1, -1, -1}; + // Extract the lo parts and zero/sign extend to i16. // Only use SSE4.1 instructions for signed v16i8 where using unpack requires // shifts to sign extend. Using unpack for unsigned only requires an xor to // create zeros and a copy due to tied registers contraints pre-avx. But using // zero_extend_vector_inreg would require an additional pshufd for the high // part. - SDValue ALo, BLo; + + SDValue ALo, AHi; if (IsSigned && VT == MVT::v16i8 && Subtarget.hasSSE41()) { ALo = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, ExVT, A); - BLo = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, ExVT, B); + + AHi = DAG.getVectorShuffle(VT, dl, A, A, PSHUFDMask); + AHi = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, ExVT, AHi); } else if (IsSigned) { - ALo = getUnpackl(DAG, dl, VT, DAG.getUNDEF(VT), A); - BLo = getUnpackl(DAG, dl, VT, DAG.getUNDEF(VT), B); - ALo = DAG.getBitcast(ExVT, ALo); - BLo = DAG.getBitcast(ExVT, BLo); + ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, DAG.getUNDEF(VT), A)); + AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, DAG.getUNDEF(VT), A)); + ALo = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, ALo, 8, DAG); - BLo = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, BLo, 8, DAG); + AHi = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, AHi, 8, DAG); } else { - ALo = getUnpackl(DAG, dl, VT, A, DAG.getConstant(0, dl, VT)); - BLo = getUnpackl(DAG, dl, VT, B, DAG.getConstant(0, dl, VT)); - ALo = DAG.getBitcast(ExVT, ALo); - BLo = DAG.getBitcast(ExVT, BLo); + ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, + DAG.getConstant(0, dl, VT))); + AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, + DAG.getConstant(0, dl, VT))); } - // Extract the hi parts and zero/sign extend to i16. - SDValue AHi, BHi; - if (IsSigned && VT == MVT::v16i8 && Subtarget.hasSSE41()) { - const int ShufMask[] = { 8, 9, 10, 11, 12, 13, 14, 15, - -1, -1, -1, -1, -1, -1, -1, -1}; - AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask); - BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask); - AHi = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, ExVT, AHi); + SDValue BLo, BHi; + if (ISD::isBuildVectorOfConstantSDNodes(B.getNode())) { + // If the LHS is a constant, manually unpackl/unpackh and extend. + SmallVector<SDValue, 16> LoOps, HiOps; + for (unsigned i = 0; i != NumElts; i += 16) { + for (unsigned j = 0; j != 8; ++j) { + SDValue LoOp = B.getOperand(i + j); + SDValue HiOp = B.getOperand(i + j + 8); + + if (IsSigned) { + LoOp = DAG.getSExtOrTrunc(LoOp, dl, MVT::i16); + HiOp = DAG.getSExtOrTrunc(HiOp, dl, MVT::i16); + } else { + LoOp = DAG.getZExtOrTrunc(LoOp, dl, MVT::i16); + HiOp = DAG.getZExtOrTrunc(HiOp, dl, MVT::i16); + } + + LoOps.push_back(LoOp); + HiOps.push_back(HiOp); + } + } + + BLo = DAG.getBuildVector(ExVT, dl, LoOps); + BHi = DAG.getBuildVector(ExVT, dl, HiOps); + } else if (IsSigned && VT == MVT::v16i8 && Subtarget.hasSSE41()) { + BLo = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, ExVT, B); + + BHi = DAG.getVectorShuffle(VT, dl, B, B, PSHUFDMask); BHi = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, ExVT, BHi); } else if (IsSigned) { - AHi = getUnpackh(DAG, dl, VT, DAG.getUNDEF(VT), A); - BHi = getUnpackh(DAG, dl, VT, DAG.getUNDEF(VT), B); - AHi = DAG.getBitcast(ExVT, AHi); - BHi = DAG.getBitcast(ExVT, BHi); - AHi = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, AHi, 8, DAG); + BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, DAG.getUNDEF(VT), B)); + BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, DAG.getUNDEF(VT), B)); + + BLo = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, BLo, 8, DAG); BHi = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, BHi, 8, DAG); } else { - AHi = getUnpackh(DAG, dl, VT, A, DAG.getConstant(0, dl, VT)); - BHi = getUnpackh(DAG, dl, VT, B, DAG.getConstant(0, dl, VT)); - AHi = DAG.getBitcast(ExVT, AHi); - BHi = DAG.getBitcast(ExVT, BHi); + BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, B, + DAG.getConstant(0, dl, VT))); + BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, B, + DAG.getConstant(0, dl, VT))); } // Multiply, lshr the upper 8bits to the lower 8bits of the lo/hi results and |