diff options
-rw-r--r-- | llvm/include/llvm/CodeGen/SelectionDAG.h | 4 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 191 |
2 files changed, 95 insertions, 100 deletions
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index 09fdb626de9..a76843200be 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1128,6 +1128,10 @@ public: SDValue FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT, SDNode *Cst1, SDNode *Cst2); + SDValue FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT, + const ConstantSDNode *Cst1, + const ConstantSDNode *Cst2); + /// Constant fold a setcc to true or false. SDValue FoldSetCC(EVT VT, SDValue N1, SDValue N2, ISD::CondCode Cond, SDLoc dl); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 7e898d582a4..6d75a7c9533 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -49,6 +49,7 @@ #include "llvm/Target/TargetSubtargetInfo.h" #include <algorithm> #include <cmath> +#include <utility> using namespace llvm; @@ -3109,6 +3110,53 @@ SDValue SelectionDAG::getNode(unsigned Opcode, SDLoc DL, return SDValue(N, 0); } +static std::pair<APInt, bool> FoldValue(unsigned Opcode, const APInt &C1, + const APInt &C2) { + switch (Opcode) { + case ISD::ADD: return std::make_pair(C1 + C2, true); + case ISD::SUB: return std::make_pair(C1 - C2, true); + case ISD::MUL: return std::make_pair(C1 * C2, true); + case ISD::AND: return std::make_pair(C1 & C2, true); + case ISD::OR: return std::make_pair(C1 | C2, true); + case ISD::XOR: return std::make_pair(C1 ^ C2, true); + case ISD::SHL: return std::make_pair(C1 << C2, true); + case ISD::SRL: return std::make_pair(C1.lshr(C2), true); + case ISD::SRA: return std::make_pair(C1.ashr(C2), true); + case ISD::ROTL: return std::make_pair(C1.rotl(C2), true); + case ISD::ROTR: return std::make_pair(C1.rotr(C2), true); + case ISD::UDIV: + if (!C2.getBoolValue()) + break; + return std::make_pair(C1.udiv(C2), true); + case ISD::UREM: + if (!C2.getBoolValue()) + break; + return std::make_pair(C1.urem(C2), true); + case ISD::SDIV: + if (!C2.getBoolValue()) + break; + return std::make_pair(C1.sdiv(C2), true); + case ISD::SREM: + if (!C2.getBoolValue()) + break; + return std::make_pair(C1.srem(C2), true); + } + return std::make_pair(APInt(1, 0), false); +} + +SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT, + const ConstantSDNode *Cst1, + const ConstantSDNode *Cst2) { + if (Cst1->isOpaque() || Cst2->isOpaque()) + return SDValue(); + + std::pair<APInt, bool> Folded = FoldValue(Opcode, Cst1->getAPIntValue(), + Cst2->getAPIntValue()); + if (!Folded.second) + return SDValue(); + return getConstant(Folded.first, DL, VT); +} + SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT, SDNode *Cst1, SDNode *Cst2) { // If the opcode is a target-specific ISD node, there's nothing we can @@ -3117,116 +3165,59 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT, if (Opcode >= ISD::BUILTIN_OP_END) return SDValue(); - SmallVector<std::pair<ConstantSDNode *, ConstantSDNode *>, 4> Inputs; - SmallVector<SDValue, 4> Outputs; - EVT SVT = VT.getScalarType(); + // Handle the case of two scalars. + if (const ConstantSDNode *Scalar1 = dyn_cast<ConstantSDNode>(Cst1)) { + if (const ConstantSDNode *Scalar2 = dyn_cast<ConstantSDNode>(Cst2)) { + if (SDValue Folded = + FoldConstantArithmetic(Opcode, DL, VT, Scalar1, Scalar2)) { + if (!VT.isVector()) + return Folded; + SmallVector<SDValue, 4> Outputs; + // We may have a vector type but a scalar result. Create a splat. + Outputs.resize(VT.getVectorNumElements(), Outputs.back()); + // Build a big vector out of the scalar elements we generated. + return getNode(ISD::BUILD_VECTOR, SDLoc(), VT, Outputs); + } else { + return SDValue(); + } + } + } - ConstantSDNode *Scalar1 = dyn_cast<ConstantSDNode>(Cst1); - ConstantSDNode *Scalar2 = dyn_cast<ConstantSDNode>(Cst2); - if (Scalar1 && Scalar2 && (Scalar1->isOpaque() || Scalar2->isOpaque())) + // For vectors extract each constant element into Inputs so we can constant + // fold them individually. + BuildVectorSDNode *BV1 = dyn_cast<BuildVectorSDNode>(Cst1); + BuildVectorSDNode *BV2 = dyn_cast<BuildVectorSDNode>(Cst2); + if (!BV1 || !BV2) return SDValue(); - if (Scalar1 && Scalar2) - // Scalar instruction. - Inputs.push_back(std::make_pair(Scalar1, Scalar2)); - else { - // For vectors extract each constant element into Inputs so we can constant - // fold them individually. - BuildVectorSDNode *BV1 = dyn_cast<BuildVectorSDNode>(Cst1); - BuildVectorSDNode *BV2 = dyn_cast<BuildVectorSDNode>(Cst2); - if (!BV1 || !BV2) - return SDValue(); - - assert(BV1->getNumOperands() == BV2->getNumOperands() && "Out of sync!"); - - for (unsigned I = 0, E = BV1->getNumOperands(); I != E; ++I) { - ConstantSDNode *V1 = dyn_cast<ConstantSDNode>(BV1->getOperand(I)); - ConstantSDNode *V2 = dyn_cast<ConstantSDNode>(BV2->getOperand(I)); - if (!V1 || !V2) // Not a constant, bail. - return SDValue(); + assert(BV1->getNumOperands() == BV2->getNumOperands() && "Out of sync!"); - if (V1->isOpaque() || V2->isOpaque()) - return SDValue(); - - // Avoid BUILD_VECTOR nodes that perform implicit truncation. - // FIXME: This is valid and could be handled by truncating the APInts. - if (V1->getValueType(0) != SVT || V2->getValueType(0) != SVT) - return SDValue(); + EVT SVT = VT.getScalarType(); + SmallVector<SDValue, 4> Outputs; + for (unsigned I = 0, E = BV1->getNumOperands(); I != E; ++I) { + ConstantSDNode *V1 = dyn_cast<ConstantSDNode>(BV1->getOperand(I)); + ConstantSDNode *V2 = dyn_cast<ConstantSDNode>(BV2->getOperand(I)); + if (!V1 || !V2) // Not a constant, bail. + return SDValue(); - Inputs.push_back(std::make_pair(V1, V2)); - } - } + if (V1->isOpaque() || V2->isOpaque()) + return SDValue(); - // We have a number of constant values, constant fold them element by element. - for (unsigned I = 0, E = Inputs.size(); I != E; ++I) { - const APInt &C1 = Inputs[I].first->getAPIntValue(); - const APInt &C2 = Inputs[I].second->getAPIntValue(); + // Avoid BUILD_VECTOR nodes that perform implicit truncation. + // FIXME: This is valid and could be handled by truncating the APInts. + if (V1->getValueType(0) != SVT || V2->getValueType(0) != SVT) + return SDValue(); - switch (Opcode) { - case ISD::ADD: - Outputs.push_back(getConstant(C1 + C2, DL, SVT)); - break; - case ISD::SUB: - Outputs.push_back(getConstant(C1 - C2, DL, SVT)); - break; - case ISD::MUL: - Outputs.push_back(getConstant(C1 * C2, DL, SVT)); - break; - case ISD::UDIV: - if (!C2.getBoolValue()) - return SDValue(); - Outputs.push_back(getConstant(C1.udiv(C2), DL, SVT)); - break; - case ISD::UREM: - if (!C2.getBoolValue()) - return SDValue(); - Outputs.push_back(getConstant(C1.urem(C2), DL, SVT)); - break; - case ISD::SDIV: - if (!C2.getBoolValue()) - return SDValue(); - Outputs.push_back(getConstant(C1.sdiv(C2), DL, SVT)); - break; - case ISD::SREM: - if (!C2.getBoolValue()) - return SDValue(); - Outputs.push_back(getConstant(C1.srem(C2), DL, SVT)); - break; - case ISD::AND: - Outputs.push_back(getConstant(C1 & C2, DL, SVT)); - break; - case ISD::OR: - Outputs.push_back(getConstant(C1 | C2, DL, SVT)); - break; - case ISD::XOR: - Outputs.push_back(getConstant(C1 ^ C2, DL, SVT)); - break; - case ISD::SHL: - Outputs.push_back(getConstant(C1 << C2, DL, SVT)); - break; - case ISD::SRL: - Outputs.push_back(getConstant(C1.lshr(C2), DL, SVT)); - break; - case ISD::SRA: - Outputs.push_back(getConstant(C1.ashr(C2), DL, SVT)); - break; - case ISD::ROTL: - Outputs.push_back(getConstant(C1.rotl(C2), DL, SVT)); - break; - case ISD::ROTR: - Outputs.push_back(getConstant(C1.rotr(C2), DL, SVT)); - break; - default: + // Fold one vector element. + std::pair<APInt, bool> Folded = FoldValue(Opcode, V1->getAPIntValue(), + V2->getAPIntValue()); + if (!Folded.second) return SDValue(); - } + Outputs.push_back(getConstant(Folded.first, DL, SVT)); } - assert((Scalar1 && Scalar2) || (VT.getVectorNumElements() == Outputs.size() && - "Expected a scalar or vector!")); - - // Handle the scalar case first. - if (!VT.isVector()) - return Outputs.back(); + assert(VT.getVectorNumElements() == Outputs.size() && + "Vector size mismatch!"); // We may have a vector type but a scalar result. Create a splat. Outputs.resize(VT.getVectorNumElements(), Outputs.back()); |