diff options
author | Wei Mi <wmi@google.com> | 2016-06-14 18:53:20 +0000 |
---|---|---|
committer | Wei Mi <wmi@google.com> | 2016-06-14 18:53:20 +0000 |
commit | b799a625f9221b017e2ab2503e291c503da6c763 (patch) | |
tree | 11a43258809590cacd2698956315b61d21259186 /llvm/lib | |
parent | 07c229c9e7f5520953207ab8cf159674c9bd33f2 (diff) | |
download | bcm5719-llvm-b799a625f9221b017e2ab2503e291c503da6c763.tar.gz bcm5719-llvm-b799a625f9221b017e2ab2503e291c503da6c763.zip |
[X86] Reduce the width of multiplification when its operands are extended from i8 or i16
For <N x i32> type mul, pmuludq will be used for targets without SSE41, which
often introduces many extra pack and unpack instructions in vectorized loop
body because pmuludq generates <N/2 x i64> type value. However when the operands
of <N x i32> mul are extended from smaller size values like i8 and i16, the type
of mul may be shrunk to use pmullw + pmulhw/pmulhuw instead of pmuludq, which
generates better code. For targets with SSE41, pmulld is supported so no
shrinking is needed.
Differential Revision: http://reviews.llvm.org/D20931
llvm-svn: 272694
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 211 |
2 files changed, 210 insertions, 3 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp index 9ba68bf2be3..572fecac219 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -670,6 +670,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) { case ISD::ADD: case ISD::SUB: case ISD::MUL: + case ISD::MULHS: + case ISD::MULHU: case ISD::FADD: case ISD::FSUB: case ISD::FMUL: diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 69cf0d269a5..06a0aa39603 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -26962,10 +26962,216 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, return SDValue(); } +/// Different mul shrinking modes. +enum ShrinkMode { MULS8, MULU8, MULS16, MULU16 }; + +static bool canReduceVMulWidth(SDNode *N, SelectionDAG &DAG, ShrinkMode &Mode) { + EVT VT = N->getOperand(0).getValueType(); + if (VT.getScalarSizeInBits() != 32) + return false; + + assert(N->getNumOperands() == 2 && "NumOperands of Mul are 2"); + unsigned SignBits[2] = {1, 1}; + bool IsPositive[2] = {false, false}; + for (unsigned i = 0; i < 2; i++) { + SDValue Opd = N->getOperand(i); + + // DAG.ComputeNumSignBits return 1 for ISD::ANY_EXTEND, so we need to + // compute signbits for it separately. + if (Opd.getOpcode() == ISD::ANY_EXTEND) { + // For anyextend, it is safe to assume an appropriate number of leading + // sign/zero bits. + if (Opd.getOperand(0).getValueType().getVectorElementType() == MVT::i8) + SignBits[i] = 25; + else if (Opd.getOperand(0).getValueType().getVectorElementType() == + MVT::i16) + SignBits[i] = 17; + else + return false; + IsPositive[i] = true; + } else if (Opd.getOpcode() == ISD::BUILD_VECTOR) { + // All the operands of BUILD_VECTOR need to be int constant. + // Find the smallest value range which all the operands belong to. + SignBits[i] = 32; + IsPositive[i] = true; + for (const SDValue &SubOp : Opd.getNode()->op_values()) { + if (SubOp.isUndef()) + continue; + auto *CN = dyn_cast<ConstantSDNode>(SubOp); + if (!CN) + return false; + APInt IntVal = CN->getAPIntValue(); + if (IntVal.isNegative()) + IsPositive[i] = false; + SignBits[i] = std::min(SignBits[i], IntVal.getNumSignBits()); + } + } else { + SignBits[i] = DAG.ComputeNumSignBits(Opd); + if (Opd.getOpcode() == ISD::ZERO_EXTEND) + IsPositive[i] = true; + } + } + + bool AllPositive = IsPositive[0] && IsPositive[1]; + unsigned MinSignBits = std::min(SignBits[0], SignBits[1]); + // When ranges are from -128 ~ 127, use MULS8 mode. + if (MinSignBits >= 25) + Mode = MULS8; + // When ranges are from 0 ~ 255, use MULU8 mode. + else if (AllPositive && MinSignBits >= 24) + Mode = MULU8; + // When ranges are from -32768 ~ 32767, use MULS16 mode. + else if (MinSignBits >= 17) + Mode = MULS16; + // When ranges are from 0 ~ 65535, use MULU16 mode. + else if (AllPositive && MinSignBits >= 16) + Mode = MULU16; + else + return false; + return true; +} + +/// When the operands of vector mul are extended from smaller size values, +/// like i8 and i16, the type of mul may be shrinked to generate more +/// efficient code. Two typical patterns are handled: +/// Pattern1: +/// %2 = sext/zext <N x i8> %1 to <N x i32> +/// %4 = sext/zext <N x i8> %3 to <N x i32> +// or %4 = build_vector <N x i32> %C1, ..., %CN (%C1..%CN are constants) +/// %5 = mul <N x i32> %2, %4 +/// +/// Pattern2: +/// %2 = zext/sext <N x i16> %1 to <N x i32> +/// %4 = zext/sext <N x i16> %3 to <N x i32> +/// or %4 = build_vector <N x i32> %C1, ..., %CN (%C1..%CN are constants) +/// %5 = mul <N x i32> %2, %4 +/// +/// There are four mul shrinking modes: +/// If %2 == sext32(trunc8(%2)), i.e., the scalar value range of %2 is +/// -128 to 128, and the scalar value range of %4 is also -128 to 128, +/// generate pmullw+sext32 for it (MULS8 mode). +/// If %2 == zext32(trunc8(%2)), i.e., the scalar value range of %2 is +/// 0 to 255, and the scalar value range of %4 is also 0 to 255, +/// generate pmullw+zext32 for it (MULU8 mode). +/// If %2 == sext32(trunc16(%2)), i.e., the scalar value range of %2 is +/// -32768 to 32767, and the scalar value range of %4 is also -32768 to 32767, +/// generate pmullw+pmulhw for it (MULS16 mode). +/// If %2 == zext32(trunc16(%2)), i.e., the scalar value range of %2 is +/// 0 to 65535, and the scalar value range of %4 is also 0 to 65535, +/// generate pmullw+pmulhuw for it (MULU16 mode). +static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // pmulld is supported since SSE41. It is better to use pmulld + // instead of pmullw+pmulhw. + if (Subtarget.hasSSE41()) + return SDValue(); + + ShrinkMode Mode; + if (!canReduceVMulWidth(N, DAG, Mode)) + return SDValue(); + + SDLoc DL(N); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N->getOperand(0).getValueType(); + unsigned RegSize = 128; + MVT OpsVT = MVT::getVectorVT(MVT::i16, RegSize / 16); + EVT ReducedVT = + EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorNumElements()); + // Shrink the operands of mul. + SDValue NewN0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N0); + SDValue NewN1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N1); + + if (VT.getVectorNumElements() >= OpsVT.getVectorNumElements()) { + // Generate the lower part of mul: pmullw. For MULU8/MULS8, only the + // lower part is needed. + SDValue MulLo = DAG.getNode(ISD::MUL, DL, ReducedVT, NewN0, NewN1); + if (Mode == MULU8 || Mode == MULS8) { + return DAG.getNode((Mode == MULU8) ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, + DL, VT, MulLo); + } else { + MVT ResVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() / 2); + // Generate the higher part of mul: pmulhw/pmulhuw. For MULU16/MULS16, + // the higher part is also needed. + SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL, + ReducedVT, NewN0, NewN1); + + // Repack the lower part and higher part result of mul into a wider + // result. + // Generate shuffle functioning as punpcklwd. + SmallVector<int, 16> ShuffleMask(VT.getVectorNumElements()); + for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) { + ShuffleMask[2 * i] = i; + ShuffleMask[2 * i + 1] = i + VT.getVectorNumElements(); + } + SDValue ResLo = + DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, &ShuffleMask[0]); + ResLo = DAG.getNode(ISD::BITCAST, DL, ResVT, ResLo); + // Generate shuffle functioning as punpckhwd. + for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) { + ShuffleMask[2 * i] = i + VT.getVectorNumElements() / 2; + ShuffleMask[2 * i + 1] = i + VT.getVectorNumElements() * 3 / 2; + } + SDValue ResHi = + DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, &ShuffleMask[0]); + ResHi = DAG.getNode(ISD::BITCAST, DL, ResVT, ResHi); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ResLo, ResHi); + } + } else { + // When VT.getVectorNumElements() < OpsVT.getVectorNumElements(), we want + // to legalize the mul explicitly because implicit legalization for type + // <4 x i16> to <4 x i32> sometimes involves unnecessary unpack + // instructions which will not exist when we explicitly legalize it by + // extending <4 x i16> to <8 x i16> (concatenating the <4 x i16> val with + // <4 x i16> undef). + // + // Legalize the operands of mul. + SmallVector<SDValue, 16> Ops(RegSize / ReducedVT.getSizeInBits(), + DAG.getUNDEF(ReducedVT)); + Ops[0] = NewN0; + NewN0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops); + Ops[0] = NewN1; + NewN1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops); + + if (Mode == MULU8 || Mode == MULS8) { + // Generate lower part of mul: pmullw. For MULU8/MULS8, only the lower + // part is needed. + SDValue Mul = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1); + + // convert the type of mul result to VT. + MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32); + SDValue Res = DAG.getNode(Mode == MULU8 ? ISD::ZERO_EXTEND_VECTOR_INREG + : ISD::SIGN_EXTEND_VECTOR_INREG, + DL, ResVT, Mul); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res, + DAG.getIntPtrConstant(0, DL)); + } else { + // Generate the lower and higher part of mul: pmulhw/pmulhuw. For + // MULU16/MULS16, both parts are needed. + SDValue MulLo = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1); + SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL, + OpsVT, NewN0, NewN1); + + // Repack the lower part and higher part result of mul into a wider + // result. Make sure the type of mul result is VT. + MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32); + SDValue Res = DAG.getNode(X86ISD::UNPCKL, DL, OpsVT, MulLo, MulHi); + Res = DAG.getNode(ISD::BITCAST, DL, ResVT, Res); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res, + DAG.getIntPtrConstant(0, DL)); + } + } +} + /// Optimize a single multiply with constant into two operations in order to /// implement it with two cheaper instructions, e.g. LEA + SHL, LEA + LEA. static SDValue combineMul(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI) { + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + EVT VT = N->getValueType(0); + if (DCI.isBeforeLegalize() && VT.isVector()) + return reduceVMULWidth(N, DAG, Subtarget); + // An imul is usually smaller than the alternative sequence. if (DAG.getMachineFunction().getFunction()->optForMinSize()) return SDValue(); @@ -26973,7 +27179,6 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG, if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer()) return SDValue(); - EVT VT = N->getValueType(0); if (VT != MVT::i64 && VT != MVT::i32) return SDValue(); @@ -30268,7 +30473,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case ISD::ADD: return combineAdd(N, DAG, Subtarget); case ISD::SUB: return combineSub(N, DAG, Subtarget); case X86ISD::ADC: return combineADC(N, DAG, DCI); - case ISD::MUL: return combineMul(N, DAG, DCI); + case ISD::MUL: return combineMul(N, DAG, DCI, Subtarget); case ISD::SHL: case ISD::SRA: case ISD::SRL: return combineShift(N, DAG, DCI, Subtarget); |