| author | Elena Demikhovsky <elena.demikhovsky@intel.com> | 2016-04-03 08:41:12 +0000 |
|---|---|---|
| committer | Elena Demikhovsky <elena.demikhovsky@intel.com> | 2016-04-03 08:41:12 +0000 |
| commit | 5e426f7356a9dfb5db04104b975bb67ef7724a52 (patch) | |
| tree | a9b6ed5b2d2d7ce2486bfcd014f548067bd09ed9 /llvm/lib/Target/X86 | |
| parent | 842fa53026a505ad89f8f9d996e15cf8c72c05d1 (diff) | |
AVX-512: Load and Extended Load for i1 vectors
Implemented load + {sign|zero}_extend for i1 vectors.
Fixed failures in i1 vector loads.
Covered loading of v2i1, v4i1, v8i1, v16i1, v32i1, and v64i1 vectors for KNL and SKX.
Differential Revision: http://reviews.llvm.org/D18737
llvm-svn: 265259
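
For context on what this lowering must produce: an i1 vector is stored in memory as packed bits, so a {sign|zero}-extending load has to expand every bit into a full lane, all-ones (-1) per set bit for sign-extension and 1 for zero-extension. The sketch below is a minimal standalone model of those per-lane semantics, not code from the patch; the helper name and the scalar modeling are illustrative only.

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Illustrative model of what an extending load of a packed i1 vector must
// produce: each mask bit becomes one lane, 0/1 for zero-extension and
// 0/-1 (all ones) for sign-extension.
template <unsigned NumElts>
std::array<int32_t, NumElts> extendMaskLoad(uint8_t PackedBits,
                                            bool SignExtend) {
  std::array<int32_t, NumElts> Lanes{};
  for (unsigned I = 0; I != NumElts; ++I) {
    int32_t Bit = (PackedBits >> I) & 1;
    Lanes[I] = SignExtend ? -Bit : Bit; // -1 is all ones in two's complement
  }
  return Lanes;
}

int main() {
  // A v4i1 mask 0b0101 sign-extended to v4i32 yields {-1, 0, -1, 0}.
  for (int32_t Lane : extendMaskLoad<4>(0b0101, /*SignExtend=*/true))
    std::printf("%d ", Lane);
  std::printf("\n");
  return 0;
}
```

On AVX-512 hardware this corresponds to loading the packed bits into a mask (k) register with a KMOV and then turning the mask into vector lanes; where no legal mask load exists, the patch instead loads the bits as a small integer and bitcasts, as shown in the diff below.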
Diffstat (limited to 'llvm/lib/Target/X86')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 110 |
| -rw-r--r-- | llvm/lib/Target/X86/X86InstrAVX512.td | 22 |
2 files changed, 122 insertions, 10 deletions
```diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index e816e46ae24..0ceb61c0e06 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1384,8 +1384,17 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::LOAD, MVT::v8f64, Legal);
     setOperationAction(ISD::LOAD, MVT::v8i64, Legal);
     setOperationAction(ISD::LOAD, MVT::v16i32, Legal);
-    setOperationAction(ISD::LOAD, MVT::v16i1, Legal);
+    setOperationAction(ISD::LOAD, MVT::v16i1, Legal);
+    setOperationAction(ISD::LOAD, MVT::v8i1, Legal);
 
+    for (MVT VT : {MVT::v2i64, MVT::v4i32, MVT::v8i32, MVT::v4i64, MVT::v8i16,
+                   MVT::v16i8, MVT::v16i16, MVT::v32i8, MVT::v16i32,
+                   MVT::v8i64, MVT::v32i16, MVT::v64i8}) {
+      MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
+      setLoadExtAction(ISD::SEXTLOAD, VT, MaskVT, Custom);
+      setLoadExtAction(ISD::ZEXTLOAD, VT, MaskVT, Custom);
+      setLoadExtAction(ISD::EXTLOAD, VT, MaskVT, Custom);
+    }
     setOperationAction(ISD::FADD, MVT::v16f32, Legal);
     setOperationAction(ISD::FSUB, MVT::v16f32, Legal);
     setOperationAction(ISD::FMUL, MVT::v16f32, Legal);
@@ -1661,6 +1670,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     addRegisterClass(MVT::v32i1, &X86::VK32RegClass);
     addRegisterClass(MVT::v64i1, &X86::VK64RegClass);
 
+    setOperationAction(ISD::LOAD, MVT::v32i1, Legal);
+    setOperationAction(ISD::LOAD, MVT::v64i1, Legal);
     setOperationAction(ISD::LOAD, MVT::v32i16, Legal);
     setOperationAction(ISD::LOAD, MVT::v64i8, Legal);
     setOperationAction(ISD::SETCC, MVT::v32i1, Custom);
@@ -1757,6 +1768,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     addRegisterClass(MVT::v4i1, &X86::VK4RegClass);
     addRegisterClass(MVT::v2i1, &X86::VK2RegClass);
 
+    setOperationAction(ISD::LOAD, MVT::v2i1, Legal);
+    setOperationAction(ISD::LOAD, MVT::v4i1, Legal);
     setOperationAction(ISD::TRUNCATE, MVT::v2i1, Custom);
     setOperationAction(ISD::TRUNCATE, MVT::v4i1, Custom);
     setOperationAction(ISD::SETCC, MVT::v4i1, Custom);
@@ -16093,6 +16106,98 @@ static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
   return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi);
 }
 
+static SDValue LowerExtended1BitVectorLoad(SDValue Op,
+                                           const X86Subtarget &Subtarget,
+                                           SelectionDAG &DAG) {
+
+  LoadSDNode *Ld = cast<LoadSDNode>(Op.getNode());
+  SDLoc dl(Ld);
+  EVT MemVT = Ld->getMemoryVT();
+  assert(MemVT.isVector() && MemVT.getScalarType() == MVT::i1 &&
+         "Expected i1 vector load");
+  unsigned ExtOpcode = Ld->getExtensionType() == ISD::ZEXTLOAD ?
+    ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
+  MVT VT = Op.getValueType().getSimpleVT();
+  unsigned NumElts = VT.getVectorNumElements();
+
+  if ((Subtarget.hasVLX() && Subtarget.hasBWI() && Subtarget.hasDQI()) ||
+      NumElts == 16) {
+    // Load and extend - everything is legal
+    if (NumElts < 8) {
+      SDValue Load = DAG.getLoad(MVT::v8i1, dl, Ld->getChain(),
+                                 Ld->getBasePtr(),
+                                 Ld->getMemOperand());
+      // Replace chain users with the new chain.
+      assert(Load->getNumValues() == 2 && "Loads must carry a chain!");
+      DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1));
+      MVT ExtVT = MVT::getVectorVT(VT.getScalarType(), 8);
+      SDValue ExtVec = DAG.getNode(ExtOpcode, dl, ExtVT, Load);
+
+      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, ExtVec,
+                         DAG.getIntPtrConstant(0, dl));
+    }
+    SDValue Load = DAG.getLoad(MemVT, dl, Ld->getChain(),
+                               Ld->getBasePtr(),
+                               Ld->getMemOperand());
+    // Replace chain users with the new chain.
+    assert(Load->getNumValues() == 2 && "Loads must carry a chain!");
+    DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1));
+
+    // Finally, do a normal sign-extend to the desired register.
+    return DAG.getNode(ExtOpcode, dl, Op.getValueType(), Load);
+  }
+
+  if (NumElts <= 8) {
+    // A subset, assume that we have only AVX-512F
+    unsigned NumBitsToLoad = NumElts < 8 ? 8 : NumElts;
+    MVT TypeToLoad = MVT::getIntegerVT(NumBitsToLoad);
+    SDValue Load = DAG.getLoad(TypeToLoad, dl, Ld->getChain(),
+                               Ld->getBasePtr(),
+                               Ld->getMemOperand());
+    // Replace chain users with the new chain.
+    assert(Load->getNumValues() == 2 && "Loads must carry a chain!");
+    DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1));
+
+    MVT MaskVT = MVT::getVectorVT(MVT::i1, NumBitsToLoad);
+    SDValue BitVec = DAG.getBitcast(MaskVT, Load);
+
+    if (NumElts == 8)
+      return DAG.getNode(ExtOpcode, dl, VT, BitVec);
+
+    // we should take care to v4i1 and v2i1
+
+    MVT ExtVT = MVT::getVectorVT(VT.getScalarType(), 8);
+    SDValue ExtVec = DAG.getNode(ExtOpcode, dl, ExtVT, BitVec);
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, ExtVec,
+                       DAG.getIntPtrConstant(0, dl));
+  }
+
+  assert(VT == MVT::v32i8 && "Unexpected extload type");
+
+  SmallVector<SDValue, 2> Chains;
+
+  SDValue BasePtr = Ld->getBasePtr();
+  SDValue LoadLo = DAG.getLoad(MVT::v16i1, dl, Ld->getChain(),
+                               Ld->getBasePtr(),
+                               Ld->getMemOperand());
+  Chains.push_back(LoadLo.getValue(1));
+
+  SDValue BasePtrHi =
+    DAG.getNode(ISD::ADD, dl, BasePtr.getValueType(), BasePtr,
+                DAG.getConstant(2, dl, BasePtr.getValueType()));
+
+  SDValue LoadHi = DAG.getLoad(MVT::v16i1, dl, Ld->getChain(),
+                               BasePtrHi,
+                               Ld->getMemOperand());
+  Chains.push_back(LoadHi.getValue(1));
+  SDValue NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Chains);
+  DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), NewChain);
+
+  SDValue Lo = DAG.getNode(ExtOpcode, dl, MVT::v16i8, LoadLo);
+  SDValue Hi = DAG.getNode(ExtOpcode, dl, MVT::v16i8, LoadHi);
+  return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v32i8, Lo, Hi);
+}
+
 // Lower vector extended loads using a shuffle. If SSSE3 is not available we
 // may emit an illegal shuffle but the expansion is still better than scalar
 // code. We generate X86ISD::VSEXT for SEXTLOADs if it's available, otherwise
@@ -16113,6 +16218,9 @@ static SDValue LowerExtendedLoad(SDValue Op, const X86Subtarget &Subtarget,
   LoadSDNode *Ld = cast<LoadSDNode>(Op.getNode());
   SDLoc dl(Ld);
   EVT MemVT = Ld->getMemoryVT();
+  if (MemVT.getScalarType() == MVT::i1)
+    return LowerExtended1BitVectorLoad(Op, Subtarget, DAG);
+
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   unsigned RegSz = RegVT.getSizeInBits();
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 67338778fe5..f6749dc1717 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -2091,6 +2091,11 @@ let Predicates = [HasDQI] in {
             (KMOVBmk addr:$dst, (COPY_TO_REGCLASS VK2:$src, VK8))>;
   def : Pat<(store VK1:$src, addr:$dst),
             (KMOVBmk addr:$dst, (COPY_TO_REGCLASS VK1:$src, VK8))>;
+
+  def : Pat<(v2i1 (load addr:$src)),
+            (COPY_TO_REGCLASS (KMOVBkm addr:$src), VK2)>;
+  def : Pat<(v4i1 (load addr:$src)),
+            (COPY_TO_REGCLASS (KMOVBkm addr:$src), VK4)>;
 }
 let Predicates = [HasAVX512, NoDQI] in {
   def : Pat<(store VK1:$src, addr:$dst),
@@ -2110,18 +2115,19 @@ let Predicates = [HasAVX512, NoDQI] in {
             (EXTRACT_SUBREG
              (KMOVWrk (COPY_TO_REGCLASS VK8:$src, VK16)),
              sub_8bit))>;
-  def : Pat<(store (i8 (bitconvert (v8i1 VK8:$src))), addr:$dst),
-            (KMOVWmk addr:$dst, (COPY_TO_REGCLASS VK8:$src, VK16))>;
-  def : Pat<(v8i1 (bitconvert (i8 (load addr:$src)))),
-            (COPY_TO_REGCLASS (KMOVWkm addr:$src), VK8)>;
+  def : Pat<(v8i1 (load addr:$src)),
+            (COPY_TO_REGCLASS (MOVZX16rm8 addr:$src), VK8)>;
+  def : Pat<(v2i1 (load addr:$src)),
+            (COPY_TO_REGCLASS (MOVZX16rm8 addr:$src), VK2)>;
+  def : Pat<(v4i1 (load addr:$src)),
+            (COPY_TO_REGCLASS (MOVZX16rm8 addr:$src), VK4)>;
 }
+
 let Predicates = [HasAVX512] in {
   def : Pat<(store (i16 (bitconvert (v16i1 VK16:$src))), addr:$dst),
             (KMOVWmk addr:$dst, VK16:$src)>;
   def : Pat<(i1 (load addr:$src)),
-            (COPY_TO_REGCLASS (AND16ri (i16 (SUBREG_TO_REG (i32 0),
-                                             (MOV8rm addr:$src), sub_8bit)),
-                                       (i16 1)), VK1)>;
+            (COPY_TO_REGCLASS (AND16ri (MOVZX16rm8 addr:$src), (i16 1)), VK1)>;
   def : Pat<(v16i1 (bitconvert (i16 (load addr:$src)))),
             (KMOVWkm addr:$src)>;
 }
@@ -2130,8 +2136,6 @@ let Predicates = [HasBWI] in {
             (KMOVDmk addr:$dst, VK32:$src)>;
   def : Pat<(v32i1 (bitconvert (i32 (load addr:$src)))),
             (KMOVDkm addr:$src)>;
-}
-let Predicates = [HasBWI] in {
   def : Pat<(store (i64 (bitconvert (v64i1 VK64:$src))), addr:$dst),
             (KMOVQmk addr:$dst, VK64:$src)>;
   def : Pat<(v64i1 (bitconvert (i64 (load addr:$src)))),
             (KMOVQkm addr:$src)>;
```
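
To make the final fallback in LowerExtended1BitVectorLoad concrete: when a v32i1 extending load must produce v32i8 and the wide mask loads are not available, the patch issues two v16i1 loads, reading the high half 2 bytes past the base pointer (16 mask bits occupy 2 bytes), extends each half to v16i8, and concatenates the results. The sketch below is a rough standalone model of that address and lane arithmetic, not LLVM code; it assumes a little-endian host, which matches x86, and the function and variable names are illustrative only.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Illustrative model of the split path in LowerExtended1BitVectorLoad: a
// 32-bit mask is read as two 16-bit halves, the high half 2 bytes past the
// base, and each half is sign-extended bit-by-bit into i8 lanes.
static void splitExtend32(const uint8_t *Base, int8_t Out[32]) {
  uint16_t Lo, Hi;
  std::memcpy(&Lo, Base, 2);     // low v16i1 half at offset 0
  std::memcpy(&Hi, Base + 2, 2); // high v16i1 half at offset 2 bytes
  for (unsigned I = 0; I != 16; ++I) {
    Out[I] = -(int8_t)((Lo >> I) & 1);      // 0 or -1 (all ones) per bit
    Out[16 + I] = -(int8_t)((Hi >> I) & 1);
  }
}

int main() {
  uint8_t Packed[4] = {0x01, 0x00, 0x80, 0x00}; // mask bits 0 and 23 set
  int8_t Lanes[32];
  splitExtend32(Packed, Lanes);
  std::printf("lane0=%d lane23=%d lane31=%d\n", Lanes[0], Lanes[23], Lanes[31]);
  return 0;
}
```

In the DAG code itself the two load chains are joined with an ISD::TokenFactor, and the original load's chain users are redirected to it so that ordering against other memory operations is preserved.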

