 llvm/lib/Target/X86/X86ISelLowering.cpp | 40
 llvm/lib/Target/X86/X86InstrSSE.td      |  5
 2 files changed, 29 insertions(+), 16 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 2aa9f50933c..0c0b788231c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1266,7 +1266,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
 
     for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
                      MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 }) {
-      setOperationAction(ISD::MLOAD, VT, Legal);
+      setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::MSTORE, VT, Legal);
     }
 
@@ -1412,15 +1412,13 @@
     setTruncStoreAction(MVT::v16i32, MVT::v16i8, Legal);
     setTruncStoreAction(MVT::v16i32, MVT::v16i16, Legal);
 
-    if (!Subtarget.hasVLX()) {
-      // With 512-bit vectors and no VLX, we prefer to widen MLOAD/MSTORE
-      // to 512-bit rather than use the AVX2 instructions so that we can use
-      // k-masks.
-      for (auto VT : {MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
-           MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64}) {
-        setOperationAction(ISD::MLOAD, VT, Custom);
-        setOperationAction(ISD::MSTORE, VT, Custom);
-      }
+    // With 512-bit vectors and no VLX, we prefer to widen MLOAD/MSTORE
+    // to 512-bit rather than use the AVX2 instructions so that we can use
+    // k-masks.
+    for (auto VT : {MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
+         MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64}) {
+      setOperationAction(ISD::MLOAD, VT, Subtarget.hasVLX() ? Legal : Custom);
+      setOperationAction(ISD::MSTORE, VT, Subtarget.hasVLX() ? Legal : Custom);
     }
 
     setOperationAction(ISD::TRUNCATE, MVT::v8i32, Custom);
@@ -26914,8 +26912,28 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget,
   MVT VT = Op.getSimpleValueType();
   MVT ScalarVT = VT.getScalarType();
   SDValue Mask = N->getMask();
+  MVT MaskVT = Mask.getSimpleValueType();
+  SDValue PassThru = N->getPassThru();
   SDLoc dl(Op);
 
+  // Handle AVX masked loads which don't support passthru other than 0.
+  if (MaskVT.getVectorElementType() != MVT::i1) {
+    // We also allow undef in the isel pattern.
+    if (PassThru.isUndef() || ISD::isBuildVectorAllZeros(PassThru.getNode()))
+      return Op;
+
+    SDValue NewLoad = DAG.getMaskedLoad(VT, dl, N->getChain(),
+                                        N->getBasePtr(), Mask,
+                                        getZeroVector(VT, Subtarget, DAG, dl),
+                                        N->getMemoryVT(), N->getMemOperand(),
+                                        N->getExtensionType(),
+                                        N->isExpandingLoad());
+    // Emit a blend.
+    SDValue Select = DAG.getNode(ISD::VSELECT, dl, MaskVT, Mask, NewLoad,
+                                 PassThru);
+    return DAG.getMergeValues({ Select, NewLoad.getValue(1) }, dl);
+  }
+
   assert((!N->isExpandingLoad() || Subtarget.hasAVX512()) &&
          "Expanding masked load is supported on AVX-512 target only!");
 
@@ -26934,7 +26952,7 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget,
   // VLX the vector should be widened to 512 bit
   unsigned NumEltsInWideVec = 512 / VT.getScalarSizeInBits();
   MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec);
-  SDValue PassThru = ExtendToType(N->getPassThru(), WideDataVT, DAG);
+  PassThru = ExtendToType(PassThru, WideDataVT, DAG);
 
   // Mask element has to be i1.
   assert(Mask.getSimpleValueType().getScalarType() == MVT::i1 &&
diff --git a/llvm/lib/Target/X86/X86InstrSSE.td b/llvm/lib/Target/X86/X86InstrSSE.td
index 15b260ebe60..ebc284ca091 100644
--- a/llvm/lib/Target/X86/X86InstrSSE.td
+++ b/llvm/lib/Target/X86/X86InstrSSE.td
@@ -7757,11 +7757,6 @@ multiclass maskmov_lowering<string InstrStr, RegisterClass RC, ValueType VT,
   def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask), (VT immAllZerosV))),
            (!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr)>;
-  def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask), (VT RC:$src0))),
-           (!cast<Instruction>(BlendStr#"rr")
-               RC:$src0,
-               (VT (!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr)),
-               RC:$mask)>;
 }
 
 let Predicates = [HasAVX] in {
   defm : maskmov_lowering<"VMASKMOVPS", VR128, v4f32, v4i32, "VBLENDVPS", v4i32>;
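
As an illustration only (not part of the commit): for an AVX masked load whose passthru is neither undef nor all-zeros, the custom LowerMLOAD path now amounts to a zero-passthru masked load followed by a variable blend, roughly the intrinsics-level sequence sketched below. The helper name is made up for the sketch; the intrinsics themselves are standard AVX/SSE4.1 intrinsics from immintrin.h.

#include <immintrin.h>

// Hypothetical helper, for illustration: lanes whose mask MSB is clear keep
// the passthru value, lanes whose mask MSB is set receive the loaded element.
__m256 masked_load_with_passthru(const float *p, __m256i mask, __m256 passthru) {
  // vmaskmovps: masked-off lanes are loaded as 0.0f.
  __m256 loaded = _mm256_maskload_ps(p, mask);
  // vblendvps: re-insert the passthru value into the masked-off lanes,
  // mirroring the VSELECT the lowering now emits instead of the removed
  // isel pattern.
  return _mm256_blendv_ps(passthru, loaded, _mm256_castsi256_ps(mask));
}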