diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 128 |
1 files changed, 90 insertions, 38 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 3dd22e2c9d6..d73e320f393 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -3010,44 +3010,9 @@ SDValue SelectionDAG::getNode(unsigned Opcode, SDLoc DL, case ISD::CTTZ: case ISD::CTTZ_ZERO_UNDEF: case ISD::CTPOP: { - EVT SVT = VT.getScalarType(); - EVT InVT = BV->getValueType(0); - EVT InSVT = InVT.getScalarType(); - - // Find legal integer scalar type for constant promotion and - // ensure that its scalar size is at least as large as source. - EVT LegalSVT = SVT; - if (SVT.isInteger()) { - LegalSVT = TLI->getTypeToTransformTo(*getContext(), SVT); - if (LegalSVT.bitsLT(SVT)) break; - } - - // Let the above scalar folding handle the folding of each element. - SmallVector<SDValue, 8> Ops; - for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) { - SDValue OpN = BV->getOperand(i); - EVT OpVT = OpN.getValueType(); - - // Build vector (integer) scalar operands may need implicit - // truncation - do this before constant folding. - if (OpVT.isInteger() && OpVT.bitsGT(InSVT)) - OpN = getNode(ISD::TRUNCATE, DL, InSVT, OpN); - - OpN = getNode(Opcode, DL, SVT, OpN); - - // Legalize the (integer) scalar constant if necessary. - if (LegalSVT != SVT) - OpN = getNode(ISD::ANY_EXTEND, DL, LegalSVT, OpN); - - if (OpN.getOpcode() != ISD::UNDEF && - OpN.getOpcode() != ISD::Constant && - OpN.getOpcode() != ISD::ConstantFP) - break; - Ops.push_back(OpN); - } - if (Ops.size() == VT.getVectorNumElements()) - return getNode(ISD::BUILD_VECTOR, DL, VT, Ops); - break; + SDValue Ops = { Operand }; + if (SDValue Fold = FoldConstantVectorArithmetic(Opcode, DL, VT, Ops)) + return Fold; } } } @@ -3348,6 +3313,93 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT, return getNode(ISD::BUILD_VECTOR, SDLoc(), VT, Outputs); } +SDValue SelectionDAG::FoldConstantVectorArithmetic(unsigned Opcode, SDLoc DL, + EVT VT, + ArrayRef<SDValue> Ops, + const SDNodeFlags *Flags) { + // If the opcode is a target-specific ISD node, there's nothing we can + // do here and the operand rules may not line up with the below, so + // bail early. + if (Opcode >= ISD::BUILTIN_OP_END) + return SDValue(); + + // We can only fold vectors - maybe merge with FoldConstantArithmetic someday? + if (!VT.isVector()) + return SDValue(); + + unsigned NumElts = VT.getVectorNumElements(); + + auto IsSameVectorSize = [&](const SDValue &Op) { + return Op.getValueType().isVector() && + Op.getValueType().getVectorNumElements() == NumElts; + }; + + auto IsConstantBuildVectorOrUndef = [&](const SDValue &Op) { + BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op); + return (Op.getOpcode() == ISD::UNDEF) || (BV && BV->isConstant()); + }; + + // All operands must be vector types with the same number of elements as + // the result type and must be either UNDEF or a build vector of constant + // or UNDEF scalars. + if (!std::all_of(Ops.begin(), Ops.end(), IsConstantBuildVectorOrUndef) || + !std::all_of(Ops.begin(), Ops.end(), IsSameVectorSize)) + return SDValue(); + + // Find legal integer scalar type for constant promotion and + // ensure that its scalar size is at least as large as source. + EVT SVT = VT.getScalarType(); + EVT LegalSVT = SVT; + if (SVT.isInteger()) { + LegalSVT = TLI->getTypeToTransformTo(*getContext(), SVT); + if (LegalSVT.bitsLT(SVT)) + return SDValue(); + } + + // Constant fold each scalar lane separately. + SmallVector<SDValue, 4> ScalarResults; + for (unsigned i = 0; i != NumElts; i++) { + SmallVector<SDValue, 4> ScalarOps; + for (SDValue Op : Ops) { + EVT InSVT = Op->getValueType(0).getScalarType(); + BuildVectorSDNode *InBV = dyn_cast<BuildVectorSDNode>(Op); + if (!InBV) { + // We've checked that this is UNDEF above. + ScalarOps.push_back(getUNDEF(LegalSVT)); + continue; + } + + SDValue ScalarOp = InBV->getOperand(i); + EVT ScalarVT = ScalarOp.getValueType(); + + // Build vector (integer) scalar operands may need implicit + // truncation - do this before constant folding. + if (ScalarVT.isInteger() && ScalarVT.bitsGT(InSVT)) + ScalarOp = getNode(ISD::TRUNCATE, DL, InSVT, ScalarOp); + + ScalarOps.push_back(ScalarOp); + } + + // Constant fold the scalar operands. + SDValue ScalarResult = getNode(Opcode, DL, SVT, ScalarOps, Flags); + + // Legalize the (integer) scalar constant if necessary. + if (LegalSVT != SVT) + ScalarResult = getNode(ISD::ANY_EXTEND, DL, LegalSVT, ScalarResult); + + // Scalar folding only succeeded if the result is a constant or UNDEF. + if (ScalarResult.getOpcode() != ISD::UNDEF && + ScalarResult.getOpcode() != ISD::Constant && + ScalarResult.getOpcode() != ISD::ConstantFP) + return SDValue(); + ScalarResults.push_back(ScalarResult); + } + + assert(ScalarResults.size() == NumElts && + "Unexpected number of scalar results for BUILD_VECTOR"); + return getNode(ISD::BUILD_VECTOR, DL, VT, ScalarResults); +} + SDValue SelectionDAG::getNode(unsigned Opcode, SDLoc DL, EVT VT, SDValue N1, SDValue N2, const SDNodeFlags *Flags) { ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); |