diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 36 |
1 files changed, 29 insertions, 7 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index c5e5193421b..057badcd6b7 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -3171,14 +3171,36 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts, return ComputeNumSignBits(InVec, DemandedSrcElts, Depth + 1); } - case ISD::EXTRACT_SUBVECTOR: - return ComputeNumSignBits(Op.getOperand(0), Depth + 1); + case ISD::EXTRACT_SUBVECTOR: { + // If we know the element index, just demand that subvector elements, + // otherwise demand them all. + SDValue Src = Op.getOperand(0); + ConstantSDNode *SubIdx = dyn_cast<ConstantSDNode>(Op.getOperand(1)); + unsigned NumSrcElts = Src.getValueType().getVectorNumElements(); + if (SubIdx && SubIdx->getAPIntValue().ule(NumSrcElts - NumElts)) { + // Offset the demanded elts by the subvector index. + uint64_t Idx = SubIdx->getZExtValue(); + APInt DemandedSrc = DemandedElts.zext(NumSrcElts).shl(Idx); + return ComputeNumSignBits(Src, DemandedSrc, Depth + 1); + } + return ComputeNumSignBits(Src, Depth + 1); + } case ISD::CONCAT_VECTORS: - // Determine the minimum number of sign bits across all input vectors. - // Early out if the result is already 1. - Tmp = ComputeNumSignBits(Op.getOperand(0), Depth + 1); - for (unsigned i = 1, e = Op.getNumOperands(); (i < e) && (Tmp > 1); ++i) - Tmp = std::min(Tmp, ComputeNumSignBits(Op.getOperand(i), Depth + 1)); + // Determine the minimum number of sign bits across all demanded + // elts of the input vectors. Early out if the result is already 1. + Tmp = UINT_MAX; + EVT SubVectorVT = Op.getOperand(0).getValueType(); + unsigned NumSubVectorElts = SubVectorVT.getVectorNumElements(); + unsigned NumSubVectors = Op.getNumOperands(); + for (unsigned i = 0; (i < NumSubVectors) && (Tmp > 1); ++i) { + APInt DemandedSub = DemandedElts.lshr(i * NumSubVectorElts); + DemandedSub = DemandedSub.trunc(NumSubVectorElts); + if (!DemandedSub) + continue; + Tmp2 = ComputeNumSignBits(Op.getOperand(i), DemandedSub, Depth + 1); + Tmp = std::min(Tmp, Tmp2); + } + assert(Tmp <= VTBits && "Failed to determine minimum sign bits"); return Tmp; } |