diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 8d247c89a0a..6b9ab714e67 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -14462,6 +14462,99 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { return SDValue(); } +/// If we are extracting a subvector produced by a wide binary operator with at +/// at least one operand that was the result of a vector concatenation, then try +/// to use the narrow vector operands directly to avoid the concatenation and +/// extraction. +static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) { + // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share + // some of these bailouts with other transforms. + + // The extract index must be a constant, so we can map it to a concat operand. + auto *ExtractIndex = dyn_cast<ConstantSDNode>(Extract->getOperand(1)); + if (!ExtractIndex) + return SDValue(); + + // Only handle the case where we are doubling and then halving. A larger ratio + // may require more than two narrow binops to replace the wide binop. + EVT VT = Extract->getValueType(0); + unsigned NumElems = VT.getVectorNumElements(); + assert((ExtractIndex->getZExtValue() % NumElems) == 0 && + "Extract index is not a multiple of the vector length."); + if (Extract->getOperand(0).getValueSizeInBits() != VT.getSizeInBits() * 2) + return SDValue(); + + // We are looking for an optionally bitcasted wide vector binary operator + // feeding an extract subvector. + SDValue BinOp = Extract->getOperand(0); + if (BinOp.getOpcode() == ISD::BITCAST) + BinOp = BinOp.getOperand(0); + + // TODO: The motivating case for this transform is an x86 AVX1 target. That + // target has temptingly almost legal versions of bitwise logic ops in 256-bit + // flavors, but no other 256-bit integer support. This could be extended to + // handle any binop, but that may require fixing/adding other folds to avoid + // codegen regressions. + unsigned BOpcode = BinOp.getOpcode(); + if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR) + return SDValue(); + + // The binop must be a vector type, so we can chop it in half. + EVT WideBVT = BinOp.getValueType(); + if (!WideBVT.isVector()) + return SDValue(); + + // Bail out if the target does not support a narrower version of the binop. + EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(), + WideBVT.getVectorNumElements() / 2); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT)) + return SDValue(); + + // Peek through bitcasts of the binary operator operands if needed. + SDValue LHS = BinOp.getOperand(0); + if (LHS.getOpcode() == ISD::BITCAST) + LHS = LHS.getOperand(0); + + SDValue RHS = BinOp.getOperand(1); + if (RHS.getOpcode() == ISD::BITCAST) + RHS = RHS.getOperand(0); + + // We need at least one concatenation operation of a binop operand to make + // this transform worthwhile. The concat must double the input vector sizes. + // TODO: Should we also handle INSERT_SUBVECTOR patterns? + bool ConcatL = + LHS.getOpcode() == ISD::CONCAT_VECTORS && LHS.getNumOperands() == 2; + bool ConcatR = + RHS.getOpcode() == ISD::CONCAT_VECTORS && RHS.getNumOperands() == 2; + if (!ConcatL && !ConcatR) + return SDValue(); + + // If one of the binop operands was not the result of a concat, we must + // extract a half-sized operand for our new narrow binop. We can't just reuse + // the original extract index operand because we may have bitcasted. + unsigned ConcatOpNum = ExtractIndex->getZExtValue() / NumElems; + unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements(); + EVT ExtBOIdxVT = Extract->getOperand(1).getValueType(); + SDLoc DL(Extract); + + // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN + // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, N) + // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, N), YN + SDValue X = ConcatL ? DAG.getBitcast(NarrowBVT, LHS.getOperand(ConcatOpNum)) + : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, + BinOp.getOperand(0), + DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT)); + + SDValue Y = ConcatR ? DAG.getBitcast(NarrowBVT, RHS.getOperand(ConcatOpNum)) + : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, + BinOp.getOperand(1), + DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT)); + + SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y); + return DAG.getBitcast(VT, NarrowBinOp); +} + SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) { EVT NVT = N->getValueType(0); SDValue V = N->getOperand(0); @@ -14517,6 +14610,9 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) { } } + if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG)) + return NarrowBOp; + return SDValue(); } |