summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Target/X86/X86ISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp136
1 files changed, 136 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 0579b979e25..c80bccf002a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -29348,8 +29348,144 @@ static SDValue OptimizeConditionalInDecrement(SDNode *N, SelectionDAG &DAG) {
DAG.getConstant(0, DL, OtherVal.getValueType()), NewCmp);
}
+static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+
+ if (!VT.isVector() || !VT.isSimple() ||
+ !(VT.getVectorElementType() == MVT::i32))
+ return SDValue();
+
+ unsigned RegSize = 128;
+ if (Subtarget.hasBWI())
+ RegSize = 512;
+ else if (Subtarget.hasAVX2())
+ RegSize = 256;
+
+ // We only handle v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512.
+ if (VT.getSizeInBits() / 4 > RegSize)
+ return SDValue();
+
+ // Detect the following pattern:
+ //
+ // 1: %2 = zext <N x i8> %0 to <N x i32>
+ // 2: %3 = zext <N x i8> %1 to <N x i32>
+ // 3: %4 = sub nsw <N x i32> %2, %3
+ // 4: %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N]
+ // 5: %6 = sub nsw <N x i32> zeroinitializer, %4
+ // 6: %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6
+ // 7: %8 = add nsw <N x i32> %7, %vec.phi
+ //
+ // The last instruction must be a reduction add. The instructions 3-6 forms an
+ // ABSDIFF pattern.
+
+ // The two operands of reduction add are from PHI and a select-op as in line 7
+ // above.
+ SDValue SelectOp, Phi;
+ if (Op0.getOpcode() == ISD::VSELECT) {
+ SelectOp = Op0;
+ Phi = Op1;
+ } else if (Op1.getOpcode() == ISD::VSELECT) {
+ SelectOp = Op1;
+ Phi = Op0;
+ } else
+ return SDValue();
+
+ // Check the condition of the select instruction is greater-than.
+ SDValue SetCC = SelectOp->getOperand(0);
+ if (SetCC.getOpcode() != ISD::SETCC)
+ return SDValue();
+ ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
+ if (CC != ISD::SETGT)
+ return SDValue();
+
+ Op0 = SelectOp->getOperand(1);
+ Op1 = SelectOp->getOperand(2);
+
+ // The second operand of SelectOp Op1 is the negation of the first operand
+ // Op0, which is implemented as 0 - Op0.
+ if (!(Op1.getOpcode() == ISD::SUB &&
+ ISD::isBuildVectorAllZeros(Op1.getOperand(0).getNode()) &&
+ Op1.getOperand(1) == Op0))
+ return SDValue();
+
+ // The first operand of SetCC is the first operand of SelectOp, which is the
+ // difference between two input vectors.
+ if (SetCC.getOperand(0) != Op0)
+ return SDValue();
+
+ // The second operand of > comparison can be either -1 or 0.
+ if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) ||
+ ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode())))
+ return SDValue();
+
+ // The first operand of SelectOp is the difference between two input vectors.
+ if (Op0.getOpcode() != ISD::SUB)
+ return SDValue();
+
+ Op1 = Op0.getOperand(1);
+ Op0 = Op0.getOperand(0);
+
+ // Check if the operands of the diff are zero-extended from vectors of i8.
+ if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
+ Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
+ Op1.getOpcode() != ISD::ZERO_EXTEND ||
+ Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8)
+ return SDValue();
+
+ // SAD pattern detected. Now build a SAD instruction and an addition for
+ // reduction. Note that the number of elments of the result of SAD is less
+ // than the number of elements of its input. Therefore, we could only update
+ // part of elements in the reduction vector.
+
+ // Legalize the type of the inputs of PSADBW.
+ EVT InVT = Op0.getOperand(0).getValueType();
+ if (InVT.getSizeInBits() <= 128)
+ RegSize = 128;
+ else if (InVT.getSizeInBits() <= 256)
+ RegSize = 256;
+
+ unsigned NumConcat = RegSize / InVT.getSizeInBits();
+ SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, InVT));
+ Ops[0] = Op0.getOperand(0);
+ MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
+ Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+ Ops[0] = Op1.getOperand(0);
+ Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+
+ // The output of PSADBW is a vector of i64.
+ MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
+ SDValue Sad = DAG.getNode(X86ISD::PSADBW, DL, SadVT, Op0, Op1);
+
+ // We need to turn the vector of i64 into a vector of i32.
+ MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
+ Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
+
+ NumConcat = VT.getSizeInBits() / ResVT.getSizeInBits();
+ if (NumConcat > 1) {
+ // Update part of elements of the reduction vector. This is done by first
+ // extracting a sub-vector from it, updating this sub-vector, and inserting
+ // it back.
+ SDValue SubPhi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Phi,
+ DAG.getIntPtrConstant(0, DL));
+ SDValue Res = DAG.getNode(ISD::ADD, DL, ResVT, Sad, SubPhi);
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Phi, Res,
+ DAG.getIntPtrConstant(0, DL));
+ } else
+ return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi);
+}
+
static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
+ const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+ if (Flags->hasVectorReduction()) {
+ if (SDValue Sad = detectSADPattern(N, DAG, Subtarget))
+ return Sad;
+ }
+
EVT VT = N->getValueType(0);
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
OpenPOWER on IntegriCloud