summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp  183
1 files changed, 108 insertions, 75 deletions
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index acdedab7e13..583a09e34ab 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -30,6 +30,7 @@
#include "SIInstrInfo.h"
#include "SIMachineFunctionInfo.h"
#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
+#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -40,18 +41,6 @@
#include "llvm/Support/KnownBits.h"
using namespace llvm;
-static bool allocateKernArg(unsigned ValNo, MVT ValVT, MVT LocVT,
- CCValAssign::LocInfo LocInfo,
- ISD::ArgFlagsTy ArgFlags, CCState &State) {
- MachineFunction &MF = State.getMachineFunction();
- AMDGPUMachineFunction *MFI = MF.getInfo<AMDGPUMachineFunction>();
-
- uint64_t Offset = MFI->allocateKernArg(LocVT.getStoreSize(),
- ArgFlags.getOrigAlign());
- State.addLoc(CCValAssign::getCustomMem(ValNo, ValVT, Offset, LocVT, LocInfo));
- return true;
-}
-
static bool allocateCCRegs(unsigned ValNo, MVT ValVT, MVT LocVT,
CCValAssign::LocInfo LocInfo,
ISD::ArgFlagsTy ArgFlags, CCState &State,
@@ -910,74 +899,118 @@ CCAssignFn *AMDGPUCallLowering::CCAssignFnForReturn(CallingConv::ID CC,
/// for each individual part is i8. We pass the memory type as LocVT to the
/// calling convention analysis function and the register type (Ins[x].VT) as
/// the ValVT.
-void AMDGPUTargetLowering::analyzeFormalArgumentsCompute(CCState &State,
- const SmallVectorImpl<ISD::InputArg> &Ins) const {
- for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
- const ISD::InputArg &In = Ins[i];
- EVT MemVT;
-
- unsigned NumRegs = getNumRegisters(State.getContext(), In.ArgVT);
-
- if (!Subtarget->isAmdHsaOS() &&
- (In.ArgVT == MVT::i16 || In.ArgVT == MVT::i8 || In.ArgVT == MVT::f16)) {
- // The ABI says the caller will extend these values to 32-bits.
- MemVT = In.ArgVT.isInteger() ? MVT::i32 : MVT::f32;
- } else if (NumRegs == 1) {
- // This argument is not split, so the IR type is the memory type.
- assert(!In.Flags.isSplit());
- if (In.ArgVT.isExtended()) {
- // We have an extended type, like i24, so we should just use the register type
- MemVT = In.VT;
- } else {
- MemVT = In.ArgVT;
- }
- } else if (In.ArgVT.isVector() && In.VT.isVector() &&
- In.ArgVT.getScalarType() == In.VT.getScalarType()) {
- assert(In.ArgVT.getVectorNumElements() > In.VT.getVectorNumElements());
- // We have a vector value which has been split into a vector with
- // the same scalar type, but fewer elements. This should handle
- // all the floating-point vector types.
- MemVT = In.VT;
- } else if (In.ArgVT.isVector() &&
- In.ArgVT.getVectorNumElements() == NumRegs) {
- // This arg has been split so that each element is stored in a separate
- // register.
- MemVT = In.ArgVT.getScalarType();
- } else if (In.ArgVT.isExtended()) {
- // We have an extended type, like i65.
- MemVT = In.VT;
- } else {
- unsigned MemoryBits = In.ArgVT.getStoreSizeInBits() / NumRegs;
- assert(In.ArgVT.getStoreSizeInBits() % NumRegs == 0);
- if (In.VT.isInteger()) {
- MemVT = EVT::getIntegerVT(State.getContext(), MemoryBits);
- } else if (In.VT.isVector()) {
- assert(!In.VT.getScalarType().isFloatingPoint());
- unsigned NumElements = In.VT.getVectorNumElements();
- assert(MemoryBits % NumElements == 0);
- // This vector type has been split into another vector type with
- // a different elements size.
- EVT ScalarVT = EVT::getIntegerVT(State.getContext(),
- MemoryBits / NumElements);
- MemVT = EVT::getVectorVT(State.getContext(), ScalarVT, NumElements);
+void AMDGPUTargetLowering::analyzeFormalArgumentsCompute(
+ CCState &State,
+ const SmallVectorImpl<ISD::InputArg> &Ins) const {
+ const MachineFunction &MF = State.getMachineFunction();
+ const Function &Fn = MF.getFunction();
+ LLVMContext &Ctx = Fn.getParent()->getContext();
+ const AMDGPUSubtarget &ST = AMDGPUSubtarget::get(MF);
+ const unsigned ExplicitOffset = ST.getExplicitKernelArgOffset(Fn);
+
+ unsigned MaxAlign = 1;
+ uint64_t ExplicitArgOffset = 0;
+ const DataLayout &DL = Fn.getParent()->getDataLayout();
+
+ unsigned InIndex = 0;
+
+ for (const Argument &Arg : Fn.args()) {
+ Type *BaseArgTy = Arg.getType();
+ unsigned Align = DL.getABITypeAlignment(BaseArgTy);
+ MaxAlign = std::max(Align, MaxAlign);
+ unsigned AllocSize = DL.getTypeAllocSize(BaseArgTy);
+
+ uint64_t ArgOffset = alignTo(ExplicitArgOffset, Align) + ExplicitOffset;
+ ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;
+
+ // We're basically throwing away everything passed into us and starting over
+ // to get accurate in-memory offsets. The "PartOffset" is completely useless
+ // to us as computed in Ins.
+ //
+ // We also need to figure out what type legalization is trying to do to get
+ // the correct memory offsets.
+
+ SmallVector<EVT, 16> ValueVTs;
+ SmallVector<uint64_t, 16> Offsets;
+ ComputeValueVTs(*this, DL, BaseArgTy, ValueVTs, &Offsets, ArgOffset);
+
+ for (unsigned Value = 0, NumValues = ValueVTs.size();
+ Value != NumValues; ++Value) {
+ uint64_t BasePartOffset = Offsets[Value];
+
+ EVT ArgVT = ValueVTs[Value];
+ EVT MemVT = ArgVT;
+ MVT RegisterVT =
+ getRegisterTypeForCallingConv(Ctx, ArgVT);
+ unsigned NumRegs =
+ getNumRegistersForCallingConv(Ctx, ArgVT);
+
+ if (!Subtarget->isAmdHsaOS() &&
+ (ArgVT == MVT::i16 || ArgVT == MVT::i8 || ArgVT == MVT::f16)) {
+ // The ABI says the caller will extend these values to 32-bits.
+ MemVT = ArgVT.isInteger() ? MVT::i32 : MVT::f32;
+ } else if (NumRegs == 1) {
+ // This argument is not split, so the IR type is the memory type.
+ if (ArgVT.isExtended()) {
+ // We have an extended type, like i24, so we should just use the
+ // register type.
+ MemVT = RegisterVT;
+ } else {
+ MemVT = ArgVT;
+ }
+ } else if (ArgVT.isVector() && RegisterVT.isVector() &&
+ ArgVT.getScalarType() == RegisterVT.getScalarType()) {
+ assert(ArgVT.getVectorNumElements() > RegisterVT.getVectorNumElements());
+ // We have a vector value which has been split into a vector with
+ // the same scalar type, but fewer elements. This should handle
+ // all the floating-point vector types.
+ MemVT = RegisterVT;
+ } else if (ArgVT.isVector() &&
+ ArgVT.getVectorNumElements() == NumRegs) {
+ // This arg has been split so that each element is stored in a separate
+ // register.
+ MemVT = ArgVT.getScalarType();
+ } else if (ArgVT.isExtended()) {
+ // We have an extended type, like i65.
+ MemVT = RegisterVT;
} else {
- llvm_unreachable("cannot deduce memory type.");
+ unsigned MemoryBits = ArgVT.getStoreSizeInBits() / NumRegs;
+ assert(ArgVT.getStoreSizeInBits() % NumRegs == 0);
+ if (RegisterVT.isInteger()) {
+ MemVT = EVT::getIntegerVT(State.getContext(), MemoryBits);
+ } else if (RegisterVT.isVector()) {
+ assert(!RegisterVT.getScalarType().isFloatingPoint());
+ unsigned NumElements = RegisterVT.getVectorNumElements();
+ assert(MemoryBits % NumElements == 0);
+ // This vector type has been split into another vector type with
+ // a different elements size.
+ EVT ScalarVT = EVT::getIntegerVT(State.getContext(),
+ MemoryBits / NumElements);
+ MemVT = EVT::getVectorVT(State.getContext(), ScalarVT, NumElements);
+ } else {
+ llvm_unreachable("cannot deduce memory type.");
+ }
}
- }
- // Convert one element vectors to scalar.
- if (MemVT.isVector() && MemVT.getVectorNumElements() == 1)
- MemVT = MemVT.getScalarType();
+ // Convert one element vectors to scalar.
+ if (MemVT.isVector() && MemVT.getVectorNumElements() == 1)
+ MemVT = MemVT.getScalarType();
- if (MemVT.isExtended()) {
- // This should really only happen if we have vec3 arguments
- assert(MemVT.isVector() && MemVT.getVectorNumElements() == 3);
- MemVT = MemVT.getPow2VectorType(State.getContext());
- }
+ if (MemVT.isExtended()) {
+ // This should really only happen if we have vec3 arguments
+ assert(MemVT.isVector() && MemVT.getVectorNumElements() == 3);
+ MemVT = MemVT.getPow2VectorType(State.getContext());
+ }
- assert(MemVT.isSimple());
- allocateKernArg(i, In.VT, MemVT.getSimpleVT(), CCValAssign::Full, In.Flags,
- State);
+ unsigned PartOffset = 0;
+ for (unsigned i = 0; i != NumRegs; ++i) {
+ State.addLoc(CCValAssign::getCustomMem(InIndex++, RegisterVT,
+ BasePartOffset + PartOffset,
+ MemVT.getSimpleVT(),
+ CCValAssign::Full));
+ PartOffset += MemVT.getStoreSize();
+ }
+ }
}
}
OpenPOWER on IntegriCloud