Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp')
-rw-r--r--  llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp | 325
1 file changed, 256 insertions(+), 69 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 8b4fca2ef3c..bc297680c37 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -76,6 +76,31 @@ static LegalizeMutation fewerEltsToSize64Vector(unsigned TypeIdx) {
   };
 }
 
+// Increase the number of vector elements to reach the next multiple of 32-bit
+// type.
+static LegalizeMutation moreEltsToNext32Bit(unsigned TypeIdx) {
+  return [=](const LegalityQuery &Query) {
+    const LLT Ty = Query.Types[TypeIdx];
+
+    const LLT EltTy = Ty.getElementType();
+    const int Size = Ty.getSizeInBits();
+    const int EltSize = EltTy.getSizeInBits();
+    const int NextMul32 = (Size + 31) / 32;
+
+    assert(EltSize < 32);
+
+    const int NewNumElts = (32 * NextMul32 + EltSize - 1) / EltSize;
+    return std::make_pair(TypeIdx, LLT::vector(NewNumElts, EltTy));
+  };
+}
+
+static LegalityPredicate vectorSmallerThan(unsigned TypeIdx, unsigned Size) {
+  return [=](const LegalityQuery &Query) {
+    const LLT QueryTy = Query.Types[TypeIdx];
+    return QueryTy.isVector() && QueryTy.getSizeInBits() < Size;
+  };
+}
+
 static LegalityPredicate vectorWiderThan(unsigned TypeIdx, unsigned Size) {
   return [=](const LegalityQuery &Query) {
     const LLT QueryTy = Query.Types[TypeIdx];
@@ -112,6 +137,14 @@ static LegalityPredicate elementTypeIs(unsigned TypeIdx, LLT Type) {
   };
 }
 
+static LegalityPredicate isWideScalarTruncStore(unsigned TypeIdx) {
+  return [=](const LegalityQuery &Query) {
+    const LLT Ty = Query.Types[TypeIdx];
+    return !Ty.isVector() && Ty.getSizeInBits() > 32 &&
+           Query.MMODescrs[0].SizeInBits < Ty.getSizeInBits();
+  };
+}
+
 AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
                                          const GCNTargetMachine &TM)
   : ST(ST_) {
@@ -126,6 +159,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
   const LLT S16 = LLT::scalar(16);
   const LLT S32 = LLT::scalar(32);
   const LLT S64 = LLT::scalar(64);
+  const LLT S96 = LLT::scalar(96);
   const LLT S128 = LLT::scalar(128);
   const LLT S256 = LLT::scalar(256);
   const LLT S512 = LLT::scalar(512);
@@ -246,7 +280,9 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
     .legalForCartesianProduct({S64, V2S32, V4S16})
     .legalForCartesianProduct({V2S64, V4S32})
     // Don't worry about the size constraint.
-    .legalIf(all(isPointer(0), isPointer(1)));
+    .legalIf(all(isPointer(0), isPointer(1)))
+    // FIXME: Testing hack
+    .legalForCartesianProduct({S16, LLT::vector(2, 8), });
 
   getActionDefinitionsBuilder(G_FCONSTANT)
     .legalFor({S32, S64, S16})
@@ -358,6 +394,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
   getActionDefinitionsBuilder({G_SEXT, G_ZEXT, G_ANYEXT})
     .legalFor({{S64, S32}, {S32, S16}, {S64, S16},
                {S32, S1}, {S64, S1}, {S16, S1},
+               {S96, S32},
               // FIXME: Hack
               {S64, LLT::scalar(33)},
               {S32, S8}, {S128, S32}, {S128, S64}, {S32, LLT::scalar(24)}})
@@ -523,79 +560,229 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
   // TODO: Should load to s16 be legal? Most loads extend to 32-bits, but we
   // handle some operations by just promoting the register during
   // selection. There are also d16 loads on GFX9+ which preserve the high bits.
-  getActionDefinitionsBuilder({G_LOAD, G_STORE})
-    .narrowScalarIf([](const LegalityQuery &Query) {
-        unsigned Size = Query.Types[0].getSizeInBits();
-        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
-        return (Size > 32 && MemSize < Size);
-      },
-      [](const LegalityQuery &Query) {
-        return std::make_pair(0, LLT::scalar(32));
-      })
-    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
-    .fewerElementsIf([=](const LegalityQuery &Query) {
-        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
-        return (MemSize == 96) &&
-               Query.Types[0].isVector() &&
-               !ST.hasDwordx3LoadStores();
-      },
-      [=](const LegalityQuery &Query) {
-        return std::make_pair(0, V2S32);
-      })
-    .legalIf([=](const LegalityQuery &Query) {
-        const LLT &Ty0 = Query.Types[0];
-
-        unsigned Size = Ty0.getSizeInBits();
-        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
-        if (Size < 32 || (Size > 32 && MemSize < Size))
-          return false;
-
-        if (Ty0.isVector() && Size != MemSize)
-          return false;
-
-        // TODO: Decompose private loads into 4-byte components.
-        // TODO: Illegal flat loads on SI
-        switch (MemSize) {
-        case 8:
-        case 16:
-          return Size == 32;
-        case 32:
-        case 64:
-        case 128:
-          return true;
+  auto maxSizeForAddrSpace = [this](unsigned AS) -> unsigned {
+    switch (AS) {
+    // FIXME: Private element size.
+    case AMDGPUAS::PRIVATE_ADDRESS:
+      return 32;
+    // FIXME: Check subtarget
+    case AMDGPUAS::LOCAL_ADDRESS:
+      return ST.useDS128() ? 128 : 64;
+
+    // Treat constant and global as identical. SMRD loads are sometimes usable
+    // for global loads (ideally constant address space should be eliminated)
+    // depending on the context. Legality cannot be context dependent, but
+    // RegBankSelect can split the load as necessary depending on the pointer
+    // register bank/uniformity and if the memory is invariant or not written in
+    // a kernel.
+    case AMDGPUAS::CONSTANT_ADDRESS:
+    case AMDGPUAS::GLOBAL_ADDRESS:
+      return 512;
+    default:
+      return 128;
+    }
+  };
 
-        case 96:
-          return ST.hasDwordx3LoadStores();
-
-        case 256:
-        case 512:
-          // TODO: Possibly support loads of i256 and i512 . This will require
-          // adding i256 and i512 types to MVT in order for to be able to use
-          // TableGen.
-          // TODO: Add support for other vector types, this will require
-          // defining more value mappings for the new types.
-          return Ty0.isVector() && (Ty0.getScalarType().getSizeInBits() == 32 ||
-                                    Ty0.getScalarType().getSizeInBits() == 64);
-
-        default:
-          return false;
-        }
-      })
-    .clampScalar(0, S32, S64);
+  const auto needToSplitLoad = [=](const LegalityQuery &Query) -> bool {
+    const LLT DstTy = Query.Types[0];
+
+    // Split vector extloads.
+    unsigned MemSize = Query.MMODescrs[0].SizeInBits;
+    if (DstTy.isVector() && DstTy.getSizeInBits() > MemSize)
+      return true;
+
+    const LLT PtrTy = Query.Types[1];
+    unsigned AS = PtrTy.getAddressSpace();
+    if (MemSize > maxSizeForAddrSpace(AS))
+      return true;
+
+    // Catch weird sized loads that don't evenly divide into the access sizes
+    // TODO: May be able to widen depending on alignment etc.
+    unsigned NumRegs = MemSize / 32;
+    if (NumRegs == 3 && !ST.hasDwordx3LoadStores())
+      return true;
+
+    unsigned Align = Query.MMODescrs[0].AlignInBits;
+    if (Align < MemSize) {
+      const SITargetLowering *TLI = ST.getTargetLowering();
+      return !TLI->allowsMisalignedMemoryAccessesImpl(MemSize, AS, Align / 8);
+    }
+
+    return false;
+  };
+
+  unsigned GlobalAlign32 = ST.hasUnalignedBufferAccess() ? 0 : 32;
+  unsigned GlobalAlign16 = ST.hasUnalignedBufferAccess() ? 0 : 16;
+  unsigned GlobalAlign8 = ST.hasUnalignedBufferAccess() ? 0 : 8;
+
+  // TODO: Refine based on subtargets which support unaligned access or 128-bit
+  // LDS
+  // TODO: Unsupported flat for SI.
+
+  for (unsigned Op : {G_LOAD, G_STORE}) {
+    const bool IsStore = Op == G_STORE;
+
+    auto &Actions = getActionDefinitionsBuilder(Op);
+    // Whitelist the common cases.
+    // TODO: Pointer loads
+    // TODO: Wide constant loads
+    // TODO: Only CI+ has 3x loads
+    // TODO: Loads to s16 on gfx9
+    Actions.legalForTypesWithMemDesc({{S32, GlobalPtr, 32, GlobalAlign32},
+                                      {V2S32, GlobalPtr, 64, GlobalAlign32},
+                                      {V3S32, GlobalPtr, 96, GlobalAlign32},
+                                      {S96, GlobalPtr, 96, GlobalAlign32},
+                                      {V4S32, GlobalPtr, 128, GlobalAlign32},
+                                      {S128, GlobalPtr, 128, GlobalAlign32},
+                                      {S64, GlobalPtr, 64, GlobalAlign32},
+                                      {V2S64, GlobalPtr, 128, GlobalAlign32},
+                                      {V2S16, GlobalPtr, 32, GlobalAlign32},
+                                      {S32, GlobalPtr, 8, GlobalAlign8},
+                                      {S32, GlobalPtr, 16, GlobalAlign16},
+
+                                      {S32, LocalPtr, 32, 32},
+                                      {S64, LocalPtr, 64, 32},
+                                      {V2S32, LocalPtr, 64, 32},
+                                      {S32, LocalPtr, 8, 8},
+                                      {S32, LocalPtr, 16, 16},
+                                      {V2S16, LocalPtr, 32, 32},
+
+                                      {S32, PrivatePtr, 32, 32},
+                                      {S32, PrivatePtr, 8, 8},
+                                      {S32, PrivatePtr, 16, 16},
+                                      {V2S16, PrivatePtr, 32, 32},
+
+                                      {S32, FlatPtr, 32, GlobalAlign32},
+                                      {S32, FlatPtr, 16, GlobalAlign16},
+                                      {S32, FlatPtr, 8, GlobalAlign8},
+                                      {V2S16, FlatPtr, 32, GlobalAlign32},
+
+                                      {S32, ConstantPtr, 32, GlobalAlign32},
+                                      {V2S32, ConstantPtr, 64, GlobalAlign32},
+                                      {V3S32, ConstantPtr, 96, GlobalAlign32},
+                                      {V4S32, ConstantPtr, 128, GlobalAlign32},
+                                      {S64, ConstantPtr, 64, GlobalAlign32},
+                                      {S128, ConstantPtr, 128, GlobalAlign32},
+                                      {V2S32, ConstantPtr, 32, GlobalAlign32}});
+    Actions
+        .narrowScalarIf(
+            [=](const LegalityQuery &Query) -> bool {
+              return !Query.Types[0].isVector() && needToSplitLoad(Query);
+            },
+            [=](const LegalityQuery &Query) -> std::pair<unsigned, LLT> {
+              const LLT DstTy = Query.Types[0];
+              const LLT PtrTy = Query.Types[1];
+
+              const unsigned DstSize = DstTy.getSizeInBits();
+              unsigned MemSize = Query.MMODescrs[0].SizeInBits;
+
+              // Split extloads.
+              if (DstSize > MemSize)
+                return std::make_pair(0, LLT::scalar(MemSize));
+
+              if (DstSize > 32 && (DstSize % 32 != 0)) {
+                // FIXME: Need a way to specify non-extload of larger size if
+                // suitably aligned.
+                return std::make_pair(0, LLT::scalar(32 * (DstSize / 32)));
+              }
+
+              unsigned MaxSize = maxSizeForAddrSpace(PtrTy.getAddressSpace());
+              if (MemSize > MaxSize)
+                return std::make_pair(0, LLT::scalar(MaxSize));
+
+              unsigned Align = Query.MMODescrs[0].AlignInBits;
+              return std::make_pair(0, LLT::scalar(Align));
+            })
+        .fewerElementsIf(
+            [=](const LegalityQuery &Query) -> bool {
+              return Query.Types[0].isVector() && needToSplitLoad(Query);
+            },
+            [=](const LegalityQuery &Query) -> std::pair<unsigned, LLT> {
+              const LLT DstTy = Query.Types[0];
+              const LLT PtrTy = Query.Types[1];
+
+              LLT EltTy = DstTy.getElementType();
+              unsigned MaxSize = maxSizeForAddrSpace(PtrTy.getAddressSpace());
+
+              // Split if it's too large for the address space.
+              if (Query.MMODescrs[0].SizeInBits > MaxSize) {
+                unsigned NumElts = DstTy.getNumElements();
+                unsigned NumPieces = Query.MMODescrs[0].SizeInBits / MaxSize;
+
+                // FIXME: Refine when odd breakdowns handled
+                // The scalars will need to be re-legalized.
+                if (NumPieces == 1 || NumPieces >= NumElts ||
+                    NumElts % NumPieces != 0)
+                  return std::make_pair(0, EltTy);
+
+                return std::make_pair(0,
+                                      LLT::vector(NumElts / NumPieces, EltTy));
+              }
+
+              // Need to split because of alignment.
+              unsigned Align = Query.MMODescrs[0].AlignInBits;
+              unsigned EltSize = EltTy.getSizeInBits();
+              if (EltSize > Align &&
+                  (EltSize / Align < DstTy.getNumElements())) {
+                return std::make_pair(0, LLT::vector(EltSize / Align, EltTy));
+              }
+
+              // May need relegalization for the scalars.
+              return std::make_pair(0, EltTy);
+            })
+        .minScalar(0, S32);
+
+    if (IsStore)
+      Actions.narrowScalarIf(isWideScalarTruncStore(0), changeTo(0, S32));
+
+    // TODO: Need a bitcast lower option?
+    Actions
+        .legalIf([=](const LegalityQuery &Query) {
+          const LLT Ty0 = Query.Types[0];
+          unsigned Size = Ty0.getSizeInBits();
+          unsigned MemSize = Query.MMODescrs[0].SizeInBits;
+          unsigned Align = Query.MMODescrs[0].AlignInBits;
+
+          // No extending vector loads.
+          if (Size > MemSize && Ty0.isVector())
+            return false;
+          // FIXME: Widening store from alignment not valid.
+          if (MemSize < Size)
+            MemSize = std::max(MemSize, Align);
+
+          switch (MemSize) {
+          case 8:
+          case 16:
+            return Size == 32;
+          case 32:
+          case 64:
+          case 128:
+            return true;
+          case 96:
+            return ST.hasDwordx3LoadStores();
+          case 256:
+          case 512:
+            return true;
+          default:
+            return false;
+          }
+        })
+        .widenScalarToNextPow2(0)
+        // TODO: v3s32->v4s32 with alignment
+        .moreElementsIf(vectorSmallerThan(0, 32), moreEltsToNext32Bit(0));
+  }
 
-  // FIXME: Handle alignment requirements.
   auto &ExtLoads = getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
-                     .legalForTypesWithMemDesc({
-                         {S32, GlobalPtr, 8, 8},
-                         {S32, GlobalPtr, 16, 8},
-                         {S32, LocalPtr, 8, 8},
-                         {S32, LocalPtr, 16, 8},
-                         {S32, PrivatePtr, 8, 8},
-                         {S32, PrivatePtr, 16, 8}});
+                       .legalForTypesWithMemDesc({{S32, GlobalPtr, 8, 8},
+                                                  {S32, GlobalPtr, 16, 2 * 8},
+                                                  {S32, LocalPtr, 8, 8},
+                                                  {S32, LocalPtr, 16, 16},
+                                                  {S32, PrivatePtr, 8, 8},
+                                                  {S32, PrivatePtr, 16, 16}});
   if (ST.hasFlatAddressSpace()) {
-    ExtLoads.legalForTypesWithMemDesc({{S32, FlatPtr, 8, 8},
-                                       {S32, FlatPtr, 16, 8}});
+    ExtLoads.legalForTypesWithMemDesc(
+        {{S32, FlatPtr, 8, 8}, {S32, FlatPtr, 16, 16}});
   }
 
   ExtLoads.clampScalar(0, S32, S32)
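
Note on the new moreEltsToNext32Bit mutation: it pads a vector of sub-32-bit elements up to the next 32-bit multiple, and the patch only applies it when vectorSmallerThan(0, 32) holds. The following standalone sketch is not part of the patch; it just reproduces the rounding arithmetic with plain integers in place of LLVM's LLT, and the helper name roundEltsToNext32Bit is illustrative only.

// Standalone sketch of the element-count rounding used by moreEltsToNext32Bit.
// Plain ints stand in for LLT; not part of the commit above.
#include <cassert>
#include <cstdio>

// Given a vector of NumElts elements of EltSize bits (EltSize < 32), return
// the element count after padding the total size up to a multiple of 32 bits.
static int roundEltsToNext32Bit(int NumElts, int EltSize) {
  assert(EltSize < 32);
  const int Size = NumElts * EltSize;
  const int NextMul32 = (Size + 31) / 32;          // ceil(Size / 32)
  return (32 * NextMul32 + EltSize - 1) / EltSize; // ceil(32 * NextMul32 / EltSize)
}

int main() {
  // <3 x s8>  (24 bits) -> <4 x s8>  (32 bits)
  std::printf("v3s8  -> v%ds8\n", roundEltsToNext32Bit(3, 8));
  // <7 x s8>  (56 bits) -> <8 x s8>  (64 bits)
  std::printf("v7s8  -> v%ds8\n", roundEltsToNext32Bit(7, 8));
  // <5 x s16> (80 bits) -> <6 x s16> (96 bits)
  std::printf("v5s16 -> v%ds16\n", roundEltsToNext32Bit(5, 16));
  return 0;
}

In the patch this pairing means only vectors totalling less than 32 bits per element group are widened by moreElementsIf; wider types are handled by the narrowScalarIf/fewerElementsIf rules driven by needToSplitLoad and maxSizeForAddrSpace.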