summary refs log tree commit diff stats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r-- llvm/lib/Target/X86/X86ISelLowering.cpp 177
1 files changed, 173 insertions, 4 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index abd2d42369a..3e2f5f1ab7c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -4622,9 +4622,9 @@ static SDValue getConstVector(ArrayRef<int> Values, MVT VT, SelectionDAG &DAG,
return ConstsNode;
}
-static SDValue getConstVector(ArrayRef<APInt> Values, SmallBitVector &Undefs,
+static SDValue getConstVector(ArrayRef<APInt> Bits, SmallBitVector &Undefs,
MVT VT, SelectionDAG &DAG, const SDLoc &dl) {
- assert(Values.size() == Undefs.size() && "Unequal constant and undef arrays");
+ assert(Bits.size() == Undefs.size() && "Unequal constant and undef arrays");
SmallVector<SDValue, 32> Ops;
bool Split = false;
@@ -4637,16 +4637,22 @@ static SDValue getConstVector(ArrayRef<APInt> Values, SmallBitVector &Undefs,
}
MVT EltVT = ConstVecVT.getVectorElementType();
- for (unsigned i = 0, e = Values.size(); i != e; ++i) {
+ for (unsigned i = 0, e = Bits.size(); i != e; ++i) {
if (Undefs[i]) {
Ops.append(Split ? 2 : 1, DAG.getUNDEF(EltVT));
continue;
}
- const APInt &V = Values[i];
+ const APInt &V = Bits[i];
assert(V.getBitWidth() == VT.getScalarSizeInBits() && "Unexpected sizes");
if (Split) {
Ops.push_back(DAG.getConstant(V.trunc(32), dl, EltVT));
Ops.push_back(DAG.getConstant(V.lshr(32).trunc(32), dl, EltVT));
+ } else if (EltVT == MVT::f32) {
+ APFloat FV(APFloat::IEEEsingle, V);
+ Ops.push_back(DAG.getConstantFP(FV, dl, EltVT));
+ } else if (EltVT == MVT::f64) {
+ APFloat FV(APFloat::IEEEdouble, V);
+ Ops.push_back(DAG.getConstantFP(FV, dl, EltVT));
} else {
Ops.push_back(DAG.getConstant(V, dl, EltVT));
}
@@ -5037,6 +5043,77 @@ static const Constant *getTargetConstantFromNode(SDValue Op) {
return dyn_cast<Constant>(CNode->getConstVal());
}
+/// Extract the raw constant bits of a constant pool vector, repacked into
+/// elements of \p EltSizeInBits bits each.
+///
+/// \param Op            Node to inspect; it must be resolvable to a constant
+///                      by getTargetConstantFromNode() for this to succeed.
+/// \param EltSizeInBits Width of each output element. Must evenly divide the
+///                      bit size of Op's value type. It may differ from the
+///                      element width of the underlying constant.
+/// \param UndefElts     [out] Must be empty on entry. On success holds one
+///                      flag per output element, set only when every bit of
+///                      that element comes from an undef source element.
+/// \param EltBits       [out] Must be empty on entry. On success holds the
+///                      bits of each output element. Elements flagged in
+///                      UndefElts are left as zero, and undef bits inside a
+///                      partially undef element read as zero.
+/// \returns true on success. Returns false (leaving both outputs untouched)
+///          if no constant is found, if the constant is not a vector whose
+///          size matches Op's value type, or if any of its elements is not
+///          an UndefValue, ConstantInt or ConstantFP.
+static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
+                                          SmallBitVector &UndefElts,
+                                          SmallVectorImpl<APInt> &EltBits) {
+  assert(UndefElts.empty() && "Expected an empty UndefElts vector");
+  assert(EltBits.empty() && "Expected an empty EltBits vector");
+
+  EVT VT = Op.getValueType();
+  unsigned SizeInBits = VT.getSizeInBits();
+  assert((SizeInBits % EltSizeInBits) == 0 && "Can't split constant!");
+  unsigned NumElts = SizeInBits / EltSizeInBits;
+
+  auto *Cst = getTargetConstantFromNode(Op);
+  if (!Cst)
+    return false;
+
+  Type *CstTy = Cst->getType();
+  if (!CstTy->isVectorTy() || (SizeInBits != CstTy->getPrimitiveSizeInBits()))
+    return false;
+
+  // Extract all the undef/constant element data and pack into single bitsets.
+  // Both bitsets span the whole vector so that they can be re-sliced below at
+  // a different element width.
+  APInt UndefBits(SizeInBits, 0);
+  APInt MaskBits(SizeInBits, 0);
+
+  unsigned CstEltSizeInBits = CstTy->getScalarSizeInBits();
+  for (unsigned i = 0, e = CstTy->getVectorNumElements(); i != e; ++i) {
+    auto *COp = Cst->getAggregateElement(i);
+    if (!COp ||
+        !(isa<UndefValue>(COp) || isa<ConstantInt>(COp) ||
+          isa<ConstantFP>(COp)))
+      return false;
+
+    if (isa<UndefValue>(COp)) {
+      // Mark every bit covered by this source element as undef.
+      APInt EltUndef = APInt::getLowBitsSet(SizeInBits, CstEltSizeInBits);
+      UndefBits |= EltUndef.shl(i * CstEltSizeInBits);
+      continue;
+    }
+
+    // The isa<> check above guarantees that one of these branches is taken.
+    APInt Bits;
+    if (auto *CInt = dyn_cast<ConstantInt>(COp))
+      Bits = CInt->getValue();
+    else if (auto *CFP = dyn_cast<ConstantFP>(COp))
+      Bits = CFP->getValueAPF().bitcastToAPInt();
+
+    Bits = Bits.zextOrTrunc(SizeInBits);
+    MaskBits |= Bits.shl(i * CstEltSizeInBits);
+  }
+
+  UndefElts = SmallBitVector(NumElts, false);
+  EltBits.resize(NumElts, APInt(EltSizeInBits, 0));
+
+  // Now extract the undef/constant bit data into the target elts.
+  for (unsigned i = 0; i != NumElts; ++i) {
+    APInt UndefEltBits = UndefBits.lshr(i * EltSizeInBits);
+    UndefEltBits = UndefEltBits.zextOrTrunc(EltSizeInBits);
+
+    // Only treat the element as UNDEF if all bits are UNDEF, otherwise
+    // treat it as zero.
+    if (UndefEltBits.isAllOnesValue()) {
+      UndefElts[i] = true;
+      continue;
+    }
+
+    APInt Bits = MaskBits.lshr(i * EltSizeInBits);
+    Bits = Bits.zextOrTrunc(EltSizeInBits);
+    // Store the APInt itself: it already has the EltSizeInBits width, whereas
+    // a round trip through getZExtValue() cannot represent elements wider
+    // than 64 bits.
+    EltBits[i] = Bits;
+  }
+
+  return true;
+}
+
static bool getTargetShuffleMaskIndices(SDValue MaskNode,
unsigned MaskEltSizeInBits,
SmallVectorImpl<uint64_t> &RawMask) {
@@ -26308,6 +26385,93 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
return false;
}
+/// Try to constant fold a shuffle whose source ops are all constants.
+///
+/// Each op in \p Ops is decoded with getTargetConstantBitsFromNode() at the
+/// mask element width; if any op cannot be decoded nothing is changed and
+/// false is returned. Otherwise the decoded elements are permuted according
+/// to \p Mask, and \p Root is replaced (through \p DCI) by a bitcast of the
+/// resulting constant vector.
+///
+/// \p Subtarget is currently not consulted.
+/// \returns true if the entire shuffle was folded to a constant.
+/// TODO: Extend this to merge multiple constant Ops and update the mask.
+static bool combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops,
+                                        ArrayRef<int> Mask, SDValue Root,
+                                        bool HasVariableMask, SelectionDAG &DAG,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        const X86Subtarget &Subtarget) {
+  MVT VT = Root.getSimpleValueType();
+
+  unsigned NumOps = Ops.size();
+  unsigned NumMaskElts = Mask.size();
+  unsigned SizeInBits = VT.getSizeInBits();
+  unsigned MaskSizeInBits = SizeInBits / NumMaskElts;
+
+  // Decode every source op into per-element undef flags and raw bits, and
+  // note whether any of the source ops has a single use.
+  SmallVector<SmallBitVector, 4> UndefEltsOps(NumOps);
+  SmallVector<SmallVector<APInt, 8>, 4> RawBitsOps(NumOps);
+  bool OneUseConstantOp = false;
+  for (unsigned OpIdx = 0; OpIdx != NumOps; ++OpIdx) {
+    SDValue SrcOp = Ops[OpIdx];
+    if (SrcOp.hasOneUse())
+      OneUseConstantOp = true;
+    if (!getTargetConstantBitsFromNode(SrcOp, MaskSizeInBits,
+                                       UndefEltsOps[OpIdx], RawBitsOps[OpIdx]))
+      return false;
+  }
+
+  // To avoid constant pool bloat, fold only when at least one constant has a
+  // single use or the combined shuffle included a variable mask shuffle.
+  if (!(OneUseConstantOp || HasVariableMask))
+    return false;
+
+  // Permute the constant bits as directed by the mask. Every mask element
+  // ends up classified as exactly one of undef, zero or constant.
+  SmallBitVector UndefElts(NumMaskElts, false);
+  SmallBitVector ZeroElts(NumMaskElts, false);
+  SmallBitVector ConstantElts(NumMaskElts, false);
+  SmallVector<APInt, 8> ConstantBitData(NumMaskElts,
+                                        APInt::getNullValue(MaskSizeInBits));
+  for (unsigned EltIdx = 0; EltIdx != NumMaskElts; ++EltIdx) {
+    int M = Mask[EltIdx];
+    if (M == SM_SentinelUndef) {
+      UndefElts[EltIdx] = true;
+      continue;
+    }
+    if (M == SM_SentinelZero) {
+      ZeroElts[EltIdx] = true;
+      continue;
+    }
+    assert(0 <= M && M < (int)(NumMaskElts * NumOps));
+
+    // Split the mask index into (source op, element within that op).
+    unsigned SrcOpIdx = (unsigned)M / NumMaskElts;
+    unsigned SrcEltIdx = (unsigned)M % NumMaskElts;
+
+    APInt &SrcBits = RawBitsOps[SrcOpIdx][SrcEltIdx];
+    if (UndefEltsOps[SrcOpIdx][SrcEltIdx]) {
+      UndefElts[EltIdx] = true;
+    } else if (!SrcBits) {
+      ZeroElts[EltIdx] = true;
+    } else {
+      ConstantElts[EltIdx] = true;
+      ConstantBitData[EltIdx] = SrcBits;
+    }
+  }
+  assert((UndefElts | ZeroElts | ConstantElts).count() == NumMaskElts);
+
+  // Build the constant with a floating point element type when the root is
+  // floating point and the mask elements are 32 or 64 bits wide; otherwise
+  // use an integer element type of the mask element width.
+  bool UseFPElts =
+      VT.isFloatingPoint() && (MaskSizeInBits == 32 || MaskSizeInBits == 64);
+  MVT MaskSVT = UseFPElts ? MVT::getFloatingPointVT(MaskSizeInBits)
+                          : MVT::getIntegerVT(MaskSizeInBits);
+  MVT MaskVT = MVT::getVectorVT(MaskSVT, NumMaskElts);
+
+  // Materialize the folded constant and replace the root with it.
+  SDLoc DL(Root);
+  SDValue CstOp = getConstVector(ConstantBitData, UndefElts, MaskVT, DAG, DL);
+  DCI.AddToWorklist(CstOp.getNode());
+  DCI.CombineTo(Root.getNode(), DAG.getBitcast(VT, CstOp));
+  return true;
+}
+
/// \brief Fully generic combining of x86 shuffle instructions.
///
/// This should be the last combine run over the x86 shuffle instructions. Once
@@ -26491,6 +26655,11 @@ static bool combineX86ShufflesRecursively(ArrayRef<SDValue> SrcOps,
HasVariableMask, DAG, DCI, Subtarget))
return true;
+ // Attempt to constant fold all of the constant source ops.
+ if (combineX86ShufflesConstants(Ops, Mask, Root, HasVariableMask, DAG, DCI,
+ Subtarget))
+ return true;
+
// We can only combine unary and binary shuffle mask cases.
if (Ops.size() > 2)
return false;
OpenPOWER on IntegriCloud