summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Target
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target')
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp208
1 files changed, 118 insertions, 90 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c7c8cb53fb7..8cf3b53b867 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -6057,102 +6057,130 @@ X86TargetLowering::LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(ISD::BITCAST, dl, VT, Select);
}
-static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
- const X86Subtarget *Subtarget) {
- EVT VT = N->getValueType(0);
+/// \brief Return true if \p N implements a horizontal binop and return the
+/// operands for the horizontal binop into V0 and V1.
+///
+/// This is a helper function of PerformBUILD_VECTORCombine.
+/// This function checks that the build_vector \p N in input implements a
+/// horizontal operation. Parameter \p Opcode defines the kind of horizontal
+/// operation to match.
+/// For example, if \p Opcode is equal to ISD::ADD, then this function
+/// checks if \p N implements a horizontal arithmetic add; if instead \p Opcode
+/// is equal to ISD::SUB, then this function checks if this is a horizontal
+/// arithmetic sub.
+///
+/// This function only analyzes elements of \p N whose indices are
+/// in range [BaseIdx, LastIdx).
+static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode,
+ unsigned BaseIdx, unsigned LastIdx,
+ SDValue &V0, SDValue &V1) {
+ assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!");
+ assert(N->getValueType(0).isVector() &&
+ N->getValueType(0).getVectorNumElements() >= LastIdx &&
+ "Invalid Vector in input!");
+
+ bool IsCommutable = (Opcode == ISD::ADD || Opcode == ISD::FADD);
+ bool CanFold = true;
+ unsigned ExpectedVExtractIdx = BaseIdx;
+ unsigned NumElts = LastIdx - BaseIdx;
- // Try to match a horizontal ADD or SUB.
- if (((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) ||
- ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) ||
- ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
- VT == MVT::v16i16) && Subtarget->hasAVX())) {
- unsigned NumOperands = N->getNumOperands();
- unsigned Opcode = N->getOperand(0)->getOpcode();
- bool isCommutable = false;
- bool CanFold = false;
- switch (Opcode) {
- default : break;
- case ISD::ADD :
- case ISD::FADD :
- isCommutable = true;
- // FALL-THROUGH
- case ISD::SUB :
- case ISD::FSUB :
- CanFold = true;
- }
-
- // Verify that operands have the same opcode; also, the opcode can only
- // be either of: ADD, FADD, SUB, FSUB.
- SDValue InVec0, InVec1;
- for (unsigned i = 0, e = NumOperands; i != e && CanFold; ++i) {
- SDValue Op = N->getOperand(i);
- CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();
-
- if (!CanFold)
- break;
+ // Check if N implements a horizontal binop.
+ for (unsigned i = 0, e = NumElts; i != e && CanFold; ++i) {
+ SDValue Op = N->getOperand(i + BaseIdx);
+ CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();
- SDValue Op0 = Op.getOperand(0);
- SDValue Op1 = Op.getOperand(1);
-
- // Try to match the following pattern:
- // (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1))
- CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
- Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
- Op0.getOperand(0) == Op1.getOperand(0) &&
- isa<ConstantSDNode>(Op0.getOperand(1)) &&
- isa<ConstantSDNode>(Op1.getOperand(1)));
- if (!CanFold)
- break;
+ if (!CanFold)
+ break;
- unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue();
- unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue();
- unsigned ExpectedIndex = (i * 2) % NumOperands;
-
- if (i == 0)
- InVec0 = Op0.getOperand(0);
- else if (i * 2 == NumOperands)
- InVec1 = Op0.getOperand(0);
-
- SDValue Expected = (i * 2 < NumOperands) ? InVec0 : InVec1;
- if (I0 == ExpectedIndex)
- CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected;
- else if (isCommutable && I1 == ExpectedIndex) {
- // Try to see if we can match the following dag sequence:
- // (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I))
- CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected;
- }
- }
+ SDValue Op0 = Op.getOperand(0);
+ SDValue Op1 = Op.getOperand(1);
+
+ // Try to match the following pattern:
+ // (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1))
+ CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+ Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+ Op0.getOperand(0) == Op1.getOperand(0) &&
+ isa<ConstantSDNode>(Op0.getOperand(1)) &&
+ isa<ConstantSDNode>(Op1.getOperand(1)));
+ if (!CanFold)
+ break;
- if (CanFold) {
- unsigned NewOpcode;
- switch (Opcode) {
- default : llvm_unreachable("Unexpected opcode found!");
- case ISD::ADD : NewOpcode = X86ISD::HADD; break;
- case ISD::FADD : NewOpcode = X86ISD::FHADD; break;
- case ISD::SUB : NewOpcode = X86ISD::HSUB; break;
- case ISD::FSUB : NewOpcode = X86ISD::FHSUB; break;
- }
+ unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue();
+ unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue();
- if (VT.is256BitVector()) {
- SDLoc dl(N);
-
- // Convert this sequence into two horizontal add/sub followed
- // by a concat vector.
- SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, dl);
- SDValue InVec0_HI =
- Extract128BitVector(InVec0, NumOperands/2, DAG, dl);
- SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, dl);
- SDValue InVec1_HI =
- Extract128BitVector(InVec1, NumOperands/2, DAG, dl);
- EVT NewVT = InVec0_LO.getValueType();
-
- SDValue LO = DAG.getNode(NewOpcode, dl, NewVT, InVec0_LO, InVec0_HI);
- SDValue HI = DAG.getNode(NewOpcode, dl, NewVT, InVec1_LO, InVec1_HI);
- return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, LO, HI);
- }
+ if (i == 0)
+ V0 = Op0.getOperand(0);
+ else if (i * 2 == NumElts) {
+ V1 = Op0.getOperand(0);
+ ExpectedVExtractIdx = BaseIdx;
+ }
+
+ SDValue Expected = (i * 2 < NumElts) ? V0 : V1;
+ if (I0 == ExpectedVExtractIdx)
+ CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected;
+ else if (IsCommutable && I1 == ExpectedVExtractIdx) {
+ // Try to match the following dag sequence:
+ // (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I))
+ CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected;
+ } else
+ CanFold = false;
- return DAG.getNode(NewOpcode, SDLoc(N), VT, InVec0, InVec1);
- }
+ ExpectedVExtractIdx += 2;
+ }
+
+ return CanFold;
+}
+
+static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
+ const X86Subtarget *Subtarget) {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ unsigned NumElts = VT.getVectorNumElements();
+ BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N);
+ SDValue InVec0, InVec1;
+
+ // Try to match horizontal ADD/SUB.
+ if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) {
+ // Try to match an SSE3 float HADD/HSUB.
+ if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
+ return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1);
+
+ if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
+ return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1);
+ } else if ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) {
+ // Try to match an SSSE3 integer HADD/HSUB.
+ if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
+ return DAG.getNode(X86ISD::HADD, DL, VT, InVec0, InVec1);
+
+ if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
+ return DAG.getNode(X86ISD::HSUB, DL, VT, InVec0, InVec1);
+ }
+
+ if ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
+ VT == MVT::v16i16) && Subtarget->hasAVX()) {
+ unsigned X86Opcode;
+ if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
+ X86Opcode = X86ISD::HADD;
+ else if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
+ X86Opcode = X86ISD::HSUB;
+ else if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
+ X86Opcode = X86ISD::FHADD;
+ else if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
+ X86Opcode = X86ISD::FHSUB;
+ else
+ return SDValue();
+
+ // Convert this build_vector into two horizontal add/sub followed by
+ // a concat vector.
+ SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, DL);
+ SDValue InVec0_HI = Extract128BitVector(InVec0, NumElts/2, DAG, DL);
+ SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, DL);
+ SDValue InVec1_HI = Extract128BitVector(InVec1, NumElts/2, DAG, DL);
+ EVT NewVT = InVec0_LO.getValueType();
+
+ SDValue LO = DAG.getNode(X86Opcode, DL, NewVT, InVec0_LO, InVec0_HI);
+ SDValue HI = DAG.getNode(X86Opcode, DL, NewVT, InVec1_LO, InVec1_HI);
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LO, HI);
}
return SDValue();
OpenPOWER on IntegriCloud