diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 120 |
1 files changed, 68 insertions, 52 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 4d1dbe13293..baf6d06da7d 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -25077,20 +25077,26 @@ static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, /// into either a single instruction if there is a special purpose instruction /// for this operation, or into a PSHUFB instruction which is a fully general /// instruction but should only be used to replace chains over a certain depth. -static bool combineX86ShuffleChain(SDValue Input, SDValue Root, +static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, ArrayRef<int> BaseMask, int Depth, bool HasVariableMask, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { assert(!BaseMask.empty() && "Cannot combine an empty shuffle mask!"); + assert((Inputs.size() == 1 || Inputs.size() == 2) && + "Unexpected number of shuffle inputs!"); - // Find the operand that enters the chain. Note that multiple uses are OK - // here, we're not going to remove the operand we find. - Input = peekThroughBitcasts(Input); + // Find the inputs that enter the chain. Note that multiple uses are OK + // here, we're not going to remove the operands we find. + bool UnaryShuffle = (Inputs.size() == 1); + SDValue V1 = peekThroughBitcasts(Inputs[0]); + SDValue V2 = (UnaryShuffle ? V1 : peekThroughBitcasts(Inputs[1])); - MVT VT = Input.getSimpleValueType(); + MVT VT1 = V1.getSimpleValueType(); + MVT VT2 = V2.getSimpleValueType(); MVT RootVT = Root.getSimpleValueType(); - assert(VT.getSizeInBits() == RootVT.getSizeInBits() && + assert(VT1.getSizeInBits() == RootVT.getSizeInBits() && + VT2.getSizeInBits() == RootVT.getSizeInBits() && "Vector size mismatch"); SDLoc DL(Root); @@ -25099,14 +25105,14 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, unsigned NumBaseMaskElts = BaseMask.size(); if (NumBaseMaskElts == 1) { assert(BaseMask[0] == 0 && "Invalid shuffle index found!"); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Input), + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, V1), /*AddTo*/ true); return true; } unsigned RootSizeInBits = RootVT.getSizeInBits(); unsigned BaseMaskEltSizeInBits = RootSizeInBits / NumBaseMaskElts; - bool FloatDomain = VT.isFloatingPoint() || + bool FloatDomain = VT1.isFloatingPoint() || VT2.isFloatingPoint() || (RootVT.is256BitVector() && !Subtarget.hasAVX2()); // Don't combine if we are a AVX512/EVEX target and the mask element size @@ -25124,7 +25130,8 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, // TODO - handle 128/256-bit lane shuffles of 512-bit vectors. // Handle 128-bit lane shuffles of 256-bit vectors. - if (RootVT.is256BitVector() && NumBaseMaskElts == 2 && + // TODO - this should support binary shuffles. + if (UnaryShuffle && RootVT.is256BitVector() && NumBaseMaskElts == 2 && !isSequentialOrUndefOrZeroInRange(BaseMask, 0, 2, 0)) { if (Depth == 1 && Root.getOpcode() == X86ISD::VPERM2X128) return false; // Nothing to do! @@ -25133,7 +25140,7 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, PermMask |= ((BaseMask[0] < 0 ? 0x8 : (BaseMask[0] & 1)) << 0); PermMask |= ((BaseMask[1] < 0 ? 0x8 : (BaseMask[1] & 1)) << 4); - Res = DAG.getBitcast(ShuffleVT, Input); + Res = DAG.getBitcast(ShuffleVT, V1); DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(X86ISD::VPERM2X128, DL, ShuffleVT, Res, DAG.getUNDEF(ShuffleVT), @@ -25168,45 +25175,47 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, MVT ShuffleVT; unsigned Shuffle, PermuteImm; - if (matchUnaryVectorShuffle(MaskVT, Mask, Subtarget, Shuffle, ShuffleVT)) { - if (Depth == 1 && Root.getOpcode() == Shuffle) - return false; // Nothing to do! - Res = DAG.getBitcast(ShuffleVT, Input); - DCI.AddToWorklist(Res.getNode()); - Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res); - DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; - } + if (UnaryShuffle) { + if (matchUnaryVectorShuffle(MaskVT, Mask, Subtarget, Shuffle, ShuffleVT)) { + if (Depth == 1 && Root.getOpcode() == Shuffle) + return false; // Nothing to do! + Res = DAG.getBitcast(ShuffleVT, V1); + DCI.AddToWorklist(Res.getNode()); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res); + DCI.AddToWorklist(Res.getNode()); + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), + /*AddTo*/ true); + return true; + } - if (matchUnaryPermuteVectorShuffle(MaskVT, Mask, Subtarget, Shuffle, ShuffleVT, - PermuteImm)) { - if (Depth == 1 && Root.getOpcode() == Shuffle) - return false; // Nothing to do! - Res = DAG.getBitcast(ShuffleVT, Input); - DCI.AddToWorklist(Res.getNode()); - Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, - DAG.getConstant(PermuteImm, DL, MVT::i8)); - DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; - } + if (matchUnaryPermuteVectorShuffle(MaskVT, Mask, Subtarget, Shuffle, ShuffleVT, + PermuteImm)) { + if (Depth == 1 && Root.getOpcode() == Shuffle) + return false; // Nothing to do! + Res = DAG.getBitcast(ShuffleVT, V1); + DCI.AddToWorklist(Res.getNode()); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, + DAG.getConstant(PermuteImm, DL, MVT::i8)); + DCI.AddToWorklist(Res.getNode()); + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), + /*AddTo*/ true); + return true; + } - if (matchBinaryVectorShuffle(MaskVT, Mask, Shuffle, ShuffleVT)) { - if (Depth == 1 && Root.getOpcode() == Shuffle) - return false; // Nothing to do! - Res = DAG.getBitcast(ShuffleVT, Input); - DCI.AddToWorklist(Res.getNode()); - Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, Res); - DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + // TODO - this should support binary shuffles. + if (matchBinaryVectorShuffle(MaskVT, Mask, Shuffle, ShuffleVT)) { + if (Depth == 1 && Root.getOpcode() == Shuffle) + return false; // Nothing to do! + Res = DAG.getBitcast(ShuffleVT, V1); + DCI.AddToWorklist(Res.getNode()); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, Res); + DCI.AddToWorklist(Res.getNode()); + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), + /*AddTo*/ true); + return true; + } } - SDValue V1 = Input, V2 = Input; if (matchBinaryPermuteVectorShuffle(MaskVT, Mask, V1, V2, DL, DAG, Subtarget, Shuffle, ShuffleVT, PermuteImm)) { if (Depth == 1 && Root.getOpcode() == Shuffle) @@ -25237,7 +25246,7 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, // If we have a single input shuffle with different shuffle patterns in the // the 128-bit lanes use the variable mask to VPERMILPS. // TODO Combine other mask types at higher depths. - if (HasVariableMask && !MaskContainsZeros && + if (UnaryShuffle && HasVariableMask && !MaskContainsZeros && ((MaskVT == MVT::v8f32 && Subtarget.hasAVX()) || (MaskVT == MVT::v16f32 && Subtarget.hasAVX512()))) { SmallVector<SDValue, 16> VPermIdx; @@ -25249,7 +25258,7 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, MVT VPermMaskVT = MVT::getVectorVT(MVT::i32, NumMaskElts); SDValue VPermMask = DAG.getBuildVector(VPermMaskVT, DL, VPermIdx); DCI.AddToWorklist(VPermMask.getNode()); - Res = DAG.getBitcast(MaskVT, Input); + Res = DAG.getBitcast(MaskVT, V1); DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(X86ISD::VPERMILPV, DL, MaskVT, Res, VPermMask); DCI.AddToWorklist(Res.getNode()); @@ -25263,7 +25272,7 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, // Intel's manuals suggest only using PSHUFB if doing so replacing 5 // instructions, but in practice PSHUFB tends to be *very* fast so we're // more aggressive. - if ((Depth >= 3 || HasVariableMask) && + if (UnaryShuffle && (Depth >= 3 || HasVariableMask) && ((RootVT.is128BitVector() && Subtarget.hasSSSE3()) || (RootVT.is256BitVector() && Subtarget.hasAVX2()) || (RootVT.is512BitVector() && Subtarget.hasBWI()))) { @@ -25285,7 +25294,7 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, PSHUFBMask.push_back(DAG.getConstant(M, DL, MVT::i8)); } MVT ByteVT = MVT::getVectorVT(MVT::i8, NumBytes); - Res = DAG.getBitcast(ByteVT, Input); + Res = DAG.getBitcast(ByteVT, V1); DCI.AddToWorklist(Res.getNode()); SDValue PSHUFBMaskOp = DAG.getBuildVector(ByteVT, DL, PSHUFBMask); DCI.AddToWorklist(PSHUFBMaskOp.getNode()); @@ -25486,8 +25495,8 @@ static bool combineX86ShufflesRecursively(ArrayRef<SDValue> SrcOps, HasVariableMask, DAG, DCI, Subtarget)) return true; - // At the moment we can only combine unary shuffle mask cases. - if (Ops.size() != 1) + // We can only combine unary and binary shuffle mask cases. + if (Ops.size() > 2) return false; // Minor canonicalization of the accumulated shuffle mask to make it easier @@ -25500,7 +25509,14 @@ static bool combineX86ShufflesRecursively(ArrayRef<SDValue> SrcOps, Mask = std::move(WidenedMask); } - return combineX86ShuffleChain(Ops[0], Root, Mask, Depth, HasVariableMask, DAG, + // Canonicalization of binary shuffle masks to improve pattern matching by + // commuting the inputs. + if (Ops.size() == 2 && canonicalizeShuffleMaskWithCommute(Mask)) { + ShuffleVectorSDNode::commuteMask(Mask); + std::swap(Ops[0], Ops[1]); + } + + return combineX86ShuffleChain(Ops, Root, Mask, Depth, HasVariableMask, DAG, DCI, Subtarget); } |

