author     Wei Mi <wmi@google.com>   2016-06-14 18:53:20 +0000
committer  Wei Mi <wmi@google.com>   2016-06-14 18:53:20 +0000
commit     b799a625f9221b017e2ab2503e291c503da6c763 (patch)
tree       11a43258809590cacd2698956315b61d21259186 /llvm/lib
parent     07c229c9e7f5520953207ab8cf159674c9bd33f2 (diff)
[X86] Reduce the width of multiplication when its operands are extended from i8 or i16
For a <N x i32> mul, pmuludq will be used on targets without SSE4.1, which often introduces many extra pack and unpack instructions in the vectorized loop body because pmuludq generates a <N/2 x i64> value. However, when the operands of the <N x i32> mul are extended from smaller types such as i8 and i16, the mul may be shrunk to use pmullw + pmulhw/pmulhuw instead of pmuludq, which generates better code. Targets with SSE4.1 support pmulld, so no shrinking is needed there.

Differential Revision: http://reviews.llvm.org/D20931

llvm-svn: 272694
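As an illustration, here is a minimal IR sketch of the kind of mul this combine targets (the function and value names are hypothetical; this is a sketch, not a test from the commit):

    ; <8 x i32> mul whose operands are sign-extended from <8 x i16>: with this
    ; patch, an SSE2-only target can select pmullw + pmulhw for it (MULS16
    ; mode) instead of widening through pmuludq.
    define <8 x i32> @mul_sext_i16(<8 x i16> %a, <8 x i16> %b) {
      %ea = sext <8 x i16> %a to <8 x i32>
      %eb = sext <8 x i16> %b to <8 x i32>
      %mul = mul <8 x i32> %ea, %eb
      ret <8 x i32> %mul
    }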
Diffstat (limited to 'llvm/lib')
-rw-r--r-- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp |   2
-rw-r--r-- llvm/lib/Target/X86/X86ISelLowering.cpp               | 211
2 files changed, 210 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 9ba68bf2be3..572fecac219 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -670,6 +670,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::ADD:
case ISD::SUB:
case ISD::MUL:
+ case ISD::MULHS:
+ case ISD::MULHU:
case ISD::FADD:
case ISD::FSUB:
case ISD::FMUL:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 69cf0d269a5..06a0aa39603 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -26962,10 +26962,216 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+/// Different mul shrinking modes.
+enum ShrinkMode { MULS8, MULU8, MULS16, MULU16 };
+
+static bool canReduceVMulWidth(SDNode *N, SelectionDAG &DAG, ShrinkMode &Mode) {
+ EVT VT = N->getOperand(0).getValueType();
+ if (VT.getScalarSizeInBits() != 32)
+ return false;
+
+ assert(N->getNumOperands() == 2 && "NumOperands of Mul should be 2");
+ unsigned SignBits[2] = {1, 1};
+ bool IsPositive[2] = {false, false};
+ for (unsigned i = 0; i < 2; i++) {
+ SDValue Opd = N->getOperand(i);
+
+ // DAG.ComputeNumSignBits returns 1 for ISD::ANY_EXTEND, so we need to
+ // compute the sign bits for it separately.
+ if (Opd.getOpcode() == ISD::ANY_EXTEND) {
+ // For anyextend, it is safe to assume an appropriate number of leading
+ // sign/zero bits.
+ if (Opd.getOperand(0).getValueType().getVectorElementType() == MVT::i8)
+ SignBits[i] = 25;
+ else if (Opd.getOperand(0).getValueType().getVectorElementType() ==
+ MVT::i16)
+ SignBits[i] = 17;
+ else
+ return false;
+ IsPositive[i] = true;
+ } else if (Opd.getOpcode() == ISD::BUILD_VECTOR) {
+ // All the operands of the BUILD_VECTOR need to be integer constants.
+ // Find the smallest value range to which all the operands belong.
+ SignBits[i] = 32;
+ IsPositive[i] = true;
+ for (const SDValue &SubOp : Opd.getNode()->op_values()) {
+ if (SubOp.isUndef())
+ continue;
+ auto *CN = dyn_cast<ConstantSDNode>(SubOp);
+ if (!CN)
+ return false;
+ APInt IntVal = CN->getAPIntValue();
+ if (IntVal.isNegative())
+ IsPositive[i] = false;
+ SignBits[i] = std::min(SignBits[i], IntVal.getNumSignBits());
+ }
+ } else {
+ SignBits[i] = DAG.ComputeNumSignBits(Opd);
+ if (Opd.getOpcode() == ISD::ZERO_EXTEND)
+ IsPositive[i] = true;
+ }
+ }
+
+ bool AllPositive = IsPositive[0] && IsPositive[1];
+ unsigned MinSignBits = std::min(SignBits[0], SignBits[1]);
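+ // An i32 sign-extended from iN has at least 32 - N + 1 leading sign bits,
+ // while one zero-extended from iN has at least 32 - N, which is where the
+ // 25/24 and 17/16 thresholds below come from.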
+ // When both ranges are from -128 to 127, use MULS8 mode.
+ if (MinSignBits >= 25)
+ Mode = MULS8;
+ // When both ranges are from 0 to 255, use MULU8 mode.
+ else if (AllPositive && MinSignBits >= 24)
+ Mode = MULU8;
+ // When both ranges are from -32768 to 32767, use MULS16 mode.
+ else if (MinSignBits >= 17)
+ Mode = MULS16;
+ // When both ranges are from 0 to 65535, use MULU16 mode.
+ else if (AllPositive && MinSignBits >= 16)
+ Mode = MULU16;
+ else
+ return false;
+ return true;
+}
+
+/// When the operands of a vector mul are extended from smaller size values,
+/// like i8 and i16, the type of the mul may be shrunk to generate more
+/// efficient code. Two typical patterns are handled:
+/// Pattern1:
+///     %2 = sext/zext <N x i8> %1 to <N x i32>
+///     %4 = sext/zext <N x i8> %3 to <N x i32>
+///  or %4 = build_vector <N x i32> %C1, ..., %CN (%C1..%CN are constants)
+///     %5 = mul <N x i32> %2, %4
+///
+/// Pattern2:
+///     %2 = zext/sext <N x i16> %1 to <N x i32>
+///     %4 = zext/sext <N x i16> %3 to <N x i32>
+///  or %4 = build_vector <N x i32> %C1, ..., %CN (%C1..%CN are constants)
+///     %5 = mul <N x i32> %2, %4
+///
+/// There are four mul shrinking modes:
+/// If %2 == sext32(trunc8(%2)), i.e., the scalar value range of %2 is
+/// -128 to 127, and the scalar value range of %4 is also -128 to 127,
+/// generate pmullw+sext32 for it (MULS8 mode).
+/// If %2 == zext32(trunc8(%2)), i.e., the scalar value range of %2 is
+/// 0 to 255, and the scalar value range of %4 is also 0 to 255,
+/// generate pmullw+zext32 for it (MULU8 mode).
+/// If %2 == sext32(trunc16(%2)), i.e., the scalar value range of %2 is
+/// -32768 to 32767, and the scalar value range of %4 is also -32768 to 32767,
+/// generate pmullw+pmulhw for it (MULS16 mode).
+/// If %2 == zext32(trunc16(%2)), i.e., the scalar value range of %2 is
+/// 0 to 65535, and the scalar value range of %4 is also 0 to 65535,
+/// generate pmullw+pmulhuw for it (MULU16 mode).
+static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ // pmulld is supported since SSE41. It is better to use pmulld
+ // instead of pmullw+pmulhw.
+ if (Subtarget.hasSSE41())
+ return SDValue();
+
+ ShrinkMode Mode;
+ if (!canReduceVMulWidth(N, DAG, Mode))
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N->getOperand(0).getValueType();
+ unsigned RegSize = 128;
+ MVT OpsVT = MVT::getVectorVT(MVT::i16, RegSize / 16);
+ EVT ReducedVT =
+ EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorNumElements());
+ // Shrink the operands of mul.
+ SDValue NewN0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N0);
+ SDValue NewN1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N1);
+
+ if (VT.getVectorNumElements() >= OpsVT.getVectorNumElements()) {
+ // Generate the lower part of mul: pmullw. For MULU8/MULS8, only the
+ // lower part is needed.
+ SDValue MulLo = DAG.getNode(ISD::MUL, DL, ReducedVT, NewN0, NewN1);
+ if (Mode == MULU8 || Mode == MULS8) {
+ return DAG.getNode((Mode == MULU8) ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND,
+ DL, VT, MulLo);
+ } else {
+ MVT ResVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() / 2);
+ // Generate the higher part of mul: pmulhw/pmulhuw. For MULU16/MULS16,
+ // the higher part is also needed.
+ SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL,
+ ReducedVT, NewN0, NewN1);
+
+ // Repack the lower part and higher part result of mul into a wider
+ // result.
+ // Generate shuffle functioning as punpcklwd.
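+ // For example, with VT = <8 x i32> the mask is {0,8,1,9,2,10,3,11}, so
+ // after the bitcast each i32 lane holds (MulHi << 16) | MulLo.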
+ SmallVector<int, 16> ShuffleMask(VT.getVectorNumElements());
+ for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) {
+ ShuffleMask[2 * i] = i;
+ ShuffleMask[2 * i + 1] = i + VT.getVectorNumElements();
+ }
+ SDValue ResLo =
+ DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, &ShuffleMask[0]);
+ ResLo = DAG.getNode(ISD::BITCAST, DL, ResVT, ResLo);
+ // Generate shuffle functioning as punpckhwd.
+ for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) {
+ ShuffleMask[2 * i] = i + VT.getVectorNumElements() / 2;
+ ShuffleMask[2 * i + 1] = i + VT.getVectorNumElements() * 3 / 2;
+ }
+ SDValue ResHi =
+ DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, &ShuffleMask[0]);
+ ResHi = DAG.getNode(ISD::BITCAST, DL, ResVT, ResHi);
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ResLo, ResHi);
+ }
+ } else {
+ // When VT.getVectorNumElements() < OpsVT.getVectorNumElements(), we want
+ // to legalize the mul explicitly, because implicit legalization from
+ // <4 x i16> to <4 x i32> sometimes involves unnecessary unpack
+ // instructions that do not appear when we explicitly legalize it by
+ // extending <4 x i16> to <8 x i16> (concatenating the <4 x i16> value
+ // with <4 x i16> undef).
+ //
+ // Legalize the operands of mul.
+ SmallVector<SDValue, 16> Ops(RegSize / ReducedVT.getSizeInBits(),
+ DAG.getUNDEF(ReducedVT));
+ Ops[0] = NewN0;
+ NewN0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops);
+ Ops[0] = NewN1;
+ NewN1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops);
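+
+ // NewN0 and NewN1 are now full 128-bit <8 x i16> vectors, with the
+ // original elements in the low lanes and undef in the rest.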
+
+ if (Mode == MULU8 || Mode == MULS8) {
+ // Generate lower part of mul: pmullw. For MULU8/MULS8, only the lower
+ // part is needed.
+ SDValue Mul = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1);
+
+ // Convert the type of the mul result to VT.
+ MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
+ SDValue Res = DAG.getNode(Mode == MULU8 ? ISD::ZERO_EXTEND_VECTOR_INREG
+ : ISD::SIGN_EXTEND_VECTOR_INREG,
+ DL, ResVT, Mul);
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
+ DAG.getIntPtrConstant(0, DL));
+ } else {
+ // Generate the lower and higher parts of the mul: pmullw and
+ // pmulhw/pmulhuw. For MULU16/MULS16, both parts are needed.
+ SDValue MulLo = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1);
+ SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL,
+ OpsVT, NewN0, NewN1);
+
+ // Repack the lower part and higher part result of mul into a wider
+ // result. Make sure the type of mul result is VT.
+ MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
+ SDValue Res = DAG.getNode(X86ISD::UNPCKL, DL, OpsVT, MulLo, MulHi);
+ Res = DAG.getNode(ISD::BITCAST, DL, ResVT, Res);
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
+ DAG.getIntPtrConstant(0, DL));
+ }
+ }
+}
+
/// Optimize a single multiply with constant into two operations in order to
/// implement it with two cheaper instructions, e.g. LEA + SHL, LEA + LEA.
static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI) {
+ TargetLowering::DAGCombinerInfo &DCI,
+ const X86Subtarget &Subtarget) {
+ EVT VT = N->getValueType(0);
+ if (DCI.isBeforeLegalize() && VT.isVector())
+ return reduceVMULWidth(N, DAG, Subtarget);
+
// An imul is usually smaller than the alternative sequence.
if (DAG.getMachineFunction().getFunction()->optForMinSize())
return SDValue();
@@ -26973,7 +27179,6 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer())
return SDValue();
- EVT VT = N->getValueType(0);
if (VT != MVT::i64 && VT != MVT::i32)
return SDValue();
@@ -30268,7 +30473,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::ADD: return combineAdd(N, DAG, Subtarget);
case ISD::SUB: return combineSub(N, DAG, Subtarget);
case X86ISD::ADC: return combineADC(N, DAG, DCI);
- case ISD::MUL: return combineMul(N, DAG, DCI);
+ case ISD::MUL: return combineMul(N, DAG, DCI, Subtarget);
case ISD::SHL:
case ISD::SRA:
case ISD::SRL: return combineShift(N, DAG, DCI, Subtarget);