diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 36 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 7 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp | 88 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 6 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 48 |
6 files changed, 149 insertions, 38 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 9f3ad172f0e..b6be6d1dfc5 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -325,6 +325,7 @@ namespace { SDValue visitSHL(SDNode *N); SDValue visitSRA(SDNode *N); SDValue visitSRL(SDNode *N); + SDValue visitFunnelShift(SDNode *N); SDValue visitRotate(SDNode *N); SDValue visitABS(SDNode *N); SDValue visitBSWAP(SDNode *N); @@ -1513,6 +1514,8 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::SRL: return visitSRL(N); case ISD::ROTR: case ISD::ROTL: return visitRotate(N); + case ISD::FSHL: + case ISD::FSHR: return visitFunnelShift(N); case ISD::ABS: return visitABS(N); case ISD::BSWAP: return visitBSWAP(N); case ISD::BITREVERSE: return visitBITREVERSE(N); @@ -6926,6 +6929,39 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitFunnelShift(SDNode *N) { + EVT VT = N->getValueType(0); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + SDValue N2 = N->getOperand(2); + bool IsFSHL = N->getOpcode() == ISD::FSHL; + unsigned BitWidth = VT.getScalarSizeInBits(); + + // fold (fshl N0, N1, 0) -> N0 + // fold (fshr N0, N1, 0) -> N1 + if (DAG.MaskedValueIsZero(N2, APInt::getAllOnesValue(BitWidth))) + return IsFSHL ? N0 : N1; + + // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth) + if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) { + if (Cst->getAPIntValue().uge(BitWidth)) { + uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth); + return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1, + DAG.getConstant(RotAmt, SDLoc(N), N2.getValueType())); + } + } + + // fold (fshl N0, N0, N2) -> (rotl N0, N2) + // fold (fshr N0, N0, N2) -> (rotr N0, N2) + // TODO: Investigate flipping this rotate if only one is legal, if funnel shift + // is legal as well we might be better off avoiding non-constant (BW - N2). + unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR; + if (N0 == N1 && hasOperation(RotOpc, VT)) + return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2); + + return SDValue(); +} + SDValue DAGCombiner::visitABS(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp index dcb479e4ce1..2f4a5451a61 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -1170,6 +1170,8 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) { } } break; + case ISD::FSHL: + case ISD::FSHR: case ISD::SRL_PARTS: case ISD::SRA_PARTS: case ISD::SHL_PARTS: { @@ -3262,6 +3264,11 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) { } break; } + case ISD::FSHL: + case ISD::FSHR: + if (TLI.expandFunnelShift(Node, Tmp1, DAG)) + Results.push_back(Tmp1); + break; case ISD::SADDSAT: case ISD::UADDSAT: case ISD::SSUBSAT: diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index 4f283e9f3b8..4229e577a3d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -129,12 +129,13 @@ class VectorLegalizer { SDValue ExpandFNEG(SDValue Op); SDValue ExpandFSUB(SDValue Op); SDValue ExpandBITREVERSE(SDValue Op); - SDValue ExpandCTPOP(SDValue Op);
- SDValue ExpandCTLZ(SDValue Op);
- SDValue ExpandCTTZ(SDValue Op);
- SDValue ExpandFMINNUM_FMAXNUM(SDValue Op);
- SDValue ExpandStrictFPOp(SDValue Op);
-
+ SDValue ExpandCTPOP(SDValue Op); + SDValue ExpandCTLZ(SDValue Op); + SDValue ExpandCTTZ(SDValue Op); + SDValue ExpandFunnelShift(SDValue Op); + SDValue ExpandFMINNUM_FMAXNUM(SDValue Op); + SDValue ExpandStrictFPOp(SDValue Op); + /// Implements vector promotion. /// /// This is essentially just bitcasting the operands to a different type and @@ -746,12 +747,15 @@ SDValue VectorLegalizer::Expand(SDValue Op) { case ISD::CTLZ: case ISD::CTLZ_ZERO_UNDEF: return ExpandCTLZ(Op); - case ISD::CTTZ:
- case ISD::CTTZ_ZERO_UNDEF:
- return ExpandCTTZ(Op);
- case ISD::FMINNUM:
- case ISD::FMAXNUM:
- return ExpandFMINNUM_FMAXNUM(Op);
+ case ISD::CTTZ: + case ISD::CTTZ_ZERO_UNDEF: + return ExpandCTTZ(Op); + case ISD::FSHL: + case ISD::FSHR: + return ExpandFunnelShift(Op); + case ISD::FMINNUM: + case ISD::FMAXNUM: + return ExpandFMINNUM_FMAXNUM(Op); case ISD::STRICT_FADD: case ISD::STRICT_FSUB: case ISD::STRICT_FMUL: @@ -1123,32 +1127,40 @@ SDValue VectorLegalizer::ExpandFSUB(SDValue Op) { return Op; // Defer to LegalizeDAG return DAG.UnrollVectorOp(Op.getNode()); -}
-
-SDValue VectorLegalizer::ExpandCTPOP(SDValue Op) {
- SDValue Result;
- if (TLI.expandCTPOP(Op.getNode(), Result, DAG))
- return Result;
-
- return DAG.UnrollVectorOp(Op.getNode());
-}
-
-SDValue VectorLegalizer::ExpandCTLZ(SDValue Op) {
- SDValue Result;
- if (TLI.expandCTLZ(Op.getNode(), Result, DAG))
- return Result;
-
- return DAG.UnrollVectorOp(Op.getNode());
-}
-
-SDValue VectorLegalizer::ExpandCTTZ(SDValue Op) {
- SDValue Result;
- if (TLI.expandCTTZ(Op.getNode(), Result, DAG))
- return Result;
-
- return DAG.UnrollVectorOp(Op.getNode());
-}
-
+} + +SDValue VectorLegalizer::ExpandCTPOP(SDValue Op) { + SDValue Result; + if (TLI.expandCTPOP(Op.getNode(), Result, DAG)) + return Result; + + return DAG.UnrollVectorOp(Op.getNode()); +} + +SDValue VectorLegalizer::ExpandCTLZ(SDValue Op) { + SDValue Result; + if (TLI.expandCTLZ(Op.getNode(), Result, DAG)) + return Result; + + return DAG.UnrollVectorOp(Op.getNode()); +} + +SDValue VectorLegalizer::ExpandCTTZ(SDValue Op) { + SDValue Result; + if (TLI.expandCTTZ(Op.getNode(), Result, DAG)) + return Result; + + return DAG.UnrollVectorOp(Op.getNode()); +} + +SDValue VectorLegalizer::ExpandFunnelShift(SDValue Op) { + SDValue Result; + if (TLI.expandFunnelShift(Op.getNode(), Result, DAG)) + return Result; + + return DAG.UnrollVectorOp(Op.getNode()); +} + SDValue VectorLegalizer::ExpandFMINNUM_FMAXNUM(SDValue Op) { if (SDValue Expanded = TLI.expandFMINNUM_FMAXNUM(Op.getNode(), DAG)) return Expanded; diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index a6a8ebe1a47..d0d748f3841 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -5751,6 +5751,12 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) { SDValue Zero = DAG.getConstant(0, sdl, VT); SDValue ShAmt = DAG.getNode(ISD::UREM, sdl, VT, Z, BitWidthC); + auto FunnelOpcode = IsFSHL ? ISD::FSHL : ISD::FSHR; + if (TLI.isOperationLegalOrCustom(FunnelOpcode, VT)) { + setValue(&I, DAG.getNode(FunnelOpcode, sdl, VT, X, Y, Z)); + return nullptr; + } + // When X == Y, this is rotate. If the data type has a power-of-2 size, we // avoid the select that is necessary in the general case to filter out // the 0-shift possibility that leads to UB. diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp index 02d45df5864..bf81d0d267a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -237,6 +237,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const { case ISD::SRL: return "srl"; case ISD::ROTL: return "rotl"; case ISD::ROTR: return "rotr"; + case ISD::FSHL: return "fshl"; + case ISD::FSHR: return "fshr"; case ISD::FADD: return "fadd"; case ISD::STRICT_FADD: return "strict_fadd"; case ISD::FSUB: return "fsub"; diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 04303d682d3..6654a02e97a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -4114,6 +4114,54 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT, return Ok; } +bool TargetLowering::expandFunnelShift(SDNode *Node, SDValue &Result, + SelectionDAG &DAG) const { + EVT VT = Node->getValueType(0); + + if (VT.isVector() && (!isOperationLegalOrCustom(ISD::SHL, VT) || + !isOperationLegalOrCustom(ISD::SRL, VT) || + !isOperationLegalOrCustom(ISD::SUB, VT) || + !isOperationLegalOrCustomOrPromote(ISD::OR, VT))) + return false; + + // fshl: (X << (Z % BW)) | (Y >> (BW - (Z % BW))) + // fshr: (X << (BW - (Z % BW))) | (Y >> (Z % BW)) + SDValue X = Node->getOperand(0); + SDValue Y = Node->getOperand(1); + SDValue Z = Node->getOperand(2); + + unsigned EltSizeInBits = VT.getScalarSizeInBits(); + bool IsFSHL = Node->getOpcode() == ISD::FSHL; + SDLoc DL(SDValue(Node, 0)); + + EVT ShVT = Z.getValueType(); + SDValue BitWidthC = DAG.getConstant(EltSizeInBits, DL, ShVT); + SDValue Zero = DAG.getConstant(0, DL, ShVT); + + SDValue ShAmt; + if (isPowerOf2_32(EltSizeInBits)) { + SDValue Mask = DAG.getConstant(EltSizeInBits - 1, DL, ShVT); + ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Z, Mask); + } else { + ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC); + } + + SDValue InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, ShAmt); + SDValue ShX = DAG.getNode(ISD::SHL, DL, VT, X, IsFSHL ? ShAmt : InvShAmt); + SDValue ShY = DAG.getNode(ISD::SRL, DL, VT, Y, IsFSHL ? InvShAmt : ShAmt); + SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShX, ShY); + + // If (Z % BW == 0), then the opposite direction shift is shift-by-bitwidth, + // and that is undefined. We must compare and select to avoid UB. + EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), ShVT); + + // For fshl, 0-shift returns the 1st arg (X). + // For fshr, 0-shift returns the 2nd arg (Y). + SDValue IsZeroShift = DAG.getSetCC(DL, CCVT, ShAmt, Zero, ISD::SETEQ); + Result = DAG.getSelect(DL, VT, IsZeroShift, IsFSHL ? X : Y, Or); + return true; +} + bool TargetLowering::expandFP_TO_SINT(SDNode *Node, SDValue &Result, SelectionDAG &DAG) const { SDValue Src = Node->getOperand(0); |