diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 11 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 14 |
2 files changed, 16 insertions, 9 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 508316a1fad..6489ca57b33 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -116,7 +116,8 @@ bool ConstantFPSDNode::isValueValidForType(EVT VT, // ISD Namespace //===----------------------------------------------------------------------===// -bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) { +bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal, + bool AllowShrink) { auto *BV = dyn_cast<BuildVectorSDNode>(N); if (!BV) return false; @@ -124,9 +125,11 @@ bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) { APInt SplatUndef; unsigned SplatBitSize; bool HasUndefs; - EVT EltVT = N->getValueType(0).getVectorElementType(); - return BV->isConstantSplat(SplatVal, SplatUndef, SplatBitSize, HasUndefs) && - EltVT.getSizeInBits() >= SplatBitSize; + unsigned EltSize = N->getValueType(0).getVectorElementType().getSizeInBits(); + unsigned MinSplatBits = AllowShrink ? 0 : EltSize; + return BV->isConstantSplat(SplatVal, SplatUndef, SplatBitSize, HasUndefs, + MinSplatBits) && + EltSize >= SplatBitSize; } // FIXME: AllOnes and AllZeros duplicate a lot of code. Could these be diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index ce00a4a9665..42c33953507 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -29567,8 +29567,9 @@ static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0, // In SetLT case, The second operand of the comparison can be either 1 or 0. APInt SplatVal; if ((CC == ISD::SETLT) && - !((ISD::isConstantSplatVector(SetCC.getOperand(1).getNode(), SplatVal) && - SplatVal == 1) || + !((ISD::isConstantSplatVector(SetCC.getOperand(1).getNode(), SplatVal, + /*AllowShrink*/false) && + SplatVal.isOneValue()) || (ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode())))) return false; @@ -32083,7 +32084,8 @@ static SDValue combineAndMaskToShift(SDNode *N, SelectionDAG &DAG, return SDValue(); APInt SplatVal; - if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal) || + if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal, + /*AllowShrink*/false) || !SplatVal.isMask()) return SDValue(); @@ -32667,7 +32669,8 @@ static SDValue detectUSatPattern(SDValue In, EVT VT) { "Unexpected types for truncate operation"); APInt C; - if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C)) { + if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C, + /*AllowShrink*/false)) { // C should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according // the element size of the destination type. return C.isMask(VT.getScalarSizeInBits()) ? In.getOperand(0) : @@ -35374,7 +35377,8 @@ static SDValue combineIncDecVector(SDNode *N, SelectionDAG &DAG) { SDNode *N1 = N->getOperand(1).getNode(); APInt SplatVal; - if (!ISD::isConstantSplatVector(N1, SplatVal) || !SplatVal.isOneValue()) + if (!ISD::isConstantSplatVector(N1, SplatVal, /*AllowShrink*/false) || + !SplatVal.isOneValue()) return SDValue(); SDValue AllOnesVec = getOnesVector(VT, DAG, SDLoc(N)); |