diff options
author | Simon Pilgrim <llvm-dev@redking.me.uk> | 2015-10-24 13:17:26 +0000 |
---|---|---|
committer | Simon Pilgrim <llvm-dev@redking.me.uk> | 2015-10-24 13:17:26 +0000 |
commit | d5ef318b5baf4f9a2b4d7fde6f83377e9733ba0b (patch) | |
tree | 743757cf8e9fe5efc73a0ce7f18000f1e039e358 /llvm/lib/CodeGen | |
parent | e1761952b43b7374fd1bc2899665d638d7a649fb (diff) | |
download | bcm5719-llvm-d5ef318b5baf4f9a2b4d7fde6f83377e9733ba0b.tar.gz bcm5719-llvm-d5ef318b5baf4f9a2b4d7fde6f83377e9733ba0b.zip |
[X86][XOP] Add support for lowering vector rotations
This patch adds support for lowering to the XOP VPROT / VPROTI vector bit rotation instructions.
This has required changes to the DAGCombiner rotation pattern matching to support vector types - so far I've only changed it to support splat vectors, but generalising this further is feasible in the future.
Differential Revision: http://reviews.llvm.org/D13851
llvm-svn: 251188
Diffstat (limited to 'llvm/lib/CodeGen')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 110 |
1 files changed, 55 insertions, 55 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index c96f1a4175a..53d47c7165e 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -3796,7 +3796,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) { /// Match "(X shl/srl V1) & V2" where V2 may not be present. static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) { if (Op.getOpcode() == ISD::AND) { - if (isa<ConstantSDNode>(Op.getOperand(1))) { + if (isConstOrConstSplat(Op.getOperand(1))) { Mask = Op.getOperand(1); Op = Op.getOperand(0); } else { @@ -3813,105 +3813,106 @@ static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) { } // Return true if we can prove that, whenever Neg and Pos are both in the -// range [0, OpSize), Neg == (Pos == 0 ? 0 : OpSize - Pos). This means that +// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that // for two opposing shifts shift1 and shift2 and a value X with OpBits bits: // // (or (shift1 X, Neg), (shift2 X, Pos)) // // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate -// in direction shift1 by Neg. The range [0, OpSize) means that we only need +// in direction shift1 by Neg. The range [0, EltSize) means that we only need // to consider shift amounts with defined behavior. -static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned OpSize) { - // If OpSize is a power of 2 then: +static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) { + // If EltSize is a power of 2 then: // - // (a) (Pos == 0 ? 0 : OpSize - Pos) == (OpSize - Pos) & (OpSize - 1) - // (b) Neg == Neg & (OpSize - 1) whenever Neg is in [0, OpSize). + // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1) + // (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize). // - // So if OpSize is a power of 2 and Neg is (and Neg', OpSize-1), we check + // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check // for the stronger condition: // - // Neg & (OpSize - 1) == (OpSize - Pos) & (OpSize - 1) [A] + // Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A] // - // for all Neg and Pos. Since Neg & (OpSize - 1) == Neg' & (OpSize - 1) + // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1) // we can just replace Neg with Neg' for the rest of the function. // // In other cases we check for the even stronger condition: // - // Neg == OpSize - Pos [B] + // Neg == EltSize - Pos [B] // // for all Neg and Pos. Note that the (or ...) then invokes undefined - // behavior if Pos == 0 (and consequently Neg == OpSize). + // behavior if Pos == 0 (and consequently Neg == EltSize). // - // We could actually use [A] whenever OpSize is a power of 2, but the + // We could actually use [A] whenever EltSize is a power of 2, but the // only extra cases that it would match are those uninteresting ones // where Neg and Pos are never in range at the same time. E.g. for - // OpSize == 32, using [A] would allow a Neg of the form (sub 64, Pos) + // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos) // as well as (sub 32, Pos), but: // // (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos)) // // always invokes undefined behavior for 32-bit X. // - // Below, Mask == OpSize - 1 when using [A] and is all-ones otherwise. + // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise. unsigned MaskLoBits = 0; - if (Neg.getOpcode() == ISD::AND && - isPowerOf2_64(OpSize) && - Neg.getOperand(1).getOpcode() == ISD::Constant && - cast<ConstantSDNode>(Neg.getOperand(1))->getAPIntValue() == OpSize - 1) { - Neg = Neg.getOperand(0); - MaskLoBits = Log2_64(OpSize); + if (Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) { + if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) { + if (NegC->getAPIntValue() == EltSize - 1) { + Neg = Neg.getOperand(0); + MaskLoBits = Log2_64(EltSize); + } + } } // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1. if (Neg.getOpcode() != ISD::SUB) return 0; - ConstantSDNode *NegC = dyn_cast<ConstantSDNode>(Neg.getOperand(0)); + ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0)); if (!NegC) return 0; SDValue NegOp1 = Neg.getOperand(1); - // On the RHS of [A], if Pos is Pos' & (OpSize - 1), just replace Pos with + // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with // Pos'. The truncation is redundant for the purpose of the equality. - if (MaskLoBits && - Pos.getOpcode() == ISD::AND && - Pos.getOperand(1).getOpcode() == ISD::Constant && - cast<ConstantSDNode>(Pos.getOperand(1))->getAPIntValue() == OpSize - 1) - Pos = Pos.getOperand(0); + if (MaskLoBits && Pos.getOpcode() == ISD::AND) + if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) + if (PosC->getAPIntValue() == EltSize - 1) + Pos = Pos.getOperand(0); // The condition we need is now: // - // (NegC - NegOp1) & Mask == (OpSize - Pos) & Mask + // (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask // // If NegOp1 == Pos then we need: // - // OpSize & Mask == NegC & Mask + // EltSize & Mask == NegC & Mask // // (because "x & Mask" is a truncation and distributes through subtraction). APInt Width; if (Pos == NegOp1) Width = NegC->getAPIntValue(); + // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC. // Then the condition we want to prove becomes: // - // (NegC - NegOp1) & Mask == (OpSize - (NegOp1 + PosC)) & Mask + // (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask // // which, again because "x & Mask" is a truncation, becomes: // - // NegC & Mask == (OpSize - PosC) & Mask - // OpSize & Mask == (NegC + PosC) & Mask - else if (Pos.getOpcode() == ISD::ADD && - Pos.getOperand(0) == NegOp1 && - Pos.getOperand(1).getOpcode() == ISD::Constant) - Width = (cast<ConstantSDNode>(Pos.getOperand(1))->getAPIntValue() + - NegC->getAPIntValue()); - else + // NegC & Mask == (EltSize - PosC) & Mask + // EltSize & Mask == (NegC + PosC) & Mask + else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) { + if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) + Width = PosC->getAPIntValue() + NegC->getAPIntValue(); + else + return false; + } else return false; - // Now we just need to check that OpSize & Mask == Width & Mask. + // Now we just need to check that EltSize & Mask == Width & Mask. if (MaskLoBits) - // Opsize & Mask is 0 since Mask is Opsize - 1. + // EltSize & Mask is 0 since Mask is EltSize - 1. return Width.getLoBits(MaskLoBits) == 0; - return Width == OpSize; + return Width == EltSize; } // A subroutine of MatchRotate used once we have found an OR of two opposite @@ -3931,7 +3932,7 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos, // (srl x, (*ext y))) -> // (rotr x, y) or (rotl x, (sub 32, y)) EVT VT = Shifted.getValueType(); - if (matchRotateSub(InnerPos, InnerNeg, VT.getSizeInBits())) { + if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits())) { bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT); return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted, HasPos ? Pos : Neg).getNode(); @@ -3974,10 +3975,10 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) { if (RHSShift.getOpcode() == ISD::SHL) { std::swap(LHS, RHS); std::swap(LHSShift, RHSShift); - std::swap(LHSMask , RHSMask ); + std::swap(LHSMask, RHSMask); } - unsigned OpSizeInBits = VT.getSizeInBits(); + unsigned EltSizeInBits = VT.getScalarSizeInBits(); SDValue LHSShiftArg = LHSShift.getOperand(0); SDValue LHSShiftAmt = LHSShift.getOperand(1); SDValue RHSShiftArg = RHSShift.getOperand(0); @@ -3985,11 +3986,10 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) { // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1) // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2) - if (LHSShiftAmt.getOpcode() == ISD::Constant && - RHSShiftAmt.getOpcode() == ISD::Constant) { - uint64_t LShVal = cast<ConstantSDNode>(LHSShiftAmt)->getZExtValue(); - uint64_t RShVal = cast<ConstantSDNode>(RHSShiftAmt)->getZExtValue(); - if ((LShVal + RShVal) != OpSizeInBits) + if (isConstOrConstSplat(LHSShiftAmt) && isConstOrConstSplat(RHSShiftAmt)) { + uint64_t LShVal = isConstOrConstSplat(LHSShiftAmt)->getZExtValue(); + uint64_t RShVal = isConstOrConstSplat(RHSShiftAmt)->getZExtValue(); + if ((LShVal + RShVal) != EltSizeInBits) return nullptr; SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, @@ -3997,15 +3997,15 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) { // If there is an AND of either shifted operand, apply it to the result. if (LHSMask.getNode() || RHSMask.getNode()) { - APInt Mask = APInt::getAllOnesValue(OpSizeInBits); + APInt Mask = APInt::getAllOnesValue(EltSizeInBits); if (LHSMask.getNode()) { - APInt RHSBits = APInt::getLowBitsSet(OpSizeInBits, LShVal); - Mask &= cast<ConstantSDNode>(LHSMask)->getAPIntValue() | RHSBits; + APInt RHSBits = APInt::getLowBitsSet(EltSizeInBits, LShVal); + Mask &= isConstOrConstSplat(LHSMask)->getAPIntValue() | RHSBits; } if (RHSMask.getNode()) { - APInt LHSBits = APInt::getHighBitsSet(OpSizeInBits, RShVal); - Mask &= cast<ConstantSDNode>(RHSMask)->getAPIntValue() | LHSBits; + APInt LHSBits = APInt::getHighBitsSet(EltSizeInBits, RShVal); + Mask &= isConstOrConstSplat(RHSMask)->getAPIntValue() | LHSBits; } Rot = DAG.getNode(ISD::AND, DL, VT, Rot, DAG.getConstant(Mask, DL, VT)); |