diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 208 |
1 files changed, 118 insertions, 90 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index c7c8cb53fb7..8cf3b53b867 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -6057,102 +6057,130 @@ X86TargetLowering::LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const { return DAG.getNode(ISD::BITCAST, dl, VT, Select); } -static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG, - const X86Subtarget *Subtarget) { - EVT VT = N->getValueType(0); +/// \brief Return true if \p N implements a horizontal binop and return the +/// operands for the horizontal binop into V0 and V1. +/// +/// This is a helper function of PerformBUILD_VECTORCombine. +/// This function checks that the build_vector \p N in input implements a +/// horizontal operation. Parameter \p Opcode defines the kind of horizontal +/// operation to match. +/// For example, if \p Opcode is equal to ISD::ADD, then this function +/// checks if \p N implements a horizontal arithmetic add; if instead \p Opcode +/// is equal to ISD::SUB, then this function checks if this is a horizontal +/// arithmetic sub. +/// +/// This function only analyzes elements of \p N whose indices are +/// in range [BaseIdx, LastIdx). +static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode, + unsigned BaseIdx, unsigned LastIdx, + SDValue &V0, SDValue &V1) { + assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!"); + assert(N->getValueType(0).isVector() && + N->getValueType(0).getVectorNumElements() >= LastIdx && + "Invalid Vector in input!"); + + bool IsCommutable = (Opcode == ISD::ADD || Opcode == ISD::FADD); + bool CanFold = true; + unsigned ExpectedVExtractIdx = BaseIdx; + unsigned NumElts = LastIdx - BaseIdx; - // Try to match a horizontal ADD or SUB. - if (((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) || - ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) || - ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 || - VT == MVT::v16i16) && Subtarget->hasAVX())) { - unsigned NumOperands = N->getNumOperands(); - unsigned Opcode = N->getOperand(0)->getOpcode(); - bool isCommutable = false; - bool CanFold = false; - switch (Opcode) { - default : break; - case ISD::ADD : - case ISD::FADD : - isCommutable = true; - // FALL-THROUGH - case ISD::SUB : - case ISD::FSUB : - CanFold = true; - } - - // Verify that operands have the same opcode; also, the opcode can only - // be either of: ADD, FADD, SUB, FSUB. - SDValue InVec0, InVec1; - for (unsigned i = 0, e = NumOperands; i != e && CanFold; ++i) { - SDValue Op = N->getOperand(i); - CanFold = Op->getOpcode() == Opcode && Op->hasOneUse(); - - if (!CanFold) - break; + // Check if N implements a horizontal binop. + for (unsigned i = 0, e = NumElts; i != e && CanFold; ++i) { + SDValue Op = N->getOperand(i + BaseIdx); + CanFold = Op->getOpcode() == Opcode && Op->hasOneUse(); - SDValue Op0 = Op.getOperand(0); - SDValue Op1 = Op.getOperand(1); - - // Try to match the following pattern: - // (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1)) - CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT && - Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT && - Op0.getOperand(0) == Op1.getOperand(0) && - isa<ConstantSDNode>(Op0.getOperand(1)) && - isa<ConstantSDNode>(Op1.getOperand(1))); - if (!CanFold) - break; + if (!CanFold) + break; - unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue(); - unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue(); - unsigned ExpectedIndex = (i * 2) % NumOperands; - - if (i == 0) - InVec0 = Op0.getOperand(0); - else if (i * 2 == NumOperands) - InVec1 = Op0.getOperand(0); - - SDValue Expected = (i * 2 < NumOperands) ? InVec0 : InVec1; - if (I0 == ExpectedIndex) - CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected; - else if (isCommutable && I1 == ExpectedIndex) { - // Try to see if we can match the following dag sequence: - // (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I)) - CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected; - } - } + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + + // Try to match the following pattern: + // (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1)) + CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT && + Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT && + Op0.getOperand(0) == Op1.getOperand(0) && + isa<ConstantSDNode>(Op0.getOperand(1)) && + isa<ConstantSDNode>(Op1.getOperand(1))); + if (!CanFold) + break; - if (CanFold) { - unsigned NewOpcode; - switch (Opcode) { - default : llvm_unreachable("Unexpected opcode found!"); - case ISD::ADD : NewOpcode = X86ISD::HADD; break; - case ISD::FADD : NewOpcode = X86ISD::FHADD; break; - case ISD::SUB : NewOpcode = X86ISD::HSUB; break; - case ISD::FSUB : NewOpcode = X86ISD::FHSUB; break; - } + unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue(); + unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue(); - if (VT.is256BitVector()) { - SDLoc dl(N); - - // Convert this sequence into two horizontal add/sub followed - // by a concat vector. - SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, dl); - SDValue InVec0_HI = - Extract128BitVector(InVec0, NumOperands/2, DAG, dl); - SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, dl); - SDValue InVec1_HI = - Extract128BitVector(InVec1, NumOperands/2, DAG, dl); - EVT NewVT = InVec0_LO.getValueType(); - - SDValue LO = DAG.getNode(NewOpcode, dl, NewVT, InVec0_LO, InVec0_HI); - SDValue HI = DAG.getNode(NewOpcode, dl, NewVT, InVec1_LO, InVec1_HI); - return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, LO, HI); - } + if (i == 0) + V0 = Op0.getOperand(0); + else if (i * 2 == NumElts) { + V1 = Op0.getOperand(0); + ExpectedVExtractIdx = BaseIdx; + } + + SDValue Expected = (i * 2 < NumElts) ? V0 : V1; + if (I0 == ExpectedVExtractIdx) + CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected; + else if (IsCommutable && I1 == ExpectedVExtractIdx) { + // Try to match the following dag sequence: + // (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I)) + CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected; + } else + CanFold = false; - return DAG.getNode(NewOpcode, SDLoc(N), VT, InVec0, InVec1); - } + ExpectedVExtractIdx += 2; + } + + return CanFold; +} + +static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG, + const X86Subtarget *Subtarget) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + unsigned NumElts = VT.getVectorNumElements(); + BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N); + SDValue InVec0, InVec1; + + // Try to match horizontal ADD/SUB. + if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) { + // Try to match an SSE3 float HADD/HSUB. + if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1)) + return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1); + + if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1)) + return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1); + } else if ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) { + // Try to match an SSSE3 integer HADD/HSUB. + if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1)) + return DAG.getNode(X86ISD::HADD, DL, VT, InVec0, InVec1); + + if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1)) + return DAG.getNode(X86ISD::HSUB, DL, VT, InVec0, InVec1); + } + + if ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 || + VT == MVT::v16i16) && Subtarget->hasAVX()) { + unsigned X86Opcode; + if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1)) + X86Opcode = X86ISD::HADD; + else if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1)) + X86Opcode = X86ISD::HSUB; + else if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1)) + X86Opcode = X86ISD::FHADD; + else if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1)) + X86Opcode = X86ISD::FHSUB; + else + return SDValue(); + + // Convert this build_vector into two horizontal add/sub followed by + // a concat vector. + SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, DL); + SDValue InVec0_HI = Extract128BitVector(InVec0, NumElts/2, DAG, DL); + SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, DL); + SDValue InVec1_HI = Extract128BitVector(InVec1, NumElts/2, DAG, DL); + EVT NewVT = InVec0_LO.getValueType(); + + SDValue LO = DAG.getNode(X86Opcode, DL, NewVT, InVec0_LO, InVec0_HI); + SDValue HI = DAG.getNode(X86Opcode, DL, NewVT, InVec1_LO, InVec1_HI); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LO, HI); } return SDValue(); |

