summaryrefslogtreecommitdiffstats
path: root/llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.h
blob: d82012b5a8531e02c8c51bdad274d569a00715a2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
//===------ OrcTestCommon.h - Utilities for Orc Unit Tests ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Common utilities for the Orc unit tests.
//
//===----------------------------------------------------------------------===//


#ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_ORCTESTCOMMON_H
#define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_ORCTESTCOMMON_H

#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/IndirectionUtils.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "gtest/gtest.h"

#include <memory>

namespace llvm {

namespace orc {
// CoreAPIsBasedStandardTest saves a bunch of boilerplate by providing the
// following:
//
// (1) SSP, ES -- A SymbolStringPool and an ExecutionSession constructed from it.
// (2) JD -- A JITDylib named "JD", created in ES.
// (3) Foo, Bar, Baz, Qux -- SymbolStringPtrs for strings "foo", "bar", "baz",
//     and "qux" respectively, interned in ES.
// (4) FooAddr, BarAddr, BazAddr, QuxAddr -- Dummy addresses. Guaranteed
//     distinct and non-null.
// (5) FooSym, BarSym, BazSym, QuxSym -- JITEvaluatedSymbols with FooAddr,
//     BarAddr, BazAddr, and QuxAddr respectively. All with default strong
//     linkage and non-hidden (JITSymbolFlags::Exported) visibility.
class CoreAPIsBasedStandardTest : public testing::Test {
protected:
  // Members are initialized in declaration order: SSP must precede ES, and ES
  // must precede JD and the interned names below, which all use it.
  std::shared_ptr<SymbolStringPool> SSP = std::make_shared<SymbolStringPool>();
  ExecutionSession ES{SSP};
  JITDylib &JD = ES.createJITDylib("JD");
  SymbolStringPtr Foo = ES.intern("foo");
  SymbolStringPtr Bar = ES.intern("bar");
  SymbolStringPtr Baz = ES.intern("baz");
  SymbolStringPtr Qux = ES.intern("qux");
  static const JITTargetAddress FooAddr = 1U;
  static const JITTargetAddress BarAddr = 2U;
  static const JITTargetAddress BazAddr = 3U;
  static const JITTargetAddress QuxAddr = 4U;
  JITEvaluatedSymbol FooSym =
      JITEvaluatedSymbol(FooAddr, JITSymbolFlags::Exported);
  JITEvaluatedSymbol BarSym =
      JITEvaluatedSymbol(BarAddr, JITSymbolFlags::Exported);
  JITEvaluatedSymbol BazSym =
      JITEvaluatedSymbol(BazAddr, JITSymbolFlags::Exported);
  JITEvaluatedSymbol QuxSym =
      JITEvaluatedSymbol(QuxAddr, JITSymbolFlags::Exported);
};

} // end namespace orc

/// One-shot initializer for the host ("native") target, its assembly parser,
/// and its assembly printer. Only the first call does any work; later calls
/// return immediately.
///
/// NOTE(review): the guard is a plain bool, so concurrent first calls are not
/// synchronized.
class OrcNativeTarget {
public:
  static void initialize() {
    // Already done by an earlier call -- nothing left to do.
    if (NativeTargetInitialized)
      return;

    InitializeNativeTarget();
    InitializeNativeTargetAsmParser();
    InitializeNativeTargetAsmPrinter();
    NativeTargetInitialized = true;
  }

private:
  // Declared here; its definition is not in this header.
  static bool NativeTargetInitialized;
};

/// A MaterializationUnit whose behavior is supplied entirely by callbacks, so
/// tests can observe or script materialization, discard and destruction.
class SimpleMaterializationUnit : public orc::MaterializationUnit {
public:
  using MaterializeFunction =
      std::function<void(orc::MaterializationResponsibility)>;
  using DiscardFunction =
      std::function<void(const orc::JITDylib &, orc::SymbolStringPtr)>;
  using DestructorFunction = std::function<void()>;

