diff options
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index b2accf6b230..c779414cd0f 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -25082,6 +25082,14 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, if (MaskEltSizeInBits > 64) return false; + // Determine the effective mask value type. + bool FloatDomain = + (VT.isFloatingPoint() || (VT.is256BitVector() && !Subtarget.hasAVX2())) && + (32 <= MaskEltSizeInBits); + MVT MaskVT = FloatDomain ? MVT::getFloatingPointVT(MaskEltSizeInBits) + : MVT::getIntegerVT(MaskEltSizeInBits); + MaskVT = MVT::getVectorVT(MaskVT, NumMaskElts); + // Attempt to match the mask against known shuffle patterns. MVT ShuffleVT; unsigned Shuffle, PermuteImm; @@ -25130,11 +25138,7 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, (Subtarget.hasAVX() && VT.is256BitVector()))) { // Convert VT to a type compatible with X86ISD::BLENDI. // TODO - add 16i16 support (requires lane duplication). - bool FloatDomain = VT.isFloatingPoint(); - MVT ShuffleVT = FloatDomain ? MVT::getFloatingPointVT(MaskEltSizeInBits) - : MVT::getIntegerVT(MaskEltSizeInBits); - ShuffleVT = MVT::getVectorVT(ShuffleVT, NumMaskElts); - + MVT ShuffleVT = MaskVT; if (Subtarget.hasAVX2()) { if (ShuffleVT == MVT::v4i64) ShuffleVT = MVT::v8i32; @@ -25213,6 +25217,7 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, // instructions, but in practice PSHUFB tends to be *very* fast so we're // more aggressive. if ((Depth >= 3 || HasVariableMask) && + !is128BitLaneCrossingShuffleMask(MaskVT, Mask) && ((VT.is128BitVector() && Subtarget.hasSSSE3()) || (VT.is256BitVector() && Subtarget.hasAVX2()) || (VT.is512BitVector() && Subtarget.hasBWI()))) { @@ -25230,9 +25235,7 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, continue; } M = Ratio * M + i % Ratio; - // Check that we are not crossing lanes. - if ((M / 16) != (i / 16)) - return false; + assert ((M / 16) == (i / 16) && "Lane crossing detected"); PSHUFBMask.push_back(DAG.getConstant(M, DL, MVT::i8)); } MVT ByteVT = MVT::getVectorVT(MVT::i8, NumBytes); |