diff options
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/Legacy.h | 60 | ||||
-rw-r--r-- | llvm/unittests/ExecutionEngine/Orc/LegacyAPIInteropTest.cpp | 73 |
2 files changed, 133 insertions, 0 deletions
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/Legacy.h b/llvm/include/llvm/ExecutionEngine/Orc/Legacy.h index 11143a872a5..27bc9c47502 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/Legacy.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/Legacy.h @@ -32,6 +32,66 @@ private: SymbolResolver &R; }; +/// @brief Use the given legacy-style FindSymbol function (i.e. a function that +/// takes a const std::string& or StringRef and returns a JITSymbol) to +/// find the flags for each symbol in Symbols and store their flags in +/// FlagsMap. If any JITSymbol returned by FindSymbol is in an error +/// state the function returns immediately with that error, otherwise it +/// returns the set of symbols not found. +/// +/// Useful for implementing lookupFlags bodies that query legacy resolvers. +template <typename FindSymbolFn> +Expected<LookupFlagsResult> +lookupFlagsWithLegacyFn(const SymbolNameSet &Symbols, FindSymbolFn FindSymbol) { + SymbolFlagsMap SymbolFlags; + SymbolNameSet SymbolsNotFound; + + for (auto &S : Symbols) { + if (JITSymbol Sym = FindSymbol(*S)) + SymbolFlags[S] = Sym.getFlags(); + else if (auto Err = Sym.takeError()) + return std::move(Err); + else + SymbolsNotFound.insert(S); + } + + return LookupFlagsResult{std::move(SymbolFlags), std::move(SymbolsNotFound)}; +} + +/// @brief Use the given legacy-style FindSymbol function (i.e. a function that +/// takes a const std::string& or StringRef and returns a JITSymbol) to +/// find the address and flags for each symbol in Symbols and store the +/// result in Query. If any JITSymbol returned by FindSymbol is in an +/// error then Query.setFailed(...) is called with that error and the +/// function returns immediately. On success, returns the set of symbols +/// not found. +/// +/// Useful for implementing lookup bodies that query legacy resolvers. +template <typename FindSymbolFn> +SymbolNameSet lookupWithLegacyFn(AsynchronousSymbolQuery &Query, + const SymbolNameSet &Symbols, + FindSymbolFn FindSymbol) { + SymbolNameSet SymbolsNotFound; + + for (auto &S : Symbols) { + if (JITSymbol Sym = FindSymbol(*S)) { + if (auto Addr = Sym.getAddress()) { + Query.setDefinition(S, JITEvaluatedSymbol(*Addr, Sym.getFlags())); + Query.notifySymbolFinalized(); + } else { + Query.setFailed(Addr.takeError()); + return {}; + } + } else if (auto Err = Sym.takeError()) { + Query.setFailed(std::move(Err)); + return {}; + } else + SymbolsNotFound.insert(S); + } + + return SymbolsNotFound; +} + } // End namespace orc } // End namespace llvm diff --git a/llvm/unittests/ExecutionEngine/Orc/LegacyAPIInteropTest.cpp b/llvm/unittests/ExecutionEngine/Orc/LegacyAPIInteropTest.cpp index 12c43b58625..963ce9f50ad 100644 --- a/llvm/unittests/ExecutionEngine/Orc/LegacyAPIInteropTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/LegacyAPIInteropTest.cpp @@ -87,4 +87,77 @@ TEST(LegacyAPIInteropTest, QueryAgainstVSO) { << "lookup returned the wrong result for address of 'foo'"; } +TEST(LegacyAPIInteropTset, LegacyLookupHelpersFn) { + constexpr JITTargetAddress FooAddr = 0xdeadbeef; + JITSymbolFlags FooFlags = JITSymbolFlags::Exported; + + bool BarMaterialized = false; + constexpr JITTargetAddress BarAddr = 0xcafef00d; + JITSymbolFlags BarFlags = static_cast<JITSymbolFlags::FlagNames>( + JITSymbolFlags::Exported | JITSymbolFlags::Weak); + + auto LegacyLookup = [&](const std::string &Name) -> JITSymbol { + if (Name == "foo") + return {FooAddr, FooFlags}; + + if (Name == "bar") { + auto BarMaterializer = [&]() -> Expected<JITTargetAddress> { + BarMaterialized = true; + return BarAddr; + }; + + return {BarMaterializer, BarFlags}; + } + + return nullptr; + }; + + SymbolStringPool SP; + auto Foo = SP.intern("foo"); + auto Bar = SP.intern("bar"); + auto Baz = SP.intern("baz"); + + SymbolNameSet Symbols({Foo, Bar, Baz}); + + auto LFR = lookupFlagsWithLegacyFn(Symbols, LegacyLookup); + + EXPECT_TRUE(!!LFR) << "lookupFlagsWithLegacy failed unexpectedly"; + EXPECT_EQ(LFR->SymbolFlags.size(), 2U) << "Wrong number of flags returned"; + EXPECT_EQ(LFR->SymbolFlags.count(Foo), 1U) << "Flags for foo missing"; + EXPECT_EQ(LFR->SymbolFlags.count(Bar), 1U) << "Flags for foo missing"; + EXPECT_EQ(LFR->SymbolFlags[Foo], FooFlags) << "Wrong flags for foo"; + EXPECT_EQ(LFR->SymbolFlags[Bar], BarFlags) << "Wrong flags for foo"; + EXPECT_EQ(LFR->SymbolsNotFound.size(), 1U) << "Expected one symbol not found"; + EXPECT_EQ(LFR->SymbolsNotFound.count(Baz), 1U) + << "Expected symbol baz to be not found"; + EXPECT_FALSE(BarMaterialized) + << "lookupFlags should not have materialized bar"; + + bool OnResolvedRun = false; + bool OnReadyRun = false; + auto OnResolved = [&](Expected<SymbolMap> Result) { + OnResolvedRun = true; + EXPECT_TRUE(!!Result) << "lookuWithLegacy failed to resolve"; + EXPECT_EQ(Result->size(), 2U) << "Wrong number of symbols resolved"; + EXPECT_EQ(Result->count(Foo), 1U) << "Result for foo missing"; + EXPECT_EQ(Result->count(Bar), 1U) << "Result for bar missing"; + EXPECT_EQ((*Result)[Foo].getAddress(), FooAddr) << "Wrong address for foo"; + EXPECT_EQ((*Result)[Foo].getFlags(), FooFlags) << "Wrong flags for foo"; + EXPECT_EQ((*Result)[Bar].getAddress(), BarAddr) << "Wrong address for bar"; + EXPECT_EQ((*Result)[Bar].getFlags(), BarFlags) << "Wrong flags for bar"; + }; + auto OnReady = [&](Error Err) { + EXPECT_FALSE(!!Err) << "Finalization unexpectedly failed"; + OnReadyRun = true; + }; + + AsynchronousSymbolQuery Q({Foo, Bar}, OnResolved, OnReady); + auto Unresolved = lookupWithLegacyFn(Q, Symbols, LegacyLookup); + + EXPECT_TRUE(OnResolvedRun) << "OnResolved was not run"; + EXPECT_TRUE(OnReadyRun) << "OnReady was not run"; + EXPECT_EQ(Unresolved.size(), 1U) << "Expected one unresolved symbol"; + EXPECT_EQ(Unresolved.count(Baz), 1U) << "Expected baz to be unresolved"; +} + } // namespace |