Diffstat (limited to 'llvm/lib')
 llvm/lib/Target/X86/X86ISelLowering.cpp | 177
 1 file changed, 173 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index abd2d42369a..3e2f5f1ab7c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -4622,9 +4622,9 @@ static SDValue getConstVector(ArrayRef<int> Values, MVT VT, SelectionDAG &DAG,
   return ConstsNode;
 }
 
-static SDValue getConstVector(ArrayRef<APInt> Values, SmallBitVector &Undefs,
+static SDValue getConstVector(ArrayRef<APInt> Bits, SmallBitVector &Undefs,
                               MVT VT, SelectionDAG &DAG, const SDLoc &dl) {
-  assert(Values.size() == Undefs.size() && "Unequal constant and undef arrays");
+  assert(Bits.size() == Undefs.size() && "Unequal constant and undef arrays");
   SmallVector<SDValue, 32> Ops;
   bool Split = false;
 
@@ -4637,16 +4637,22 @@ static SDValue getConstVector(ArrayRef<APInt> Values, SmallBitVector &Undefs,
   }
 
   MVT EltVT = ConstVecVT.getVectorElementType();
-  for (unsigned i = 0, e = Values.size(); i != e; ++i) {
+  for (unsigned i = 0, e = Bits.size(); i != e; ++i) {
     if (Undefs[i]) {
       Ops.append(Split ? 2 : 1, DAG.getUNDEF(EltVT));
       continue;
     }
-    const APInt &V = Values[i];
+    const APInt &V = Bits[i];
     assert(V.getBitWidth() == VT.getScalarSizeInBits() && "Unexpected sizes");
     if (Split) {
       Ops.push_back(DAG.getConstant(V.trunc(32), dl, EltVT));
       Ops.push_back(DAG.getConstant(V.lshr(32).trunc(32), dl, EltVT));
+    } else if (EltVT == MVT::f32) {
+      APFloat FV(APFloat::IEEEsingle, V);
+      Ops.push_back(DAG.getConstantFP(FV, dl, EltVT));
+    } else if (EltVT == MVT::f64) {
+      APFloat FV(APFloat::IEEEdouble, V);
+      Ops.push_back(DAG.getConstantFP(FV, dl, EltVT));
     } else {
       Ops.push_back(DAG.getConstant(V, dl, EltVT));
     }
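The new f32/f64 arms above rebuild an APFloat from the raw APInt bit pattern, so floating-point build vectors get genuine FP constants instead of bitcasted integers. A minimal standalone sketch of that bit-level round trip, assuming 2016-era LLVM headers where APFloat::IEEEsingle is a plain fltSemantics object (illustrative only, not part of the patch):

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include <cassert>

using namespace llvm;

int main() {
  // 0x3F800000 is the IEEE-754 single-precision bit pattern of 1.0f.
  APInt V(32, 0x3F800000ULL);
  // Reinterpret the bits as a float value, as the patch does for MVT::f32.
  APFloat FV(APFloat::IEEEsingle, V);
  assert(FV.convertToFloat() == 1.0f);
  // bitcastToAPInt() recovers the exact original bits; nothing is lost.
  assert(FV.bitcastToAPInt() == V);
  return 0;
}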
@@ -5037,6 +5043,77 @@ static const Constant *getTargetConstantFromNode(SDValue Op) {
   return dyn_cast<Constant>(CNode->getConstVal());
 }
 
+// Extract constant bits from a constant pool vector.
+static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
+                                          SmallBitVector &UndefElts,
+                                          SmallVectorImpl<APInt> &EltBits) {
+  assert(UndefElts.empty() && "Expected an empty UndefElts vector");
+  assert(EltBits.empty() && "Expected an empty EltBits vector");
+
+  EVT VT = Op.getValueType();
+  unsigned SizeInBits = VT.getSizeInBits();
+  assert((SizeInBits % EltSizeInBits) == 0 && "Can't split constant!");
+  unsigned NumElts = SizeInBits / EltSizeInBits;
+
+  auto *Cst = getTargetConstantFromNode(Op);
+  if (!Cst)
+    return false;
+
+  Type *CstTy = Cst->getType();
+  if (!CstTy->isVectorTy() || (SizeInBits != CstTy->getPrimitiveSizeInBits()))
+    return false;
+
+  // Extract all the undef/constant element data and pack into single bitsets.
+  APInt UndefBits(SizeInBits, 0);
+  APInt MaskBits(SizeInBits, 0);
+
+  unsigned CstEltSizeInBits = CstTy->getScalarSizeInBits();
+  for (unsigned i = 0, e = CstTy->getVectorNumElements(); i != e; ++i) {
+    auto *COp = Cst->getAggregateElement(i);
+    if (!COp || !(isa<UndefValue>(COp) || isa<ConstantInt>(COp) ||
+                  isa<ConstantFP>(COp)))
+      return false;
+
+    if (isa<UndefValue>(COp)) {
+      APInt EltUndef = APInt::getLowBitsSet(SizeInBits, CstEltSizeInBits);
+      UndefBits |= EltUndef.shl(i * CstEltSizeInBits);
+      continue;
+    }
+
+    APInt Bits;
+    if (auto *CInt = dyn_cast<ConstantInt>(COp))
+      Bits = CInt->getValue();
+    else if (auto *CFP = dyn_cast<ConstantFP>(COp))
+      Bits = CFP->getValueAPF().bitcastToAPInt();
+
+    Bits = Bits.zextOrTrunc(SizeInBits);
+    MaskBits |= Bits.shl(i * CstEltSizeInBits);
+  }
+
+  UndefElts = SmallBitVector(NumElts, false);
+  EltBits.resize(NumElts, APInt(EltSizeInBits, 0));
+
+  // Now extract the undef/constant bit data into the target elts.
+  for (unsigned i = 0; i != NumElts; ++i) {
+    APInt UndefEltBits = UndefBits.lshr(i * EltSizeInBits);
+    UndefEltBits = UndefEltBits.zextOrTrunc(EltSizeInBits);
+
+    // Only treat the element as UNDEF if all bits are UNDEF, otherwise
+    // treat it as zero.
+    if (UndefEltBits.isAllOnesValue()) {
+      UndefElts[i] = true;
+      continue;
+    }
+
+    APInt Bits = MaskBits.lshr(i * EltSizeInBits);
+    Bits = Bits.zextOrTrunc(EltSizeInBits);
+    EltBits[i] = Bits.getZExtValue();
+  }
+
+  return true;
+}
+
 static bool getTargetShuffleMaskIndices(SDValue MaskNode,
                                         unsigned MaskEltSizeInBits,
                                         SmallVectorImpl<uint64_t> &RawMask) {
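getTargetConstantBitsFromNode packs the source constant into one wide bit vector, then re-slices it at the caller's element width, so mismatched element sizes (e.g. reading a <4 x i32> pool constant as 2 x i64) fall out naturally; a wide element that is only partially undef becomes a constant whose undef bits read as zero. A standalone sketch of that repacking rule using only APInt (hypothetical values, not patch code):

#include "llvm/ADT/APInt.h"
#include <cassert>

using namespace llvm;

int main() {
  // Pack a <4 x i32> constant {1, undef, 3, 4} into 128-bit masks, the
  // same way getTargetConstantBitsFromNode does.
  const unsigned SrcEltBits = 32, DstEltBits = 64, SizeInBits = 128;
  APInt MaskBits(SizeInBits, 0), UndefBits(SizeInBits, 0);
  const uint64_t Vals[4] = {1, 0, 3, 4};
  for (unsigned i = 0; i != 4; ++i) {
    if (i == 1) { // pretend element 1 is undef
      UndefBits |= APInt::getLowBitsSet(SizeInBits, SrcEltBits)
                       .shl(i * SrcEltBits);
      continue;
    }
    MaskBits |= APInt(SizeInBits, Vals[i]).shl(i * SrcEltBits);
  }

  // Re-slice as 2 x i64. Neither 64-bit element is *entirely* undef, so
  // both are treated as constants, the undef half reading as zero.
  for (unsigned i = 0; i != SizeInBits / DstEltBits; ++i) {
    APInt Undef = UndefBits.lshr(i * DstEltBits).zextOrTrunc(DstEltBits);
    assert(!Undef.isAllOnesValue() && "element would have been undef");
    APInt Elt = MaskBits.lshr(i * DstEltBits).zextOrTrunc(DstEltBits);
    // Element 0 = {1, undef} -> 1; element 1 = {3, 4} -> (4 << 32) | 3.
    assert(Elt.getZExtValue() == (i == 0 ? 1 : ((4ULL << 32) | 3)));
  }
  return 0;
}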
@@ -26308,6 +26385,93 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
   return false;
 }
 
+// Attempt to constant fold all of the constant source ops.
+// Returns true if the entire shuffle is folded to a constant.
+// TODO: Extend this to merge multiple constant Ops and update the mask.
+static bool combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops,
+                                        ArrayRef<int> Mask, SDValue Root,
+                                        bool HasVariableMask, SelectionDAG &DAG,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        const X86Subtarget &Subtarget) {
+  MVT VT = Root.getSimpleValueType();
+
+  unsigned SizeInBits = VT.getSizeInBits();
+  unsigned NumMaskElts = Mask.size();
+  unsigned MaskSizeInBits = SizeInBits / NumMaskElts;
+  unsigned NumOps = Ops.size();
+
+  // Extract constant bits from each source op.
+  bool OneUseConstantOp = false;
+  SmallVector<SmallBitVector, 4> UndefEltsOps(NumOps);
+  SmallVector<SmallVector<APInt, 8>, 4> RawBitsOps(NumOps);
+  for (unsigned i = 0; i != NumOps; ++i) {
+    SDValue SrcOp = Ops[i];
+    OneUseConstantOp |= SrcOp.hasOneUse();
+    if (!getTargetConstantBitsFromNode(SrcOp, MaskSizeInBits, UndefEltsOps[i],
+                                       RawBitsOps[i]))
+      return false;
+  }
+
+  // Only fold if at least one of the constants is only used once or
+  // the combined shuffle has included a variable mask shuffle; this
+  // is to avoid constant pool bloat.
+  if (!OneUseConstantOp && !HasVariableMask)
+    return false;
+
+  // Shuffle the constant bits according to the mask.
+  SmallBitVector UndefElts(NumMaskElts, false);
+  SmallBitVector ZeroElts(NumMaskElts, false);
+  SmallBitVector ConstantElts(NumMaskElts, false);
+  SmallVector<APInt, 8> ConstantBitData(NumMaskElts,
+                                        APInt::getNullValue(MaskSizeInBits));
+  for (unsigned i = 0; i != NumMaskElts; ++i) {
+    int M = Mask[i];
+    if (M == SM_SentinelUndef) {
+      UndefElts[i] = true;
+      continue;
+    } else if (M == SM_SentinelZero) {
+      ZeroElts[i] = true;
+      continue;
+    }
+    assert(0 <= M && M < (int)(NumMaskElts * NumOps));
+
+    unsigned SrcOpIdx = (unsigned)M / NumMaskElts;
+    unsigned SrcMaskIdx = (unsigned)M % NumMaskElts;
+
+    auto &SrcUndefElts = UndefEltsOps[SrcOpIdx];
+    if (SrcUndefElts[SrcMaskIdx]) {
+      UndefElts[i] = true;
+      continue;
+    }
+
+    auto &SrcEltBits = RawBitsOps[SrcOpIdx];
+    APInt &Bits = SrcEltBits[SrcMaskIdx];
+    if (!Bits) {
+      ZeroElts[i] = true;
+      continue;
+    }
+
+    ConstantElts[i] = true;
+    ConstantBitData[i] = Bits;
+  }
+  assert((UndefElts | ZeroElts | ConstantElts).count() == NumMaskElts);
+
+  // Create the constant data.
+  MVT MaskSVT;
+  if (VT.isFloatingPoint() && (MaskSizeInBits == 32 || MaskSizeInBits == 64))
+    MaskSVT = MVT::getFloatingPointVT(MaskSizeInBits);
+  else
+    MaskSVT = MVT::getIntegerVT(MaskSizeInBits);
+
+  MVT MaskVT = MVT::getVectorVT(MaskSVT, NumMaskElts);
+
+  SDLoc DL(Root);
+  SDValue CstOp = getConstVector(ConstantBitData, UndefElts, MaskVT, DAG, DL);
+  DCI.AddToWorklist(CstOp.getNode());
+  DCI.CombineTo(Root.getNode(), DAG.getBitcast(VT, CstOp));
+  return true;
+}
+
 /// \brief Fully generic combining of x86 shuffle instructions.
 ///
 /// This should be the last combine run over the x86 shuffle instructions. Once
@@ -26491,6 +26655,11 @@ static bool combineX86ShufflesRecursively(ArrayRef<SDValue> SrcOps,
                                       HasVariableMask, DAG, DCI, Subtarget))
     return true;
 
+  // Attempt to constant fold all of the constant source ops.
+  if (combineX86ShufflesConstants(Ops, Mask, Root, HasVariableMask, DAG, DCI,
+                                  Subtarget))
+    return true;
+
   // We can only combine unary and binary shuffle mask cases.
   if (Ops.size() > 2)
     return false;
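Taken together, combineX86ShufflesConstants lets combineX86ShufflesRecursively resolve a shuffle chain whose sources are all constant-pool loads into a single constant. Its classification loop can be modelled outside the DAG; below is a minimal sketch with plain arrays mirroring the SM_SentinelUndef / SM_SentinelZero handling (sentinel values and data are illustrative only, not the patch's types):

#include <cassert>
#include <cstdint>
#include <vector>

enum { SM_SentinelUndef = -1, SM_SentinelZero = -2 };

int main() {
  // Two 4-element constant source ops and a combined 4-element mask.
  std::vector<uint64_t> Op0 = {1, 2, 0, 4}, Op1 = {5, 6, 7, 8};
  std::vector<std::vector<uint64_t>> Ops = {Op0, Op1};
  int Mask[4] = {0, SM_SentinelZero, 6, SM_SentinelUndef};

  uint64_t Out[4] = {};
  bool Undef[4] = {};
  for (unsigned i = 0; i != 4; ++i) {
    int M = Mask[i];
    if (M == SM_SentinelUndef) { Undef[i] = true; continue; }
    if (M == SM_SentinelZero) { Out[i] = 0; continue; }
    Out[i] = Ops[M / 4][M % 4]; // select the M-th source element
  }
  assert(Out[0] == 1 && Out[1] == 0 && Out[2] == 7 && Undef[3]);
  return 0;
}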