From 5bc10ede53000011f43f4943e8307304c77807b1 Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Mon, 25 Sep 2017 19:26:08 +0000
Subject: [SelectionDAG] Teach simplifyDemandedBits to handle shifts by
 constant splat vectors

This teaches simplifyDemandedBits to handle constant splat vector shifts.

This required changing some uses of getZExtValue to getLimitedValue, since
we can't rely on legalization having used getShiftAmountTy for the shift
amount.

I believe there may have been a bug in the ((X << C1) >>u ShAmt) handling
where we didn't check whether the inner shift amount was too large. I've
fixed that here.

I had to add new patterns to ARM because the zext/sext that the patterns
were trying to match got turned into an any_extend with this patch. Happy
to split that out too, but I'm not sure how to test it without this change.

Differential Revision: https://reviews.llvm.org/D37665

llvm-svn: 314139
---
 llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 132 ++++++++++++-----------
 1 file changed, 70 insertions(+), 62 deletions(-)

(limited to 'llvm/lib/CodeGen')
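As a rough sketch of the headline change, the snippet below models a vector shl whose shift amount is a constant splat using four plain 32-bit lanes instead of SelectionDAG nodes (the lane values and the absence of any DAG machinery are purely illustrative). Because every lane is shifted by the same amount, the scalar known-bits reasoning carries over lane by lane: after the shift, the low ShAmt bits of each lane are known zero, which is what Known.Zero.setLowBits(ShAmt) records in the ISD::SHL hunk below.

  // Standalone sketch: a "vector" shl by a splat amount, on plain uint32_t
  // lanes rather than SelectionDAG nodes.
  #include <array>
  #include <cassert>
  #include <cstdint>

  int main() {
    const unsigned ShAmt = 5;                    // the splat shift amount, < 32
    const uint32_t LowBits = (1u << ShAmt) - 1;  // bits that become known zero

    std::array<uint32_t, 4> Lanes = {0x1u, 0xDEADBEEFu, 0xFFFFFFFFu, 0x12345678u};
    for (uint32_t &Lane : Lanes)
      Lane <<= ShAmt;                            // every lane shifts by the same amount

    for (uint32_t Lane : Lanes)
      assert((Lane & LowBits) == 0);             // low ShAmt bits are zero in each lane
    return 0;
  }
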
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 4004d69c580..f6d14a8546c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -779,33 +779,38 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
     break;
   }
   case ISD::SHL:
-    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-      unsigned ShAmt = SA->getZExtValue();
+    if (ConstantSDNode *SA = isConstOrConstSplat(Op.getOperand(1))) {
       SDValue InOp = Op.getOperand(0);
 
       // If the shift count is an invalid immediate, don't do anything.
-      if (ShAmt >= BitWidth)
+      if (SA->getAPIntValue().uge(BitWidth))
         break;
 
+      unsigned ShAmt = SA->getZExtValue();
+
       // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
       // single shift.  We can do this if the bottom bits (which are shifted
       // out) are never demanded.
-      if (InOp.getOpcode() == ISD::SRL &&
-          isa<ConstantSDNode>(InOp.getOperand(1))) {
-        if (ShAmt && (NewMask & APInt::getLowBitsSet(BitWidth, ShAmt)) == 0) {
-          unsigned C1= cast<ConstantSDNode>(InOp.getOperand(1))->getZExtValue();
-          unsigned Opc = ISD::SHL;
-          int Diff = ShAmt-C1;
-          if (Diff < 0) {
-            Diff = -Diff;
-            Opc = ISD::SRL;
-          }
+      if (InOp.getOpcode() == ISD::SRL) {
+        if (ConstantSDNode *SA2 = isConstOrConstSplat(InOp.getOperand(1))) {
+          if (ShAmt && (NewMask & APInt::getLowBitsSet(BitWidth, ShAmt)) == 0) {
+            if (SA2->getAPIntValue().ult(BitWidth)) {
+              unsigned C1 = SA2->getZExtValue();
+              unsigned Opc = ISD::SHL;
+              int Diff = ShAmt-C1;
+              if (Diff < 0) {
+                Diff = -Diff;
+                Opc = ISD::SRL;
+              }
 
-          SDValue NewSA =
-            TLO.DAG.getConstant(Diff, dl, Op.getOperand(1).getValueType());
-          EVT VT = Op.getValueType();
-          return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT,
-                                                   InOp.getOperand(0), NewSA));
+              SDValue NewSA =
+                TLO.DAG.getConstant(Diff, dl, Op.getOperand(1).getValueType());
+              EVT VT = Op.getValueType();
+              return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT,
+                                                       InOp.getOperand(0),
+                                                       NewSA));
+            }
+          }
         }
       }
 
@@ -817,7 +822,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
       if (InOp.getNode()->getOpcode() == ISD::ANY_EXTEND) {
         SDValue InnerOp = InOp.getOperand(0);
         EVT InnerVT = InnerOp.getValueType();
-        unsigned InnerBits = InnerVT.getSizeInBits();
+        unsigned InnerBits = InnerVT.getScalarSizeInBits();
         if (ShAmt < InnerBits && NewMask.getActiveBits() <= InnerBits &&
             isTypeDesirableForOp(ISD::SHL, InnerVT)) {
           EVT ShTy = getShiftAmountTy(InnerVT, DL);
@@ -836,45 +841,42 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
           // (shl (anyext x), c2-c1).  This requires that the bottom c1 bits
           // aren't demanded (as above) and that the shifted upper c1 bits of
           // x aren't demanded.
-          if (InOp.hasOneUse() &&
-              InnerOp.getOpcode() == ISD::SRL &&
-              InnerOp.hasOneUse() &&
-              isa<ConstantSDNode>(InnerOp.getOperand(1))) {
-            unsigned InnerShAmt = cast<ConstantSDNode>(InnerOp.getOperand(1))
-              ->getZExtValue();
-            if (InnerShAmt < ShAmt &&
-                InnerShAmt < InnerBits &&
-                NewMask.getActiveBits() <= (InnerBits - InnerShAmt + ShAmt) &&
-                NewMask.countTrailingZeros() >= ShAmt) {
-              SDValue NewSA =
-                TLO.DAG.getConstant(ShAmt - InnerShAmt, dl,
-                                    Op.getOperand(1).getValueType());
-              EVT VT = Op.getValueType();
-              SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT,
-                                               InnerOp.getOperand(0));
-              return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl, VT,
-                                                       NewExt, NewSA));
+          if (InOp.hasOneUse() && InnerOp.getOpcode() == ISD::SRL &&
+              InnerOp.hasOneUse()) {
+            if (ConstantSDNode *SA2 = isConstOrConstSplat(InnerOp.getOperand(1))) {
+              unsigned InnerShAmt = SA2->getLimitedValue(InnerBits);
+              if (InnerShAmt < ShAmt &&
+                  InnerShAmt < InnerBits &&
+                  NewMask.getActiveBits() <= (InnerBits - InnerShAmt + ShAmt) &&
+                  NewMask.countTrailingZeros() >= ShAmt) {
+                SDValue NewSA =
+                  TLO.DAG.getConstant(ShAmt - InnerShAmt, dl,
+                                      Op.getOperand(1).getValueType());
+                EVT VT = Op.getValueType();
+                SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT,
+                                                 InnerOp.getOperand(0));
+                return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl, VT,
+                                                         NewExt, NewSA));
+              }
            }
          }
        }
      }
 
-      Known.Zero <<= SA->getZExtValue();
-      Known.One <<= SA->getZExtValue();
+      Known.Zero <<= ShAmt;
+      Known.One <<= ShAmt;
       // low bits known zero.
-      Known.Zero.setLowBits(SA->getZExtValue());
+      Known.Zero.setLowBits(ShAmt);
     }
     break;
   case ISD::SRL:
-    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-      EVT VT = Op.getValueType();
-      unsigned ShAmt = SA->getZExtValue();
-      unsigned VTSize = VT.getSizeInBits();
+    if (ConstantSDNode *SA = isConstOrConstSplat(Op.getOperand(1))) {
       SDValue InOp = Op.getOperand(0);
 
       // If the shift count is an invalid immediate, don't do anything.
-      if (ShAmt >= BitWidth)
+      if (SA->getAPIntValue().uge(BitWidth))
         break;
 
+      unsigned ShAmt = SA->getZExtValue();
       APInt InDemandedMask = (NewMask << ShAmt);
 
       // If the shift is exact, then it does demand the low bits (and knows that
@@ -885,21 +887,27 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
       // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a
       // single shift.  We can do this if the top bits (which are shifted out)
       // are never demanded.
-      if (InOp.getOpcode() == ISD::SHL &&
-          isa<ConstantSDNode>(InOp.getOperand(1))) {
-        if (ShAmt && (NewMask & APInt::getHighBitsSet(VTSize, ShAmt)) == 0) {
-          unsigned C1= cast<ConstantSDNode>(InOp.getOperand(1))->getZExtValue();
-          unsigned Opc = ISD::SRL;
-          int Diff = ShAmt-C1;
-          if (Diff < 0) {
-            Diff = -Diff;
-            Opc = ISD::SHL;
-          }
+      if (InOp.getOpcode() == ISD::SHL) {
+        if (ConstantSDNode *SA2 = isConstOrConstSplat(InOp.getOperand(1))) {
+          if (ShAmt &&
+              (NewMask & APInt::getHighBitsSet(BitWidth, ShAmt)) == 0) {
+            if (SA2->getAPIntValue().ult(BitWidth)) {
+              unsigned C1 = SA2->getZExtValue();
+              unsigned Opc = ISD::SRL;
+              int Diff = ShAmt-C1;
+              if (Diff < 0) {
+                Diff = -Diff;
+                Opc = ISD::SHL;
+              }
 
-          SDValue NewSA =
-            TLO.DAG.getConstant(Diff, dl, Op.getOperand(1).getValueType());
-          return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT,
-                                                   InOp.getOperand(0), NewSA));
+              SDValue NewSA =
+                TLO.DAG.getConstant(Diff, dl, Op.getOperand(1).getValueType());
+              EVT VT = Op.getValueType();
+              return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT,
+                                                       InOp.getOperand(0),
+                                                       NewSA));
+            }
+          }
         }
       }
 
@@ -923,14 +931,14 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
                            TLO.DAG.getNode(ISD::SRL, dl, Op.getValueType(),
                                            Op.getOperand(0), Op.getOperand(1)));
 
-    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+    if (ConstantSDNode *SA = isConstOrConstSplat(Op.getOperand(1))) {
      EVT VT = Op.getValueType();
-      unsigned ShAmt = SA->getZExtValue();
 
       // If the shift count is an invalid immediate, don't do anything.
-      if (ShAmt >= BitWidth)
+      if (SA->getAPIntValue().uge(BitWidth))
         break;
 
+      unsigned ShAmt = SA->getZExtValue();
       APInt InDemandedMask = (NewMask << ShAmt);
 
       // If the shift is exact, then it does demand the low bits (and knows that
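
For the ((X << C1) >>u ShAmt) case called out in the commit message, here is a minimal standalone sketch of the fold and of why an in-range check on the inner shift amount matters. It works on plain uint32_t values; foldSrlOfShl and demandedMask are illustrative names rather than LLVM APIs, and the assert on c1 corresponds to the new SA2->getAPIntValue().ult(BitWidth) guard in the ISD::SRL hunk above.

  // (X << C1) >>u ShAmt can be replaced by a single shift as long as the top
  // ShAmt bits of the result are never demanded and both amounts are in range.
  #include <cassert>
  #include <cstdint>
  #include <initializer_list>

  constexpr unsigned BitWidth = 32;

  // Mask of the demanded bits when the top shAmt bits of the result are unused.
  uint32_t demandedMask(unsigned shAmt) { return ~0u >> shAmt; }

  // The single-shift replacement for ((x << c1) >> shAmt).
  uint32_t foldSrlOfShl(uint32_t x, unsigned c1, unsigned shAmt) {
    assert(shAmt < BitWidth && c1 < BitWidth && "shift amounts must be in range");
    int Diff = int(shAmt) - int(c1);
    return Diff < 0 ? x << unsigned(-Diff) : x >> unsigned(Diff);
  }

  int main() {
    for (uint32_t x : {0x12345678u, 0xFFFFFFFFu, 0x80000001u})
      for (unsigned c1 : {1u, 4u, 9u})
        for (unsigned shAmt : {1u, 4u, 9u}) {
          uint32_t TwoShifts = (x << c1) >> shAmt;
          uint32_t OneShift = foldSrlOfShl(x, c1, shAmt);
          // The two forms agree on every demanded (non-top) bit.
          assert((TwoShifts & demandedMask(shAmt)) ==
                 (OneShift & demandedMask(shAmt)));
        }
    return 0;
  }

Without such a check, an inner shift amount at or above the bit width would feed a meaningless Diff; the explicit ult(BitWidth)/uge(BitWidth) tests, and the switch to getLimitedValue elsewhere, keep the transform away from that case.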