| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
 | //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Fix bitcasted functions.
///
/// WebAssembly requires caller and callee signatures to match, however in LLVM,
/// some amount of slop is vaguely permitted. Detect mismatch by looking for
/// bitcasts of functions and rewrite them to use wrapper functions instead.
///
/// This doesn't catch all cases, such as when a function's address is taken in
/// one place and cast in another, but it works for many common cases.
///
/// Note that LLVM already optimizes away function bitcasts in common cases by
/// dropping arguments as needed, so this pass only ends up getting used in less
/// common cases.
///
//===----------------------------------------------------------------------===//
#include "WebAssembly.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
#define DEBUG_TYPE "wasm-fix-function-bitcasts"
// Hidden boolean command-line option "wasm-temporary-workarounds", enabled by
// default. While it is set, runOnModule in this file does not synthesize a
// wrapper for "main" and skips any use where either the casted-to function
// type or the underlying function is vararg.
static cl::opt<bool> TemporaryWorkarounds(
  "wasm-temporary-workarounds",
  cl::desc("Apply certain temporary workarounds"),
  cl::init(true), cl::Hidden);
namespace {
// Module pass that looks for functions called through a bitcast to a
// different function type and redirects those calls to wrapper functions.
class FixFunctionBitcasts final : public ModulePass {
public:
  // Static pass ID handed to the ModulePass constructor.
  static char ID;

  FixFunctionBitcasts() : ModulePass(ID) {}

private:
  // Pass entry point; defined out of line further down in this file.
  bool runOnModule(Module &M) override;

  // Declares that the CFG is preserved, then defers to the base class for
  // the rest of the analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    ModulePass::getAnalysisUsage(AU);
  }

  // Human-readable name of the pass.
  StringRef getPassName() const override {
    return "WebAssembly Fix Function Bitcasts";
  }
};
} // end anonymous namespace
// Storage for the static pass ID that the FixFunctionBitcasts constructor
// passes to ModulePass.
char FixFunctionBitcasts::ID = 0;
// Register the pass under the DEBUG_TYPE name ("wasm-fix-function-bitcasts")
// with the description string below.
// NOTE(review): the two trailing `false` flags are presumably "CFG-only" and
// "is analysis" — confirm against the INITIALIZE_PASS macro definition.
INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
                "Fix mismatching bitcasts for WebAssembly", false, false)
