diff options
Diffstat (limited to 'llvm')
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/OrcError.h | 3 | ||||
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h | 244 | ||||
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h | 171 | ||||
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h | 123 | ||||
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/RPCChannel.h | 79 | ||||
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h | 524 | ||||
-rw-r--r-- | llvm/lib/ExecutionEngine/Orc/OrcError.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp | 60 | ||||
-rw-r--r-- | llvm/tools/lli/ChildTarget/ChildTarget.cpp | 6 | ||||
-rw-r--r-- | llvm/tools/lli/RemoteJITUtils.h | 1 | ||||
-rw-r--r-- | llvm/tools/lli/lli.cpp | 21 | ||||
-rw-r--r-- | llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp | 103 |
12 files changed, 877 insertions, 460 deletions
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h index 48f35d6b39b..aeee03f86e9 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -26,7 +26,8 @@ enum class OrcErrorCode : int { RemoteMProtectAddrUnrecognized, RemoteIndirectStubsOwnerDoesNotExist, RemoteIndirectStubsOwnerIdAlreadyInUse, - UnexpectedRPCCall + UnexpectedRPCCall, + UnexpectedRPCResponse, }; std::error_code orcError(OrcErrorCode ErrCode); diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h index 8068733dcdd..9ecf904c9ff 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h @@ -36,6 +36,7 @@ namespace remote { template <typename ChannelT> class OrcRemoteTargetClient : public OrcRemoteTargetRPCAPI { public: + /// Remote memory manager. class RCMemoryManager : public RuntimeDyld::MemoryManager { public: @@ -105,11 +106,13 @@ public: DEBUG(dbgs() << "Allocator " << Id << " reserved:\n"); if (CodeSize != 0) { - std::error_code EC = Client.reserveMem(Unmapped.back().RemoteCodeAddr, - Id, CodeSize, CodeAlign); - // FIXME; Add error to poll. - assert(!EC && "Failed reserving remote memory."); - (void)EC; + if (auto AddrOrErr = Client.reserveMem(Id, CodeSize, CodeAlign)) + Unmapped.back().RemoteCodeAddr = *AddrOrErr; + else { + // FIXME; Add error to poll. + assert(!AddrOrErr.getError() && "Failed reserving remote memory."); + } + DEBUG(dbgs() << " code: " << format("0x%016x", Unmapped.back().RemoteCodeAddr) << " (" << CodeSize << " bytes, alignment " << CodeAlign @@ -117,11 +120,13 @@ public: } if (RODataSize != 0) { - std::error_code EC = Client.reserveMem(Unmapped.back().RemoteRODataAddr, - Id, RODataSize, RODataAlign); - // FIXME; Add error to poll. - assert(!EC && "Failed reserving remote memory."); - (void)EC; + if (auto AddrOrErr = Client.reserveMem(Id, RODataSize, RODataAlign)) + Unmapped.back().RemoteRODataAddr = *AddrOrErr; + else { + // FIXME; Add error to poll. + assert(!AddrOrErr.getError() && "Failed reserving remote memory."); + } + DEBUG(dbgs() << " ro-data: " << format("0x%016x", Unmapped.back().RemoteRODataAddr) << " (" << RODataSize << " bytes, alignment " @@ -129,11 +134,13 @@ public: } if (RWDataSize != 0) { - std::error_code EC = Client.reserveMem(Unmapped.back().RemoteRWDataAddr, - Id, RWDataSize, RWDataAlign); - // FIXME; Add error to poll. - assert(!EC && "Failed reserving remote memory."); - (void)EC; + if (auto AddrOrErr = Client.reserveMem(Id, RWDataSize, RWDataAlign)) + Unmapped.back().RemoteRWDataAddr = *AddrOrErr; + else { + // FIXME; Add error to poll. + assert(!AddrOrErr.getError() && "Failed reserving remote memory."); + } + DEBUG(dbgs() << " rw-data: " << format("0x%016x", Unmapped.back().RemoteRWDataAddr) << " (" << RWDataSize << " bytes, alignment " @@ -431,8 +438,10 @@ public: TargetAddress PtrBase; unsigned NumStubsEmitted; - Remote.emitIndirectStubs(StubBase, PtrBase, NumStubsEmitted, Id, - NewStubsRequired); + if (auto StubInfoOrErr = Remote.emitIndirectStubs(Id, NewStubsRequired)) + std::tie(StubBase, PtrBase, NumStubsEmitted) = *StubInfoOrErr; + else + return StubInfoOrErr.getError(); unsigned NewBlockId = RemoteIndirectStubsInfos.size(); RemoteIndirectStubsInfos.push_back({StubBase, PtrBase, NumStubsEmitted}); @@ -484,8 +493,12 @@ public: void grow() override { TargetAddress BlockAddr = 0; uint32_t NumTrampolines = 0; - auto EC = Remote.emitTrampolineBlock(BlockAddr, NumTrampolines); - assert(!EC && "Failed to create trampolines"); + if (auto TrampolineInfoOrErr = Remote.emitTrampolineBlock()) + std::tie(BlockAddr, NumTrampolines) = *TrampolineInfoOrErr; + else { + // FIXME: Return error. + llvm_unreachable("Failed to create trampolines"); + } uint32_t TrampolineSize = Remote.getTrampolineSize(); for (unsigned I = 0; I < NumTrampolines; ++I) @@ -503,53 +516,33 @@ public: OrcRemoteTargetClient H(Channel, EC); if (EC) return EC; - return H; + return ErrorOr<OrcRemoteTargetClient>(std::move(H)); } /// Call the int(void) function at the given address in the target and return /// its result. - std::error_code callIntVoid(int &Result, TargetAddress Addr) { + ErrorOr<int> callIntVoid(TargetAddress Addr) { DEBUG(dbgs() << "Calling int(*)(void) " << format("0x%016x", Addr) << "\n"); - if (auto EC = call<CallIntVoid>(Channel, Addr)) - return EC; - - unsigned NextProcId; - if (auto EC = listenForCompileRequests(NextProcId)) - return EC; - - if (NextProcId != CallIntVoidResponseId) - return orcError(OrcErrorCode::UnexpectedRPCCall); - - return handle<CallIntVoidResponse>(Channel, [&](int R) { - Result = R; - DEBUG(dbgs() << "Result: " << R << "\n"); - return std::error_code(); - }); + auto Listen = + [&](RPCChannel &C, uint32_t Id) { + return listenForCompileRequests(C, Id); + }; + return callSTHandling<CallIntVoid>(Channel, Listen, Addr); } /// Call the int(int, char*[]) function at the given address in the target and /// return its result. - std::error_code callMain(int &Result, TargetAddress Addr, - const std::vector<std::string> &Args) { + ErrorOr<int> callMain(TargetAddress Addr, + const std::vector<std::string> &Args) { DEBUG(dbgs() << "Calling int(*)(int, char*[]) " << format("0x%016x", Addr) << "\n"); - if (auto EC = call<CallMain>(Channel, Addr, Args)) - return EC; - - unsigned NextProcId; - if (auto EC = listenForCompileRequests(NextProcId)) - return EC; - - if (NextProcId != CallMainResponseId) - return orcError(OrcErrorCode::UnexpectedRPCCall); - - return handle<CallMainResponse>(Channel, [&](int R) { - Result = R; - DEBUG(dbgs() << "Result: " << R << "\n"); - return std::error_code(); - }); + auto Listen = + [&](RPCChannel &C, uint32_t Id) { + return listenForCompileRequests(C, Id); + }; + return callSTHandling<CallMain>(Channel, Listen, Addr, Args); } /// Call the void() function at the given address in the target and wait for @@ -558,17 +551,11 @@ public: DEBUG(dbgs() << "Calling void(*)(void) " << format("0x%016x", Addr) << "\n"); - if (auto EC = call<CallVoidVoid>(Channel, Addr)) - return EC; - - unsigned NextProcId; - if (auto EC = listenForCompileRequests(NextProcId)) - return EC; - - if (NextProcId != CallVoidVoidResponseId) - return orcError(OrcErrorCode::UnexpectedRPCCall); - - return handle<CallVoidVoidResponse>(Channel, doNothing); + auto Listen = + [&](RPCChannel &C, JITFuncId Id) { + return listenForCompileRequests(C, Id); + }; + return callSTHandling<CallVoidVoid>(Channel, Listen, Addr); } /// Create an RCMemoryManager which will allocate its memory on the remote @@ -578,7 +565,7 @@ public: assert(!MM && "MemoryManager should be null before creation."); auto Id = AllocatorIds.getNext(); - if (auto EC = call<CreateRemoteAllocator>(Channel, Id)) + if (auto EC = callST<CreateRemoteAllocator>(Channel, Id)) return EC; MM = llvm::make_unique<RCMemoryManager>(*this, Id); return std::error_code(); @@ -590,7 +577,7 @@ public: createIndirectStubsManager(std::unique_ptr<RCIndirectStubsManager> &I) { assert(!I && "Indirect stubs manager should be null before creation."); auto Id = IndirectStubOwnerIds.getNext(); - if (auto EC = call<CreateIndirectStubsOwner>(Channel, Id)) + if (auto EC = callST<CreateIndirectStubsOwner>(Channel, Id)) return EC; I = llvm::make_unique<RCIndirectStubsManager>(*this, Id); return std::error_code(); @@ -599,45 +586,39 @@ public: /// Search for symbols in the remote process. Note: This should be used by /// symbol resolvers *after* they've searched the local symbol table in the /// JIT stack. - std::error_code getSymbolAddress(TargetAddress &Addr, StringRef Name) { + ErrorOr<TargetAddress> getSymbolAddress(StringRef Name) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return ExistingError; - // Request remote symbol address. - if (auto EC = call<GetSymbolAddress>(Channel, Name)) - return EC; - - return expect<GetSymbolAddressResponse>(Channel, [&](TargetAddress &A) { - Addr = A; - DEBUG(dbgs() << "Remote address lookup " << Name << " = " - << format("0x%016x", Addr) << "\n"); - return std::error_code(); - }); + return callST<GetSymbolAddress>(Channel, Name); } /// Get the triple for the remote target. const std::string &getTargetTriple() const { return RemoteTargetTriple; } - std::error_code terminateSession() { return call<TerminateSession>(Channel); } + std::error_code terminateSession() { + return callST<TerminateSession>(Channel); + } private: OrcRemoteTargetClient(ChannelT &Channel, std::error_code &EC) : Channel(Channel) { - if ((EC = call<GetRemoteInfo>(Channel))) - return; - - EC = expect<GetRemoteInfoResponse>( - Channel, readArgs(RemoteTargetTriple, RemotePointerSize, RemotePageSize, - RemoteTrampolineSize, RemoteIndirectStubSize)); + if (auto RIOrErr = callST<GetRemoteInfo>(Channel)) { + std::tie(RemoteTargetTriple, RemotePointerSize, RemotePageSize, + RemoteTrampolineSize, RemoteIndirectStubSize) = + *RIOrErr; + EC = std::error_code(); + } else + EC = RIOrErr.getError(); } std::error_code deregisterEHFrames(TargetAddress Addr, uint32_t Size) { - return call<RegisterEHFrames>(Channel, Addr, Size); + return callST<RegisterEHFrames>(Channel, Addr, Size); } void destroyRemoteAllocator(ResourceIdMgr::ResourceId Id) { - if (auto EC = call<DestroyRemoteAllocator>(Channel, Id)) { + if (auto EC = callST<DestroyRemoteAllocator>(Channel, Id)) { // FIXME: This will be triggered by a removeModuleSet call: Propagate // error return up through that. llvm_unreachable("Failed to destroy remote allocator."); @@ -647,19 +628,13 @@ private: std::error_code destroyIndirectStubsManager(ResourceIdMgr::ResourceId Id) { IndirectStubOwnerIds.release(Id); - return call<DestroyIndirectStubsOwner>(Channel, Id); + return callST<DestroyIndirectStubsOwner>(Channel, Id); } - std::error_code emitIndirectStubs(TargetAddress &StubBase, - TargetAddress &PtrBase, - uint32_t &NumStubsEmitted, - ResourceIdMgr::ResourceId Id, - uint32_t NumStubsRequired) { - if (auto EC = call<EmitIndirectStubs>(Channel, Id, NumStubsRequired)) - return EC; - - return expect<EmitIndirectStubsResponse>( - Channel, readArgs(StubBase, PtrBase, NumStubsEmitted)); + ErrorOr<std::tuple<TargetAddress, TargetAddress, uint32_t>> + emitIndirectStubs(ResourceIdMgr::ResourceId Id, + uint32_t NumStubsRequired) { + return callST<EmitIndirectStubs>(Channel, Id, NumStubsRequired); } std::error_code emitResolverBlock() { @@ -667,24 +642,16 @@ private: if (ExistingError) return ExistingError; - return call<EmitResolverBlock>(Channel); + return callST<EmitResolverBlock>(Channel); } - std::error_code emitTrampolineBlock(TargetAddress &BlockAddr, - uint32_t &NumTrampolines) { + ErrorOr<std::tuple<TargetAddress, uint32_t>> + emitTrampolineBlock() { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return ExistingError; - if (auto EC = call<EmitTrampolineBlock>(Channel)) - return EC; - - return expect<EmitTrampolineBlockResponse>( - Channel, [&](TargetAddress BAddr, uint32_t NTrampolines) { - BlockAddr = BAddr; - NumTrampolines = NTrampolines; - return std::error_code(); - }); + return callST<EmitTrampolineBlock>(Channel); } uint32_t getIndirectStubSize() const { return RemoteIndirectStubSize; } @@ -693,67 +660,46 @@ private: uint32_t getTrampolineSize() const { return RemoteTrampolineSize; } - std::error_code listenForCompileRequests(uint32_t &NextId) { + std::error_code listenForCompileRequests(RPCChannel &C, uint32_t &Id) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return ExistingError; - if (auto EC = getNextProcId(Channel, NextId)) - return EC; - - while (NextId == RequestCompileId) { - TargetAddress TrampolineAddr = 0; - if (auto EC = handle<RequestCompile>(Channel, readArgs(TrampolineAddr))) - return EC; - - TargetAddress ImplAddr = CompileCallback(TrampolineAddr); - if (auto EC = call<RequestCompileResponse>(Channel, ImplAddr)) - return EC; - - if (auto EC = getNextProcId(Channel, NextId)) + if (Id == RequestCompileId) { + if (auto EC = handle<RequestCompile>(C, CompileCallback)) return EC; + return std::error_code(); } - - return std::error_code(); + // else + return orcError(OrcErrorCode::UnexpectedRPCCall); } - std::error_code readMem(char *Dst, TargetAddress Src, uint64_t Size) { + ErrorOr<std::vector<char>> readMem(char *Dst, TargetAddress Src, uint64_t Size) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return ExistingError; - if (auto EC = call<ReadMem>(Channel, Src, Size)) - return EC; - - if (auto EC = expect<ReadMemResponse>( - Channel, [&]() { return Channel.readBytes(Dst, Size); })) - return EC; - - return std::error_code(); + return callST<ReadMem>(Channel, Src, Size); } std::error_code registerEHFrames(TargetAddress &RAddr, uint32_t Size) { - return call<RegisterEHFrames>(Channel, RAddr, Size); + return callST<RegisterEHFrames>(Channel, RAddr, Size); } - std::error_code reserveMem(TargetAddress &RemoteAddr, - ResourceIdMgr::ResourceId Id, uint64_t Size, - uint32_t Align) { + ErrorOr<TargetAddress> reserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size, + uint32_t Align) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return ExistingError; - if (std::error_code EC = call<ReserveMem>(Channel, Id, Size, Align)) - return EC; - - return expect<ReserveMemResponse>(Channel, readArgs(RemoteAddr)); + return callST<ReserveMem>(Channel, Id, Size, Align); } std::error_code setProtections(ResourceIdMgr::ResourceId Id, TargetAddress RemoteSegAddr, unsigned ProtFlags) { - return call<SetProtections>(Channel, Id, RemoteSegAddr, ProtFlags); + return callST<SetProtections>(Channel, Id, RemoteSegAddr, ProtFlags); } std::error_code writeMem(TargetAddress Addr, const char *Src, uint64_t Size) { @@ -761,15 +707,7 @@ private: if (ExistingError) return ExistingError; - // Make the send call. - if (auto EC = call<WriteMem>(Channel, Addr, Size)) - return EC; - - // Follow this up with the section contents. - if (auto EC = Channel.appendBytes(Src, Size)) - return EC; - - return Channel.send(); + return callST<WriteMem>(Channel, DirectBufferWriter(Src, Addr, Size)); } std::error_code writePointer(TargetAddress Addr, TargetAddress PtrVal) { @@ -777,7 +715,7 @@ private: if (ExistingError) return ExistingError; - return call<WritePtr>(Channel, Addr, PtrVal); + return callST<WritePtr>(Channel, Addr, PtrVal); } static std::error_code doNothing() { return std::error_code(); } diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h index 94327d0e320..e9d4ac7af96 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -24,8 +24,48 @@ namespace llvm { namespace orc { namespace remote { +class DirectBufferWriter { +public: + DirectBufferWriter() = default; + DirectBufferWriter(const char *Src, TargetAddress Dst, uint64_t Size) + : Src(Src), Dst(Dst), Size(Size) {} + + const char *getSrc() const { return Src; } + TargetAddress getDst() const { return Dst; } + uint64_t getSize() const { return Size; } +private: + const char *Src; + TargetAddress Dst; + uint64_t Size; +}; + +inline std::error_code serialize(RPCChannel &C, + const DirectBufferWriter &DBW) { + if (auto EC = serialize(C, DBW.getDst())) + return EC; + if (auto EC = serialize(C, DBW.getSize())) + return EC; + return C.appendBytes(DBW.getSrc(), DBW.getSize()); +} + +inline std::error_code deserialize(RPCChannel &C, + DirectBufferWriter &DBW) { + TargetAddress Dst; + if (auto EC = deserialize(C, Dst)) + return EC; + uint64_t Size; + if (auto EC = deserialize(C, Size)) + return EC; + char *Addr = reinterpret_cast<char*>(static_cast<uintptr_t>(Dst)); + + DBW = DirectBufferWriter(0, Dst, Size); + + return C.readBytes(Addr, Size); +} + class OrcRemoteTargetRPCAPI : public RPC<RPCChannel> { protected: + class ResourceIdMgr { public: typedef uint64_t ResourceId; @@ -45,146 +85,111 @@ protected: }; public: - enum JITProcId : uint32_t { - InvalidId = 0, - CallIntVoidId, - CallIntVoidResponseId, + enum JITFuncId : uint32_t { + InvalidId = RPCFunctionIdTraits<JITFuncId>::InvalidId, + CallIntVoidId = RPCFunctionIdTraits<JITFuncId>::FirstValidId, CallMainId, - CallMainResponseId, CallVoidVoidId, - CallVoidVoidResponseId, CreateRemoteAllocatorId, CreateIndirectStubsOwnerId, DeregisterEHFramesId, DestroyRemoteAllocatorId, DestroyIndirectStubsOwnerId, EmitIndirectStubsId, - EmitIndirectStubsResponseId, EmitResolverBlockId, EmitTrampolineBlockId, - EmitTrampolineBlockResponseId, GetSymbolAddressId, - GetSymbolAddressResponseId, GetRemoteInfoId, - GetRemoteInfoResponseId, ReadMemId, - ReadMemResponseId, RegisterEHFramesId, ReserveMemId, - ReserveMemResponseId, RequestCompileId, - RequestCompileResponseId, SetProtectionsId, TerminateSessionId, WriteMemId, WritePtrId }; - static const char *getJITProcIdName(JITProcId Id); - - typedef Procedure<CallIntVoidId, void(TargetAddress Addr)> CallIntVoid; + static const char *getJITFuncIdName(JITFuncId Id); - typedef Procedure<CallIntVoidResponseId, void(int Result)> - CallIntVoidResponse; + typedef Function<CallIntVoidId, int32_t(TargetAddress Addr)> CallIntVoid; - typedef Procedure<CallMainId, void(TargetAddress Addr, - std::vector<std::string> Args)> + typedef Function<CallMainId, int32_t(TargetAddress Addr, + std::vector<std::string> Args)> CallMain; - typedef Procedure<CallMainResponseId, void(int Result)> CallMainResponse; - - typedef Procedure<CallVoidVoidId, void(TargetAddress FnAddr)> CallVoidVoid; - - typedef Procedure<CallVoidVoidResponseId, void()> CallVoidVoidResponse; + typedef Function<CallVoidVoidId, void(TargetAddress FnAddr)> CallVoidVoid; - typedef Procedure<CreateRemoteAllocatorId, - void(ResourceIdMgr::ResourceId AllocatorID)> + typedef Function<CreateRemoteAllocatorId, + void(ResourceIdMgr::ResourceId AllocatorID)> CreateRemoteAllocator; - typedef Procedure<CreateIndirectStubsOwnerId, - void(ResourceIdMgr::ResourceId StubOwnerID)> + typedef Function<CreateIndirectStubsOwnerId, + void(ResourceIdMgr::ResourceId StubOwnerID)> CreateIndirectStubsOwner; - typedef Procedure<DeregisterEHFramesId, - void(TargetAddress Addr, uint32_t Size)> + typedef Function<DeregisterEHFramesId, + void(TargetAddress Addr, uint32_t Size)> DeregisterEHFrames; - typedef Procedure<DestroyRemoteAllocatorId, - void(ResourceIdMgr::ResourceId AllocatorID)> + typedef Function<DestroyRemoteAllocatorId, + void(ResourceIdMgr::ResourceId AllocatorID)> DestroyRemoteAllocator; - typedef Procedure<DestroyIndirectStubsOwnerId, - void(ResourceIdMgr::ResourceId StubsOwnerID)> + typedef Function<DestroyIndirectStubsOwnerId, + void(ResourceIdMgr::ResourceId StubsOwnerID)> DestroyIndirectStubsOwner; - typedef Procedure<EmitIndirectStubsId, - void(ResourceIdMgr::ResourceId StubsOwnerID, - uint32_t NumStubsRequired)> + /// EmitIndirectStubs result is (StubsBase, PtrsBase, NumStubsEmitted). + typedef Function<EmitIndirectStubsId, + std::tuple<TargetAddress, TargetAddress, uint32_t>( + ResourceIdMgr::ResourceId StubsOwnerID, + uint32_t NumStubsRequired)> EmitIndirectStubs; - typedef Procedure<EmitIndirectStubsResponseId, - void(TargetAddress StubsBaseAddr, - TargetAddress PtrsBaseAddr, - uint32_t NumStubsEmitted)> - EmitIndirectStubsResponse; + typedef Function<EmitResolverBlockId, void()> EmitResolverBlock; - typedef Procedure<EmitResolverBlockId, void()> EmitResolverBlock; + /// EmitTrampolineBlock result is (BlockAddr, NumTrampolines). + typedef Function<EmitTrampolineBlockId, + std::tuple<TargetAddress, uint32_t>()> EmitTrampolineBlock; - typedef Procedure<EmitTrampolineBlockId, void()> EmitTrampolineBlock; - - typedef Procedure<EmitTrampolineBlockResponseId, - void(TargetAddress BlockAddr, uint32_t NumTrampolines)> - EmitTrampolineBlockResponse; - - typedef Procedure<GetSymbolAddressId, void(std::string SymbolName)> + typedef Function<GetSymbolAddressId, TargetAddress(std::string SymbolName)> GetSymbolAddress; - typedef Procedure<GetSymbolAddressResponseId, void(uint64_t SymbolAddr)> - GetSymbolAddressResponse; - - typedef Procedure<GetRemoteInfoId, void()> GetRemoteInfo; - - typedef Procedure<GetRemoteInfoResponseId, - void(std::string Triple, uint32_t PointerSize, - uint32_t PageSize, uint32_t TrampolineSize, - uint32_t IndirectStubSize)> - GetRemoteInfoResponse; + /// GetRemoteInfo result is (Triple, PointerSize, PageSize, TrampolineSize, + /// IndirectStubsSize). + typedef Function<GetRemoteInfoId, + std::tuple<std::string, uint32_t, uint32_t, uint32_t, + uint32_t>()> GetRemoteInfo; - typedef Procedure<ReadMemId, void(TargetAddress Src, uint64_t Size)> + typedef Function<ReadMemId, + std::vector<char>(TargetAddress Src, uint64_t Size)> ReadMem; - typedef Procedure<ReadMemResponseId, void()> ReadMemResponse; - - typedef Procedure<RegisterEHFramesId, - void(TargetAddress Addr, uint32_t Size)> + typedef Function<RegisterEHFramesId, + void(TargetAddress Addr, uint32_t Size)> RegisterEHFrames; - typedef Procedure<ReserveMemId, - void(ResourceIdMgr::ResourceId AllocID, uint64_t Size, - uint32_t Align)> + typedef Function<ReserveMemId, + TargetAddress(ResourceIdMgr::ResourceId AllocID, + uint64_t Size, uint32_t Align)> ReserveMem; - typedef Procedure<ReserveMemResponseId, void(TargetAddress Addr)> - ReserveMemResponse; - - typedef Procedure<RequestCompileId, void(TargetAddress TrampolineAddr)> + typedef Function<RequestCompileId, + TargetAddress(TargetAddress TrampolineAddr)> RequestCompile; - typedef Procedure<RequestCompileResponseId, void(TargetAddress ImplAddr)> - RequestCompileResponse; - - typedef Procedure<SetProtectionsId, - void(ResourceIdMgr::ResourceId AllocID, TargetAddress Dst, - uint32_t ProtFlags)> + typedef Function<SetProtectionsId, + void(ResourceIdMgr::ResourceId AllocID, TargetAddress Dst, + uint32_t ProtFlags)> SetProtections; - typedef Procedure<TerminateSessionId, void()> TerminateSession; + typedef Function<TerminateSessionId, void()> TerminateSession; - typedef Procedure<WriteMemId, - void(TargetAddress Dst, uint64_t Size /* Data to follow */)> + typedef Function<WriteMemId, void(DirectBufferWriter DB)> WriteMem; - typedef Procedure<WritePtrId, void(TargetAddress Dst, TargetAddress Val)> + typedef Function<WritePtrId, void(TargetAddress Dst, TargetAddress Val)> WritePtr; }; diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h index a6afd3183aa..f15342dfea2 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h @@ -45,14 +45,14 @@ public: EHFramesRegister(std::move(EHFramesRegister)), EHFramesDeregister(std::move(EHFramesDeregister)) {} - std::error_code getNextProcId(JITProcId &Id) { + std::error_code getNextFuncId(JITFuncId &Id) { return deserialize(Channel, Id); } - std::error_code handleKnownProcedure(JITProcId Id) { + std::error_code handleKnownFunction(JITFuncId Id) { typedef OrcRemoteTargetServer ThisT; - DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n"); + DEBUG(dbgs() << "Handling known proc: " << getJITFuncIdName(Id) << "\n"); switch (Id) { case CallIntVoidId: @@ -111,27 +111,17 @@ public: llvm_unreachable("Unhandled JIT RPC procedure Id."); } - std::error_code requestCompile(TargetAddress &CompiledFnAddr, - TargetAddress TrampolineAddr) { - if (auto EC = call<RequestCompile>(Channel, TrampolineAddr)) - return EC; - - while (1) { - JITProcId Id = InvalidId; - if (auto EC = getNextProcId(Id)) - return EC; + ErrorOr<TargetAddress> requestCompile(TargetAddress TrampolineAddr) { + auto Listen = + [&](RPCChannel &C, uint32_t Id) { + return handleKnownFunction(static_cast<JITFuncId>(Id)); + }; - switch (Id) { - case RequestCompileResponseId: - return handle<RequestCompileResponse>(Channel, - readArgs(CompiledFnAddr)); - default: - if (auto EC = handleKnownProcedure(Id)) - return EC; - } - } + return callSTHandling<RequestCompile>(Channel, Listen, TrampolineAddr); + } - llvm_unreachable("Fell through request-compile command loop."); + void handleTerminateSession() { + handle<TerminateSession>(Channel, [](){ return std::error_code(); }); } private: @@ -175,18 +165,16 @@ private: static std::error_code doNothing() { return std::error_code(); } static TargetAddress reenter(void *JITTargetAddr, void *TrampolineAddr) { - TargetAddress CompiledFnAddr = 0; - auto T = static_cast<OrcRemoteTargetServer *>(JITTargetAddr); - auto EC = T->requestCompile( - CompiledFnAddr, static_cast<TargetAddress>( - reinterpret_cast<uintptr_t>(TrampolineAddr))); - assert(!EC && "Compile request failed"); - (void)EC; - return CompiledFnAddr; + auto AddrOrErr = T->requestCompile( + static_cast<TargetAddress>( + reinterpret_cast<uintptr_t>(TrampolineAddr))); + // FIXME: Allow customizable failure substitution functions. + assert(AddrOrErr && "Compile request failed"); + return *AddrOrErr; } - std::error_code handleCallIntVoid(TargetAddress Addr) { + ErrorOr<int32_t> handleCallIntVoid(TargetAddress Addr) { typedef int (*IntVoidFnTy)(); IntVoidFnTy Fn = reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr)); @@ -195,11 +183,11 @@ private: int Result = Fn(); DEBUG(dbgs() << " Result = " << Result << "\n"); - return call<CallIntVoidResponse>(Channel, Result); + return Result; } - std::error_code handleCallMain(TargetAddress Addr, - std::vector<std::string> Args) { + ErrorOr<int32_t> handleCallMain(TargetAddress Addr, + std::vector<std::string> Args) { typedef int (*MainFnTy)(int, const char *[]); MainFnTy Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr)); @@ -214,7 +202,7 @@ private: int Result = Fn(ArgC, ArgV.get()); DEBUG(dbgs() << " Result = " << Result << "\n"); - return call<CallMainResponse>(Channel, Result); + return Result; } std::error_code handleCallVoidVoid(TargetAddress Addr) { @@ -226,7 +214,7 @@ private: Fn(); DEBUG(dbgs() << " Complete.\n"); - return call<CallVoidVoidResponse>(Channel); + return std::error_code(); } std::error_code handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) { @@ -273,8 +261,9 @@ private: return std::error_code(); } - std::error_code handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id, - uint32_t NumStubsRequired) { + ErrorOr<std::tuple<TargetAddress, TargetAddress, uint32_t>> + handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id, + uint32_t NumStubsRequired) { DEBUG(dbgs() << " ISMgr " << Id << " request " << NumStubsRequired << " stubs.\n"); @@ -296,8 +285,7 @@ private: auto &BlockList = StubOwnerItr->second; BlockList.push_back(std::move(IS)); - return call<EmitIndirectStubsResponse>(Channel, StubsBase, PtrsBase, - NumStubsEmitted); + return std::make_tuple(StubsBase, PtrsBase, NumStubsEmitted); } std::error_code handleEmitResolverBlock() { @@ -316,7 +304,8 @@ private: sys::Memory::MF_EXEC); } - std::error_code handleEmitTrampolineBlock() { + ErrorOr<std::tuple<TargetAddress, uint32_t>> + handleEmitTrampolineBlock() { std::error_code EC; auto TrampolineBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( @@ -325,7 +314,7 @@ private: if (EC) return EC; - unsigned NumTrampolines = + uint32_t NumTrampolines = (sys::Process::getPageSize() - TargetT::PointerSize) / TargetT::TrampolineSize; @@ -339,20 +328,21 @@ private: TrampolineBlocks.push_back(std::move(TrampolineBlock)); - return call<EmitTrampolineBlockResponse>( - Channel, - static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)), - NumTrampolines); + auto TrampolineBaseAddr = + static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)); + + return std::make_tuple(TrampolineBaseAddr, NumTrampolines); } - std::error_code handleGetSymbolAddress(const std::string &Name) { + ErrorOr<TargetAddress> handleGetSymbolAddress(const std::string &Name) { TargetAddress Addr = SymbolLookup(Name); DEBUG(dbgs() << " Symbol '" << Name << "' = " << format("0x%016x", Addr) << "\n"); - return call<GetSymbolAddressResponse>(Channel, Addr); + return Addr; } - std::error_code handleGetRemoteInfo() { + ErrorOr<std::tuple<std::string, uint32_t, uint32_t, uint32_t, uint32_t>> + handleGetRemoteInfo() { std::string ProcessTriple = sys::getProcessTriple(); uint32_t PointerSize = TargetT::PointerSize; uint32_t PageSize = sys::Process::getPageSize(); @@ -364,24 +354,23 @@ private: << " page size = " << PageSize << "\n" << " trampoline size = " << TrampolineSize << "\n" << " indirect stub size = " << IndirectStubSize << "\n"); - return call<GetRemoteInfoResponse>(Channel, ProcessTriple, PointerSize, - PageSize, TrampolineSize, - IndirectStubSize); + return std::make_tuple(ProcessTriple, PointerSize, PageSize ,TrampolineSize, + IndirectStubSize); } - std::error_code handleReadMem(TargetAddress RSrc, uint64_t Size) { + ErrorOr<std::vector<char>> + handleReadMem(TargetAddress RSrc, uint64_t Size) { char *Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc)); DEBUG(dbgs() << " Reading " << Size << " bytes from " << format("0x%016x", RSrc) << "\n"); - if (auto EC = call<ReadMemResponse>(Channel)) - return EC; - - if (auto EC = Channel.appendBytes(Src, Size)) - return EC; + std::vector<char> Buffer; + Buffer.resize(Size); + for (char *P = Src; Size != 0; --Size) + Buffer.push_back(*P++); - return Channel.send(); + return Buffer; } std::error_code handleRegisterEHFrames(TargetAddress TAddr, uint32_t Size) { @@ -392,8 +381,9 @@ private: return std::error_code(); } - std::error_code handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size, - uint32_t Align) { + ErrorOr<TargetAddress> + handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size, + uint32_t Align) { auto I = Allocators.find(Id); if (I == Allocators.end()) return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); @@ -408,7 +398,7 @@ private: TargetAddress AllocAddr = static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr)); - return call<ReserveMemResponse>(Channel, AllocAddr); + return AllocAddr; } std::error_code handleSetProtections(ResourceIdMgr::ResourceId Id, @@ -425,11 +415,10 @@ private: return Allocator.setProtections(LocalAddr, Flags); } - std::error_code handleWriteMem(TargetAddress RDst, uint64_t Size) { - char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst)); - DEBUG(dbgs() << " Writing " << Size << " bytes to " - << format("0x%016x", RDst) << "\n"); - return Channel.readBytes(Dst, Size); + std::error_code handleWriteMem(DirectBufferWriter DBW) { + DEBUG(dbgs() << " Writing " << DBW.getSize() << " bytes to " + << format("0x%016x", DBW.getDst()) << "\n"); + return std::error_code(); } std::error_code handleWritePtr(TargetAddress Addr, TargetAddress PtrVal) { diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCChannel.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCChannel.h index b97b6daf586..98314aea163 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCChannel.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCChannel.h @@ -5,8 +5,10 @@ #include "OrcError.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Endian.h" +#include <mutex> #include <system_error> namespace llvm { @@ -26,31 +28,68 @@ public: /// Flush the stream if possible. virtual std::error_code send() = 0; + + /// Get the lock for stream reading. + std::mutex& getReadLock() { return readLock; } + + /// Get the lock for stream writing. + std::mutex& getWriteLock() { return writeLock; } + +private: + std::mutex readLock, writeLock; }; +/// Notify the channel that we're starting a message send. +/// Locks the channel for writing. +inline std::error_code startSendMessage(RPCChannel &C) { + C.getWriteLock().lock(); + return std::error_code(); +} + +/// Notify the channel that we're ending a message send. +/// Unlocks the channel for writing. +inline std::error_code endSendMessage(RPCChannel &C) { + C.getWriteLock().unlock(); + return std::error_code(); +} + +/// Notify the channel that we're starting a message receive. +/// Locks the channel for reading. +inline std::error_code startReceiveMessage(RPCChannel &C) { + C.getReadLock().lock(); + return std::error_code(); +} + +/// Notify the channel that we're ending a message receive. +/// Unlocks the channel for reading. +inline std::error_code endReceiveMessage(RPCChannel &C) { + C.getReadLock().unlock(); + return std::error_code(); +} + /// RPC channel serialization for a variadic list of arguments. template <typename T, typename... Ts> -std::error_code serialize_seq(RPCChannel &C, const T &Arg, const Ts &... Args) { +std::error_code serializeSeq(RPCChannel &C, const T &Arg, const Ts &... Args) { if (auto EC = serialize(C, Arg)) return EC; - return serialize_seq(C, Args...); + return serializeSeq(C, Args...); } /// RPC channel serialization for an (empty) variadic list of arguments. -inline std::error_code serialize_seq(RPCChannel &C) { +inline std::error_code serializeSeq(RPCChannel &C) { return std::error_code(); } /// RPC channel deserialization for a variadic list of arguments. template <typename T, typename... Ts> -std::error_code deserialize_seq(RPCChannel &C, T &Arg, Ts &... Args) { +std::error_code deserializeSeq(RPCChannel &C, T &Arg, Ts &... Args) { if (auto EC = deserialize(C, Arg)) return EC; - return deserialize_seq(C, Args...); + return deserializeSeq(C, Args...); } /// RPC channel serialization for an (empty) variadic list of arguments. -inline std::error_code deserialize_seq(RPCChannel &C) { +inline std::error_code deserializeSeq(RPCChannel &C) { return std::error_code(); } @@ -138,6 +177,34 @@ inline std::error_code deserialize(RPCChannel &C, std::string &S) { return C.readBytes(&S[0], Count); } +// Serialization helper for std::tuple. +template <typename TupleT, size_t... Is> +inline std::error_code serializeTupleHelper(RPCChannel &C, + const TupleT &V, + llvm::index_sequence<Is...> _) { + return serializeSeq(C, std::get<Is>(V)...); +} + +/// RPC channel serialization for std::tuple. +template <typename... ArgTs> +inline std::error_code serialize(RPCChannel &C, const std::tuple<ArgTs...> &V) { + return serializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>()); +} + +// Serialization helper for std::tuple. +template <typename TupleT, size_t... Is> +inline std::error_code deserializeTupleHelper(RPCChannel &C, + TupleT &V, + llvm::index_sequence<Is...> _) { + return deserializeSeq(C, std::get<Is>(V)...); +} + +/// RPC channel deserialization for std::tuple. +template <typename... ArgTs> +inline std::error_code deserialize(RPCChannel &C, std::tuple<ArgTs...> &V) { + return deserializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>()); +} + /// RPC channel serialization for ArrayRef<T>. template <typename T> std::error_code serialize(RPCChannel &C, const ArrayRef<T> &A) { diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h index d1b8546268f..3556e6b1eef 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -14,46 +14,197 @@ #ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ExecutionEngine/Orc/OrcError.h" +#include "llvm/Support/ErrorOr.h" +#include <future> +#include <map> namespace llvm { namespace orc { namespace remote { +/// Describes reserved RPC Function Ids. +/// +/// The default implementation will serve for integer and enum function id +/// types. If you want to use a custom type as your FunctionId you can +/// specialize this class and provide unique values for InvalidId, +/// ResponseId and FirstValidId. + +template <typename T> +class RPCFunctionIdTraits { +public: + constexpr static const T InvalidId = static_cast<T>(0); + constexpr static const T ResponseId = static_cast<T>(1); + constexpr static const T FirstValidId = static_cast<T>(2); +}; + // Base class containing utilities that require partial specialization. // These cannot be included in RPC, as template class members cannot be // partially specialized. class RPCBase { protected: - template <typename ProcedureIdT, ProcedureIdT ProcId, typename FnT> - class ProcedureHelper { + + // RPC Function description type. + // + // This class provides the information and operations needed to support the + // RPC primitive operations (call, expect, etc) for a given function. It + // is specialized for void and non-void functions to deal with the differences + // betwen the two. Both specializations have the same interface: + // + // Id - The function's unique identifier. + // OptionalReturn - The return type for asyncronous calls. + // ErrorReturn - The return type for synchronous calls. + // optionalToErrorReturn - Conversion from a valid OptionalReturn to an + // ErrorReturn. + // readResult - Deserialize a result from a channel. + // abandon - Abandon a promised (asynchronous) result. + // respond - Retun a result on the channel. + template <typename FunctionIdT, FunctionIdT FuncId, typename FnT> + class FunctionHelper {}; + + // RPC Function description specialization for non-void functions. + template <typename FunctionIdT, FunctionIdT FuncId, typename RetT, + typename... ArgTs> + class FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> { public: - static const ProcedureIdT Id = ProcId; + + static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId && + FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId, + "Cannot define custom function with InvalidId or ResponseId. " + "Please use RPCFunctionTraits<FunctionIdT>::FirstValidId."); + + static const FunctionIdT Id = FuncId; + + typedef Optional<RetT> OptionalReturn; + + typedef ErrorOr<RetT> ErrorReturn; + + static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) { + assert(V && "Return value not available"); + return std::move(*V); + } + + template <typename ChannelT> + static std::error_code readResult(ChannelT &C, + std::promise<OptionalReturn> &P) { + RetT Val; + auto EC = deserialize(C, Val); + // FIXME: Join error EC2 from endReceiveMessage with the deserialize + // error once we switch to using Error. + auto EC2 = endReceiveMessage(C); + (void)EC2; + + if (EC) { + P.set_value(OptionalReturn()); + return EC; + } + P.set_value(std::move(Val)); + return std::error_code(); + } + + static void abandon(std::promise<OptionalReturn> &P) { + P.set_value(OptionalReturn()); + } + + template <typename ChannelT, typename SequenceNumberT> + static std::error_code respond(ChannelT &C, SequenceNumberT SeqNo, + const ErrorReturn &Result) { + FunctionIdT ResponseId = + RPCFunctionIdTraits<FunctionIdT>::ResponseId; + + // If the handler returned an error then bail out with that. + if (!Result) + return Result.getError(); + + // Otherwise open a new message on the channel and send the result. + if (auto EC = startSendMessage(C)) + return EC; + if (auto EC = serializeSeq(C, ResponseId, SeqNo, *Result)) + return EC; + return endSendMessage(C); + } }; - template <typename ChannelT, typename Proc> class CallHelper; + // RPC Function description specialization for void functions. + template <typename FunctionIdT, FunctionIdT FuncId, typename... ArgTs> + class FunctionHelper<FunctionIdT, FuncId, void(ArgTs...)> { + public: - template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId, - typename... ArgTs> - class CallHelper<ChannelT, - ProcedureHelper<ProcedureIdT, ProcId, void(ArgTs...)>> { + static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId && + FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId, + "Cannot define custom function with InvalidId or ResponseId. " + "Please use RPCFunctionTraits<FunctionIdT>::FirstValidId."); + + static const FunctionIdT Id = FuncId; + + typedef bool OptionalReturn; + typedef std::error_code ErrorReturn; + + static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) { + assert(V && "Return value not available"); + return std::error_code(); + } + + template <typename ChannelT> + static std::error_code readResult(ChannelT &C, + std::promise<OptionalReturn> &P) { + // Void functions don't have anything to deserialize, so we're good. + P.set_value(true); + return endReceiveMessage(C); + } + + static void abandon(std::promise<OptionalReturn> &P) { + P.set_value(false); + } + + template <typename ChannelT, typename SequenceNumberT> + static std::error_code respond(ChannelT &C, SequenceNumberT SeqNo, + const ErrorReturn &Result) { + const FunctionIdT ResponseId = + RPCFunctionIdTraits<FunctionIdT>::ResponseId; + + // If the handler returned an error then bail out with that. + if (Result) + return Result; + + // Otherwise open a new message on the channel and send the result. + if (auto EC = startSendMessage(C)) + return EC; + if (auto EC = serializeSeq(C, ResponseId, SeqNo)) + return EC; + return endSendMessage(C); + } + }; + + // Helper for the call primitive. + template <typename ChannelT, typename SequenceNumberT, typename Func> + class CallHelper; + + template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT, + FunctionIdT FuncId, typename RetT, typename... ArgTs> + class CallHelper<ChannelT, SequenceNumberT, + FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> { public: - static std::error_code call(ChannelT &C, const ArgTs &... Args) { - if (auto EC = serialize(C, ProcId)) + static std::error_code call(ChannelT &C, SequenceNumberT SeqNo, + const ArgTs &... Args) { + if (auto EC = startSendMessage(C)) + return EC; + if (auto EC = serializeSeq(C, FuncId, SeqNo, Args...)) return EC; - // If you see a compile-error on this line you're probably calling a - // function with the wrong signature. - return serialize_seq(C, Args...); + return endSendMessage(C); } }; - template <typename ChannelT, typename Proc> class HandlerHelper; + // Helper for handle primitive. + template <typename ChannelT, typename SequenceNumberT, typename Func> + class HandlerHelper; - template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId, - typename... ArgTs> - class HandlerHelper<ChannelT, - ProcedureHelper<ProcedureIdT, ProcId, void(ArgTs...)>> { + template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT, + FunctionIdT FuncId, typename RetT, typename... ArgTs> + class HandlerHelper<ChannelT, SequenceNumberT, + FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> { public: template <typename HandlerT> static std::error_code handle(ChannelT &C, HandlerT Handler) { @@ -61,34 +212,46 @@ protected: } private: + + typedef FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> Func; + template <typename HandlerT, size_t... Is> static std::error_code readAndHandle(ChannelT &C, HandlerT Handler, llvm::index_sequence<Is...> _) { std::tuple<ArgTs...> RPCArgs; + SequenceNumberT SeqNo; // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning // for RPCArgs. Void cast RPCArgs to work around this for now. // FIXME: Remove this workaround once we can assume a working GCC version. (void)RPCArgs; - if (auto EC = deserialize_seq(C, std::get<Is>(RPCArgs)...)) + if (auto EC = deserializeSeq(C, SeqNo, std::get<Is>(RPCArgs)...)) + return EC; + + // We've deserialized the arguments, so unlock the channel for reading + // before we call the handler. This allows recursive RPC calls. + if (auto EC = endReceiveMessage(C)) return EC; - return Handler(std::get<Is>(RPCArgs)...); + + return Func::template respond<ChannelT, SequenceNumberT>( + C, SeqNo, Handler(std::get<Is>(RPCArgs)...)); } + }; - template <typename ClassT, typename... ArgTs> class MemberFnWrapper { + // Helper for wrapping member functions up as functors. + template <typename ClassT, typename RetT, typename... ArgTs> + class MemberFnWrapper { public: - typedef std::error_code (ClassT::*MethodT)(ArgTs...); + typedef RetT(ClassT::*MethodT)(ArgTs...); MemberFnWrapper(ClassT &Instance, MethodT Method) : Instance(Instance), Method(Method) {} - std::error_code operator()(ArgTs &... Args) { - return (Instance.*Method)(Args...); - } - + RetT operator()(ArgTs &... Args) { return (Instance.*Method)(Args...); } private: ClassT &Instance; MethodT Method; }; + // Helper that provides a Functor for deserializing arguments. template <typename... ArgTs> class ReadArgs { public: std::error_code operator()() { return std::error_code(); } @@ -112,7 +275,7 @@ protected: /// Contains primitive utilities for defining, calling and handling calls to /// remote procedures. ChannelT is a bidirectional stream conforming to the -/// RPCChannel interface (see RPCChannel.h), and ProcedureIdT is a procedure +/// RPCChannel interface (see RPCChannel.h), and FunctionIdT is a procedure /// identifier type that must be serializable on ChannelT. /// /// These utilities support the construction of very primitive RPC utilities. @@ -129,120 +292,184 @@ protected: /// /// Overview (see comments individual types/methods for details): /// -/// Procedure<Id, Args...> : +/// Function<Id, Args...> : /// /// associates a unique serializable id with an argument list. /// /// -/// call<Proc>(Channel, Args...) : +/// call<Func>(Channel, Args...) : /// -/// Calls the remote procedure 'Proc' by serializing Proc's id followed by its +/// Calls the remote procedure 'Func' by serializing Func's id followed by its /// arguments and sending the resulting bytes to 'Channel'. /// /// -/// handle<Proc>(Channel, <functor matching std::error_code(Args...)> : +/// handle<Func>(Channel, <functor matching std::error_code(Args...)> : /// -/// Handles a call to 'Proc' by deserializing its arguments and calling the -/// given functor. This assumes that the id for 'Proc' has already been +/// Handles a call to 'Func' by deserializing its arguments and calling the +/// given functor. This assumes that the id for 'Func' has already been /// deserialized. /// -/// expect<Proc>(Channel, <functor matching std::error_code(Args...)> : +/// expect<Func>(Channel, <functor matching std::error_code(Args...)> : /// /// The same as 'handle', except that the procedure id should not have been -/// read yet. Expect will deserialize the id and assert that it matches Proc's +/// read yet. Expect will deserialize the id and assert that it matches Func's /// id. If it does not, and unexpected RPC call error is returned. - -template <typename ChannelT, typename ProcedureIdT = uint32_t> +template <typename ChannelT, typename FunctionIdT = uint32_t, + typename SequenceNumberT = uint16_t> class RPC : public RPCBase { public: + + RPC() = default; + RPC(const RPC&) = delete; + RPC& operator=(const RPC&) = delete; + RPC(RPC &&Other) : SequenceNumberMgr(std::move(Other.SequenceNumberMgr)), OutstandingResults(std::move(Other.OutstandingResults)) {} + RPC& operator=(RPC&&) = default; + /// Utility class for defining/referring to RPC procedures. /// /// Typedefs of this utility are used when calling/handling remote procedures. /// - /// ProcId should be a unique value of ProcedureIdT (i.e. not used with any - /// other Procedure typedef in the RPC API being defined. + /// FuncId should be a unique value of FunctionIdT (i.e. not used with any + /// other Function typedef in the RPC API being defined. /// /// the template argument Ts... gives the argument list for the remote /// procedure. /// /// E.g. /// - /// typedef Procedure<0, bool> Proc1; - /// typedef Procedure<1, std::string, std::vector<int>> Proc2; + /// typedef Function<0, bool> Func1; + /// typedef Function<1, std::string, std::vector<int>> Func2; /// - /// if (auto EC = call<Proc1>(Channel, true)) + /// if (auto EC = call<Func1>(Channel, true)) /// /* handle EC */; /// - /// if (auto EC = expect<Proc2>(Channel, + /// if (auto EC = expect<Func2>(Channel, /// [](std::string &S, std::vector<int> &V) { /// // Stuff. /// return std::error_code(); /// }) /// /* handle EC */; /// - template <ProcedureIdT ProcId, typename FnT> - using Procedure = ProcedureHelper<ProcedureIdT, ProcId, FnT>; + template <FunctionIdT FuncId, typename FnT> + using Function = FunctionHelper<FunctionIdT, FuncId, FnT>; + + /// Return type for asynchronous call primitives. + template <typename Func> + using AsyncCallResult = + std::pair<std::future<typename Func::OptionalReturn>, SequenceNumberT>; /// Serialize Args... to channel C, but do not call C.send(). /// - /// For buffered channels, this can be used to queue up several calls before - /// flushing the channel. - template <typename Proc, typename... ArgTs> - static std::error_code appendCall(ChannelT &C, const ArgTs &... Args) { - return CallHelper<ChannelT, Proc>::call(C, Args...); + /// For void functions returns a std::future<Error>. For functions that + /// return an R, returns a std::future<Optional<R>>. + template <typename Func, typename... ArgTs> + ErrorOr<AsyncCallResult<Func>> + appendCallAsync(ChannelT &C, const ArgTs &... Args) { + auto SeqNo = SequenceNumberMgr.getSequenceNumber(); + std::promise<typename Func::OptionalReturn> Promise; + auto Result = Promise.get_future(); + OutstandingResults[SeqNo] = std::move(Promise); + + if (auto EC = + CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo, + Args...)) { + abandonOutstandingResults(); + return EC; + } else + return AsyncCallResult<Func>(std::move(Result), SeqNo); } /// Serialize Args... to channel C and call C.send(). - template <typename Proc, typename... ArgTs> - static std::error_code call(ChannelT &C, const ArgTs &... Args) { - if (auto EC = appendCall<Proc>(C, Args...)) + template <typename Func, typename... ArgTs> + ErrorOr<AsyncCallResult<Func>> + callAsync(ChannelT &C, const ArgTs &... Args) { + auto SeqNo = SequenceNumberMgr.getSequenceNumber(); + std::promise<typename Func::OptionalReturn> Promise; + auto Result = Promise.get_future(); + OutstandingResults[SeqNo] = + createOutstandingResult<Func>(std::move(Promise)); + if (auto EC = + CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo, Args...)) { + abandonOutstandingResults(); + return EC; + } + if (auto EC = C.send()) { + abandonOutstandingResults(); return EC; - return C.send(); + } + return AsyncCallResult<Func>(std::move(Result), SeqNo); + } + + /// This can be used in single-threaded mode. + template <typename Func, typename HandleFtor, typename... ArgTs> + typename Func::ErrorReturn + callSTHandling(ChannelT &C, HandleFtor &HandleOther, const ArgTs &... Args) { + if (auto ResultAndSeqNoOrErr = callAsync<Func>(C, Args...)) { + auto &ResultAndSeqNo = *ResultAndSeqNoOrErr; + if (auto EC = waitForResult(C, ResultAndSeqNo.second, HandleOther)) + return EC; + return Func::optionalToErrorReturn(ResultAndSeqNo.first.get()); + } else + return ResultAndSeqNoOrErr.getError(); + } + + // This can be used in single-threaded mode. + template <typename Func, typename... ArgTs> + typename Func::ErrorReturn + callST(ChannelT &C, const ArgTs &... Args) { + return callSTHandling<Func>(C, handleNone, Args...); } - /// Deserialize and return an enum whose underlying type is ProcedureIdT. - static std::error_code getNextProcId(ChannelT &C, ProcedureIdT &Id) { + /// Start receiving a new function call. + /// + /// Calls startReceiveMessage on the channel, then deserializes a FunctionId + /// into Id. + std::error_code startReceivingFunction(ChannelT &C, FunctionIdT &Id) { + if (auto EC = startReceiveMessage(C)) + return EC; + return deserialize(C, Id); } - /// Deserialize args for Proc from C and call Handler. The signature of + /// Deserialize args for Func from C and call Handler. The signature of /// handler must conform to 'std::error_code(Args...)' where Args... matches - /// the arguments used in the Proc typedef. - template <typename Proc, typename HandlerT> + /// the arguments used in the Func typedef. + template <typename Func, typename HandlerT> static std::error_code handle(ChannelT &C, HandlerT Handler) { - return HandlerHelper<ChannelT, Proc>::handle(C, Handler); + return HandlerHelper<ChannelT, SequenceNumberT, Func>::handle(C, Handler); } /// Helper version of 'handle' for calling member functions. - template <typename Proc, typename ClassT, typename... ArgTs> + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> static std::error_code handle(ChannelT &C, ClassT &Instance, - std::error_code (ClassT::*HandlerMethod)(ArgTs...)) { - return handle<Proc>( - C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod)); + RetT (ClassT::*HandlerMethod)(ArgTs...)) { + return handle<Func>( + C, + MemberFnWrapper<ClassT, RetT, ArgTs...>(Instance, HandlerMethod)); } - /// Deserialize a ProcedureIdT from C and verify it matches the id for Proc. + /// Deserialize a FunctionIdT from C and verify it matches the id for Func. /// If the id does match, deserialize the arguments and call the handler /// (similarly to handle). /// If the id does not match, return an unexpect RPC call error and do not /// deserialize any further bytes. - template <typename Proc, typename HandlerT> - static std::error_code expect(ChannelT &C, HandlerT Handler) { - ProcedureIdT ProcId; - if (auto EC = getNextProcId(C, ProcId)) - return EC; - if (ProcId != Proc::Id) + template <typename Func, typename HandlerT> + std::error_code expect(ChannelT &C, HandlerT Handler) { + FunctionIdT FuncId; + if (auto EC = startReceivingFunction(C, FuncId)) + return std::move(EC); + if (FuncId != Func::Id) return orcError(OrcErrorCode::UnexpectedRPCCall); - return handle<Proc>(C, Handler); + return handle<Func>(C, Handler); } /// Helper version of expect for calling member functions. - template <typename Proc, typename ClassT, typename... ArgTs> + template <typename Func, typename ClassT, typename... ArgTs> static std::error_code expect(ChannelT &C, ClassT &Instance, std::error_code (ClassT::*HandlerMethod)(ArgTs...)) { - return expect<Proc>( + return expect<Func>( C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod)); } @@ -251,18 +478,163 @@ public: /// channel. /// E.g. /// - /// typedef Procedure<0, bool, int> Proc1; + /// typedef Function<0, bool, int> Func1; /// /// ... /// bool B; /// int I; - /// if (auto EC = expect<Proc1>(Channel, readArgs(B, I))) + /// if (auto EC = expect<Func1>(Channel, readArgs(B, I))) /// /* Handle Args */ ; /// template <typename... ArgTs> static ReadArgs<ArgTs...> readArgs(ArgTs &... Args) { return ReadArgs<ArgTs...>(Args...); } + + /// Read a response from Channel. + /// This should be called from the receive loop to retrieve results. + std::error_code handleResponse(ChannelT &C, SequenceNumberT &SeqNo) { + if (auto EC = deserialize(C, SeqNo)) { + abandonOutstandingResults(); + return EC; + } + + auto I = OutstandingResults.find(SeqNo); + if (I == OutstandingResults.end()) { + abandonOutstandingResults(); + return orcError(OrcErrorCode::UnexpectedRPCResponse); + } + + if (auto EC = I->second->readResult(C)) { + abandonOutstandingResults(); + // FIXME: Release sequence numbers? + return EC; + } + + OutstandingResults.erase(I); + SequenceNumberMgr.releaseSequenceNumber(SeqNo); + + return std::error_code(); + } + + // Loop waiting for a result with the given sequence number. + // This can be used as a receive loop if the user doesn't have a default. + template <typename HandleOtherFtor> + std::error_code waitForResult(ChannelT &C, SequenceNumberT TgtSeqNo, + HandleOtherFtor &HandleOther = handleNone) { + bool GotTgtResult = false; + + while (!GotTgtResult) { + FunctionIdT Id = + RPCFunctionIdTraits<FunctionIdT>::InvalidId; + if (auto EC = startReceivingFunction(C, Id)) + return EC; + if (Id == RPCFunctionIdTraits<FunctionIdT>::ResponseId) { + SequenceNumberT SeqNo; + if (auto EC = handleResponse(C, SeqNo)) + return EC; + GotTgtResult = (SeqNo == TgtSeqNo); + } else if (auto EC = HandleOther(C, Id)) + return EC; + } + + return std::error_code(); + }; + + // Default handler for 'other' (non-response) functions when waiting for a + // result from the channel. + static std::error_code handleNone(ChannelT&, FunctionIdT) { + return orcError(OrcErrorCode::UnexpectedRPCCall); + }; + +private: + + // Manage sequence numbers. + class SequenceNumberManager { + public: + + SequenceNumberManager() = default; + + SequenceNumberManager(SequenceNumberManager &&Other) + : NextSequenceNumber(std::move(Other.NextSequenceNumber)), + FreeSequenceNumbers(std::move(Other.FreeSequenceNumbers)) {} + + SequenceNumberManager& operator=(SequenceNumberManager &&Other) { + NextSequenceNumber = std::move(Other.NextSequenceNumber); + FreeSequenceNumbers = std::move(Other.FreeSequenceNumbers); + } + + void reset() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + NextSequenceNumber = 0; + FreeSequenceNumbers.clear(); + } + + SequenceNumberT getSequenceNumber() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + if (FreeSequenceNumbers.empty()) + return NextSequenceNumber++; + auto SequenceNumber = FreeSequenceNumbers.back(); + FreeSequenceNumbers.pop_back(); + return SequenceNumber; + } + + void releaseSequenceNumber(SequenceNumberT SequenceNumber) { + std::lock_guard<std::mutex> Lock(SeqNoLock); + FreeSequenceNumbers.push_back(SequenceNumber); + } + + private: + std::mutex SeqNoLock; + SequenceNumberT NextSequenceNumber = 0; + std::vector<SequenceNumberT> FreeSequenceNumbers; + }; + + // Base class for results that haven't been returned from the other end of the + // RPC connection yet. + class OutstandingResult { + public: + virtual ~OutstandingResult() {} + virtual std::error_code readResult(ChannelT &C) = 0; + virtual void abandon() = 0; + }; + + // Outstanding results for a specific function. + template <typename Func> + class OutstandingResultImpl : public OutstandingResult { + private: + public: + OutstandingResultImpl(std::promise<typename Func::OptionalReturn> &&P) + : P(std::move(P)) {} + + std::error_code readResult(ChannelT &C) override { + return Func::readResult(C, P); + } + + void abandon() override { Func::abandon(P); } + + private: + std::promise<typename Func::OptionalReturn> P; + }; + + // Create an outstanding result for the given function. + template <typename Func> + std::unique_ptr<OutstandingResult> + createOutstandingResult(std::promise<typename Func::OptionalReturn> &&P) { + return llvm::make_unique<OutstandingResultImpl<Func>>(std::move(P)); + } + + // Abandon all outstanding results. + void abandonOutstandingResults() { + for (auto &KV : OutstandingResults) + KV.second->abandon(); + OutstandingResults.clear(); + SequenceNumberMgr.reset(); + } + + SequenceNumberManager SequenceNumberMgr; + std::map<SequenceNumberT, std::unique_ptr<OutstandingResult>> + OutstandingResults; }; } // end namespace remote diff --git a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp index e95115ec6fe..5e12c86c704 100644 --- a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp +++ b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp @@ -38,6 +38,8 @@ public: return "Remote indirect stubs owner Id already in use"; case OrcErrorCode::UnexpectedRPCCall: return "Unexpected RPC call"; + case OrcErrorCode::UnexpectedRPCResponse: + return "Unexpected RPC response"; } llvm_unreachable("Unhandled error code"); } diff --git a/llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp b/llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp index 81e51a83021..d1a021aee3a 100644 --- a/llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp +++ b/llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp @@ -13,50 +13,40 @@ namespace llvm { namespace orc { namespace remote { -#define PROCNAME(X) \ +#define FUNCNAME(X) \ case X ## Id: \ return #X -const char *OrcRemoteTargetRPCAPI::getJITProcIdName(JITProcId Id) { +const char *OrcRemoteTargetRPCAPI::getJITFuncIdName(JITFuncId Id) { switch (Id) { case InvalidId: - return "*** Invalid JITProcId ***"; - PROCNAME(CallIntVoid); - PROCNAME(CallIntVoidResponse); - PROCNAME(CallMain); - PROCNAME(CallMainResponse); - PROCNAME(CallVoidVoid); - PROCNAME(CallVoidVoidResponse); - PROCNAME(CreateRemoteAllocator); - PROCNAME(CreateIndirectStubsOwner); - PROCNAME(DeregisterEHFrames); - PROCNAME(DestroyRemoteAllocator); - PROCNAME(DestroyIndirectStubsOwner); - PROCNAME(EmitIndirectStubs); - PROCNAME(EmitIndirectStubsResponse); - PROCNAME(EmitResolverBlock); - PROCNAME(EmitTrampolineBlock); - PROCNAME(EmitTrampolineBlockResponse); - PROCNAME(GetSymbolAddress); - PROCNAME(GetSymbolAddressResponse); - PROCNAME(GetRemoteInfo); - PROCNAME(GetRemoteInfoResponse); - PROCNAME(ReadMem); - PROCNAME(ReadMemResponse); - PROCNAME(RegisterEHFrames); - PROCNAME(ReserveMem); - PROCNAME(ReserveMemResponse); - PROCNAME(RequestCompile); - PROCNAME(RequestCompileResponse); - PROCNAME(SetProtections); - PROCNAME(TerminateSession); - PROCNAME(WriteMem); - PROCNAME(WritePtr); + return "*** Invalid JITFuncId ***"; + FUNCNAME(CallIntVoid); + FUNCNAME(CallMain); + FUNCNAME(CallVoidVoid); + FUNCNAME(CreateRemoteAllocator); + FUNCNAME(CreateIndirectStubsOwner); + FUNCNAME(DeregisterEHFrames); + FUNCNAME(DestroyRemoteAllocator); + FUNCNAME(DestroyIndirectStubsOwner); + FUNCNAME(EmitIndirectStubs); + FUNCNAME(EmitResolverBlock); + FUNCNAME(EmitTrampolineBlock); + FUNCNAME(GetSymbolAddress); + FUNCNAME(GetRemoteInfo); + FUNCNAME(ReadMem); + FUNCNAME(RegisterEHFrames); + FUNCNAME(ReserveMem); + FUNCNAME(RequestCompile); + FUNCNAME(SetProtections); + FUNCNAME(TerminateSession); + FUNCNAME(WriteMem); + FUNCNAME(WritePtr); }; return nullptr; } -#undef PROCNAME +#undef FUNCNAME } // end namespace remote } // end namespace orc diff --git a/llvm/tools/lli/ChildTarget/ChildTarget.cpp b/llvm/tools/lli/ChildTarget/ChildTarget.cpp index 93925d6aa87..33de1850547 100644 --- a/llvm/tools/lli/ChildTarget/ChildTarget.cpp +++ b/llvm/tools/lli/ChildTarget/ChildTarget.cpp @@ -54,8 +54,8 @@ int main(int argc, char *argv[]) { JITServer Server(Channel, SymbolLookup, RegisterEHFrames, DeregisterEHFrames); while (1) { - JITServer::JITProcId Id = JITServer::InvalidId; - if (auto EC = Server.getNextProcId(Id)) { + JITServer::JITFuncId Id = JITServer::InvalidId; + if (auto EC = Server.getNextFuncId(Id)) { errs() << "Error: " << EC.message() << "\n"; return 1; } @@ -63,7 +63,7 @@ int main(int argc, char *argv[]) { case JITServer::TerminateSessionId: return 0; default: - if (auto EC = Server.handleKnownProcedure(Id)) { + if (auto EC = Server.handleKnownFunction(Id)) { errs() << "Error: " << EC.message() << "\n"; return 1; } diff --git a/llvm/tools/lli/RemoteJITUtils.h b/llvm/tools/lli/RemoteJITUtils.h index d5488ad555c..63915d13dde 100644 --- a/llvm/tools/lli/RemoteJITUtils.h +++ b/llvm/tools/lli/RemoteJITUtils.h @@ -16,6 +16,7 @@ #include "llvm/ExecutionEngine/Orc/RPCChannel.h" #include "llvm/ExecutionEngine/RTDyldMemoryManager.h" +#include <mutex> #if !defined(_MSC_VER) && !defined(__MINGW32__) #include <unistd.h> diff --git a/llvm/tools/lli/lli.cpp b/llvm/tools/lli/lli.cpp index 0f30a4a5ff0..ce99b6ac7a0 100644 --- a/llvm/tools/lli/lli.cpp +++ b/llvm/tools/lli/lli.cpp @@ -582,7 +582,7 @@ int main(int argc, char **argv, char * const *envp) { // Reset errno to zero on entry to main. errno = 0; - int Result; + int Result = -1; // Sanity check use of remote-jit: LLI currently only supports use of the // remote JIT on Unix platforms. @@ -681,12 +681,13 @@ int main(int argc, char **argv, char * const *envp) { static_cast<ForwardingMemoryManager*>(RTDyldMM)->setResolver( orc::createLambdaResolver( [&](const std::string &Name) { - orc::TargetAddress Addr = 0; - if (auto EC = R->getSymbolAddress(Addr, Name)) { - errs() << "Failure during symbol lookup: " << EC.message() << "\n"; - exit(1); - } - return RuntimeDyld::SymbolInfo(Addr, JITSymbolFlags::Exported); + if (auto AddrOrErr = R->getSymbolAddress(Name)) + return RuntimeDyld::SymbolInfo(*AddrOrErr, JITSymbolFlags::Exported); + else { + errs() << "Failure during symbol lookup: " + << AddrOrErr.getError().message() << "\n"; + exit(1); + } }, [](const std::string &Name) { return nullptr; } )); @@ -698,8 +699,10 @@ int main(int argc, char **argv, char * const *envp) { EE->finalizeObject(); DEBUG(dbgs() << "Executing '" << EntryFn->getName() << "' at 0x" << format("%llx", Entry) << "\n"); - if (auto EC = R->callIntVoid(Result, Entry)) - errs() << "ERROR: " << EC.message() << "\n"; + if (auto ResultOrErr = R->callIntVoid(Entry)) + Result = *ResultOrErr; + else + errs() << "ERROR: " << ResultOrErr.getError().message() << "\n"; // Like static constructors, the remote target MCJIT support doesn't handle // this yet. It could. FIXME. diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 3b01c3828b6..77632e35eb1 100644 --- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -44,26 +44,25 @@ private: class DummyRPC : public testing::Test, public RPC<QueueChannel> { public: - typedef Procedure<1, void(bool)> Proc1; - typedef Procedure<2, void(int8_t, uint8_t, int16_t, uint16_t, - int32_t, uint32_t, int64_t, uint64_t, - bool, std::string, std::vector<int>)> AllTheTypes; + typedef Function<2, void(bool)> BasicVoid; + typedef Function<3, int32_t(bool)> BasicInt; + typedef Function<4, void(int8_t, uint8_t, int16_t, uint16_t, + int32_t, uint32_t, int64_t, uint64_t, + bool, std::string, std::vector<int>)> AllTheTypes; }; -TEST_F(DummyRPC, TestBasic) { +TEST_F(DummyRPC, TestAsyncBasicVoid) { std::queue<char> Queue; QueueChannel C(Queue); - { - // Make a call to Proc1. - auto EC = call<Proc1>(C, true); - EXPECT_FALSE(EC) << "Simple call over queue failed"; - } + // Make an async call. + auto ResOrErr = callAsync<BasicVoid>(C, true); + EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed"; { // Expect a call to Proc1. - auto EC = expect<Proc1>(C, + auto EC = expect<BasicVoid>(C, [&](bool &B) { EXPECT_EQ(B, true) << "Bool serialization broken"; @@ -71,31 +70,71 @@ TEST_F(DummyRPC, TestBasic) { }); EXPECT_FALSE(EC) << "Simple expect over queue failed"; } + + { + // Wait for the result. + auto EC = waitForResult(C, ResOrErr->second, handleNone); + EXPECT_FALSE(EC) << "Could not read result."; + } + + // Verify that the function returned ok. + auto Val = ResOrErr->first.get(); + EXPECT_TRUE(Val) << "Remote void function failed to execute."; } -TEST_F(DummyRPC, TestSerialization) { +TEST_F(DummyRPC, TestAsyncBasicInt) { std::queue<char> Queue; QueueChannel C(Queue); + // Make an async call. + auto ResOrErr = callAsync<BasicInt>(C, false); + EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed"; + { - // Make a call to Proc1. - std::vector<int> v({42, 7}); - auto EC = call<AllTheTypes>(C, - -101, - 250, - -10000, - 10000, - -1000000000, - 1000000000, - -10000000000, - 10000000000, - true, - "foo", - v); - EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed"; + // Expect a call to Proc1. + auto EC = expect<BasicInt>(C, + [&](bool &B) { + EXPECT_EQ(B, false) + << "Bool serialization broken"; + return 42; + }); + EXPECT_FALSE(EC) << "Simple expect over queue failed"; } { + // Wait for the result. + auto EC = waitForResult(C, ResOrErr->second, handleNone); + EXPECT_FALSE(EC) << "Could not read result."; + } + + // Verify that the function returned ok. + auto Val = ResOrErr->first.get(); + EXPECT_TRUE(!!Val) << "Remote int function failed to execute."; + EXPECT_EQ(*Val, 42) << "Remote int function return wrong value."; +} + +TEST_F(DummyRPC, TestSerialization) { + std::queue<char> Queue; + QueueChannel C(Queue); + + // Make a call to Proc1. + std::vector<int> v({42, 7}); + auto ResOrErr = callAsync<AllTheTypes>(C, + -101, + 250, + -10000, + 10000, + -1000000000, + 1000000000, + -10000000000, + 10000000000, + true, + "foo", + v); + EXPECT_TRUE(!!ResOrErr) + << "Big (serialization test) call over queue failed"; + + { // Expect a call to Proc1. auto EC = expect<AllTheTypes>(C, [&](int8_t &s8, @@ -136,4 +175,14 @@ TEST_F(DummyRPC, TestSerialization) { }); EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed"; } + + { + // Wait for the result. + auto EC = waitForResult(C, ResOrErr->second, handleNone); + EXPECT_FALSE(EC) << "Could not read result."; + } + + // Verify that the function returned ok. + auto Val = ResOrErr->first.get(); + EXPECT_TRUE(Val) << "Remote void function failed to execute."; } |