Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp | 183
1 file changed, 108 insertions(+), 75 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index acdedab7e13..583a09e34ab 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -30,6 +30,7 @@
 #include "SIInstrInfo.h"
 #include "SIMachineFunctionInfo.h"
 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
+#include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -40,18 +41,6 @@
 #include "llvm/Support/KnownBits.h"
 using namespace llvm;
 
-static bool allocateKernArg(unsigned ValNo, MVT ValVT, MVT LocVT,
-                            CCValAssign::LocInfo LocInfo,
-                            ISD::ArgFlagsTy ArgFlags, CCState &State) {
-  MachineFunction &MF = State.getMachineFunction();
-  AMDGPUMachineFunction *MFI = MF.getInfo<AMDGPUMachineFunction>();
-
-  uint64_t Offset = MFI->allocateKernArg(LocVT.getStoreSize(),
-                                         ArgFlags.getOrigAlign());
-  State.addLoc(CCValAssign::getCustomMem(ValNo, ValVT, Offset, LocVT, LocInfo));
-  return true;
-}
-
 static bool allocateCCRegs(unsigned ValNo, MVT ValVT, MVT LocVT,
                            CCValAssign::LocInfo LocInfo,
                            ISD::ArgFlagsTy ArgFlags, CCState &State,
@@ -910,74 +899,118 @@ CCAssignFn *AMDGPUCallLowering::CCAssignFnForReturn(CallingConv::ID CC,
 /// for each individual part is i8. We pass the memory type as LocVT to the
 /// calling convention analysis function and the register type (Ins[x].VT) as
 /// the ValVT.
-void AMDGPUTargetLowering::analyzeFormalArgumentsCompute(CCState &State,
-                             const SmallVectorImpl<ISD::InputArg> &Ins) const {
-  for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
-    const ISD::InputArg &In = Ins[i];
-    EVT MemVT;
-
-    unsigned NumRegs = getNumRegisters(State.getContext(), In.ArgVT);
-
-    if (!Subtarget->isAmdHsaOS() &&
-        (In.ArgVT == MVT::i16 || In.ArgVT == MVT::i8 || In.ArgVT == MVT::f16)) {
-      // The ABI says the caller will extend these values to 32-bits.
-      MemVT = In.ArgVT.isInteger() ? MVT::i32 : MVT::f32;
-    } else if (NumRegs == 1) {
-      // This argument is not split, so the IR type is the memory type.
-      assert(!In.Flags.isSplit());
-      if (In.ArgVT.isExtended()) {
-        // We have an extended type, like i24, so we should just use the register type
-        MemVT = In.VT;
-      } else {
-        MemVT = In.ArgVT;
-      }
-    } else if (In.ArgVT.isVector() && In.VT.isVector() &&
-               In.ArgVT.getScalarType() == In.VT.getScalarType()) {
-      assert(In.ArgVT.getVectorNumElements() > In.VT.getVectorNumElements());
-      // We have a vector value which has been split into a vector with
-      // the same scalar type, but fewer elements. This should handle
-      // all the floating-point vector types.
-      MemVT = In.VT;
-    } else if (In.ArgVT.isVector() &&
-               In.ArgVT.getVectorNumElements() == NumRegs) {
-      // This arg has been split so that each element is stored in a separate
-      // register.
-      MemVT = In.ArgVT.getScalarType();
-    } else if (In.ArgVT.isExtended()) {
-      // We have an extended type, like i65.
-      MemVT = In.VT;
-    } else {
-      unsigned MemoryBits = In.ArgVT.getStoreSizeInBits() / NumRegs;
-      assert(In.ArgVT.getStoreSizeInBits() % NumRegs == 0);
-      if (In.VT.isInteger()) {
-        MemVT = EVT::getIntegerVT(State.getContext(), MemoryBits);
-      } else if (In.VT.isVector()) {
-        assert(!In.VT.getScalarType().isFloatingPoint());
-        unsigned NumElements = In.VT.getVectorNumElements();
-        assert(MemoryBits % NumElements == 0);
-        // This vector type has been split into another vector type with
-        // a different elements size.
-        EVT ScalarVT = EVT::getIntegerVT(State.getContext(),
-                                         MemoryBits / NumElements);
-        MemVT = EVT::getVectorVT(State.getContext(), ScalarVT, NumElements);
+void AMDGPUTargetLowering::analyzeFormalArgumentsCompute(
+  CCState &State,
+  const SmallVectorImpl<ISD::InputArg> &Ins) const {
+  const MachineFunction &MF = State.getMachineFunction();
+  const Function &Fn = MF.getFunction();
+  LLVMContext &Ctx = Fn.getParent()->getContext();
+  const AMDGPUSubtarget &ST = AMDGPUSubtarget::get(MF);
+  const unsigned ExplicitOffset = ST.getExplicitKernelArgOffset(Fn);
+
+  unsigned MaxAlign = 1;
+  uint64_t ExplicitArgOffset = 0;
+  const DataLayout &DL = Fn.getParent()->getDataLayout();
+
+  unsigned InIndex = 0;
+
+  for (const Argument &Arg : Fn.args()) {
+    Type *BaseArgTy = Arg.getType();
+    unsigned Align = DL.getABITypeAlignment(BaseArgTy);
+    MaxAlign = std::max(Align, MaxAlign);
+    unsigned AllocSize = DL.getTypeAllocSize(BaseArgTy);
+
+    uint64_t ArgOffset = alignTo(ExplicitArgOffset, Align) + ExplicitOffset;
+    ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;
+
+    // We're basically throwing away everything passed into us and starting over
+    // to get accurate in-memory offsets. The "PartOffset" is completely useless
+    // to us as computed in Ins.
+    //
+    // We also need to figure out what type legalization is trying to do to get
+    // the correct memory offsets.
+
+    SmallVector<EVT, 16> ValueVTs;
+    SmallVector<uint64_t, 16> Offsets;
+    ComputeValueVTs(*this, DL, BaseArgTy, ValueVTs, &Offsets, ArgOffset);
+
+    for (unsigned Value = 0, NumValues = ValueVTs.size();
+         Value != NumValues; ++Value) {
+      uint64_t BasePartOffset = Offsets[Value];
+
+      EVT ArgVT = ValueVTs[Value];
+      EVT MemVT = ArgVT;
+      MVT RegisterVT =
+        getRegisterTypeForCallingConv(Ctx, ArgVT);
+      unsigned NumRegs =
+        getNumRegistersForCallingConv(Ctx, ArgVT);
+
+      if (!Subtarget->isAmdHsaOS() &&
+          (ArgVT == MVT::i16 || ArgVT == MVT::i8 || ArgVT == MVT::f16)) {
+        // The ABI says the caller will extend these values to 32-bits.
+        MemVT = ArgVT.isInteger() ? MVT::i32 : MVT::f32;
+      } else if (NumRegs == 1) {
+        // This argument is not split, so the IR type is the memory type.
+        if (ArgVT.isExtended()) {
+          // We have an extended type, like i24, so we should just use the
+          // register type.
+          MemVT = RegisterVT;
+        } else {
+          MemVT = ArgVT;
+        }
+      } else if (ArgVT.isVector() && RegisterVT.isVector() &&
+                 ArgVT.getScalarType() == RegisterVT.getScalarType()) {
+        assert(ArgVT.getVectorNumElements() > RegisterVT.getVectorNumElements());
+        // We have a vector value which has been split into a vector with
+        // the same scalar type, but fewer elements. This should handle
+        // all the floating-point vector types.
+        MemVT = RegisterVT;
+      } else if (ArgVT.isVector() &&
+                 ArgVT.getVectorNumElements() == NumRegs) {
+        // This arg has been split so that each element is stored in a separate
+        // register.
+        MemVT = ArgVT.getScalarType();
+      } else if (ArgVT.isExtended()) {
+        // We have an extended type, like i65.
+        MemVT = RegisterVT;
       } else {
-        llvm_unreachable("cannot deduce memory type.");
+        unsigned MemoryBits = ArgVT.getStoreSizeInBits() / NumRegs;
+        assert(ArgVT.getStoreSizeInBits() % NumRegs == 0);
+        if (RegisterVT.isInteger()) {
+          MemVT = EVT::getIntegerVT(State.getContext(), MemoryBits);
+        } else if (RegisterVT.isVector()) {
+          assert(!RegisterVT.getScalarType().isFloatingPoint());
+          unsigned NumElements = RegisterVT.getVectorNumElements();
+          assert(MemoryBits % NumElements == 0);
+          // This vector type has been split into another vector type with
+          // a different elements size.
+          EVT ScalarVT = EVT::getIntegerVT(State.getContext(),
+                                           MemoryBits / NumElements);
+          MemVT = EVT::getVectorVT(State.getContext(), ScalarVT, NumElements);
+        } else {
+          llvm_unreachable("cannot deduce memory type.");
+        }
       }
-    }

-    // Convert one element vectors to scalar.
-    if (MemVT.isVector() && MemVT.getVectorNumElements() == 1)
-      MemVT = MemVT.getScalarType();
+      // Convert one element vectors to scalar.
+      if (MemVT.isVector() && MemVT.getVectorNumElements() == 1)
+        MemVT = MemVT.getScalarType();

-    if (MemVT.isExtended()) {
-      // This should really only happen if we have vec3 arguments
-      assert(MemVT.isVector() && MemVT.getVectorNumElements() == 3);
-      MemVT = MemVT.getPow2VectorType(State.getContext());
-    }
+      if (MemVT.isExtended()) {
+        // This should really only happen if we have vec3 arguments
+        assert(MemVT.isVector() && MemVT.getVectorNumElements() == 3);
+        MemVT = MemVT.getPow2VectorType(State.getContext());
+      }

-    assert(MemVT.isSimple());
-    allocateKernArg(i, In.VT, MemVT.getSimpleVT(), CCValAssign::Full, In.Flags,
-                    State);
+      unsigned PartOffset = 0;
+      for (unsigned i = 0; i != NumRegs; ++i) {
+        State.addLoc(CCValAssign::getCustomMem(InIndex++, RegisterVT,
+                                               BasePartOffset + PartOffset,
+                                               MemVT.getSimpleVT(),
+                                               CCValAssign::Full));
+        PartOffset += MemVT.getStoreSize();
+      }
+    }
   }
 }
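The heart of the change is the replacement of the stateful allocateKernArg() helper with offsets computed directly from the IR signature: each argument lands at alignTo(ExplicitArgOffset, Align) plus the target's explicit kernarg base offset. The standalone sketch below replays that arithmetic for a made-up kernel signature; the sizes, alignments, and the base offset of 36 are illustrative stand-ins for what DataLayout and ST.getExplicitKernelArgOffset(Fn) would return, not values taken from the patch.

// Minimal sketch (not LLVM code) of the new kernarg offset computation,
// under the assumptions stated above.
#include <cstdint>
#include <cstdio>

// Same contract as llvm::alignTo for Align >= 1.
static uint64_t alignTo(uint64_t Value, uint64_t Align) {
  return (Value + Align - 1) / Align * Align;
}

int main() {
  // Hypothetical kernel: kernel void f(char c, int4 v, float x).
  struct Arg { const char *Name; uint64_t Size, Align; };
  const Arg Args[] = {{"c", 1, 1}, {"v", 16, 16}, {"x", 4, 4}};

  const uint64_t ExplicitOffset = 36; // assumed non-HSA kernarg base offset
  uint64_t ExplicitArgOffset = 0;     // running offset, as in the new loop

  for (const Arg &A : Args) {
    // Mirrors the patch:
    //   ArgOffset = alignTo(ExplicitArgOffset, Align) + ExplicitOffset;
    //   ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;
    uint64_t ArgOffset = alignTo(ExplicitArgOffset, A.Align) + ExplicitOffset;
    ExplicitArgOffset = alignTo(ExplicitArgOffset, A.Align) + A.Size;
    std::printf("%-2s -> kernarg offset %llu\n", A.Name,
                (unsigned long long)ArgOffset);
  }
  return 0;
}

The printed offsets are 36, 52, and 68: the i8 occupies one byte at the base, the int4 is padded up to the next 16-byte boundary, and the float follows it. The old code produced the same packing implicitly through MFI->allocateKernArg's internal state; computing the offsets here is what lets the new loop emit one CCValAssign per legalized register part (BasePartOffset + PartOffset) instead of one per IR argument.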