Diffstat (limited to 'llvm/lib')
-rw-r--r--   llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 19
-rw-r--r--   llvm/lib/Target/X86/X86ISelLowering.cpp       | 33
2 files changed, 42 insertions, 10 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 216389ac42a..76187e29116 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -273,6 +273,7 @@ namespace {
     SDValue visitANY_EXTEND(SDNode *N);
     SDValue visitSIGN_EXTEND_INREG(SDNode *N);
     SDValue visitSIGN_EXTEND_VECTOR_INREG(SDNode *N);
+    SDValue visitZERO_EXTEND_VECTOR_INREG(SDNode *N);
     SDValue visitTRUNCATE(SDNode *N);
     SDValue visitBITCAST(SDNode *N);
     SDValue visitBUILD_PAIR(SDNode *N);
@@ -1396,6 +1397,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
   case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
   case ISD::SIGN_EXTEND_VECTOR_INREG: return visitSIGN_EXTEND_VECTOR_INREG(N);
+  case ISD::ZERO_EXTEND_VECTOR_INREG: return visitZERO_EXTEND_VECTOR_INREG(N);
   case ISD::TRUNCATE:           return visitTRUNCATE(N);
   case ISD::BITCAST:            return visitBITCAST(N);
   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
@@ -5715,7 +5717,8 @@ static SDNode *tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
   EVT VT = N->getValueType(0);
 
   assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
-         Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
+         Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
+         Opcode == ISD::ZERO_EXTEND_VECTOR_INREG)
          && "Expected EXTEND dag node in input!");
 
   // fold (sext c1) -> c1
@@ -6993,6 +6996,20 @@ SDValue DAGCombiner::visitSIGN_EXTEND_VECTOR_INREG(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitZERO_EXTEND_VECTOR_INREG(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+
+  if (N0.getOpcode() == ISD::UNDEF)
+    return DAG.getUNDEF(VT);
+
+  if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes,
+                                              LegalOperations))
+    return SDValue(Res, 0);
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 3b45544730e..692633c494e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -28332,13 +28332,15 @@ static SDValue getDivRem8(SDNode *N, SelectionDAG &DAG) {
   return R.getValue(1);
 }
 
-/// Convert a SEXT of a vector to a SIGN_EXTEND_VECTOR_INREG, this requires
-/// the splitting (or concatenating with UNDEFs) of the input to vectors of the
-/// same size as the target type which then extends the lowest elements.
+/// Convert a SEXT or ZEXT of a vector to a SIGN_EXTEND_VECTOR_INREG or
+/// ZERO_EXTEND_VECTOR_INREG, this requires the splitting (or concatenating
+/// with UNDEFs) of the input to vectors of the same size as the target type
+/// which then extends the lowest elements.
 static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG,
                                           TargetLowering::DAGCombinerInfo &DCI,
                                           const X86Subtarget &Subtarget) {
-  if (N->getOpcode() != ISD::SIGN_EXTEND)
+  unsigned Opcode = N->getOpcode();
+  if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND)
     return SDValue();
   if (!DCI.isBeforeLegalizeOps())
     return SDValue();
@@ -28359,6 +28361,12 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG,
   if (InSVT != MVT::i32 && InSVT != MVT::i16 && InSVT != MVT::i8)
     return SDValue();
 
+  // On AVX2+ targets, if the input/output types are both legal then we will be
+  // able to use SIGN_EXTEND/ZERO_EXTEND directly.
+  if (Subtarget.hasInt256() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+      DAG.getTargetLoweringInfo().isTypeLegal(InVT))
+    return SDValue();
+
   SDLoc DL(N);
 
   auto ExtendVecSize = [&DAG](SDLoc DL, SDValue N, unsigned Size) {
@@ -28378,20 +28386,22 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG,
     EVT ExVT = EVT::getVectorVT(*DAG.getContext(), SVT,
                                 128 / SVT.getSizeInBits());
     SDValue Ex = ExtendVecSize(DL, N0, Scale * InVT.getSizeInBits());
-    SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, ExVT, Ex);
+    SDValue SExt = DAG.getNode(Opcode, DL, ExVT, Ex);
     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, SExt,
                        DAG.getIntPtrConstant(0, DL));
   }
 
   // If target-size is 128-bits (or 256-bits on AVX2 target), then convert to
-  // ISD::SIGN_EXTEND_VECTOR_INREG which ensures lowering to X86ISD::VSEXT.
+  // ISD::*_EXTEND_VECTOR_INREG which ensures lowering to X86ISD::V*EXT.
   if (VT.is128BitVector() || (VT.is256BitVector() && Subtarget.hasInt256())) {
     SDValue ExOp = ExtendVecSize(DL, N0, VT.getSizeInBits());
-    return DAG.getSignExtendVectorInReg(ExOp, DL, VT);
+    return Opcode == ISD::SIGN_EXTEND
+               ? DAG.getSignExtendVectorInReg(ExOp, DL, VT)
+               : DAG.getZeroExtendVectorInReg(ExOp, DL, VT);
   }
 
   // On pre-AVX2 targets, split into 128-bit nodes of
-  // ISD::SIGN_EXTEND_VECTOR_INREG.
+  // ISD::*_EXTEND_VECTOR_INREG.
   if (!Subtarget.hasInt256() && !(VT.getSizeInBits() % 128)) {
     unsigned NumVecs = VT.getSizeInBits() / 128;
     unsigned NumSubElts = 128 / SVT.getSizeInBits();
@@ -28403,7 +28413,9 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG,
       SDValue SrcVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InSubVT, N0,
                                    DAG.getIntPtrConstant(Offset, DL));
       SrcVec = ExtendVecSize(DL, SrcVec, 128);
-      SrcVec = DAG.getSignExtendVectorInReg(SrcVec, DL, SubVT);
+      SrcVec = Opcode == ISD::SIGN_EXTEND
+                   ? DAG.getSignExtendVectorInReg(SrcVec, DL, SubVT)
+                   : DAG.getZeroExtendVectorInReg(SrcVec, DL, SubVT);
       Opnds.push_back(SrcVec);
     }
     return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Opnds);
@@ -28522,6 +28534,9 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  if (SDValue V = combineToExtendVectorInReg(N, DAG, DCI, Subtarget))
+    return V;
+
   if (VT.is256BitVector())
     if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget))
       return R;
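
For reference, below is a minimal standalone C++ sketch of the lane semantics that ISD::ZERO_EXTEND_VECTOR_INREG models, which the combineZext path above now emits for vector zext nodes: the lowest result-count lanes of the narrow source vector are zero-extended into the wider result element type, and the upper source lanes are dropped. This is illustration only, not code from the patch; the helper name zeroExtendVectorInReg and the lane counts used are hypothetical.

// Scalar model of ISD::ZERO_EXTEND_VECTOR_INREG lane semantics (illustration
// only): zero-extend the low NOut lanes of a narrow source vector into wider
// result lanes, ignoring the remaining upper source lanes.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

template <std::size_t NOut, typename NarrowT, typename WideT, std::size_t NIn>
std::array<WideT, NOut> zeroExtendVectorInReg(const std::array<NarrowT, NIn> &Src) {
  static_assert(NOut <= NIn, "result has fewer, wider lanes than the source");
  std::array<WideT, NOut> Dst{};
  for (std::size_t I = 0; I != NOut; ++I)
    Dst[I] = static_cast<WideT>(Src[I]); // unsigned narrow type => zero extension
  return Dst;
}

int main() {
  // v16i8-style source, v4i32-style result: only lanes 0..3 of Src are used.
  std::array<std::uint8_t, 16> Src{0xFF, 0x01, 0x80, 0x7F};
  auto Dst = zeroExtendVectorInReg<4, std::uint8_t, std::uint32_t>(Src);
  for (std::uint32_t V : Dst)
    std::printf("0x%08X\n", (unsigned)V); // 0x000000FF 0x00000001 0x00000080 0x0000007F
  return 0;
}

On SSE4.1+/AVX targets this is the behaviour lowered to PMOVZX; the sign-extend counterpart, already handled by combineToExtendVectorInReg before this change, lowers to PMOVSX in the same way.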