diff options
Diffstat (limited to 'llvm/lib/Target/PTX/PTXISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/PTX/PTXISelLowering.cpp | 102 | 
1 files changed, 82 insertions, 20 deletions
| diff --git a/llvm/lib/Target/PTX/PTXISelLowering.cpp b/llvm/lib/Target/PTX/PTXISelLowering.cpp index 3307d91a618..7f55871f63b 100644 --- a/llvm/lib/Target/PTX/PTXISelLowering.cpp +++ b/llvm/lib/Target/PTX/PTXISelLowering.cpp @@ -20,6 +20,7 @@  #include "llvm/Support/ErrorHandling.h"  #include "llvm/CodeGen/CallingConvLower.h"  #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFrameInfo.h"  #include "llvm/CodeGen/MachineRegisterInfo.h"  #include "llvm/CodeGen/SelectionDAG.h"  #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" @@ -352,40 +353,101 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,                               SmallVectorImpl<SDValue> &InVals) const {    MachineFunction& MF = DAG.getMachineFunction(); -  PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); -  PTXParamManager &PM = MFI->getParamManager(); - +  PTXMachineFunctionInfo *PTXMFI = MF.getInfo<PTXMachineFunctionInfo>(); +  PTXParamManager &PM = PTXMFI->getParamManager(); +  MachineFrameInfo *MFI = MF.getFrameInfo(); +      assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&           "Calls are not handled for the target device"); +  // Identify the callee function +  const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal(); +  const Function *function = cast<Function>(GV); +   +  // allow non-device calls only for printf +  bool isPrintf = function->getName() == "printf" || function->getName() == "puts";	 +   +  assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) && +			 "PTX function calls must be to PTX device functions"); +   +  unsigned outSize = isPrintf ? 2 : Outs.size(); +      std::vector<SDValue> Ops;    // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs] -  Ops.resize(Outs.size() + Ins.size() + 4); +  Ops.resize(outSize + Ins.size() + 4);    Ops[0] = Chain;    // Identify the callee function -  const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal(); -  assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device && -         "PTX function calls must be to PTX device functions");    Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());    Ops[Ins.size()+2] = Callee; -  // Generate STORE_PARAM nodes for each function argument.  In PTX, function -  // arguments are explicitly stored into .param variables and passed as -  // arguments. There is no register/stack-based calling convention in PTX. -  Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32); -  for (unsigned i = 0; i != OutVals.size(); ++i) { -    unsigned Size = OutVals[i].getValueType().getSizeInBits(); -    unsigned Param = PM.addLocalParam(Size); -    const std::string &ParamName = PM.getParamName(Param); -    SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), -                                                     MVT::Other); +  // #Outs +  Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32); +   +  if (isPrintf) { +    // first argument is the address of the global string variable in memory +    unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits()); +    SDValue ParamValue0 = DAG.getTargetExternalSymbol(PM.getParamName(Param0).c_str(), +                                                      MVT::Other);      Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, -                        ParamValue, OutVals[i]); -    Ops[i+Ins.size()+4] = ParamValue; -  } +                        ParamValue0, OutVals[0]); +    Ops[Ins.size()+4] = ParamValue0; +       +    // alignment is the maximum size of all the arguments +    unsigned alignment = 0; +    for (unsigned i = 1; i < OutVals.size(); ++i) { +      alignment = std::max(alignment,  +    		               OutVals[i].getValueType().getSizeInBits()); +    } + +    // size is the alignment multiplied by the number of arguments +    unsigned size = alignment * (OutVals.size() - 1); +   +    // second argument is the address of the stack object (unless no arguments) +    unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits()); +    SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(), +                                                      MVT::Other); +    Ops[Ins.size()+5] = ParamValue1; +     +    if (size > 0) +    { +      // create a local stack object to store the arguments +      unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false); +      SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy()); +	   +      // store each of the arguments to the stack in turn +      for (unsigned int i = 1; i != OutVals.size(); i++) { +        SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy())); +        Chain = DAG.getStore(Chain, dl, OutVals[i], FrameAddr, +                             MachinePointerInfo(), +                             false, false, 0); +      } +      // copy the address of the local frame index to get the address in non-local space +      SDValue genericAddr = DAG.getNode(PTXISD::COPY_ADDRESS, dl, getPointerTy(), FrameIndex); + +      // store this address in the second argument +      Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, ParamValue1, genericAddr); +    } +  } +  else +  { +	  // Generate STORE_PARAM nodes for each function argument.  In PTX, function +	  // arguments are explicitly stored into .param variables and passed as +	  // arguments. There is no register/stack-based calling convention in PTX. +	  for (unsigned i = 0; i != OutVals.size(); ++i) { +		unsigned Size = OutVals[i].getValueType().getSizeInBits(); +		unsigned Param = PM.addLocalParam(Size); +		const std::string &ParamName = PM.getParamName(Param); +		SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), +														 MVT::Other); +		Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, +							ParamValue, OutVals[i]); +		Ops[i+Ins.size()+4] = ParamValue; +	  } +  } +      std::vector<SDValue> InParams;    // Generate list of .param variables to hold the return value(s). | 

