diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 250af7ab8bc..2bed6c85bec 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -43158,6 +43158,55 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG, return SDValue(); } +/// If we are extracting a subvector of a vector select and the select condition +/// is composed of concatenated vectors, try to narrow the select width. This +/// is a common pattern for AVX1 integer code because 256-bit selects may be +/// legal, but there is almost no integer math/logic available for 256-bit. +/// This function should only be called with legal types (otherwise, the calls +/// to get simple value types will assert). +static SDValue narrowExtractedVectorSelect(SDNode *Ext, SelectionDAG &DAG) { + SDValue Sel = peekThroughBitcasts(Ext->getOperand(0)); + SmallVector<SDValue, 4> CatOps; + if (Sel.getOpcode() != ISD::VSELECT || + !collectConcatOps(Sel.getOperand(0).getNode(), CatOps)) + return SDValue(); + + // TODO: This can be extended to handle extraction to 256-bits. + MVT VT = Ext->getSimpleValueType(0); + if (!VT.is128BitVector()) + return SDValue(); + + MVT WideVT = Ext->getOperand(0).getSimpleValueType(); + MVT SelVT = Sel.getSimpleValueType(); + unsigned SelElts = SelVT.getVectorNumElements(); + unsigned CastedElts = WideVT.getVectorNumElements(); + unsigned ExtIdx = cast<ConstantSDNode>(Ext->getOperand(1))->getZExtValue(); + if (SelElts % CastedElts == 0) { + // The select has the same or more (narrower) elements than the extract + // operand. The extraction index gets scaled by that factor. + ExtIdx *= (SelElts / CastedElts); + } else if (CastedElts % SelElts == 0) { + // The select has less (wider) elements than the extract operand. Make sure + // that the extraction index can be divided evenly. + unsigned IndexDivisor = CastedElts / SelElts; + if (ExtIdx % IndexDivisor != 0) + return SDValue(); + ExtIdx /= IndexDivisor; + } else { + llvm_unreachable("Element count of simple vector types are not divisible?"); + } + + unsigned NarrowingFactor = WideVT.getSizeInBits() / VT.getSizeInBits(); + unsigned NarrowElts = SelElts / NarrowingFactor; + MVT NarrowSelVT = MVT::getVectorVT(SelVT.getVectorElementType(), NarrowElts); + SDLoc DL(Ext); + SDValue ExtCond = extract128BitVector(Sel.getOperand(0), ExtIdx, DAG, DL); + SDValue ExtT = extract128BitVector(Sel.getOperand(1), ExtIdx, DAG, DL); + SDValue ExtF = extract128BitVector(Sel.getOperand(2), ExtIdx, DAG, DL); + SDValue NarrowSel = DAG.getSelect(DL, NarrowSelVT, ExtCond, ExtT, ExtF); + return DAG.getBitcast(VT, NarrowSel); +} + static SDValue combineExtractSubvector(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -43200,6 +43249,9 @@ static SDValue combineExtractSubvector(SDNode *N, SelectionDAG &DAG, if (DCI.isBeforeLegalizeOps()) return SDValue(); + if (SDValue V = narrowExtractedVectorSelect(N, DAG)) + return V; + SDValue InVec = N->getOperand(0); unsigned IdxVal = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); |