diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 45 |
1 files changed, 39 insertions, 6 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index bd069534901..436ee4c4205 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -32466,6 +32466,39 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode( } break; } + case X86ISD::MOVMSK: { + SDValue Src = Op.getOperand(0); + MVT VT = Op.getSimpleValueType(); + MVT SrcVT = Src.getSimpleValueType(); + unsigned SrcBits = SrcVT.getScalarSizeInBits(); + unsigned NumElts = SrcVT.getVectorNumElements(); + + // If we don't need the sign bits at all just return zero. + if (OriginalDemandedBits.countTrailingZeros() >= NumElts) + return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT)); + + // Only demand the vector elements of the sign bits we need. + APInt KnownUndef, KnownZero; + APInt DemandedElts = OriginalDemandedBits.zextOrTrunc(NumElts); + if (SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef, KnownZero, + TLO, Depth + 1)) + return true; + + Known.Zero = KnownZero.zextOrSelf(BitWidth); + Known.Zero.setHighBits(BitWidth - NumElts); + + // MOVMSK only uses the MSB from each vector element. + KnownBits KnownSrc; + if (SimplifyDemandedBits(Src, APInt::getSignMask(SrcBits), KnownSrc, TLO, + Depth + 1)) + return true; + + if (KnownSrc.One[SrcBits - 1]) + Known.One.setLowBits(NumElts); + else if (KnownSrc.Zero[SrcBits - 1]) + Known.Zero.setLowBits(NumElts); + return false; + } } return TargetLowering::SimplifyDemandedBitsForTargetNode( @@ -39566,10 +39599,11 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI) { SDValue Src = N->getOperand(0); MVT SrcVT = Src.getSimpleValueType(); + MVT VT = N->getSimpleValueType(0); // Perform constant folding. if (ISD::isBuildVectorOfConstantSDNodes(Src.getNode())) { - assert(N->getValueType(0) == MVT::i32 && "Unexpected result type"); + assert(VT== MVT::i32 && "Unexpected result type"); APInt Imm(32, 0); for (unsigned Idx = 0, e = Src.getNumOperands(); Idx < e; ++Idx) { SDValue In = Src.getOperand(Idx); @@ -39577,7 +39611,7 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG, cast<ConstantSDNode>(In)->getAPIntValue().isNegative()) Imm.setBit(Idx); } - return DAG.getConstant(Imm, SDLoc(N), N->getValueType(0)); + return DAG.getConstant(Imm, SDLoc(N), VT); } // Look through int->fp bitcasts that don't change the element width. @@ -39587,11 +39621,10 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG, EVT(SrcVT).changeVectorElementTypeToInteger()) Src = Src.getOperand(0); + // Simplify the inputs. const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - - // MOVMSK only uses the MSB from each vector element. - APInt DemandedMask(APInt::getSignMask(SrcVT.getScalarSizeInBits())); - if (TLI.SimplifyDemandedBits(Src, DemandedMask, DCI)) + APInt DemandedMask(APInt::getAllOnesValue(VT.getScalarSizeInBits())); + if (TLI.SimplifyDemandedBits(SDValue(N, 0), DemandedMask, DCI)) return SDValue(N, 0); // Combine (movmsk (setne (and X, (1 << C)), 0)) -> (movmsk (X << C)). |