 llvm/include/llvm/CodeGen/SelectionDAGNodes.h   |   5
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp   | 302
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp  |  13
 llvm/test/CodeGen/AArch64/neon-shl-ashr-lshr.ll |   6
 llvm/test/CodeGen/X86/avx2-vector-shifts.ll     |  27
 llvm/test/CodeGen/X86/sse2-vector-shifts.ll     | 152
 6 files changed, 341 insertions(+), 164 deletions(-)
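
As a rough, editorial illustration of the class of folds this patch enables (not part of the commit itself): the DAGCombiner shift combines below now look through splat BUILD_VECTOR shift amounts via the new isConstOrConstSplat helper, so IR like the following -- mirroring the new srl_srl_v4i32 test added in sse2-vector-shifts.ll -- is expected to collapse to a single vector shift (psrld $6 on X86/SSE2) instead of two:

define <4 x i32> @srl_srl_splat_example(<4 x i32> %x) {
  ; (srl (srl x, <2,2,2,2>), <4,4,4,4>) -> (srl x, <6,6,6,6>)
  %s0 = lshr <4 x i32> %x, <i32 2, i32 2, i32 2, i32 2>
  %s1 = lshr <4 x i32> %s0, <i32 4, i32 4, i32 4, i32 4>
  ret <4 x i32> %s1
}

The function name above is hypothetical; the equivalent checked-in test is srl_srl_v4i32 in llvm/test/CodeGen/X86/sse2-vector-shifts.ll, which is part of this diff.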
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index 0b18d1d358c..a6c72ca2d1c 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1522,6 +1522,11 @@ public:                         unsigned MinSplatBits = 0,                         bool isBigEndian = false) const; +  /// isConstantSplat - Simpler form of isConstantSplat. Get the constant splat +  /// when you only care about the value. Returns nullptr if this isn't a +  /// constant splat vector. +  ConstantSDNode *isConstantSplat() const; +    bool isConstant() const;    static inline bool classof(const SDNode *N) { diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 842d5a3ee73..c45d6a1a790 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -280,7 +280,7 @@ namespace {      SDValue XformToShuffleWithZero(SDNode *N);      SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS); -    SDValue visitShiftByConstant(SDNode *N, unsigned Amt); +    SDValue visitShiftByConstant(SDNode *N, ConstantSDNode *Amt);      bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);      SDValue SimplifyBinOpWithSameOpcodeHands(SDNode *N); @@ -634,7 +634,23 @@ static bool isOneUseSetCC(SDValue N) {    return false;  } -// \brief Returns the SDNode if it is a constant BuildVector or constant int. +/// isConstantSplatVector - Returns true if N is a BUILD_VECTOR node whose +/// elements are all the same constant or undefined. +static bool isConstantSplatVector(SDNode *N, APInt& SplatValue) { +  BuildVectorSDNode *C = dyn_cast<BuildVectorSDNode>(N); +  if (!C) +    return false; + +  APInt SplatUndef; +  unsigned SplatBitSize; +  bool HasAnyUndefs; +  EVT EltVT = N->getValueType(0).getVectorElementType(); +  return (C->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, +                             HasAnyUndefs) && +          EltVT.getSizeInBits() >= SplatBitSize); +} + +// \brief Returns the SDNode if it is a constant BuildVector or constant.  static SDNode *isConstantBuildVectorOrConstantInt(SDValue N) {    if (isa<ConstantSDNode>(N))      return N.getNode(); @@ -644,6 +660,18 @@ static SDNode *isConstantBuildVectorOrConstantInt(SDValue N) {    return NULL;  } +// \brief Returns the SDNode if it is a constant splat BuildVector or constant +// int. +static ConstantSDNode *isConstOrConstSplat(SDValue N) { +  if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) +    return CN; + +  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) +    return BV->isConstantSplat(); + +  return nullptr; +} +  SDValue DAGCombiner::ReassociateOps(unsigned Opc, SDLoc DL,                                      SDValue N0, SDValue N1) {    EVT VT = N0.getValueType(); @@ -1830,22 +1858,6 @@ SDValue DAGCombiner::visitSUBE(SDNode *N) {    return SDValue();  } -/// isConstantSplatVector - Returns true if N is a BUILD_VECTOR node whose -/// elements are all the same constant or undefined. 
-static bool isConstantSplatVector(SDNode *N, APInt& SplatValue) { -  BuildVectorSDNode *C = dyn_cast<BuildVectorSDNode>(N); -  if (!C) -    return false; - -  APInt SplatUndef; -  unsigned SplatBitSize; -  bool HasAnyUndefs; -  EVT EltVT = N->getValueType(0).getVectorElementType(); -  return (C->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, -                             HasAnyUndefs) && -          EltVT.getSizeInBits() >= SplatBitSize); -} -  SDValue DAGCombiner::visitMUL(SDNode *N) {    SDValue N0 = N->getOperand(0);    SDValue N1 = N->getOperand(1); @@ -3805,11 +3817,9 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {  /// visitShiftByConstant - Handle transforms common to the three shifts, when  /// the shift amount is a constant. -SDValue DAGCombiner::visitShiftByConstant(SDNode *N, unsigned Amt) { -  assert(isa<ConstantSDNode>(N->getOperand(1)) && -         "Expected an ConstantSDNode operand."); +SDValue DAGCombiner::visitShiftByConstant(SDNode *N, ConstantSDNode *Amt) {    // We can't and shouldn't fold opaque constants. -  if (cast<ConstantSDNode>(N->getOperand(1))->isOpaque()) +  if (Amt->isOpaque())      return SDValue();    SDNode *LHS = N->getOperand(0).getNode(); @@ -3888,11 +3898,11 @@ SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {    if (N->hasOneUse() && N->getOperand(0).hasOneUse()) {      SDValue N01 = N->getOperand(0).getOperand(1); -    if (ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N01)) { +    if (ConstantSDNode *N01C = isConstOrConstSplat(N01)) {        EVT TruncVT = N->getValueType(0);        SDValue N00 = N->getOperand(0).getOperand(0);        APInt TruncC = N01C->getAPIntValue(); -      TruncC = TruncC.trunc(TruncVT.getScalarType().getSizeInBits()); +      TruncC = TruncC.trunc(TruncVT.getScalarSizeInBits());        return DAG.getNode(ISD::AND, SDLoc(N), TruncVT,                           DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N00), @@ -3921,7 +3931,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {    ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);    ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);    EVT VT = N0.getValueType(); -  unsigned OpSizeInBits = VT.getScalarType().getSizeInBits(); +  unsigned OpSizeInBits = VT.getScalarSizeInBits();    // fold vector ops    if (VT.isVector()) { @@ -3931,18 +3941,21 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {      BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);      // If setcc produces all-one true value then:      // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV) -    if (N1CV && N1CV->isConstant() && -        TLI.getBooleanContents(true) == -          TargetLowering::ZeroOrNegativeOneBooleanContent && -        N0.getOpcode() == ISD::AND) { -      SDValue N00 = N0->getOperand(0); -      SDValue N01 = N0->getOperand(1); -      BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01); - -      if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC) { -        SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, VT, N01CV, N1CV); -        if (C.getNode()) -          return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C); +    if (N1CV && N1CV->isConstant()) { +      if (N0.getOpcode() == ISD::AND && +          TLI.getBooleanContents(true) == +          TargetLowering::ZeroOrNegativeOneBooleanContent) { +        SDValue N00 = N0->getOperand(0); +        SDValue N01 = N0->getOperand(1); +        BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01); + +        if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC) { +          SDValue C 
= DAG.FoldConstantArithmetic(ISD::SHL, VT, N01CV, N1CV); +          if (C.getNode()) +            return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C); +        } +      } else { +        N1C = isConstOrConstSplat(N1);        }      }    } @@ -3978,14 +3991,15 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {      return SDValue(N, 0);    // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2)) -  if (N1C && N0.getOpcode() == ISD::SHL && -      N0.getOperand(1).getOpcode() == ISD::Constant) { -    uint64_t c1 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue(); -    uint64_t c2 = N1C->getZExtValue(); -    if (c1 + c2 >= OpSizeInBits) -      return DAG.getConstant(0, VT); -    return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), -                       DAG.getConstant(c1 + c2, N1.getValueType())); +  if (N1C && N0.getOpcode() == ISD::SHL) { +    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { +      uint64_t c1 = N0C1->getZExtValue(); +      uint64_t c2 = N1C->getZExtValue(); +      if (c1 + c2 >= OpSizeInBits) +        return DAG.getConstant(0, VT); +      return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), +                         DAG.getConstant(c1 + c2, N1.getValueType())); +    }    }    // fold (shl (ext (shl x, c1)), c2) -> (ext (shl x, (add c1, c2))) @@ -3996,20 +4010,21 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {    if (N1C && (N0.getOpcode() == ISD::ZERO_EXTEND ||                N0.getOpcode() == ISD::ANY_EXTEND ||                N0.getOpcode() == ISD::SIGN_EXTEND) && -      N0.getOperand(0).getOpcode() == ISD::SHL && -      isa<ConstantSDNode>(N0.getOperand(0)->getOperand(1))) { -    uint64_t c1 = -      cast<ConstantSDNode>(N0.getOperand(0)->getOperand(1))->getZExtValue(); -    uint64_t c2 = N1C->getZExtValue(); -    EVT InnerShiftVT = N0.getOperand(0).getValueType(); -    uint64_t InnerShiftSize = InnerShiftVT.getScalarType().getSizeInBits(); -    if (c2 >= OpSizeInBits - InnerShiftSize) { -      if (c1 + c2 >= OpSizeInBits) -        return DAG.getConstant(0, VT); -      return DAG.getNode(ISD::SHL, SDLoc(N0), VT, -                         DAG.getNode(N0.getOpcode(), SDLoc(N0), VT, -                                     N0.getOperand(0)->getOperand(0)), -                         DAG.getConstant(c1 + c2, N1.getValueType())); +      N0.getOperand(0).getOpcode() == ISD::SHL) { +    SDValue N0Op0 = N0.getOperand(0); +    if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) { +      uint64_t c1 = N0Op0C1->getZExtValue(); +      uint64_t c2 = N1C->getZExtValue(); +      EVT InnerShiftVT = N0Op0.getValueType(); +      uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits(); +      if (c2 >= OpSizeInBits - InnerShiftSize) { +        if (c1 + c2 >= OpSizeInBits) +          return DAG.getConstant(0, VT); +        return DAG.getNode(ISD::SHL, SDLoc(N0), VT, +                           DAG.getNode(N0.getOpcode(), SDLoc(N0), VT, +                                       N0Op0->getOperand(0)), +                           DAG.getConstant(c1 + c2, N1.getValueType())); +      }      }    } @@ -4017,19 +4032,20 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {    // Only fold this if the inner zext has no other uses to avoid increasing    // the total number of instructions.    
if (N1C && N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() && -      N0.getOperand(0).getOpcode() == ISD::SRL && -      isa<ConstantSDNode>(N0.getOperand(0)->getOperand(1))) { -    uint64_t c1 = -      cast<ConstantSDNode>(N0.getOperand(0)->getOperand(1))->getZExtValue(); -    if (c1 < VT.getSizeInBits()) { -      uint64_t c2 = N1C->getZExtValue(); -      if (c1 == c2) { -        SDValue NewOp0 = N0.getOperand(0); -        EVT CountVT = NewOp0.getOperand(1).getValueType(); -        SDValue NewSHL = DAG.getNode(ISD::SHL, SDLoc(N), NewOp0.getValueType(), -                                     NewOp0, DAG.getConstant(c2, CountVT)); -        AddToWorkList(NewSHL.getNode()); -        return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL); +      N0.getOperand(0).getOpcode() == ISD::SRL) { +    SDValue N0Op0 = N0.getOperand(0); +    if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) { +      uint64_t c1 = N0Op0C1->getZExtValue(); +      if (c1 < VT.getScalarSizeInBits()) { +        uint64_t c2 = N1C->getZExtValue(); +        if (c1 == c2) { +          SDValue NewOp0 = N0.getOperand(0); +          EVT CountVT = NewOp0.getOperand(1).getValueType(); +          SDValue NewSHL = DAG.getNode(ISD::SHL, SDLoc(N), NewOp0.getValueType(), +                                       NewOp0, DAG.getConstant(c2, CountVT)); +          AddToWorkList(NewSHL.getNode()); +          return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL); +        }        }      }    } @@ -4038,40 +4054,39 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {    //                               (and (srl x, (sub c1, c2), MASK)    // Only fold this if the inner shift has no other uses -- if it does, folding    // this will increase the total number of instructions. -  if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse() && -      N0.getOperand(1).getOpcode() == ISD::Constant) { -    uint64_t c1 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue(); -    if (c1 < VT.getSizeInBits()) { -      uint64_t c2 = N1C->getZExtValue(); -      APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(), -                                         VT.getSizeInBits() - c1); -      SDValue Shift; -      if (c2 > c1) { -        Mask = Mask.shl(c2-c1); -        Shift = DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), -                            DAG.getConstant(c2-c1, N1.getValueType())); -      } else { -        Mask = Mask.lshr(c1-c2); -        Shift = DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), -                            DAG.getConstant(c1-c2, N1.getValueType())); +  if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse()) { +    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { +      uint64_t c1 = N0C1->getZExtValue(); +      if (c1 < OpSizeInBits) { +        uint64_t c2 = N1C->getZExtValue(); +        APInt Mask = APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - c1); +        SDValue Shift; +        if (c2 > c1) { +          Mask = Mask.shl(c2 - c1); +          Shift = DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), +                              DAG.getConstant(c2 - c1, N1.getValueType())); +        } else { +          Mask = Mask.lshr(c1 - c2); +          Shift = DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), +                              DAG.getConstant(c1 - c2, N1.getValueType())); +        } +        return DAG.getNode(ISD::AND, SDLoc(N0), VT, Shift, +                           DAG.getConstant(Mask, VT));        } -      return DAG.getNode(ISD::AND, SDLoc(N0), VT, 
Shift, -                         DAG.getConstant(Mask, VT));      }    }    // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))    if (N1C && N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1)) { +    unsigned BitSize = VT.getScalarSizeInBits();      SDValue HiBitsMask = -      DAG.getConstant(APInt::getHighBitsSet(VT.getSizeInBits(), -                                            VT.getSizeInBits() - -                                              N1C->getZExtValue()), -                      VT); +      DAG.getConstant(APInt::getHighBitsSet(BitSize, +                                            BitSize - N1C->getZExtValue()), VT);      return DAG.getNode(ISD::AND, SDLoc(N), VT, N0.getOperand(0),                         HiBitsMask);    }    if (N1C) { -    SDValue NewSHL = visitShiftByConstant(N, N1C->getZExtValue()); +    SDValue NewSHL = visitShiftByConstant(N, N1C);      if (NewSHL.getNode())        return NewSHL;    } @@ -4091,6 +4106,8 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {    if (VT.isVector()) {      SDValue FoldedVOp = SimplifyVBinOp(N);      if (FoldedVOp.getNode()) return FoldedVOp; + +    N1C = isConstOrConstSplat(N1);    }    // fold (sra c1, c2) -> (sra c1, c2) @@ -4124,11 +4141,12 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {    // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))    if (N1C && N0.getOpcode() == ISD::SRA) { -    if (ConstantSDNode *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) { +    if (ConstantSDNode *C1 = isConstOrConstSplat(N0.getOperand(1))) {        unsigned Sum = N1C->getZExtValue() + C1->getZExtValue(); -      if (Sum >= OpSizeInBits) Sum = OpSizeInBits-1; +      if (Sum >= OpSizeInBits) +        Sum = OpSizeInBits - 1;        return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0), -                         DAG.getConstant(Sum, N1C->getValueType(0))); +                         DAG.getConstant(Sum, N1.getValueType()));      }    } @@ -4137,14 +4155,17 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {    // result_size - n != m.    // If truncate is free for the target sext(shl) is likely to result in better    // code. -  if (N0.getOpcode() == ISD::SHL) { +  if (N0.getOpcode() == ISD::SHL && N1C) {      // Get the two constanst of the shifts, CN0 = m, CN = n. -    const ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1)); -    if (N01C && N1C) { +    const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1)); +    if (N01C) { +      LLVMContext &Ctx = *DAG.getContext();        // Determine what the truncate's result bitsize and type would be. -      EVT TruncVT = -        EVT::getIntegerVT(*DAG.getContext(), -                          OpSizeInBits - N1C->getZExtValue()); +      EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue()); + +      if (VT.isVector()) +        TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorNumElements()); +        // Determine the residual right-shift amount.        
signed ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue(); @@ -4177,26 +4198,27 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {        return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1);    } -  // fold (sra (trunc (sr x, c1)), c2) -> (trunc (sra x, c1+c2)) +  // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))    //      if c1 is equal to the number of bits the trunc removes    if (N0.getOpcode() == ISD::TRUNCATE &&        (N0.getOperand(0).getOpcode() == ISD::SRL ||         N0.getOperand(0).getOpcode() == ISD::SRA) &&        N0.getOperand(0).hasOneUse() &&        N0.getOperand(0).getOperand(1).hasOneUse() && -      N1C && isa<ConstantSDNode>(N0.getOperand(0).getOperand(1))) { -    EVT LargeVT = N0.getOperand(0).getValueType(); -    ConstantSDNode *LargeShiftAmt = -      cast<ConstantSDNode>(N0.getOperand(0).getOperand(1)); - -    if (LargeVT.getScalarType().getSizeInBits() - OpSizeInBits == -        LargeShiftAmt->getZExtValue()) { -      SDValue Amt = -        DAG.getConstant(LargeShiftAmt->getZExtValue() + N1C->getZExtValue(), -              getShiftAmountTy(N0.getOperand(0).getOperand(0).getValueType())); -      SDValue SRA = DAG.getNode(ISD::SRA, SDLoc(N), LargeVT, -                                N0.getOperand(0).getOperand(0), Amt); -      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, SRA); +      N1C) { +    SDValue N0Op0 = N0.getOperand(0); +    if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) { +      unsigned LargeShiftVal = LargeShift->getZExtValue(); +      EVT LargeVT = N0Op0.getValueType(); + +      if (LargeVT.getScalarSizeInBits() - OpSizeInBits == LargeShiftVal) { +        SDValue Amt = +          DAG.getConstant(LargeShiftVal + N1C->getZExtValue(), +                          getShiftAmountTy(N0Op0.getOperand(0).getValueType())); +        SDValue SRA = DAG.getNode(ISD::SRA, SDLoc(N), LargeVT, +                                  N0Op0.getOperand(0), Amt); +        return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, SRA); +      }      }    } @@ -4210,7 +4232,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);    if (N1C) { -    SDValue NewSRA = visitShiftByConstant(N, N1C->getZExtValue()); +    SDValue NewSRA = visitShiftByConstant(N, N1C);      if (NewSRA.getNode())        return NewSRA;    } @@ -4230,6 +4252,8 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {    if (VT.isVector()) {      SDValue FoldedVOp = SimplifyVBinOp(N);      if (FoldedVOp.getNode()) return FoldedVOp; + +    N1C = isConstOrConstSplat(N1);    }    // fold (srl c1, c2) -> c1 >>u c2 @@ -4250,14 +4274,15 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {      return DAG.getConstant(0, VT);    // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2)) -  if (N1C && N0.getOpcode() == ISD::SRL && -      N0.getOperand(1).getOpcode() == ISD::Constant) { -    uint64_t c1 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue(); -    uint64_t c2 = N1C->getZExtValue(); -    if (c1 + c2 >= OpSizeInBits) -      return DAG.getConstant(0, VT); -    return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), -                       DAG.getConstant(c1 + c2, N1.getValueType())); +  if (N1C && N0.getOpcode() == ISD::SRL) { +    if (ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1))) { +      uint64_t c1 = N01C->getZExtValue(); +      uint64_t c2 = N1C->getZExtValue(); +      if (c1 + c2 >= OpSizeInBits) +        return DAG.getConstant(0, VT); +      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), +            
             DAG.getConstant(c1 + c2, N1.getValueType())); +    }    }    // fold (srl (trunc (srl x, c1)), c2) -> 0 or (trunc (srl x, (add c1, c2))) @@ -4282,18 +4307,21 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {    }    // fold (srl (shl x, c), c) -> (and x, cst2) -  if (N1C && N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 && -      N0.getValueSizeInBits() <= 64) { -    uint64_t ShAmt = N1C->getZExtValue()+64-N0.getValueSizeInBits(); -    return DAG.getNode(ISD::AND, SDLoc(N), VT, N0.getOperand(0), -                       DAG.getConstant(~0ULL >> ShAmt, VT)); +  if (N1C && N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1) { +    unsigned BitSize = N0.getScalarValueSizeInBits(); +    if (BitSize <= 64) { +      uint64_t ShAmt = N1C->getZExtValue() + 64 - BitSize; +      return DAG.getNode(ISD::AND, SDLoc(N), VT, N0.getOperand(0), +                         DAG.getConstant(~0ULL >> ShAmt, VT)); +    }    }    // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)    if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {      // Shifting in all undef bits?      EVT SmallVT = N0.getOperand(0).getValueType(); -    if (N1C->getZExtValue() >= SmallVT.getSizeInBits()) +    unsigned BitSize = SmallVT.getScalarSizeInBits(); +    if (N1C->getZExtValue() >= BitSize)        return DAG.getUNDEF(VT);      if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) { @@ -4302,7 +4330,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {                                         N0.getOperand(0),                            DAG.getConstant(ShiftAmt, getShiftAmountTy(SmallVT)));        AddToWorkList(SmallShift.getNode()); -      APInt Mask = APInt::getAllOnesValue(VT.getSizeInBits()).lshr(ShiftAmt); +      APInt Mask = APInt::getAllOnesValue(OpSizeInBits).lshr(ShiftAmt);        return DAG.getNode(ISD::AND, SDLoc(N), VT,                           DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, SmallShift),                           DAG.getConstant(Mask, VT)); @@ -4311,14 +4339,14 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {    // fold (srl (sra X, Y), 31) -> (srl X, 31).  This srl only looks at the sign    // bit, which is unmodified by sra. -  if (N1C && N1C->getZExtValue() + 1 == VT.getSizeInBits()) { +  if (N1C && N1C->getZExtValue() + 1 == OpSizeInBits) {      if (N0.getOpcode() == ISD::SRA)        return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1);    }    // fold (srl (ctlz x), "5") -> x  iff x has one bit set (the low bit).    
if (N1C && N0.getOpcode() == ISD::CTLZ && -      N1C->getAPIntValue() == Log2_32(VT.getSizeInBits())) { +      N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {      APInt KnownZero, KnownOne;      DAG.ComputeMaskedBits(N0.getOperand(0), KnownZero, KnownOne); @@ -4365,7 +4393,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {      return SDValue(N, 0);    if (N1C) { -    SDValue NewSRL = visitShiftByConstant(N, N1C->getZExtValue()); +    SDValue NewSRL = visitShiftByConstant(N, N1C);      if (NewSRL.getNode())        return NewSRL;    } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 43a02fe9c7e..df8d423ab22 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -6573,6 +6573,19 @@ bool BuildVectorSDNode::isConstantSplat(APInt &SplatValue,    return true;  } +ConstantSDNode *BuildVectorSDNode::isConstantSplat() const { +  SDValue Op0 = getOperand(0); +  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { +    SDValue Opi = getOperand(i); +    unsigned Opc = Opi.getOpcode(); +    if ((Opc != ISD::UNDEF && Opc != ISD::Constant && Opc != ISD::ConstantFP) || +        Opi != Op0) +      return nullptr; +  } + +  return cast<ConstantSDNode>(Op0); +} +  bool BuildVectorSDNode::isConstant() const {    for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {      unsigned Opc = getOperand(i).getOpcode(); diff --git a/llvm/test/CodeGen/AArch64/neon-shl-ashr-lshr.ll b/llvm/test/CodeGen/AArch64/neon-shl-ashr-lshr.ll index bd52fbde42b..0b520d7ac84 100644 --- a/llvm/test/CodeGen/AArch64/neon-shl-ashr-lshr.ll +++ b/llvm/test/CodeGen/AArch64/neon-shl-ashr-lshr.ll @@ -186,14 +186,14 @@ define <2 x i64> @ashr.v2i64(<2 x i64> %a, <2 x i64> %b) {  define <1 x i64> @shl.v1i64.0(<1 x i64> %a) {  ; CHECK-LABEL: shl.v1i64.0: -; CHECK: shl d{{[0-9]+}}, d{{[0-9]+}}, #0 +; CHECK-NOT: shl d{{[0-9]+}}, d{{[0-9]+}}, #0    %c = shl <1 x i64> %a, zeroinitializer    ret <1 x i64> %c  }  define <2 x i32> @shl.v2i32.0(<2 x i32> %a) {  ; CHECK-LABEL: shl.v2i32.0: -; CHECK: shl v{{[0-9]+}}.2s, v{{[0-9]+}}.2s, #0 +; CHECK-NOT: shl v{{[0-9]+}}.2s, v{{[0-9]+}}.2s, #0    %c = shl <2 x i32> %a, zeroinitializer    ret <2 x i32> %c  } @@ -285,7 +285,7 @@ define <1 x i16> @shl.v1i16.imm(<1 x i16> %a) {  define <1 x i32> @shl.v1i32.imm(<1 x i32> %a) {  ; CHECK-LABEL: shl.v1i32.imm: -; CHECK: shl v{{[0-9]+}}.2s, v{{[0-9]+}}.2s, #0 +; CHECK-NOT: shl v{{[0-9]+}}.2s, v{{[0-9]+}}.2s, #0    %c = shl <1 x i32> %a, zeroinitializer    ret <1 x i32> %c  } diff --git a/llvm/test/CodeGen/X86/avx2-vector-shifts.ll b/llvm/test/CodeGen/X86/avx2-vector-shifts.ll index 4868e4b4797..4ae2905ef22 100644 --- a/llvm/test/CodeGen/X86/avx2-vector-shifts.ll +++ b/llvm/test/CodeGen/X86/avx2-vector-shifts.ll @@ -9,7 +9,7 @@ entry:  }  ; CHECK-LABEL: test_sllw_1: -; CHECK: vpsllw  $0, %ymm0, %ymm0 +; CHECK-NOT: vpsllw  $0, %ymm0, %ymm0  ; CHECK: ret  define <16 x i16> @test_sllw_2(<16 x i16> %InVec) { @@ -39,7 +39,7 @@ entry:  }  ; CHECK-LABEL: test_slld_1: -; CHECK: vpslld  $0, %ymm0, %ymm0 +; CHECK-NOT: vpslld  $0, %ymm0, %ymm0  ; CHECK: ret  define <8 x i32> @test_slld_2(<8 x i32> %InVec) { @@ -69,7 +69,7 @@ entry:  }  ; CHECK-LABEL: test_sllq_1: -; CHECK: vpsllq  $0, %ymm0, %ymm0 +; CHECK-NOT: vpsllq  $0, %ymm0, %ymm0  ; CHECK: ret  define <4 x i64> @test_sllq_2(<4 x i64> %InVec) { @@ -101,7 +101,7 @@ entry:  }  ; CHECK-LABEL: test_sraw_1: -; CHECK: vpsraw  $0, %ymm0, %ymm0 +; CHECK-NOT: vpsraw  $0, %ymm0, %ymm0  ; CHECK: ret  
define <16 x i16> @test_sraw_2(<16 x i16> %InVec) { @@ -131,7 +131,7 @@ entry:  }  ; CHECK-LABEL: test_srad_1: -; CHECK: vpsrad  $0, %ymm0, %ymm0 +; CHECK-NOT: vpsrad  $0, %ymm0, %ymm0  ; CHECK: ret  define <8 x i32> @test_srad_2(<8 x i32> %InVec) { @@ -163,7 +163,7 @@ entry:  }  ; CHECK-LABEL: test_srlw_1: -; CHECK: vpsrlw  $0, %ymm0, %ymm0 +; CHECK-NOT: vpsrlw  $0, %ymm0, %ymm0  ; CHECK: ret  define <16 x i16> @test_srlw_2(<16 x i16> %InVec) { @@ -193,7 +193,7 @@ entry:  }  ; CHECK-LABEL: test_srld_1: -; CHECK: vpsrld  $0, %ymm0, %ymm0 +; CHECK-NOT: vpsrld  $0, %ymm0, %ymm0  ; CHECK: ret  define <8 x i32> @test_srld_2(<8 x i32> %InVec) { @@ -223,7 +223,7 @@ entry:  }  ; CHECK-LABEL: test_srlq_1: -; CHECK: vpsrlq  $0, %ymm0, %ymm0 +; CHECK-NOT: vpsrlq  $0, %ymm0, %ymm0  ; CHECK: ret  define <4 x i64> @test_srlq_2(<4 x i64> %InVec) { @@ -245,3 +245,14 @@ entry:  ; CHECK-LABEL: test_srlq_3:  ; CHECK: vpsrlq $63, %ymm0, %ymm0  ; CHECK: ret + +; CHECK-LABEL: @srl_trunc_and_v4i64 +; CHECK: vpand +; CHECK-NEXT: vpsrlvd +; CHECK: ret +define <4 x i32> @srl_trunc_and_v4i64(<4 x i32> %x, <4 x i64> %y) nounwind { +  %and = and <4 x i64> %y, <i64 8, i64 8, i64 8, i64 8> +  %trunc = trunc <4 x i64> %and to <4 x i32> +  %sra = lshr <4 x i32> %x, %trunc +  ret <4 x i32> %sra +} diff --git a/llvm/test/CodeGen/X86/sse2-vector-shifts.ll b/llvm/test/CodeGen/X86/sse2-vector-shifts.ll index 47a01ff2583..7c8d5e57889 100644 --- a/llvm/test/CodeGen/X86/sse2-vector-shifts.ll +++ b/llvm/test/CodeGen/X86/sse2-vector-shifts.ll @@ -9,8 +9,8 @@ entry:  }  ; CHECK-LABEL: test_sllw_1: -; CHECK: psllw   $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psllw   $0, %xmm0 +; CHECK: ret  define <8 x i16> @test_sllw_2(<8 x i16> %InVec) {  entry: @@ -39,8 +39,8 @@ entry:  }  ; CHECK-LABEL: test_slld_1: -; CHECK: pslld   $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: pslld   $0, %xmm0 +; CHECK: ret  define <4 x i32> @test_slld_2(<4 x i32> %InVec) {  entry: @@ -69,8 +69,8 @@ entry:  }  ; CHECK-LABEL: test_sllq_1: -; CHECK: psllq   $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psllq   $0, %xmm0 +; CHECK: ret  define <2 x i64> @test_sllq_2(<2 x i64> %InVec) {  entry: @@ -101,8 +101,8 @@ entry:  }  ; CHECK-LABEL: test_sraw_1: -; CHECK: psraw   $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psraw   $0, %xmm0 +; CHECK: ret  define <8 x i16> @test_sraw_2(<8 x i16> %InVec) {  entry: @@ -131,8 +131,8 @@ entry:  }  ; CHECK-LABEL: test_srad_1: -; CHECK: psrad   $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psrad   $0, %xmm0 +; CHECK: ret  define <4 x i32> @test_srad_2(<4 x i32> %InVec) {  entry: @@ -163,8 +163,8 @@ entry:  }  ; CHECK-LABEL: test_srlw_1: -; CHECK: psrlw   $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psrlw   $0, %xmm0 +; CHECK: ret  define <8 x i16> @test_srlw_2(<8 x i16> %InVec) {  entry: @@ -193,8 +193,8 @@ entry:  }  ; CHECK-LABEL: test_srld_1: -; CHECK: psrld   $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psrld   $0, %xmm0 +; CHECK: ret  define <4 x i32> @test_srld_2(<4 x i32> %InVec) {  entry: @@ -223,8 +223,8 @@ entry:  }  ; CHECK-LABEL: test_srlq_1: -; CHECK: psrlq   $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psrlq   $0, %xmm0 +; CHECK: ret  define <2 x i64> @test_srlq_2(<2 x i64> %InVec) {  entry: @@ -245,3 +245,123 @@ entry:  ; CHECK-LABEL: test_srlq_3:  ; CHECK: psrlq $63, %xmm0  ; CHECK-NEXT: ret + + +; CHECK-LABEL: sra_sra_v4i32: +; CHECK: psrad $6, %xmm0 +; CHECK-NEXT: retq +define <4 x i32> @sra_sra_v4i32(<4 x i32> %x) nounwind { +  %sra0 = ashr <4 x i32> %x, <i32 2, i32 2, i32 2, i32 2> +  %sra1 = ashr <4 x i32> %sra0, <i32 4, i32 
4, i32 4, i32 4> +  ret <4 x i32> %sra1 +} + +; CHECK-LABEL: @srl_srl_v4i32 +; CHECK: psrld $6, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @srl_srl_v4i32(<4 x i32> %x) nounwind { +  %srl0 = lshr <4 x i32> %x, <i32 2, i32 2, i32 2, i32 2> +  %srl1 = lshr <4 x i32> %srl0, <i32 4, i32 4, i32 4, i32 4> +  ret <4 x i32> %srl1 +} + +; CHECK-LABEL: @srl_shl_v4i32 +; CHECK: andps +; CHECK-NEXT: retq +define <4 x i32> @srl_shl_v4i32(<4 x i32> %x) nounwind { +  %srl0 = shl <4 x i32> %x, <i32 4, i32 4, i32 4, i32 4> +  %srl1 = lshr <4 x i32> %srl0, <i32 4, i32 4, i32 4, i32 4> +  ret <4 x i32> %srl1 +} + +; CHECK-LABEL: @srl_sra_31_v4i32 +; CHECK: psrld $31, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @srl_sra_31_v4i32(<4 x i32> %x, <4 x i32> %y) nounwind { +  %sra = ashr <4 x i32> %x, %y +  %srl1 = lshr <4 x i32> %sra, <i32 31, i32 31, i32 31, i32 31> +  ret <4 x i32> %srl1 +} + +; CHECK-LABEL: @shl_shl_v4i32 +; CHECK: pslld $6, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @shl_shl_v4i32(<4 x i32> %x) nounwind { +  %shl0 = shl <4 x i32> %x, <i32 2, i32 2, i32 2, i32 2> +  %shl1 = shl <4 x i32> %shl0, <i32 4, i32 4, i32 4, i32 4> +  ret <4 x i32> %shl1 +} + +; CHECK-LABEL: @shl_sra_v4i32 +; CHECK: andps +; CHECK-NEXT: ret +define <4 x i32> @shl_sra_v4i32(<4 x i32> %x) nounwind { +  %shl0 = ashr <4 x i32> %x, <i32 4, i32 4, i32 4, i32 4> +  %shl1 = shl <4 x i32> %shl0, <i32 4, i32 4, i32 4, i32 4> +  ret <4 x i32> %shl1 +} + +; CHECK-LABEL: @shl_srl_v4i32 +; CHECK: pslld $3, %xmm0 +; CHECK-NEXT: pand +; CHECK-NEXT: ret +define <4 x i32> @shl_srl_v4i32(<4 x i32> %x) nounwind { +  %shl0 = lshr <4 x i32> %x, <i32 2, i32 2, i32 2, i32 2> +  %shl1 = shl <4 x i32> %shl0, <i32 5, i32 5, i32 5, i32 5> +  ret <4 x i32> %shl1 +} + +; CHECK-LABEL: @shl_zext_srl_v4i32 +; CHECK: andps +; CHECK-NEXT: ret +define <4 x i32> @shl_zext_srl_v4i32(<4 x i16> %x) nounwind { +  %srl = lshr <4 x i16> %x, <i16 2, i16 2, i16 2, i16 2> +  %zext = zext <4 x i16> %srl to <4 x i32> +  %shl = shl <4 x i32> %zext, <i32 2, i32 2, i32 2, i32 2> +  ret <4 x i32> %shl +} + +; CHECK: @sra_trunc_srl_v4i32 +; CHECK: psrad $19, %xmm0 +; CHECK-NEXT: retq +define <4 x i16> @sra_trunc_srl_v4i32(<4 x i32> %x) nounwind { +  %srl = lshr <4 x i32> %x, <i32 16, i32 16, i32 16, i32 16> +  %trunc = trunc <4 x i32> %srl to <4 x i16> +  %sra = ashr <4 x i16> %trunc, <i16 3, i16 3, i16 3, i16 3> +  ret <4 x i16> %sra +} + +; CHECK-LABEL: @shl_zext_shl_v4i32 +; CHECK: pand +; CHECK-NEXT: pslld $19, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @shl_zext_shl_v4i32(<4 x i16> %x) nounwind { +  %shl0 = shl <4 x i16> %x, <i16 2, i16 2, i16 2, i16 2> +  %ext = zext <4 x i16> %shl0 to <4 x i32> +  %shl1 = shl <4 x i32> %ext, <i32 17, i32 17, i32 17, i32 17> +  ret <4 x i32> %shl1 +} + +; CHECK-LABEL: @sra_v4i32 +; CHECK: psrad $3, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @sra_v4i32(<4 x i32> %x) nounwind { +  %sra = ashr <4 x i32> %x, <i32 3, i32 3, i32 3, i32 3> +  ret <4 x i32> %sra +} + +; CHECK-LABEL: @srl_v4i32 +; CHECK: psrld $3, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @srl_v4i32(<4 x i32> %x) nounwind { +  %sra = lshr <4 x i32> %x, <i32 3, i32 3, i32 3, i32 3> +  ret <4 x i32> %sra +} + +; CHECK-LABEL: @shl_v4i32 +; CHECK: pslld $3, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @shl_v4i32(<4 x i32> %x) nounwind { +  %sra = shl <4 x i32> %x, <i32 3, i32 3, i32 3, i32 3> +  ret <4 x i32> %sra +}  | 

