diff options
author | Matt Arsenault <Matthew.Arsenault@amd.com> | 2014-03-17 18:58:01 +0000 |
---|---|---|
committer | Matt Arsenault <Matthew.Arsenault@amd.com> | 2014-03-17 18:58:01 +0000 |
commit | 985b9de4856b37489de54cfa7054ee36a2464d34 (patch) | |
tree | 07e04c9ac0768b442c5b235301212c80cb4108bd /llvm/lib/CodeGen | |
parent | 76c82037bb9e9a3a2dafaf8478dd42cc54206f23 (diff) | |
download | bcm5719-llvm-985b9de4856b37489de54cfa7054ee36a2464d34.tar.gz bcm5719-llvm-985b9de4856b37489de54cfa7054ee36a2464d34.zip |
Make DAGCombiner work on vector bitshifts with constant splat vectors.
llvm-svn: 204071
Diffstat (limited to 'llvm/lib/CodeGen')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 302 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 13 |
2 files changed, 178 insertions, 137 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 842d5a3ee73..c45d6a1a790 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -280,7 +280,7 @@ namespace { SDValue XformToShuffleWithZero(SDNode *N); SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS); - SDValue visitShiftByConstant(SDNode *N, unsigned Amt); + SDValue visitShiftByConstant(SDNode *N, ConstantSDNode *Amt); bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS); SDValue SimplifyBinOpWithSameOpcodeHands(SDNode *N); @@ -634,7 +634,23 @@ static bool isOneUseSetCC(SDValue N) { return false; } -// \brief Returns the SDNode if it is a constant BuildVector or constant int. +/// isConstantSplatVector - Returns true if N is a BUILD_VECTOR node whose +/// elements are all the same constant or undefined. +static bool isConstantSplatVector(SDNode *N, APInt& SplatValue) { + BuildVectorSDNode *C = dyn_cast<BuildVectorSDNode>(N); + if (!C) + return false; + + APInt SplatUndef; + unsigned SplatBitSize; + bool HasAnyUndefs; + EVT EltVT = N->getValueType(0).getVectorElementType(); + return (C->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, + HasAnyUndefs) && + EltVT.getSizeInBits() >= SplatBitSize); +} + +// \brief Returns the SDNode if it is a constant BuildVector or constant. static SDNode *isConstantBuildVectorOrConstantInt(SDValue N) { if (isa<ConstantSDNode>(N)) return N.getNode(); @@ -644,6 +660,18 @@ static SDNode *isConstantBuildVectorOrConstantInt(SDValue N) { return NULL; } +// \brief Returns the SDNode if it is a constant splat BuildVector or constant +// int. +static ConstantSDNode *isConstOrConstSplat(SDValue N) { + if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) + return CN; + + if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) + return BV->isConstantSplat(); + + return nullptr; +} + SDValue DAGCombiner::ReassociateOps(unsigned Opc, SDLoc DL, SDValue N0, SDValue N1) { EVT VT = N0.getValueType(); @@ -1830,22 +1858,6 @@ SDValue DAGCombiner::visitSUBE(SDNode *N) { return SDValue(); } -/// isConstantSplatVector - Returns true if N is a BUILD_VECTOR node whose -/// elements are all the same constant or undefined. -static bool isConstantSplatVector(SDNode *N, APInt& SplatValue) { - BuildVectorSDNode *C = dyn_cast<BuildVectorSDNode>(N); - if (!C) - return false; - - APInt SplatUndef; - unsigned SplatBitSize; - bool HasAnyUndefs; - EVT EltVT = N->getValueType(0).getVectorElementType(); - return (C->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, - HasAnyUndefs) && - EltVT.getSizeInBits() >= SplatBitSize); -} - SDValue DAGCombiner::visitMUL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -3805,11 +3817,9 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { /// visitShiftByConstant - Handle transforms common to the three shifts, when /// the shift amount is a constant. -SDValue DAGCombiner::visitShiftByConstant(SDNode *N, unsigned Amt) { - assert(isa<ConstantSDNode>(N->getOperand(1)) && - "Expected an ConstantSDNode operand."); +SDValue DAGCombiner::visitShiftByConstant(SDNode *N, ConstantSDNode *Amt) { // We can't and shouldn't fold opaque constants. - if (cast<ConstantSDNode>(N->getOperand(1))->isOpaque()) + if (Amt->isOpaque()) return SDValue(); SDNode *LHS = N->getOperand(0).getNode(); @@ -3888,11 +3898,11 @@ SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) { if (N->hasOneUse() && N->getOperand(0).hasOneUse()) { SDValue N01 = N->getOperand(0).getOperand(1); - if (ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N01)) { + if (ConstantSDNode *N01C = isConstOrConstSplat(N01)) { EVT TruncVT = N->getValueType(0); SDValue N00 = N->getOperand(0).getOperand(0); APInt TruncC = N01C->getAPIntValue(); - TruncC = TruncC.trunc(TruncVT.getScalarType().getSizeInBits()); + TruncC = TruncC.trunc(TruncVT.getScalarSizeInBits()); return DAG.getNode(ISD::AND, SDLoc(N), TruncVT, DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N00), @@ -3921,7 +3931,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); EVT VT = N0.getValueType(); - unsigned OpSizeInBits = VT.getScalarType().getSizeInBits(); + unsigned OpSizeInBits = VT.getScalarSizeInBits(); // fold vector ops if (VT.isVector()) { @@ -3931,18 +3941,21 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1); // If setcc produces all-one true value then: // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV) - if (N1CV && N1CV->isConstant() && - TLI.getBooleanContents(true) == - TargetLowering::ZeroOrNegativeOneBooleanContent && - N0.getOpcode() == ISD::AND) { - SDValue N00 = N0->getOperand(0); - SDValue N01 = N0->getOperand(1); - BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01); - - if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC) { - SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, VT, N01CV, N1CV); - if (C.getNode()) - return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C); + if (N1CV && N1CV->isConstant()) { + if (N0.getOpcode() == ISD::AND && + TLI.getBooleanContents(true) == + TargetLowering::ZeroOrNegativeOneBooleanContent) { + SDValue N00 = N0->getOperand(0); + SDValue N01 = N0->getOperand(1); + BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01); + + if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC) { + SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, VT, N01CV, N1CV); + if (C.getNode()) + return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C); + } + } else { + N1C = isConstOrConstSplat(N1); } } } @@ -3978,14 +3991,15 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { return SDValue(N, 0); // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2)) - if (N1C && N0.getOpcode() == ISD::SHL && - N0.getOperand(1).getOpcode() == ISD::Constant) { - uint64_t c1 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue(); - uint64_t c2 = N1C->getZExtValue(); - if (c1 + c2 >= OpSizeInBits) - return DAG.getConstant(0, VT); - return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(c1 + c2, N1.getValueType())); + if (N1C && N0.getOpcode() == ISD::SHL) { + if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { + uint64_t c1 = N0C1->getZExtValue(); + uint64_t c2 = N1C->getZExtValue(); + if (c1 + c2 >= OpSizeInBits) + return DAG.getConstant(0, VT); + return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), + DAG.getConstant(c1 + c2, N1.getValueType())); + } } // fold (shl (ext (shl x, c1)), c2) -> (ext (shl x, (add c1, c2))) @@ -3996,20 +4010,21 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { if (N1C && (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND || N0.getOpcode() == ISD::SIGN_EXTEND) && - N0.getOperand(0).getOpcode() == ISD::SHL && - isa<ConstantSDNode>(N0.getOperand(0)->getOperand(1))) { - uint64_t c1 = - cast<ConstantSDNode>(N0.getOperand(0)->getOperand(1))->getZExtValue(); - uint64_t c2 = N1C->getZExtValue(); - EVT InnerShiftVT = N0.getOperand(0).getValueType(); - uint64_t InnerShiftSize = InnerShiftVT.getScalarType().getSizeInBits(); - if (c2 >= OpSizeInBits - InnerShiftSize) { - if (c1 + c2 >= OpSizeInBits) - return DAG.getConstant(0, VT); - return DAG.getNode(ISD::SHL, SDLoc(N0), VT, - DAG.getNode(N0.getOpcode(), SDLoc(N0), VT, - N0.getOperand(0)->getOperand(0)), - DAG.getConstant(c1 + c2, N1.getValueType())); + N0.getOperand(0).getOpcode() == ISD::SHL) { + SDValue N0Op0 = N0.getOperand(0); + if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) { + uint64_t c1 = N0Op0C1->getZExtValue(); + uint64_t c2 = N1C->getZExtValue(); + EVT InnerShiftVT = N0Op0.getValueType(); + uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits(); + if (c2 >= OpSizeInBits - InnerShiftSize) { + if (c1 + c2 >= OpSizeInBits) + return DAG.getConstant(0, VT); + return DAG.getNode(ISD::SHL, SDLoc(N0), VT, + DAG.getNode(N0.getOpcode(), SDLoc(N0), VT, + N0Op0->getOperand(0)), + DAG.getConstant(c1 + c2, N1.getValueType())); + } } } @@ -4017,19 +4032,20 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // Only fold this if the inner zext has no other uses to avoid increasing // the total number of instructions. if (N1C && N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() && - N0.getOperand(0).getOpcode() == ISD::SRL && - isa<ConstantSDNode>(N0.getOperand(0)->getOperand(1))) { - uint64_t c1 = - cast<ConstantSDNode>(N0.getOperand(0)->getOperand(1))->getZExtValue(); - if (c1 < VT.getSizeInBits()) { - uint64_t c2 = N1C->getZExtValue(); - if (c1 == c2) { - SDValue NewOp0 = N0.getOperand(0); - EVT CountVT = NewOp0.getOperand(1).getValueType(); - SDValue NewSHL = DAG.getNode(ISD::SHL, SDLoc(N), NewOp0.getValueType(), - NewOp0, DAG.getConstant(c2, CountVT)); - AddToWorkList(NewSHL.getNode()); - return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL); + N0.getOperand(0).getOpcode() == ISD::SRL) { + SDValue N0Op0 = N0.getOperand(0); + if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) { + uint64_t c1 = N0Op0C1->getZExtValue(); + if (c1 < VT.getScalarSizeInBits()) { + uint64_t c2 = N1C->getZExtValue(); + if (c1 == c2) { + SDValue NewOp0 = N0.getOperand(0); + EVT CountVT = NewOp0.getOperand(1).getValueType(); + SDValue NewSHL = DAG.getNode(ISD::SHL, SDLoc(N), NewOp0.getValueType(), + NewOp0, DAG.getConstant(c2, CountVT)); + AddToWorkList(NewSHL.getNode()); + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL); + } } } } @@ -4038,40 +4054,39 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // (and (srl x, (sub c1, c2), MASK) // Only fold this if the inner shift has no other uses -- if it does, folding // this will increase the total number of instructions. - if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse() && - N0.getOperand(1).getOpcode() == ISD::Constant) { - uint64_t c1 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue(); - if (c1 < VT.getSizeInBits()) { - uint64_t c2 = N1C->getZExtValue(); - APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(), - VT.getSizeInBits() - c1); - SDValue Shift; - if (c2 > c1) { - Mask = Mask.shl(c2-c1); - Shift = DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(c2-c1, N1.getValueType())); - } else { - Mask = Mask.lshr(c1-c2); - Shift = DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(c1-c2, N1.getValueType())); + if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse()) { + if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { + uint64_t c1 = N0C1->getZExtValue(); + if (c1 < OpSizeInBits) { + uint64_t c2 = N1C->getZExtValue(); + APInt Mask = APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - c1); + SDValue Shift; + if (c2 > c1) { + Mask = Mask.shl(c2 - c1); + Shift = DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), + DAG.getConstant(c2 - c1, N1.getValueType())); + } else { + Mask = Mask.lshr(c1 - c2); + Shift = DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), + DAG.getConstant(c1 - c2, N1.getValueType())); + } + return DAG.getNode(ISD::AND, SDLoc(N0), VT, Shift, + DAG.getConstant(Mask, VT)); } - return DAG.getNode(ISD::AND, SDLoc(N0), VT, Shift, - DAG.getConstant(Mask, VT)); } } // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1)) if (N1C && N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1)) { + unsigned BitSize = VT.getScalarSizeInBits(); SDValue HiBitsMask = - DAG.getConstant(APInt::getHighBitsSet(VT.getSizeInBits(), - VT.getSizeInBits() - - N1C->getZExtValue()), - VT); + DAG.getConstant(APInt::getHighBitsSet(BitSize, + BitSize - N1C->getZExtValue()), VT); return DAG.getNode(ISD::AND, SDLoc(N), VT, N0.getOperand(0), HiBitsMask); } if (N1C) { - SDValue NewSHL = visitShiftByConstant(N, N1C->getZExtValue()); + SDValue NewSHL = visitShiftByConstant(N, N1C); if (NewSHL.getNode()) return NewSHL; } @@ -4091,6 +4106,8 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { if (VT.isVector()) { SDValue FoldedVOp = SimplifyVBinOp(N); if (FoldedVOp.getNode()) return FoldedVOp; + + N1C = isConstOrConstSplat(N1); } // fold (sra c1, c2) -> (sra c1, c2) @@ -4124,11 +4141,12 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2)) if (N1C && N0.getOpcode() == ISD::SRA) { - if (ConstantSDNode *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) { + if (ConstantSDNode *C1 = isConstOrConstSplat(N0.getOperand(1))) { unsigned Sum = N1C->getZExtValue() + C1->getZExtValue(); - if (Sum >= OpSizeInBits) Sum = OpSizeInBits-1; + if (Sum >= OpSizeInBits) + Sum = OpSizeInBits - 1; return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(Sum, N1C->getValueType(0))); + DAG.getConstant(Sum, N1.getValueType())); } } @@ -4137,14 +4155,17 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { // result_size - n != m. // If truncate is free for the target sext(shl) is likely to result in better // code. - if (N0.getOpcode() == ISD::SHL) { + if (N0.getOpcode() == ISD::SHL && N1C) { // Get the two constanst of the shifts, CN0 = m, CN = n. - const ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1)); - if (N01C && N1C) { + const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1)); + if (N01C) { + LLVMContext &Ctx = *DAG.getContext(); // Determine what the truncate's result bitsize and type would be. - EVT TruncVT = - EVT::getIntegerVT(*DAG.getContext(), - OpSizeInBits - N1C->getZExtValue()); + EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue()); + + if (VT.isVector()) + TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorNumElements()); + // Determine the residual right-shift amount. signed ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue(); @@ -4177,26 +4198,27 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1); } - // fold (sra (trunc (sr x, c1)), c2) -> (trunc (sra x, c1+c2)) + // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2)) // if c1 is equal to the number of bits the trunc removes if (N0.getOpcode() == ISD::TRUNCATE && (N0.getOperand(0).getOpcode() == ISD::SRL || N0.getOperand(0).getOpcode() == ISD::SRA) && N0.getOperand(0).hasOneUse() && N0.getOperand(0).getOperand(1).hasOneUse() && - N1C && isa<ConstantSDNode>(N0.getOperand(0).getOperand(1))) { - EVT LargeVT = N0.getOperand(0).getValueType(); - ConstantSDNode *LargeShiftAmt = - cast<ConstantSDNode>(N0.getOperand(0).getOperand(1)); - - if (LargeVT.getScalarType().getSizeInBits() - OpSizeInBits == - LargeShiftAmt->getZExtValue()) { - SDValue Amt = - DAG.getConstant(LargeShiftAmt->getZExtValue() + N1C->getZExtValue(), - getShiftAmountTy(N0.getOperand(0).getOperand(0).getValueType())); - SDValue SRA = DAG.getNode(ISD::SRA, SDLoc(N), LargeVT, - N0.getOperand(0).getOperand(0), Amt); - return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, SRA); + N1C) { + SDValue N0Op0 = N0.getOperand(0); + if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) { + unsigned LargeShiftVal = LargeShift->getZExtValue(); + EVT LargeVT = N0Op0.getValueType(); + + if (LargeVT.getScalarSizeInBits() - OpSizeInBits == LargeShiftVal) { + SDValue Amt = + DAG.getConstant(LargeShiftVal + N1C->getZExtValue(), + getShiftAmountTy(N0Op0.getOperand(0).getValueType())); + SDValue SRA = DAG.getNode(ISD::SRA, SDLoc(N), LargeVT, + N0Op0.getOperand(0), Amt); + return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, SRA); + } } } @@ -4210,7 +4232,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1); if (N1C) { - SDValue NewSRA = visitShiftByConstant(N, N1C->getZExtValue()); + SDValue NewSRA = visitShiftByConstant(N, N1C); if (NewSRA.getNode()) return NewSRA; } @@ -4230,6 +4252,8 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { if (VT.isVector()) { SDValue FoldedVOp = SimplifyVBinOp(N); if (FoldedVOp.getNode()) return FoldedVOp; + + N1C = isConstOrConstSplat(N1); } // fold (srl c1, c2) -> c1 >>u c2 @@ -4250,14 +4274,15 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { return DAG.getConstant(0, VT); // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2)) - if (N1C && N0.getOpcode() == ISD::SRL && - N0.getOperand(1).getOpcode() == ISD::Constant) { - uint64_t c1 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue(); - uint64_t c2 = N1C->getZExtValue(); - if (c1 + c2 >= OpSizeInBits) - return DAG.getConstant(0, VT); - return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(c1 + c2, N1.getValueType())); + if (N1C && N0.getOpcode() == ISD::SRL) { + if (ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1))) { + uint64_t c1 = N01C->getZExtValue(); + uint64_t c2 = N1C->getZExtValue(); + if (c1 + c2 >= OpSizeInBits) + return DAG.getConstant(0, VT); + return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), + DAG.getConstant(c1 + c2, N1.getValueType())); + } } // fold (srl (trunc (srl x, c1)), c2) -> 0 or (trunc (srl x, (add c1, c2))) @@ -4282,18 +4307,21 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { } // fold (srl (shl x, c), c) -> (and x, cst2) - if (N1C && N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 && - N0.getValueSizeInBits() <= 64) { - uint64_t ShAmt = N1C->getZExtValue()+64-N0.getValueSizeInBits(); - return DAG.getNode(ISD::AND, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(~0ULL >> ShAmt, VT)); + if (N1C && N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1) { + unsigned BitSize = N0.getScalarValueSizeInBits(); + if (BitSize <= 64) { + uint64_t ShAmt = N1C->getZExtValue() + 64 - BitSize; + return DAG.getNode(ISD::AND, SDLoc(N), VT, N0.getOperand(0), + DAG.getConstant(~0ULL >> ShAmt, VT)); + } } // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask) if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) { // Shifting in all undef bits? EVT SmallVT = N0.getOperand(0).getValueType(); - if (N1C->getZExtValue() >= SmallVT.getSizeInBits()) + unsigned BitSize = SmallVT.getScalarSizeInBits(); + if (N1C->getZExtValue() >= BitSize) return DAG.getUNDEF(VT); if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) { @@ -4302,7 +4330,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { N0.getOperand(0), DAG.getConstant(ShiftAmt, getShiftAmountTy(SmallVT))); AddToWorkList(SmallShift.getNode()); - APInt Mask = APInt::getAllOnesValue(VT.getSizeInBits()).lshr(ShiftAmt); + APInt Mask = APInt::getAllOnesValue(OpSizeInBits).lshr(ShiftAmt); return DAG.getNode(ISD::AND, SDLoc(N), VT, DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, SmallShift), DAG.getConstant(Mask, VT)); @@ -4311,14 +4339,14 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { // fold (srl (sra X, Y), 31) -> (srl X, 31). This srl only looks at the sign // bit, which is unmodified by sra. - if (N1C && N1C->getZExtValue() + 1 == VT.getSizeInBits()) { + if (N1C && N1C->getZExtValue() + 1 == OpSizeInBits) { if (N0.getOpcode() == ISD::SRA) return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1); } // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit). if (N1C && N0.getOpcode() == ISD::CTLZ && - N1C->getAPIntValue() == Log2_32(VT.getSizeInBits())) { + N1C->getAPIntValue() == Log2_32(OpSizeInBits)) { APInt KnownZero, KnownOne; DAG.ComputeMaskedBits(N0.getOperand(0), KnownZero, KnownOne); @@ -4365,7 +4393,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { return SDValue(N, 0); if (N1C) { - SDValue NewSRL = visitShiftByConstant(N, N1C->getZExtValue()); + SDValue NewSRL = visitShiftByConstant(N, N1C); if (NewSRL.getNode()) return NewSRL; } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 43a02fe9c7e..df8d423ab22 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -6573,6 +6573,19 @@ bool BuildVectorSDNode::isConstantSplat(APInt &SplatValue, return true; } +ConstantSDNode *BuildVectorSDNode::isConstantSplat() const { + SDValue Op0 = getOperand(0); + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + SDValue Opi = getOperand(i); + unsigned Opc = Opi.getOpcode(); + if ((Opc != ISD::UNDEF && Opc != ISD::Constant && Opc != ISD::ConstantFP) || + Opi != Op0) + return nullptr; + } + + return cast<ConstantSDNode>(Op0); +} + bool BuildVectorSDNode::isConstant() const { for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { unsigned Opc = getOperand(i).getOpcode(); |