summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Target/X86/X86ISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp54
1 files changed, 41 insertions, 13 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 50cd8ab1dbc..e3fee6fbb63 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -29144,12 +29144,18 @@ static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0,
if (SetCC.getOpcode() != ISD::SETCC)
return false;
ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
- if (CC != ISD::SETGT)
+ if (CC != ISD::SETGT && CC != ISD::SETLT)
return false;
SDValue SelectOp1 = Select->getOperand(1);
SDValue SelectOp2 = Select->getOperand(2);
+ // The following instructions assume SelectOp1 is the subtraction operand
+ // and SelectOp2 is the negation operand.
+ // In the case of SETLT this is the other way around.
+ if (CC == ISD::SETLT)
+ std::swap(SelectOp1, SelectOp2);
+
// The second operand of the select should be the negation of the first
// operand, which is implemented as 0 - SelectOp1.
if (!(SelectOp2.getOpcode() == ISD::SUB &&
@@ -29162,8 +29168,17 @@ static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0,
if (SetCC.getOperand(0) != SelectOp1)
return false;
- // The second operand of the comparison can be either -1 or 0.
- if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) ||
+ // In SetLT case, The second operand of the comparison can be either 1 or 0.
+ APInt SplatVal;
+ if ((CC == ISD::SETLT) &&
+ !((ISD::isConstantSplatVector(SetCC.getOperand(1).getNode(), SplatVal) &&
+ SplatVal == 1) ||
+ (ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()))))
+ return false;
+
+ // In SetGT case, The second operand of the comparison can be either -1 or 0.
+ if ((CC == ISD::SETGT) &&
+ !(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) ||
ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode())))
return false;
@@ -29292,11 +29307,9 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
if (!Subtarget.hasSSE2())
return SDValue();
- // Verify the type we're extracting from is appropriate
- // TODO: There's nothing special about i32, any integer type above i16 should
- // work just as well.
+ // Verify the type we're extracting from is any integer type above i16.
EVT VT = Extract->getOperand(0).getValueType();
- if (!VT.isSimple() || !(VT.getVectorElementType() == MVT::i32))
+ if (!VT.isSimple() || !(VT.getVectorElementType().getSizeInBits() > 16))
return SDValue();
unsigned RegSize = 128;
@@ -29305,15 +29318,28 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
else if (Subtarget.hasAVX2())
RegSize = 256;
- // We only handle v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512.
+ // We handle upto v16i* for SSE2 / v32i* for AVX2 / v64i* for AVX512.
// TODO: We should be able to handle larger vectors by splitting them before
// feeding them into several SADs, and then reducing over those.
- if (VT.getSizeInBits() / 4 > RegSize)
+ if (RegSize / VT.getVectorNumElements() < 8)
return SDValue();
// Match shuffle + add pyramid.
SDValue Root = matchBinOpReduction(Extract, ISD::ADD);
+ // The operand is expected to be zero extended from i8
+ // (verified in detectZextAbsDiff).
+ // In order to convert to i64 and above, additional any/zero/sign
+ // extend is expected.
+ // The zero extend from 32 bit has no mathematical effect on the result.
+ // Also the sign extend is basically zero extend
+ // (extends the sign bit which is zero).
+ // So it is correct to skip the sign/zero extend instruction.
+ if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND ||
+ Root.getOpcode() == ISD::ZERO_EXTEND ||
+ Root.getOpcode() == ISD::ANY_EXTEND))
+ Root = Root.getOperand(0);
+
// If there was a match, we want Root to be a select that is the root of an
// abs-diff pattern.
if (!Root || (Root.getOpcode() != ISD::VSELECT))
@@ -29324,7 +29350,7 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
if (!detectZextAbsDiff(Root, Zext0, Zext1))
return SDValue();
- // Create the SAD instruction
+ // Create the SAD instruction.
SDLoc DL(Extract);
SDValue SAD = createPSADBW(DAG, Zext0, Zext1, DL);
@@ -29346,10 +29372,12 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
}
}
- // Return the lowest i32.
- MVT ResVT = MVT::getVectorVT(MVT::i32, SadVT.getSizeInBits() / 32);
+ MVT Type = Extract->getSimpleValueType(0);
+ unsigned TypeSizeInBits = Type.getSizeInBits();
+ // Return the lowest TypeSizeInBits bits.
+ MVT ResVT = MVT::getVectorVT(Type, SadVT.getSizeInBits() / TypeSizeInBits);
SAD = DAG.getNode(ISD::BITCAST, DL, ResVT, SAD);
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, SAD,
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Type, SAD,
Extract->getOperand(1));
}
OpenPOWER on IntegriCloud