diff options
Diffstat (limited to 'llvm/lib/CodeGen')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 8 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp | 16 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 44 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 40 |
4 files changed, 89 insertions, 19 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp index 385bd346b7b..4ee0b933b9e 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -3321,9 +3321,11 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) { } case ISD::UMULO: case ISD::SMULO: { - auto Pair = TLI.expandMULO(Node, DAG); - Results.push_back(Pair.first); - Results.push_back(Pair.second); + SDValue Result, Overflow; + if (TLI.expandMULO(Node, Result, Overflow, DAG)) { + Results.push_back(Result); + Results.push_back(Overflow); + } break; } case ISD::BUILD_PAIR: { diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index 5d080e06d75..511bff484c7 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -140,6 +140,7 @@ class VectorLegalizer { SDValue ExpandFunnelShift(SDValue Op); SDValue ExpandROT(SDValue Op); SDValue ExpandFMINNUM_FMAXNUM(SDValue Op); + SDValue ExpandMULO(SDValue Op); SDValue ExpandAddSubSat(SDValue Op); SDValue ExpandFixedPointMul(SDValue Op); SDValue ExpandStrictFPOp(SDValue Op); @@ -418,6 +419,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { case ISD::UMAX: case ISD::SMUL_LOHI: case ISD::UMUL_LOHI: + case ISD::SMULO: + case ISD::UMULO: case ISD::FCANONICALIZE: case ISD::SADDSAT: case ISD::UADDSAT: @@ -779,6 +782,9 @@ SDValue VectorLegalizer::Expand(SDValue Op) { case ISD::FMINNUM: case ISD::FMAXNUM: return ExpandFMINNUM_FMAXNUM(Op); + case ISD::UMULO: + case ISD::SMULO: + return ExpandMULO(Op); case ISD::USUBSAT: case ISD::SSUBSAT: case ISD::UADDSAT: @@ -1216,6 +1222,16 @@ SDValue VectorLegalizer::ExpandFMINNUM_FMAXNUM(SDValue Op) { return DAG.UnrollVectorOp(Op.getNode()); } +SDValue VectorLegalizer::ExpandMULO(SDValue Op) { + SDValue Result, Overflow; + if (!TLI.expandMULO(Op.getNode(), Result, Overflow, DAG)) + std::tie(Result, Overflow) = DAG.UnrollVectorOverflowOp(Op.getNode()); + + AddLegalizedOperand(Op.getValue(0), Result); + AddLegalizedOperand(Op.getValue(1), Overflow); + return Op.getResNo() ? Overflow : Result; +} + SDValue VectorLegalizer::ExpandAddSubSat(SDValue Op) { if (SDValue Expanded = TLI.expandAddSubSat(Op.getNode(), DAG)) return Expanded; diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index d7d7b8b7191..65b738151e6 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -8918,6 +8918,50 @@ SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) { return getBuildVector(VecVT, dl, Scalars); } +std::pair<SDValue, SDValue> SelectionDAG::UnrollVectorOverflowOp( + SDNode *N, unsigned ResNE) { + unsigned Opcode = N->getOpcode(); + assert((Opcode == ISD::UADDO || Opcode == ISD::SADDO || + Opcode == ISD::USUBO || Opcode == ISD::SSUBO || + Opcode == ISD::UMULO || Opcode == ISD::SMULO) && + "Expected an overflow opcode"); + + EVT ResVT = N->getValueType(0); + EVT OvVT = N->getValueType(1); + EVT ResEltVT = ResVT.getVectorElementType(); + EVT OvEltVT = OvVT.getVectorElementType(); + SDLoc dl(N); + + // If ResNE is 0, fully unroll the vector op. + unsigned NE = ResVT.getVectorNumElements(); + if (ResNE == 0) + ResNE = NE; + else if (NE > ResNE) + NE = ResNE; + + SmallVector<SDValue, 8> LHSScalars; + SmallVector<SDValue, 8> RHSScalars; + ExtractVectorElements(N->getOperand(0), LHSScalars, 0, NE); + ExtractVectorElements(N->getOperand(1), RHSScalars, 0, NE); + + SDVTList VTs = getVTList(ResEltVT, OvEltVT); + SmallVector<SDValue, 8> ResScalars; + SmallVector<SDValue, 8> OvScalars; + for (unsigned i = 0; i < NE; ++i) { + SDValue Res = getNode(Opcode, dl, VTs, LHSScalars[i], RHSScalars[i]); + ResScalars.push_back(Res); + OvScalars.push_back(SDValue(Res.getNode(), 1)); + } + + ResScalars.append(ResNE - NE, getUNDEF(ResEltVT)); + OvScalars.append(ResNE - NE, getUNDEF(OvEltVT)); + + EVT NewResVT = EVT::getVectorVT(*getContext(), ResEltVT, ResNE); + EVT NewOvVT = EVT::getVectorVT(*getContext(), OvEltVT, ResNE); + return std::make_pair(getBuildVector(NewResVT, dl, ResScalars), + getBuildVector(NewOvVT, dl, OvScalars)); +} + bool SelectionDAG::areNonVolatileConsecutiveLoads(LoadSDNode *LD, LoadSDNode *Base, unsigned Bytes, diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 9b85fa2051c..b05e8f14b38 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -5522,11 +5522,15 @@ TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const { DAG.getConstant(Scale, dl, ShiftTy)); } -std::pair<SDValue, SDValue> TargetLowering::expandMULO( - SDNode *Node, SelectionDAG &DAG) const { +bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result, + SDValue &Overflow, SelectionDAG &DAG) const { SDLoc dl(Node); EVT VT = Node->getValueType(0); - EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits() * 2); + EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits() * 2); + if (VT.isVector()) + WideVT = EVT::getVectorVT(*DAG.getContext(), WideVT, + VT.getVectorNumElements()); + SDValue LHS = Node->getOperand(0); SDValue RHS = Node->getOperand(1); SDValue BottomHalf; @@ -5546,11 +5550,15 @@ std::pair<SDValue, SDValue> TargetLowering::expandMULO( LHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, LHS); RHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, RHS); SDValue Mul = DAG.getNode(ISD::MUL, dl, WideVT, LHS, RHS); - BottomHalf = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, VT, Mul, - DAG.getIntPtrConstant(0, dl)); - TopHalf = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, VT, Mul, - DAG.getIntPtrConstant(1, dl)); + BottomHalf = DAG.getNode(ISD::TRUNCATE, dl, VT, Mul); + SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits(), dl, + getShiftAmountTy(WideVT, DAG.getDataLayout())); + TopHalf = DAG.getNode(ISD::TRUNCATE, dl, VT, + DAG.getNode(ISD::SRL, dl, WideVT, Mul, ShiftAmt)); } else { + if (VT.isVector()) + return false; + // We can fall back to a libcall with an illegal type for the MUL if we // have a libcall big enough. // Also, we can fall back to a division in some cases, but that's a big @@ -5618,24 +5626,24 @@ std::pair<SDValue, SDValue> TargetLowering::expandMULO( } EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + Result = BottomHalf; if (isSigned) { SDValue ShiftAmt = DAG.getConstant( - VT.getSizeInBits() - 1, dl, + VT.getScalarSizeInBits() - 1, dl, getShiftAmountTy(BottomHalf.getValueType(), DAG.getDataLayout())); SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, BottomHalf, ShiftAmt); - TopHalf = DAG.getSetCC(dl, SetCCVT, TopHalf, Sign, ISD::SETNE); + Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf, Sign, ISD::SETNE); } else { - TopHalf = DAG.getSetCC(dl, SetCCVT, TopHalf, - DAG.getConstant(0, dl, VT), ISD::SETNE); + Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf, + DAG.getConstant(0, dl, VT), ISD::SETNE); } // Truncate the result if SetCC returns a larger type than needed. EVT RType = Node->getValueType(1); - if (RType.getSizeInBits() < TopHalf.getValueSizeInBits()) - TopHalf = DAG.getNode(ISD::TRUNCATE, dl, RType, TopHalf); + if (RType.getSizeInBits() < Overflow.getValueSizeInBits()) + Overflow = DAG.getNode(ISD::TRUNCATE, dl, RType, Overflow); - assert(RType.getSizeInBits() == TopHalf.getValueSizeInBits() && + assert(RType.getSizeInBits() == Overflow.getValueSizeInBits() && "Unexpected result type for S/UMULO legalization"); - - return std::make_pair(BottomHalf, TopHalf); + return true; } |