author | Kerry McLaughlin <kerry.mclaughlin@arm.com> | 2019-10-28 10:00:57 +0000
---|---|---
committer | Kerry McLaughlin <kerry.mclaughlin@arm.com> | 2019-10-28 10:06:14 +0000
commit | da720a38b9f24cc92b46fd5df503b13d5c823285 (patch) |
tree | d452c0795fb73e6598c9a158c8a743b7bf3a2c93 | /llvm/lib/Target
parent | 7214f7a79f4bf791e5c6726757dbcec143f0aa91 (diff) |
[AArch64][SVE] Implement masked load intrinsics
Summary:
Adds support for codegen of masked loads, with non-extending,
zero-extending and sign-extending variants.
Reviewers: huntergr, rovka, greened, dmgreen
Reviewed By: dmgreen
Subscribers: dmgreen, samparker, tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, cfe-commits, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D68877
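To illustrate what this enables, the IR below is a sketch in the style of the tests added with the patch (the function name is illustrative): a zero-extending masked load of i8 data into a 64-bit-element scalable vector, which can now be selected directly instead of being rejected or scalarized.

```llvm
; Illustrative example: zero-extending masked load, nxv2i8 -> nxv2i64.
define <vscale x 2 x i64> @masked_zload_nxv2i8(<vscale x 2 x i8>* %src, <vscale x 2 x i1> %mask) {
  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8.p0nxv2i8(<vscale x 2 x i8>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
  %ext = zext <vscale x 2 x i8> %load to <vscale x 2 x i64>
  ret <vscale x 2 x i64> %ext
}

declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8.p0nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
```

Because the zero-extension is folded into the load, this should select to a single predicated extending load such as `ld1b { z0.d }, p0/z, [x0]`.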
Diffstat (limited to 'llvm/lib/Target')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp | 20
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 7
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.h | 1
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64InstrInfo.td | 49
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td | 38
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 15
-rw-r--r-- | llvm/lib/Target/AArch64/SVEInstrFormats.td | 10
7 files changed, 136 insertions, 4 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 1f08505f37e..054d0f15e0d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -140,6 +140,26 @@ public:
     return SelectAddrModeXRO(N, Width / 8, Base, Offset, SignExtend, DoShift);
   }
 
+  bool SelectDupZeroOrUndef(SDValue N) {
+    switch(N->getOpcode()) {
+    case ISD::UNDEF:
+      return true;
+    case AArch64ISD::DUP:
+    case ISD::SPLAT_VECTOR: {
+      auto Opnd0 = N->getOperand(0);
+      if (auto CN = dyn_cast<ConstantSDNode>(Opnd0))
+        if (CN->isNullValue())
+          return true;
+      if (auto CN = dyn_cast<ConstantFPSDNode>(Opnd0))
+        if (CN->isZero())
+          return true;
+    }
+    default:
+      break;
+    }
+
+    return false;
+  }
+
   /// Form sequences of consecutive 64/128-bit registers for use in NEON
   /// instructions making use of a vector-list (e.g. ldN, tbl). Vecs must have
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2746117e8ee..aa9e26c879a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -802,6 +802,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   }
 
   if (Subtarget->hasSVE()) {
+    // FIXME: Add custom lowering of MLOAD to handle different passthrus (not a
+    // splat of 0 or undef) once vector selects supported in SVE codegen. See
+    // D68877 for more details.
     for (MVT VT : MVT::integer_scalable_vector_valuetypes()) {
       if (isTypeLegal(VT) && VT.getVectorElementType() != MVT::i1)
         setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
@@ -2886,6 +2889,10 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   }
 }
 
+bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
+  return ExtVal.getValueType().isScalableVector();
+}
+
 // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16.
 static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST,
                                         EVT VT, EVT MemVT,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 00fa96bc4e6..5a76f0c467b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -741,6 +741,7 @@ private:
     return TargetLowering::getInlineAsmMemConstraint(ConstraintCode);
  }
 
+  bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
   bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
   bool mayBeEmittedAsTailCall(const CallInst *CI) const override;
   bool getIndexedAddressParts(SDNode *Op, SDValue &Base, SDValue &Offset,
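Two hooks above do the DAG-level work: `SelectDupZeroOrUndef` lets the new patterns accept a masked load whose passthru operand is undef or a splat of zero, matching the zeroing behaviour of predicated SVE loads (other passthru values are not handled yet, as the FIXME notes), while `isVectorLoadExtDesirable` keeps sign/zero-extensions folded into scalable-vector masked loads so the extending LD1 forms can be matched. Below is a sketch of IR with an explicit zero passthru that the new ComplexPattern is intended to accept (names are illustrative):

```llvm
define <vscale x 4 x i32> @masked_load_zero_passthru(<vscale x 4 x i32>* %src, <vscale x 4 x i1> %mask) {
  ; A zeroinitializer passthru becomes a splat of zero in the DAG, which
  ; SelectDupZeroOrUndef treats the same as an undef passthru.
  %load = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0nxv4i32(<vscale x 4 x i32>* %src, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> zeroinitializer)
  ret <vscale x 4 x i32> %load
}

declare <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0nxv4i32(<vscale x 4 x i32>*, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)
```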
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 1981bd5d3bf..89de1212443 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -259,6 +259,55 @@ def SDT_AArch64WrapperLarge : SDTypeProfile<1, 4,
                                          SDTCisSameAs<1, 2>, SDTCisSameAs<1, 3>,
                                          SDTCisSameAs<1, 4>]>;
 
+// non-extending masked load fragment.
+def nonext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+}]>;
+// sign extending masked load fragments.
+def asext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def),[{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::EXTLOAD ||
+         cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::SEXTLOAD;
+}]>;
+def asext_masked_load_i8 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def asext_masked_load_i16 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def asext_masked_load_i32 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+// zero extending masked load fragments.
+def zext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::ZEXTLOAD;
+}]>;
+def zext_masked_load_i8 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def zext_masked_load_i16 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def zext_masked_load_i32 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+
 // Node definitions.
 def AArch64adrp : SDNode<"AArch64ISD::ADRP", SDTIntUnaryOp, []>;
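The PatFrags above classify a masked load first by extension kind (non-extending, any/sign-extending, zero-extending) and then by the in-memory element type, so the SVE patterns that follow can pick the matching LD1 or LD1S form. For instance, IR along these lines (illustrative, modelled on the patch's sign-extension tests) becomes an `asext_masked_load_i16` with a 4-element predicate, which the patterns in AArch64SVEInstrInfo.td map to an LD1SH:

```llvm
define <vscale x 4 x i32> @masked_sload_nxv4i16(<vscale x 4 x i16>* %src, <vscale x 4 x i1> %mask) {
  ; The sext is folded into the masked load, producing a sign-extending
  ; masked load whose memory type has i16 elements.
  %load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16.p0nxv4i16(<vscale x 4 x i16>* %src, i32 2, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
  %ext = sext <vscale x 4 x i16> %load to <vscale x 4 x i32>
  ret <vscale x 4 x i32> %ext
}

declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16.p0nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
```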
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index b573eac7675..379640eb5d3 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1070,6 +1070,44 @@ let Predicates = [HasSVE] in {
   def : Pat<(nxv2f64 (bitconvert (nxv8f16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
   def : Pat<(nxv2f64 (bitconvert (nxv4f32 ZPR:$src))), (nxv2f64 ZPR:$src)>;
 
+  // Add more complex addressing modes here as required
+  multiclass pred_load<ValueType Ty, ValueType PredTy, SDPatternOperator Load,
+                       Instruction RegImmInst> {
+
+    def _default_z : Pat<(Ty (Load GPR64:$base, (PredTy PPR:$gp), (SVEDup0Undef))),
+                         (RegImmInst PPR:$gp, GPR64:$base, (i64 0))>;
+  }
+
+  // 2-element contiguous loads
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i8,   LD1B_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i8,  LD1SB_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i16,  LD1H_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i16, LD1SH_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i32,  LD1W_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i32, LD1SW_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, nonext_masked_load,    LD1D_IMM>;
+  defm : pred_load<nxv2f16, nxv2i1, nonext_masked_load,    LD1H_D_IMM>;
+  defm : pred_load<nxv2f32, nxv2i1, nonext_masked_load,    LD1W_D_IMM>;
+  defm : pred_load<nxv2f64, nxv2i1, nonext_masked_load,    LD1D_IMM>;
+
+  // 4-element contiguous loads
+  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i8,   LD1B_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i8,  LD1SB_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i16,  LD1H_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i16, LD1SH_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, nonext_masked_load,    LD1W_IMM>;
+  defm : pred_load<nxv4f16, nxv4i1, nonext_masked_load,    LD1H_S_IMM>;
+  defm : pred_load<nxv4f32, nxv4i1, nonext_masked_load,    LD1W_IMM>;
+
+  // 8-element contiguous loads
+  defm : pred_load<nxv8i16, nxv8i1, zext_masked_load_i8,   LD1B_H_IMM>;
+  defm : pred_load<nxv8i16, nxv8i1, asext_masked_load_i8,  LD1SB_H_IMM>;
+  defm : pred_load<nxv8i16, nxv8i1, nonext_masked_load,    LD1H_IMM>;
+  defm : pred_load<nxv8f16, nxv8i1, nonext_masked_load,    LD1H_IMM>;
+
+  // 16-element contiguous loads
+  defm : pred_load<nxv16i8, nxv16i1, nonext_masked_load,   LD1B_IMM>;
+
 }
 
 let Predicates = [HasSVE2] in {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index e85350cb6d4..9956c9c07e4 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -147,6 +147,21 @@ public:
   bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info);
 
+  bool isLegalMaskedLoad(Type *DataType, MaybeAlign Alignment) {
+    if (!isa<VectorType>(DataType) || !ST->hasSVE())
+      return false;
+
+    Type *Ty = DataType->getVectorElementType();
+    if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy())
+      return true;
+
+    if (Ty->isIntegerTy(8) || Ty->isIntegerTy(16) ||
+        Ty->isIntegerTy(32) || Ty->isIntegerTy(64))
+      return true;
+
+    return false;
+  }
+
   int getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, unsigned Factor,
                                  ArrayRef<unsigned> Indices, unsigned Alignment,
                                  unsigned AddressSpace,
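`isLegalMaskedLoad` is the TTI hook consulted by passes such as the loop vectorizer and ScalarizeMaskedMemIntrin; with SVE available it now reports masked loads of half/float/double and i8/i16/i32/i64 element vectors as legal, so `llvm.masked.load` calls like the sketch below (illustrative) are kept intact and reach instruction selection rather than being scalarized:

```llvm
define <vscale x 4 x float> @masked_load_nxv4f32(<vscale x 4 x float>* %src, <vscale x 4 x i1> %mask) {
  ; Non-extending masked load of float elements; with the patterns above
  ; this should select to a predicated ld1w.
  %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0nxv4f32(<vscale x 4 x float>* %src, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> undef)
  ret <vscale x 4 x float> %load
}

declare <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0nxv4f32(<vscale x 4 x float>*, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
```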
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 8ccf6aa675b..4d5b3ee7b8d 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -293,6 +293,8 @@ class SVE_3_Op_Pat<ValueType vtd, SDPatternOperator op, ValueType vt1,
 : Pat<(vtd (op vt1:$Op1, vt2:$Op2, vt3:$Op3)),
       (inst $Op1, $Op2, $Op3)>;
 
+def SVEDup0Undef : ComplexPattern<i64, 0, "SelectDupZeroOrUndef", []>;
+
 //===----------------------------------------------------------------------===//
 // SVE Predicate Misc Group
 //===----------------------------------------------------------------------===//
@@ -4736,14 +4738,14 @@ class sve_mem_cld_si_base<bits<4> dtype, bit nf, string asm,
 
 multiclass sve_mem_cld_si_base<bits<4> dtype, bit nf, string asm,
                                RegisterOperand listty, ZPRRegOp zprty> {
-  def _REAL : sve_mem_cld_si_base<dtype, nf, asm, listty>;
+  def "" : sve_mem_cld_si_base<dtype, nf, asm, listty>;
 
   def : InstAlias<asm # "\t$Zt, $Pg/z, [$Rn]",
-                  (!cast<Instruction>(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 0>;
+                  (!cast<Instruction>(NAME) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 0>;
   def : InstAlias<asm # "\t$Zt, $Pg/z, [$Rn, $imm4, mul vl]",
-                  (!cast<Instruction>(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), 0>;
+                  (!cast<Instruction>(NAME) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), 0>;
   def : InstAlias<asm # "\t$Zt, $Pg/z, [$Rn]",
-                  (!cast<Instruction>(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 1>;
+                  (!cast<Instruction>(NAME) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 1>;
 }
 
 multiclass sve_mem_cld_si<bits<4> dtype, string asm, RegisterOperand listty,