Diffstat (limited to 'llvm/lib')
-rw-r--r--  llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp     22
-rw-r--r--  llvm/lib/Target/AArch64/AArch64CallLowering.cpp  80
2 files changed, 91 insertions, 11 deletions
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 58c444d129d..37ac96e5290 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2193,6 +2193,20 @@ void IRTranslator::finalizeFunction() {
   FuncInfo.clear();
 }
 
+/// Returns true if a BasicBlock \p BB within a variadic function contains a
+/// variadic musttail call.
+static bool checkForMustTailInVarArgFn(bool IsVarArg, const BasicBlock &BB) {
+  if (!IsVarArg)
+    return false;
+
+  // Walk the block backwards, because tail calls usually only appear at the
+  // end of a block.
+  return std::any_of(BB.rbegin(), BB.rend(), [](const Instruction &I) {
+    const auto *CI = dyn_cast<CallInst>(&I);
+    return CI && CI->isMustTailCall();
+  });
+}
+
 bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
   MF = &CurMF;
   const Function &F = MF->getFunction();
@@ -2254,6 +2268,9 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
   SwiftError.setFunction(CurMF);
   SwiftError.createEntriesInEntryBlock(DbgLoc);
 
+  bool IsVarArg = F.isVarArg();
+  bool HasMustTailInVarArgFn = false;
+
   // Create all blocks, in IR order, to preserve the layout.
   for (const BasicBlock &BB: F) {
     auto *&MBB = BBToMBB[&BB];
@@ -2263,8 +2280,13 @@ bool IRTranslator::runOnMachineFunction(MachineFunction &CurMF) {
 
     if (BB.hasAddressTaken())
       MBB->setHasAddressTaken();
+
+    if (!HasMustTailInVarArgFn)
+      HasMustTailInVarArgFn = checkForMustTailInVarArgFn(IsVarArg, BB);
   }
 
+  MF->getFrameInfo().setHasMustTailInVarArgFunc(HasMustTailInVarArgFn);
+
   // Make our arguments/constants entry block fallthrough to the IR entry block.
   EntryBB->addSuccessor(&getMBB(F.front()));
 
diff --git a/llvm/lib/Target/AArch64/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/AArch64CallLowering.cpp
index 9f1945cbc31..45f07fc2ae8 100644
--- a/llvm/lib/Target/AArch64/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64CallLowering.cpp
@@ -368,6 +368,49 @@ bool AArch64CallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
   return Success;
 }
 
+/// Helper function to compute forwarded registers for musttail calls. Computes
+/// the forwarded registers, sets MBB liveness, and emits COPY instructions that
+/// can be used to save + restore registers later.
+static void handleMustTailForwardedRegisters(MachineIRBuilder &MIRBuilder,
+                                             CCAssignFn *AssignFn) {
+  MachineBasicBlock &MBB = MIRBuilder.getMBB();
+  MachineFunction &MF = MIRBuilder.getMF();
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+
+  if (!MFI.hasMustTailInVarArgFunc())
+    return;
+
+  AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+  const Function &F = MF.getFunction();
+  assert(F.isVarArg() && "Expected F to be vararg?");
+
+  // Compute the set of forwarded registers. The rest are scratch.
+  SmallVector<CCValAssign, 16> ArgLocs;
+  CCState CCInfo(F.getCallingConv(), /*IsVarArg=*/true, MF, ArgLocs,
+                 F.getContext());
+  SmallVector<MVT, 2> RegParmTypes;
+  RegParmTypes.push_back(MVT::i64);
+  RegParmTypes.push_back(MVT::f128);
+
+  // Later on, we can use this vector to restore the registers if necessary.
+  SmallVectorImpl<ForwardedRegister> &Forwards =
+      FuncInfo->getForwardedMustTailRegParms();
+  CCInfo.analyzeMustTailForwardedRegisters(Forwards, RegParmTypes, AssignFn);
+
+  // Conservatively forward X8, since it might be used for an aggregate
+  // return.
+  if (!CCInfo.isAllocated(AArch64::X8)) {
+    unsigned X8VReg = MF.addLiveIn(AArch64::X8, &AArch64::GPR64RegClass);
+    Forwards.push_back(ForwardedRegister(X8VReg, AArch64::X8, MVT::i64));
+  }
+
+  // Add the forwards to the MachineBasicBlock and MachineFunction.
+  for (const auto &F : Forwards) {
+    MBB.addLiveIn(F.PReg);
+    MIRBuilder.buildCopy(Register(F.VReg), Register(F.PReg));
+  }
+}
+
 bool AArch64CallLowering::lowerFormalArguments(
     MachineIRBuilder &MIRBuilder, const Function &F,
     ArrayRef<ArrayRef<Register>> VRegs) const {
@@ -441,6 +484,8 @@ bool AArch64CallLowering::lowerFormalArguments(
   if (Subtarget.hasCustomCallingConv())
     Subtarget.getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
 
+  handleMustTailForwardedRegisters(MIRBuilder, AssignFn);
+
   // Move back to the end of the basic block.
   MIRBuilder.setMBB(MBB);
 
@@ -695,16 +740,6 @@ bool AArch64CallLowering::isEligibleForTailCallOptimization(
   assert((!Info.IsVarArg || CalleeCC == CallingConv::C) &&
          "Unexpected variadic calling convention");
 
-  // Before we can musttail varargs, we need to forward parameters like in
-  // r345641. Make sure that we don't enable musttail with varargs without
-  // addressing that!
-  if (Info.IsVarArg && Info.IsMustTailCall) {
-    LLVM_DEBUG(
-        dbgs()
-        << "... Cannot handle vararg musttail functions yet.\n");
-    return false;
-  }
-
   // Verify that the incoming and outgoing arguments from the callee are
   // safe to tail call.
   if (!doCallerAndCalleePassArgsTheSameWay(Info, MF, InArgs)) {
@@ -745,6 +780,7 @@ bool AArch64CallLowering::lowerTailCall(
   const Function &F = MF.getFunction();
   MachineRegisterInfo &MRI = MF.getRegInfo();
   const AArch64TargetLowering &TLI = *getTLI<AArch64TargetLowering>();
+  AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
 
   // True when we're tail calling, but without -tailcallopt.
   bool IsSibCall = !MF.getTarget().Options.GuaranteedTailCallOpt;
@@ -800,7 +836,6 @@ bool AArch64CallLowering::lowerTailCall(
   // We aren't sibcalling, so we need to compute FPDiff. We need to do this
   // before handling assignments, because FPDiff must be known for memory
   // arguments.
-  AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
   unsigned NumReusableBytes = FuncInfo->getBytesInStackArgArea();
   SmallVector<CCValAssign, 16> OutLocs;
   CCState OutInfo(CalleeCC, false, MF, OutLocs, F.getContext());
@@ -823,6 +858,8 @@ bool AArch64CallLowering::lowerTailCall(
     assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
   }
 
+  const auto &Forwards = FuncInfo->getForwardedMustTailRegParms();
+
   // Do the actual argument marshalling.
   SmallVector<unsigned, 8> PhysRegs;
   OutgoingArgHandler Handler(MIRBuilder, MRI, MIB, AssignFnFixed,
@@ -830,6 +867,27 @@ bool AArch64CallLowering::lowerTailCall(
   if (!handleAssignments(MIRBuilder, OutArgs, Handler))
     return false;
 
+  if (Info.IsVarArg && Info.IsMustTailCall) {
+    // Now we know what's being passed to the function. Add uses to the call for
+    // the forwarded registers that we *aren't* passing as parameters. This will
+    // preserve the copies we build earlier.
+    for (const auto &F : Forwards) {
+      Register ForwardedReg = F.PReg;
+      // If the register is already passed, or aliases a register which is
+      // already being passed, then skip it.
+      if (any_of(MIB->uses(), [&ForwardedReg, &TRI](const MachineOperand &Use) {
+            if (!Use.isReg())
+              return false;
+            return TRI->regsOverlap(Use.getReg(), ForwardedReg);
+          }))
+        continue;
+
+      // We aren't passing it already, so we should add it to the call.
+      MIRBuilder.buildCopy(ForwardedReg, Register(F.VReg));
+      MIB.addReg(ForwardedReg, RegState::Implicit);
+    }
+  }
+
   // If we have -tailcallopt, we need to adjust the stack. We'll do the call
   // sequence start and end here.
   if (!IsSibCall) {
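
For reference, here is a minimal, hypothetical LLVM IR example (not part of this commit; @parent and @child are made-up names) of the pattern the change above enables: a variadic function forwarding its arguments through a musttail call, which the removed check in isEligibleForTailCallOptimization previously rejected.

; Illustrative only -- not taken from the commit's tests.
declare i32 @child(i32, ...)

define i32 @parent(i32 %x, ...) {
entry:
  ; A musttail call in a vararg function forwards the caller's variadic
  ; arguments, written as "..." in the argument list. The incoming argument
  ; registers (X0-X7, Q0-Q7, and conservatively X8) must stay live up to the
  ; call; that is what the copies emitted by handleMustTailForwardedRegisters
  ; preserve.
  %r = musttail call i32 (i32, ...) @child(i32 %x, ...)
  ret i32 %r
}

Running something like llc -mtriple=aarch64 -global-isel on such a module should now lower the musttail call as a tail call (a branch to @child), where the old bailout made it ineligible.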