  /// \param SymbolFlags The symbols (with flags) this unit provides.
  /// \param Materialize Invoked with the responsibility object by materialize().
  /// \param Discard Optional; invoked by discard(). If empty, reaching
  ///        discard() is treated as unreachable.
  /// \param Destructor Optional; invoked from this unit's destructor.
  SimpleMaterializationUnit(
      orc::SymbolFlagsMap SymbolFlags, MaterializeFunction Materialize,
      DiscardFunction Discard = DiscardFunction(),
      DestructorFunction Destructor = DestructorFunction())
      : MaterializationUnit(std::move(SymbolFlags), orc::VModuleKey()),
        Materialize(std::move(Materialize)), Discard(std::move(Discard)),
        Destructor(std::move(Destructor)) {}

  ~SimpleMaterializationUnit() override {
    if (Destructor)
      Destructor();
  }

  StringRef getName() const override { return "<Simple>"; }

  void materialize(orc::MaterializationResponsibility R) override {
    Materialize(std::move(R));
  }

  // Name arrives as a const reference, so it is simply copied into the
  // callback's by-value parameter (a std::move here would be a no-op).
  void discard(const orc::JITDylib &JD,
               const orc::SymbolStringPtr &Name) override {
    if (Discard)
      Discard(JD, Name);
    else
      llvm_unreachable("Discard not supported");
  }

private:
  MaterializeFunction Materialize;
  DiscardFunction Discard;
  DestructorFunction Destructor;
};

// Base class for Orc tests that will execute code.
//
// The constructor initializes the native target and tries to select a
// TargetMachine for the host. It then sets SupportsJIT / SupportsIndirection.
// Both flags remain false unless a TargetMachine was found for a non-Windows
// x86 or x86-64 host.
class OrcExecutionTest {
public:
  OrcExecutionTest() {
    // Initialize the native target if it hasn't been done already.
    OrcNativeTarget::initialize();

    // Try to select a TargetMachine for the host.
    TM.reset(EngineBuilder().selectTarget());

    // No TargetMachine: nothing to probe, both flags stay false.
    if (!TM)
      return;

    // We found a TargetMachine; check that it's one that Orc supports.
    const Triple &TT = TM->getTargetTriple();

    // Bail out for anything other than x86 / x86-64, and for Windows
    // platforms. We do not support these yet.
    if ((TT.getArch() != Triple::x86_64 && TT.getArch() != Triple::x86) ||
        TT.isOSWindows())
      return;

    // Target can JIT?
    SupportsJIT = TM->getTarget().hasJIT();
    // Use ability to create callback manager to detect whether Orc
    // has indirection support on this platform. This way the test
    // and Orc code do not get out of sync.
    // NOTE(review): if this returns an llvm::Expected, a failure would leave
    // its Error unconsumed -- confirm against IndirectionUtils.h.
    SupportsIndirection = !!orc::createLocalCompileCallbackManager(TT, ES, 0);
  }

protected:
  orc::ExecutionSession ES;
  LLVMContext Context;
  std::unique_ptr<TargetMachine> TM;
  bool SupportsJIT = false;
  bool SupportsIndirection = false;
};

/// Owns a Module and offers small helpers for populating it in tests.
class ModuleBuilder {
public:
  /// Creates the owned module. The definition is not in this header;
  /// presumably it names the module \p Name and applies \p Triple -- confirm.
  ModuleBuilder(LLVMContext &Context, StringRef Triple, StringRef Name);

  /// Creates an externally-linked function named \p Name with type \p FTy in
  /// the owned module and returns it.
  Function *createFunctionDecl(FunctionType *FTy, StringRef Name) {
    Module *Owner = M.get();
    return Function::Create(FTy, GlobalValue::ExternalLinkage, Name, Owner);
  }

  /// Accessors for the owned module (null after takeModule()).
  Module *getModule() { return M.get(); }
  const Module *getModule() const { return M.get(); }

