diff options
Diffstat (limited to 'llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp')
-rw-r--r-- | llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp | 107 |
1 files changed, 82 insertions, 25 deletions
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp index d5e47ee8251..4f6dcb113b7 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp @@ -103,14 +103,25 @@ static void FindUses(Value *V, Function &F, // - Return value is not needed: drop it // - Return value needed but not present: supply an undef // -// For now, return nullptr without creating a wrapper if the wrapper cannot -// be generated due to incompatible types. +// If the all the argument types of trivially castable to one another (i.e. +// I32 vs pointer type) then we don't create a wrapper at all (return nullptr +// instead). +// +// If there is a type mismatch that would result in an invalid wasm module +// being written then generate wrapper that contains unreachable (i.e. abort +// at runtime). Such programs are deep into undefined behaviour territory, +// but we choose to fail at runtime rather than generate and invalid module +// or fail at compiler time. The reason we delay the error is that we want +// to support the CMake which expects to be able to compile and link programs +// that refer to functions with entirely incorrect signatures (this is how +// CMake detects the existence of a function in a toolchain). static Function *CreateWrapper(Function *F, FunctionType *Ty) { Module *M = F->getParent(); - Function *Wrapper = - Function::Create(Ty, Function::PrivateLinkage, "bitcast", M); + Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage, + F->getName() + "_bitcast", M); BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper); + const DataLayout &DL = BB->getModule()->getDataLayout(); // Determine what arguments to pass. SmallVector<Value *, 4> Args; @@ -118,34 +129,80 @@ static Function *CreateWrapper(Function *F, FunctionType *Ty) { Function::arg_iterator AE = Wrapper->arg_end(); FunctionType::param_iterator PI = F->getFunctionType()->param_begin(); FunctionType::param_iterator PE = F->getFunctionType()->param_end(); + bool TypeMismatch = false; + bool WrapperNeeded = false; + + if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) || + (F->getFunctionType()->isVarArg() != Ty->isVarArg())) + WrapperNeeded = true; + for (; AI != AE && PI != PE; ++AI, ++PI) { - if (AI->getType() != *PI) { - Wrapper->eraseFromParent(); - return nullptr; + Type *ArgType = AI->getType(); + Type *ParamType = *PI; + + if (ArgType == ParamType) { + Args.push_back(&*AI); + } else { + if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) { + Instruction *PtrCast = + CastInst::CreateBitOrPointerCast(AI, ParamType, "cast"); + BB->getInstList().push_back(PtrCast); + Args.push_back(PtrCast); + } else { + LLVM_DEBUG(dbgs() << "CreateWrapper: arg type mismatch calling: " + << F->getName() << "\n"); + LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: " + << *ParamType << " Got: " << *ArgType << "\n"); + TypeMismatch = true; + break; + } } - Args.push_back(&*AI); } - for (; PI != PE; ++PI) - Args.push_back(UndefValue::get(*PI)); - if (F->isVarArg()) - for (; AI != AE; ++AI) - Args.push_back(&*AI); - CallInst *Call = CallInst::Create(F, Args, "", BB); - - // Determine what value to return. - if (Ty->getReturnType()->isVoidTy()) - ReturnInst::Create(M->getContext(), BB); - else if (F->getFunctionType()->getReturnType()->isVoidTy()) - ReturnInst::Create(M->getContext(), UndefValue::get(Ty->getReturnType()), - BB); - else if (F->getFunctionType()->getReturnType() == Ty->getReturnType()) - ReturnInst::Create(M->getContext(), Call, BB); - else { + if (!TypeMismatch) { + for (; PI != PE; ++PI) + Args.push_back(UndefValue::get(*PI)); + if (F->isVarArg()) + for (; AI != AE; ++AI) + Args.push_back(&*AI); + + CallInst *Call = CallInst::Create(F, Args, "", BB); + + Type *ExpectedRtnType = F->getFunctionType()->getReturnType(); + Type *RtnType = Ty->getReturnType(); + // Determine what value to return. + if (RtnType->isVoidTy()) { + ReturnInst::Create(M->getContext(), BB); + WrapperNeeded = true; + } else if (ExpectedRtnType->isVoidTy()) { + ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB); + WrapperNeeded = true; + } else if (RtnType == ExpectedRtnType) { + ReturnInst::Create(M->getContext(), Call, BB); + } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType, + DL)) { + Instruction *Cast = + CastInst::CreateBitOrPointerCast(Call, RtnType, "cast"); + BB->getInstList().push_back(Cast); + ReturnInst::Create(M->getContext(), Cast, BB); + } else { + LLVM_DEBUG(dbgs() << "CreateWrapper: return type mismatch calling: " + << F->getName() << "\n"); + LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType + << " Got: " << *RtnType << "\n"); + TypeMismatch = true; + } + } + + if (TypeMismatch) { + new UnreachableInst(M->getContext(), BB); + Wrapper->setName(F->getName() + "_bitcast_invalid"); + } else if (!WrapperNeeded) { + LLVM_DEBUG(dbgs() << "CreateWrapper: no wrapper needed: " << F->getName() + << "\n"); Wrapper->eraseFromParent(); return nullptr; } - return Wrapper; } |