diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/Evaluator.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/Evaluator.cpp | 33 |
1 files changed, 30 insertions, 3 deletions
diff --git a/llvm/lib/Transforms/Utils/Evaluator.cpp b/llvm/lib/Transforms/Utils/Evaluator.cpp index 9440ae3ef2a..3f9daa2d29a 100644 --- a/llvm/lib/Transforms/Utils/Evaluator.cpp +++ b/llvm/lib/Transforms/Utils/Evaluator.cpp @@ -217,6 +217,33 @@ Constant *Evaluator::ComputeLoadResult(Constant *P) { return nullptr; // don't know how to evaluate. } +Function *Evaluator::getCallee(Value *V) { + auto *CE = dyn_cast<ConstantExpr>(V); + if (!CE) + return dyn_cast<Function>(getVal(V)); + + Constant *C = + CE->getOpcode() == Instruction::BitCast + ? ConstantFoldLoadThroughBitcast(CE, CE->getOperand(0)->getType(), DL) + : CE; + return dyn_cast<Function>(C); +} + +/// If call expression contains bitcast then we may need to cast +/// evaluated return value to a type of the call expression. +Constant *Evaluator::castCallResultIfNeeded(Value *CallExpr, Constant *RV) { + ConstantExpr *CE = dyn_cast<ConstantExpr>(CallExpr); + if (!CE || CE->getOpcode() != Instruction::BitCast) + return RV; + + if (auto *FT = + dyn_cast<FunctionType>(CE->getType()->getPointerElementType())) { + RV = ConstantFoldLoadThroughBitcast(RV, FT->getReturnType(), DL); + assert(RV && "Failed to fold bitcast call expr"); + } + return RV; +} + /// Evaluate all instructions in block BB, returning true if successful, false /// if we can't evaluate it. NewBB returns the next BB that control flows into, /// or null upon return. @@ -465,7 +492,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, } // Resolve function pointers. - Function *Callee = dyn_cast<Function>(getVal(CS.getCalledValue())); + Function *Callee = getCallee(CS.getCalledValue()); if (!Callee || Callee->isInterposable()) { LLVM_DEBUG(dbgs() << "Can not resolve function pointer.\n"); return false; // Cannot resolve. @@ -478,7 +505,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, if (Callee->isDeclaration()) { // If this is a function we can constant fold, do it. if (Constant *C = ConstantFoldCall(CS, Callee, Formals, TLI)) { - InstResult = C; + InstResult = castCallResultIfNeeded(CS.getCalledValue(), C); LLVM_DEBUG(dbgs() << "Constant folded function call. Result: " << *InstResult << "\n"); } else { @@ -499,7 +526,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, return false; } ValueStack.pop_back(); - InstResult = RetVal; + InstResult = castCallResultIfNeeded(CS.getCalledValue(), RetVal); if (InstResult) { LLVM_DEBUG(dbgs() << "Successfully evaluated function. Result: " |