  /// Hands ownership of the module to the caller.
  std::unique_ptr<Module> takeModule() { return std::move(M); }

private:
  std::unique_ptr<Module> M;
};

// Dummy struct type.
struct DummyStruct {
  int X[256];
};

/// Returns the IR counterpart of DummyStruct: a literal struct containing a
/// single array of i32 with as many elements as DummyStruct::X.
///
/// The element count is derived from DummyStruct::X so the C++ and IR
/// definitions cannot drift apart. NOTE: using i32 for `int` assumes a
/// 32-bit C++ int on the host.
inline StructType *getDummyStructTy(LLVMContext &Context) {
  constexpr auto NumElements = sizeof(DummyStruct::X) / sizeof(int);
  return StructType::get(
      ArrayType::get(Type::getInt32Ty(Context), NumElements));
}

template <typename HandleT, typename ModuleT>
class MockBaseLayer {
public:

  using ModuleHandleT = HandleT;

  using AddModuleSignature =
    Expected<ModuleHandleT>(ModuleT M,
                            std::shared_ptr<JITSymbolResolver> R);

  using RemoveModuleSignature = Error(ModuleHandleT H);
  using FindSymbolSignature = JITSymbol(const std::string &Name,
                                        bool ExportedSymbolsOnly);
  using FindSymbolInSignature = JITSymbol(ModuleHandleT H,
                                          const std::string &Name,
                                          bool ExportedSymbolsONly);
  using EmitAndFinalizeSignature = Error(ModuleHandleT H);

  std::function<AddModuleSignature> addModuleImpl;
  std::function<RemoveModuleSignature> removeModuleImpl;
  std::function<FindSymbolSignature> findSymbolImpl;
  std::function<FindSymbolInSignature> findSymbolInImpl;
  std::function<EmitAndFinalizeSignature> emitAndFinalizeImpl;

  Expected<ModuleHandleT> addModule(ModuleT M,
                                    std::shared_ptr<JITSymbolResolver> R) {
    assert(addModuleImpl &&
           "addModule called, but no mock implementation was provided");
    return addModuleImpl(std::move(M), std::move(R));
  }

  Error removeModule(ModuleHandleT H) {
    assert(removeModuleImpl &&
           "removeModule called, but no mock implementation was provided");
    return removeModuleImpl(H);
  }

  JITSymbol findSymbol(const std::string &Name, bool ExportedSymbolsOnly) {
    assert(findSymbolImpl &&
           "findSymbol called, but no mock implementation was provided");
    return findSymbolImpl(Name, ExportedSymbolsOnly);
  }

  JITSymbol findSymbolIn(ModuleHandleT H, const std::string &Name,
                         bool ExportedSymbolsOnly) {
    assert(findSymbolInImpl &&
           "findSymbolIn called, but no mock implementation was provided");
    return findSymbolInImpl(H, Name, ExportedSymbolsOnly);
  }

  Error emitAndFinaliez(ModuleHandleT H) {
    assert(emitAndFinalizeImpl &&
           "emitAndFinalize called, but no mock implementation was provided");
    return emitAndFinalizeImpl(H);
  }
};

/// Function object that accepts any arguments, ignores them, and returns a
/// JITSymbol constructed from nullptr.
///
/// Arguments are taken by const reference so that ignoring them never forces
/// a copy, and non-copyable lvalues are accepted too.
class ReturnNullJITSymbol {
public:
  template <typename... Args>
  JITSymbol operator()(const Args &...) const {
    return nullptr;
  }
};

/// Function object that accepts any arguments, ignores them, and returns a
/// copy of the value it was constructed with.
///
/// \tparam ReturnT Type of the stored value. It must be copyable, since
///         operator() is const and returns a copy on every call.
template <typename ReturnT>
class DoNothingAndReturn {
public:
  DoNothingAndReturn(ReturnT Ret) : Ret(std::move(Ret)) {}

  // Return type must be ReturnT: returning a value from a void function is
  // ill-formed and would break every instantiation of this template.
  template <typename... Args>
  ReturnT operator()(const Args &...) const { return Ret; }

private:
  ReturnT Ret;
};

/// Specialization for void: accepts any arguments, ignores them, and returns
/// nothing.
template <>
class DoNothingAndReturn<void> {
public:
  template <typename... Args>
  void operator()(const Args &...) const {}
};

} // namespace llvm

#endif
OpenPOWER on IntegriCloud