// Factory for the pass: heap-allocates a new FixFunctionBitcasts instance and
// returns it through its ModulePass base pointer.
ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {
  ModulePass *Pass = new FixFunctionBitcasts();
  return Pass;
}
// Walk the uses of V, looking through any chain of bitcasts, and collect the
// call sites that call a bitcast of F at a type different from F's own type.
//
//  - V is the value whose uses are inspected: F itself on the initial call, a
//    bitcast of F on recursive calls.
//  - Uses receives one (use, F) pair per call site that needs a wrapper.
//  - ConstantBCs tracks constant bitcasts that were already recorded. Those
//    are later replaced wholesale (RAUW), so one entry per constant suffices.
static void FindUses(Value *V, Function &F,
                     SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
                     SmallPtrSetImpl<Constant *> &ConstantBCs) {
  for (Use &TheUse : V->uses()) {
    auto *TheUser = TheUse.getUser();

    // A bitcast of V: keep descending into the users of the bitcast.
    if (auto *Cast = dyn_cast<BitCastOperator>(TheUser)) {
      FindUses(Cast, F, Uses, ConstantBCs);
      continue;
    }

    // The value is used at F's own type, so there is nothing to fix.
    Value *Used = TheUse.get();
    if (Used->getType() == F.getType())
      continue;

    // Only direct calls are handled: the user must be a call site and the
    // value must be its callee, not one of its arguments.
    CallSite Site(TheUser);
    if (!Site || Site.getCalledValue() != V)
      continue;

    // Record each constant bitcast only once.
    if (auto *ConstantCast = dyn_cast<Constant>(Used))
      if (!ConstantBCs.insert(ConstantCast).second)
        continue;

    Uses.push_back(std::make_pair(&TheUse, &F));
  }
}
// Build a private function named "bitcast" of type Ty, in F's module, whose
// body simply calls F. F's own type may differ from Ty; the following common
// mismatches are tolerated:
//  - Ty has more parameters than F: the extras are dropped (they are
//    forwarded only when F is vararg).
//  - Ty has fewer parameters than F: the missing ones are passed as undef.
//  - Ty returns void: F's result, if any, is discarded.
//  - Ty returns a value but F returns void: undef is returned.
//
// If a shared parameter position has differing types, or both return types
// are non-void and differ, the partially built wrapper is erased and nullptr
// is returned.
static Function *CreateWrapper(Function *F, FunctionType *Ty) {
  Module *M = F->getParent();
  auto &Ctx = M->getContext();

  Function *Wrapper =
      Function::Create(Ty, Function::PrivateLinkage, "bitcast", M);
  BasicBlock *Body = BasicBlock::Create(Ctx, "body", Wrapper);

  FunctionType *CalleeTy = F->getFunctionType();

  // Assemble the argument list for the call to F.
  SmallVector<Value *, 4> CallArgs;
  auto WrapperArg = Wrapper->arg_begin();
  const auto WrapperArgEnd = Wrapper->arg_end();
  auto CalleeParam = CalleeTy->param_begin();
  const auto CalleeParamEnd = CalleeTy->param_end();

  // Positions present on both sides must agree exactly in type.
  while (WrapperArg != WrapperArgEnd && CalleeParam != CalleeParamEnd) {
    if (WrapperArg->getType() != *CalleeParam) {
      Wrapper->eraseFromParent();
      return nullptr;
    }
    CallArgs.push_back(&*WrapperArg);
    ++WrapperArg;
    ++CalleeParam;
  }

  // Parameters F expects but the wrapper does not receive become undef.
  while (CalleeParam != CalleeParamEnd) {
    CallArgs.push_back(UndefValue::get(*CalleeParam));
    ++CalleeParam;
  }

  // Surplus wrapper arguments are passed along only to a vararg callee.
  if (F->isVarArg()) {
    while (WrapperArg != WrapperArgEnd) {
      CallArgs.push_back(&*WrapperArg);
      ++WrapperArg;
    }
  }

  CallInst *Call = CallInst::Create(F, CallArgs, "", Body);

  // Pick the return instruction that reconciles the two return types.
  auto *WrapperRetTy = Ty->getReturnType();
  auto *CalleeRetTy = CalleeTy->getReturnType();

  if (WrapperRetTy->isVoidTy()) {
    ReturnInst::Create(Ctx, Body);
    return Wrapper;
  }
  if (CalleeRetTy->isVoidTy()) {
    ReturnInst::Create(Ctx, UndefValue::get(WrapperRetTy), Body);
    return Wrapper;
  }
  if (CalleeRetTy == WrapperRetTy) {
    ReturnInst::Create(Ctx, Call, Body);
    return Wrapper;
  }

  // Two different non-void return types: give up on this wrapper.
  Wrapper->eraseFromParent();
  return nullptr;
}
// Pass entry point: redirect calls made through bitcasts of functions to
// wrapper functions whose type matches the call.
//
// Steps:
//  1. For every function in M, collect the call-site uses that reach it
//     through a bitcast to a different type (FindUses).
//  2. Unless TemporaryWorkarounds is set, if M defines "main" with a type
//     other than "i32 (i32, i8**)", fabricate a detached call through a
//     bitcast to that type so that a wrapper is generated for it as well.
//  3. For each collected use, create a wrapper (at most once per
//     function/type pair) and point the use at it. Uses whose pointee is not
//     a function type, or for which no wrapper can be built, are left alone.
//  4. If a wrapper for main was actually created, it takes over the name,
//     linkage and visibility of main, and the original becomes the private
//     function "__original_main".
//
// Always returns true.
bool FixFunctionBitcasts::runOnModule(Module &M) {
  Function *Main = nullptr;
  CallInst *CallMain = nullptr;
  SmallVector<std::pair<Use *, Function *>, 0> Uses;
  SmallPtrSet<Constant *, 2> ConstantBCs;

  // Collect all the places that need wrappers.
  for (Function &F : M) {
    FindUses(&F, F, Uses, ConstantBCs);

    // If we have a "main" function, and its type isn't
    // "int main(int argc, char *argv[])", create an artificial call with it
    // bitcasted to that type so that we generate a wrapper for it, so that
    // the C runtime can call it.
    if (!TemporaryWorkarounds && !F.isDeclaration() && F.getName() == "main") {
      Main = &F;
      LLVMContext &C = M.getContext();
      // Parameter order must follow the C signature: argc (i32) first, then
      // argv (i8**).
      Type *MainArgTys[] = {
        Type::getInt32Ty(C),
        PointerType::get(Type::getInt8PtrTy(C), 0)
      };
      FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
                                               /*isVarArg=*/false);
      if (F.getFunctionType() != MainTy) {
        Value *Args[] = {
          UndefValue::get(MainArgTys[0]),
          UndefValue::get(MainArgTys[1])
        };
        Value *Casted = ConstantExpr::getBitCast(Main,
                                                 PointerType::get(MainTy, 0));
        // The call is never inserted into a basic block; it only exists so
        // that its callee use can go through the common wrapper machinery.
        CallMain = CallInst::Create(Casted, Args, "call_main");
        // With two arguments, operand 2 is the callee operand of the call.
        Use *UseMain = &CallMain->getOperandUse(2);
        Uses.push_back(std::make_pair(UseMain, &F));
      }
    }
  }

  // Cache of wrappers keyed by (function, wrapper type). A nullptr entry
  // records that wrapper creation was attempted and failed.
  DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;

  for (auto &UseFunc : Uses) {
    Use *U = UseFunc.first;
    Function *F = UseFunc.second;
    PointerType *PTy = cast<PointerType>(U->get()->getType());
    FunctionType *Ty = dyn_cast<FunctionType>(PTy->getElementType());

    // If the function is casted to something like i8* as a "generic pointer"
    // to be later casted to something else, we can't generate a wrapper for it.
    // Just ignore such casts for now.
    if (!Ty)
      continue;

    // Bitcasted vararg functions occur in Emscripten's implementation of
    // EM_ASM, so suppress wrappers for them for now.
    if (TemporaryWorkarounds && (Ty->isVarArg() || F->isVarArg()))
      continue;

    auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
    if (Pair.second)
      Pair.first->second = CreateWrapper(F, Ty);

    Function *Wrapper = Pair.first->second;
    if (!Wrapper)
      continue;

    // A constant bitcast is shared by all its users, so replace it everywhere
    // at once; an instruction operand is rewritten individually.
    if (isa<Constant>(U->get()))
      U->get()->replaceAllUsesWith(Wrapper);
    else
      U->set(Wrapper);
  }

  // If we created a wrapper for main, rename the wrapper so that it's the
  // one that gets called from startup.
  if (CallMain) {
    Function *MainWrapper =
        cast<Function>(CallMain->getCalledValue()->stripPointerCasts());
    // If no wrapper could be created, the callee is still the bitcast of Main
    // and strips back to Main itself. Leave main untouched in that case
    // instead of renaming it onto itself and then making it private.
    if (MainWrapper != Main) {
      Main->setName("__original_main");
      MainWrapper->setName("main");
      MainWrapper->setLinkage(Main->getLinkage());
      MainWrapper->setVisibility(Main->getVisibility());
      Main->setLinkage(Function::PrivateLinkage);
      Main->setVisibility(Function::DefaultVisibility);
    }
    // The artificial call was never inserted anywhere, so free it directly.
    delete CallMain;
  }

  return true;
}
 |