From ed6a0a817ffdd4078c25336b473094c60f98b509 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Sun, 4 Nov 2018 17:31:27 +0000 Subject: [X86] Add vector shift by immediate to SimplifyDemandedBitsForTargetNode. Summary: This also enables some constant folding from KnownBits propagation. This helps on some cases vXi64 case in 32-bit mode where constant vectors appear as vXi32 and a bitcast. This can prevent getNode from constant folding sra/shl/srl. Reviewers: RKSimon, spatel Reviewed By: spatel Subscribers: llvm-commits Differential Revision: https://reviews.llvm.org/D54069 llvm-svn: 346102 --- llvm/lib/Target/X86/X86ISelLowering.cpp | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) (limited to 'llvm/lib/Target') diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 905b99590a6..891f4a4cbdf 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -31817,6 +31817,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode( bool X86TargetLowering::SimplifyDemandedBitsForTargetNode( SDValue Op, const APInt &OriginalDemandedBits, KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const { + unsigned BitWidth = OriginalDemandedBits.getBitWidth(); unsigned Opc = Op.getOpcode(); switch(Opc) { case X86ISD::PMULDQ: @@ -31833,6 +31834,42 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode( return true; break; } + case X86ISD::VSHLI: { + if (auto *ShiftImm = dyn_cast(Op.getOperand(1))) { + if (ShiftImm->getAPIntValue().uge(BitWidth)) + break; + + KnownBits KnownOp; + unsigned ShAmt = ShiftImm->getZExtValue(); + APInt DemandedMask = OriginalDemandedBits.lshr(ShAmt); + if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask, KnownOp, TLO, + Depth + 1)) + return true; + } + break; + } + case X86ISD::VSRAI: + case X86ISD::VSRLI: { + if (auto *ShiftImm = dyn_cast(Op.getOperand(1))) { + if (ShiftImm->getAPIntValue().uge(BitWidth)) + break; + + KnownBits KnownOp; + unsigned ShAmt = ShiftImm->getZExtValue(); + APInt DemandedMask = OriginalDemandedBits << ShAmt; + + // If any of the demanded bits are produced by the sign extension, we also + // demand the input sign bit. + if (Opc == X86ISD::VSRAI && + OriginalDemandedBits.countLeadingZeros() < ShAmt) + DemandedMask.setSignBit(); + + if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask, KnownOp, TLO, + Depth + 1)) + return true; + } + break; + } } return TargetLowering::SimplifyDemandedBitsForTargetNode( @@ -34861,6 +34898,11 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG, return getConstVector(EltBits, UndefElts, VT.getSimpleVT(), DAG, SDLoc(N)); } + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (TLI.SimplifyDemandedBits(SDValue(N, 0), + APInt::getAllOnesValue(NumBitsPerElt), DCI)) + return SDValue(N, 0); + return SDValue(); } -- cgit v1.2.3