diff options
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/IPO/PartialInlining.cpp | 12 | ||||
-rw-r--r-- | llvm/lib/Transforms/Utils/CodeExtractor.cpp | 57 | ||||
-rw-r--r-- | llvm/lib/Transforms/Utils/InlineFunction.cpp | 28 |
3 files changed, 73 insertions, 24 deletions
diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp index c47d8b78df3..c00e13c4ae2 100644 --- a/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -149,7 +149,12 @@ struct PartialInlinerImpl { // the return block. void NormalizeReturnBlock(); - // Do function outlining: + // Do function outlining. + // NOTE: For vararg functions that do the vararg handling in the outlined + // function, we temporarily generate IR that does not properly + // forward varargs to the outlined function. Calling InlineFunction + // will update calls to the outlined functions to properly forward + // the varargs. Function *doFunctionOutlining(); Function *OrigFunc = nullptr; @@ -813,7 +818,8 @@ Function *PartialInlinerImpl::FunctionCloner::doFunctionOutlining() { // Extract the body of the if. OutlinedFunc = CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, - ClonedFuncBFI.get(), &BPI) + ClonedFuncBFI.get(), &BPI, + /* AllowVarargs */ true) .extractCodeRegion(); if (OutlinedFunc) { @@ -938,7 +944,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { << ore::NV("Caller", CS.getCaller()); InlineFunctionInfo IFI(nullptr, GetAssumptionCache, PSI); - if (!InlineFunction(CS, IFI)) + if (!InlineFunction(CS, IFI, nullptr, true, Cloner.OutlinedFunc)) continue; ORE.emit(OR); diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index 557171e1a28..f9e2727e804 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -78,7 +78,8 @@ AggregateArgsOpt("aggregate-extracted-args", cl::Hidden, cl::desc("Aggregate arguments to code-extracted functions")); /// \brief Test whether a block is valid for extraction. -bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) { +bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB, + bool AllowVarArgs) { // Landing pads must be in the function where they were inserted for cleanup. if (BB.isEHPad()) return false; @@ -110,14 +111,19 @@ bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) { } } - // Don't hoist code containing allocas, invokes, or vastarts. + // Don't hoist code containing allocas or invokes. If explicitly requested, + // allow vastart. for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) { if (isa<AllocaInst>(I) || isa<InvokeInst>(I)) return false; if (const CallInst *CI = dyn_cast<CallInst>(I)) if (const Function *F = CI->getCalledFunction()) - if (F->getIntrinsicID() == Intrinsic::vastart) - return false; + if (F->getIntrinsicID() == Intrinsic::vastart) { + if (AllowVarArgs) + continue; + else + return false; + } } return true; @@ -125,7 +131,8 @@ bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) { /// \brief Build a set of blocks to extract if the input blocks are viable. static SetVector<BasicBlock *> -buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT) { +buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, + bool AllowVarArgs) { assert(!BBs.empty() && "The set of blocks to extract must be non-empty"); SetVector<BasicBlock *> Result; @@ -138,7 +145,7 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT) { if (!Result.insert(BB)) llvm_unreachable("Repeated basic blocks in extraction input"); - if (!CodeExtractor::isBlockValidForExtraction(*BB)) { + if (!CodeExtractor::isBlockValidForExtraction(*BB, AllowVarArgs)) { Result.clear(); return Result; } @@ -160,15 +167,17 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT) { CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI) + BranchProbabilityInfo *BPI, bool AllowVarArgs) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(BBs, DT)) {} + BPI(BPI), AllowVarArgs(AllowVarArgs), + Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs)) {} CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI) : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT)) {} + BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT, + /* AllowVarArgs */ false)) {} /// definedInRegion - Return true if the specified value is defined in the /// extracted region. @@ -594,7 +603,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, paramTy.push_back(PointerType::getUnqual(StructTy)); } FunctionType *funcType = - FunctionType::get(RetTy, paramTy, false); + FunctionType::get(RetTy, paramTy, + AllowVarArgs && oldFunction->isVarArg()); // Create the new function Function *newFunction = Function::Create(funcType, @@ -957,12 +967,31 @@ Function *CodeExtractor::extractCodeRegion() { if (!isEligible()) return nullptr; - ValueSet inputs, outputs, SinkingCands, HoistingCands; - BasicBlock *CommonExit = nullptr; - // Assumption: this is a single-entry code region, and the header is the first // block in the region. BasicBlock *header = *Blocks.begin(); + Function *oldFunction = header->getParent(); + + // For functions with varargs, check that varargs handling is only done in the + // outlined function, i.e vastart and vaend are only used in outlined blocks. + if (AllowVarArgs && oldFunction->getFunctionType()->isVarArg()) { + auto containsVarArgIntrinsic = [](Instruction &I) { + if (const CallInst *CI = dyn_cast<CallInst>(&I)) + if (const Function *F = CI->getCalledFunction()) + return F->getIntrinsicID() == Intrinsic::vastart || + F->getIntrinsicID() == Intrinsic::vaend; + return false; + }; + + for (auto &BB : *oldFunction) { + if (Blocks.count(&BB)) + continue; + if (llvm::any_of(BB, containsVarArgIntrinsic)) + return nullptr; + } + } + ValueSet inputs, outputs, SinkingCands, HoistingCands; + BasicBlock *CommonExit = nullptr; // Calculate the entry frequency of the new function before we change the root // block. @@ -984,8 +1013,6 @@ Function *CodeExtractor::extractCodeRegion() { // that the return is not in the region. splitReturnBlocks(); - Function *oldFunction = header->getParent(); - // This takes place of the original loop BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), "codeRepl", oldFunction, diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp index 6b1391e0c80..23a72e86e50 100644 --- a/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -1490,7 +1490,8 @@ static void updateCalleeCount(BlockFrequencyInfo *CallerBFI, BasicBlock *CallBB, /// exists in the instruction stream. Similarly this will inline a recursive /// function by one level. bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, - AAResults *CalleeAAR, bool InsertLifetime) { + AAResults *CalleeAAR, bool InsertLifetime, + Function *ForwardVarArgsTo) { Instruction *TheCall = CS.getInstruction(); assert(TheCall->getParent() && TheCall->getFunction() && "Instruction not in function!"); @@ -1500,8 +1501,9 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, Function *CalledFunc = CS.getCalledFunction(); if (!CalledFunc || // Can't inline external function or indirect - CalledFunc->isDeclaration() || // call, or call to a vararg function! - CalledFunc->getFunctionType()->isVarArg()) return false; + CalledFunc->isDeclaration() || + (!ForwardVarArgsTo && CalledFunc->isVarArg())) // call, or call to a vararg function! + return false; // The inliner does not know how to inline through calls with operand bundles // in general ... @@ -1628,8 +1630,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, auto &DL = Caller->getParent()->getDataLayout(); - assert(CalledFunc->arg_size() == CS.arg_size() && - "No varargs calls can be inlined!"); + assert((CalledFunc->arg_size() == CS.arg_size() || ForwardVarArgsTo) && + "Varargs calls can only be inlined if the Varargs are forwarded!"); // Calculate the vector of arguments to pass into the function cloner, which // matches up the formal to the actual argument values. @@ -1811,6 +1813,11 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, replaceDbgDeclareForAlloca(AI, AI, DIB, /*Deref=*/false); } + SmallVector<Value*,4> VarArgsToForward; + for (unsigned i = CalledFunc->getFunctionType()->getNumParams(); + i < CS.getNumArgOperands(); i++) + VarArgsToForward.push_back(CS.getArgOperand(i)); + bool InlinedMustTailCalls = false, InlinedDeoptimizeCalls = false; if (InlinedFunctionInfo.ContainsCalls) { CallInst::TailCallKind CallSiteTailKind = CallInst::TCK_None; @@ -1819,7 +1826,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, for (Function::iterator BB = FirstNewBlock, E = Caller->end(); BB != E; ++BB) { - for (Instruction &I : *BB) { + for (auto II = BB->begin(); II != BB->end();) { + Instruction &I = *II++; CallInst *CI = dyn_cast<CallInst>(&I); if (!CI) continue; @@ -1850,6 +1858,14 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // 'nounwind'. if (MarkNoUnwind) CI->setDoesNotThrow(); + + if (ForwardVarArgsTo && CI->getCalledFunction() == ForwardVarArgsTo) { + SmallVector<Value*, 6> Params(CI->arg_operands()); + Params.append(VarArgsToForward.begin(), VarArgsToForward.end()); + CallInst *Call = CallInst::Create(CI->getCalledFunction(), Params, "", CI); + CI->replaceAllUsesWith(Call); + CI->eraseFromParent(); + } } } } |