diff options
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/OrcError.h | 3 | ||||
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h | 244 | ||||
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h | 171 | ||||
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h | 123 | ||||
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/RPCChannel.h | 79 | ||||
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h | 524 | ||||
-rw-r--r-- | llvm/lib/ExecutionEngine/Orc/OrcError.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp | 60 | ||||
-rw-r--r-- | llvm/tools/lli/ChildTarget/ChildTarget.cpp | 6 | ||||
-rw-r--r-- | llvm/tools/lli/RemoteJITUtils.h | 1 | ||||
-rw-r--r-- | llvm/tools/lli/lli.cpp | 21 | ||||
-rw-r--r-- | llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp | 103 |
12 files changed, 460 insertions, 877 deletions
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h index aeee03f86e9..48f35d6b39b 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -26,8 +26,7 @@ enum class OrcErrorCode : int { RemoteMProtectAddrUnrecognized, RemoteIndirectStubsOwnerDoesNotExist, RemoteIndirectStubsOwnerIdAlreadyInUse, - UnexpectedRPCCall, - UnexpectedRPCResponse, + UnexpectedRPCCall }; std::error_code orcError(OrcErrorCode ErrCode); diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h index 9ecf904c9ff..8068733dcdd 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h @@ -36,7 +36,6 @@ namespace remote { template <typename ChannelT> class OrcRemoteTargetClient : public OrcRemoteTargetRPCAPI { public: - /// Remote memory manager. class RCMemoryManager : public RuntimeDyld::MemoryManager { public: @@ -106,13 +105,11 @@ public: DEBUG(dbgs() << "Allocator " << Id << " reserved:\n"); if (CodeSize != 0) { - if (auto AddrOrErr = Client.reserveMem(Id, CodeSize, CodeAlign)) - Unmapped.back().RemoteCodeAddr = *AddrOrErr; - else { - // FIXME; Add error to poll. - assert(!AddrOrErr.getError() && "Failed reserving remote memory."); - } - + std::error_code EC = Client.reserveMem(Unmapped.back().RemoteCodeAddr, + Id, CodeSize, CodeAlign); + // FIXME; Add error to poll. + assert(!EC && "Failed reserving remote memory."); + (void)EC; DEBUG(dbgs() << " code: " << format("0x%016x", Unmapped.back().RemoteCodeAddr) << " (" << CodeSize << " bytes, alignment " << CodeAlign @@ -120,13 +117,11 @@ public: } if (RODataSize != 0) { - if (auto AddrOrErr = Client.reserveMem(Id, RODataSize, RODataAlign)) - Unmapped.back().RemoteRODataAddr = *AddrOrErr; - else { - // FIXME; Add error to poll. - assert(!AddrOrErr.getError() && "Failed reserving remote memory."); - } - + std::error_code EC = Client.reserveMem(Unmapped.back().RemoteRODataAddr, + Id, RODataSize, RODataAlign); + // FIXME; Add error to poll. + assert(!EC && "Failed reserving remote memory."); + (void)EC; DEBUG(dbgs() << " ro-data: " << format("0x%016x", Unmapped.back().RemoteRODataAddr) << " (" << RODataSize << " bytes, alignment " @@ -134,13 +129,11 @@ public: } if (RWDataSize != 0) { - if (auto AddrOrErr = Client.reserveMem(Id, RWDataSize, RWDataAlign)) - Unmapped.back().RemoteRWDataAddr = *AddrOrErr; - else { - // FIXME; Add error to poll. - assert(!AddrOrErr.getError() && "Failed reserving remote memory."); - } - + std::error_code EC = Client.reserveMem(Unmapped.back().RemoteRWDataAddr, + Id, RWDataSize, RWDataAlign); + // FIXME; Add error to poll. + assert(!EC && "Failed reserving remote memory."); + (void)EC; DEBUG(dbgs() << " rw-data: " << format("0x%016x", Unmapped.back().RemoteRWDataAddr) << " (" << RWDataSize << " bytes, alignment " @@ -438,10 +431,8 @@ public: TargetAddress PtrBase; unsigned NumStubsEmitted; - if (auto StubInfoOrErr = Remote.emitIndirectStubs(Id, NewStubsRequired)) - std::tie(StubBase, PtrBase, NumStubsEmitted) = *StubInfoOrErr; - else - return StubInfoOrErr.getError(); + Remote.emitIndirectStubs(StubBase, PtrBase, NumStubsEmitted, Id, + NewStubsRequired); unsigned NewBlockId = RemoteIndirectStubsInfos.size(); RemoteIndirectStubsInfos.push_back({StubBase, PtrBase, NumStubsEmitted}); @@ -493,12 +484,8 @@ public: void grow() override { TargetAddress BlockAddr = 0; uint32_t NumTrampolines = 0; - if (auto TrampolineInfoOrErr = Remote.emitTrampolineBlock()) - std::tie(BlockAddr, NumTrampolines) = *TrampolineInfoOrErr; - else { - // FIXME: Return error. - llvm_unreachable("Failed to create trampolines"); - } + auto EC = Remote.emitTrampolineBlock(BlockAddr, NumTrampolines); + assert(!EC && "Failed to create trampolines"); uint32_t TrampolineSize = Remote.getTrampolineSize(); for (unsigned I = 0; I < NumTrampolines; ++I) @@ -516,33 +503,53 @@ public: OrcRemoteTargetClient H(Channel, EC); if (EC) return EC; - return ErrorOr<OrcRemoteTargetClient>(std::move(H)); + return H; } /// Call the int(void) function at the given address in the target and return /// its result. - ErrorOr<int> callIntVoid(TargetAddress Addr) { + std::error_code callIntVoid(int &Result, TargetAddress Addr) { DEBUG(dbgs() << "Calling int(*)(void) " << format("0x%016x", Addr) << "\n"); - auto Listen = - [&](RPCChannel &C, uint32_t Id) { - return listenForCompileRequests(C, Id); - }; - return callSTHandling<CallIntVoid>(Channel, Listen, Addr); + if (auto EC = call<CallIntVoid>(Channel, Addr)) + return EC; + + unsigned NextProcId; + if (auto EC = listenForCompileRequests(NextProcId)) + return EC; + + if (NextProcId != CallIntVoidResponseId) + return orcError(OrcErrorCode::UnexpectedRPCCall); + + return handle<CallIntVoidResponse>(Channel, [&](int R) { + Result = R; + DEBUG(dbgs() << "Result: " << R << "\n"); + return std::error_code(); + }); } /// Call the int(int, char*[]) function at the given address in the target and /// return its result. - ErrorOr<int> callMain(TargetAddress Addr, - const std::vector<std::string> &Args) { + std::error_code callMain(int &Result, TargetAddress Addr, + const std::vector<std::string> &Args) { DEBUG(dbgs() << "Calling int(*)(int, char*[]) " << format("0x%016x", Addr) << "\n"); - auto Listen = - [&](RPCChannel &C, uint32_t Id) { - return listenForCompileRequests(C, Id); - }; - return callSTHandling<CallMain>(Channel, Listen, Addr, Args); + if (auto EC = call<CallMain>(Channel, Addr, Args)) + return EC; + + unsigned NextProcId; + if (auto EC = listenForCompileRequests(NextProcId)) + return EC; + + if (NextProcId != CallMainResponseId) + return orcError(OrcErrorCode::UnexpectedRPCCall); + + return handle<CallMainResponse>(Channel, [&](int R) { + Result = R; + DEBUG(dbgs() << "Result: " << R << "\n"); + return std::error_code(); + }); } /// Call the void() function at the given address in the target and wait for @@ -551,11 +558,17 @@ public: DEBUG(dbgs() << "Calling void(*)(void) " << format("0x%016x", Addr) << "\n"); - auto Listen = - [&](RPCChannel &C, JITFuncId Id) { - return listenForCompileRequests(C, Id); - }; - return callSTHandling<CallVoidVoid>(Channel, Listen, Addr); + if (auto EC = call<CallVoidVoid>(Channel, Addr)) + return EC; + + unsigned NextProcId; + if (auto EC = listenForCompileRequests(NextProcId)) + return EC; + + if (NextProcId != CallVoidVoidResponseId) + return orcError(OrcErrorCode::UnexpectedRPCCall); + + return handle<CallVoidVoidResponse>(Channel, doNothing); } /// Create an RCMemoryManager which will allocate its memory on the remote @@ -565,7 +578,7 @@ public: assert(!MM && "MemoryManager should be null before creation."); auto Id = AllocatorIds.getNext(); - if (auto EC = callST<CreateRemoteAllocator>(Channel, Id)) + if (auto EC = call<CreateRemoteAllocator>(Channel, Id)) return EC; MM = llvm::make_unique<RCMemoryManager>(*this, Id); return std::error_code(); @@ -577,7 +590,7 @@ public: createIndirectStubsManager(std::unique_ptr<RCIndirectStubsManager> &I) { assert(!I && "Indirect stubs manager should be null before creation."); auto Id = IndirectStubOwnerIds.getNext(); - if (auto EC = callST<CreateIndirectStubsOwner>(Channel, Id)) + if (auto EC = call<CreateIndirectStubsOwner>(Channel, Id)) return EC; I = llvm::make_unique<RCIndirectStubsManager>(*this, Id); return std::error_code(); @@ -586,39 +599,45 @@ public: /// Search for symbols in the remote process. Note: This should be used by /// symbol resolvers *after* they've searched the local symbol table in the /// JIT stack. - ErrorOr<TargetAddress> getSymbolAddress(StringRef Name) { + std::error_code getSymbolAddress(TargetAddress &Addr, StringRef Name) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return ExistingError; - return callST<GetSymbolAddress>(Channel, Name); + // Request remote symbol address. + if (auto EC = call<GetSymbolAddress>(Channel, Name)) + return EC; + + return expect<GetSymbolAddressResponse>(Channel, [&](TargetAddress &A) { + Addr = A; + DEBUG(dbgs() << "Remote address lookup " << Name << " = " + << format("0x%016x", Addr) << "\n"); + return std::error_code(); + }); } /// Get the triple for the remote target. const std::string &getTargetTriple() const { return RemoteTargetTriple; } - std::error_code terminateSession() { - return callST<TerminateSession>(Channel); - } + std::error_code terminateSession() { return call<TerminateSession>(Channel); } private: OrcRemoteTargetClient(ChannelT &Channel, std::error_code &EC) : Channel(Channel) { - if (auto RIOrErr = callST<GetRemoteInfo>(Channel)) { - std::tie(RemoteTargetTriple, RemotePointerSize, RemotePageSize, - RemoteTrampolineSize, RemoteIndirectStubSize) = - *RIOrErr; - EC = std::error_code(); - } else - EC = RIOrErr.getError(); + if ((EC = call<GetRemoteInfo>(Channel))) + return; + + EC = expect<GetRemoteInfoResponse>( + Channel, readArgs(RemoteTargetTriple, RemotePointerSize, RemotePageSize, + RemoteTrampolineSize, RemoteIndirectStubSize)); } std::error_code deregisterEHFrames(TargetAddress Addr, uint32_t Size) { - return callST<RegisterEHFrames>(Channel, Addr, Size); + return call<RegisterEHFrames>(Channel, Addr, Size); } void destroyRemoteAllocator(ResourceIdMgr::ResourceId Id) { - if (auto EC = callST<DestroyRemoteAllocator>(Channel, Id)) { + if (auto EC = call<DestroyRemoteAllocator>(Channel, Id)) { // FIXME: This will be triggered by a removeModuleSet call: Propagate // error return up through that. llvm_unreachable("Failed to destroy remote allocator."); @@ -628,13 +647,19 @@ private: std::error_code destroyIndirectStubsManager(ResourceIdMgr::ResourceId Id) { IndirectStubOwnerIds.release(Id); - return callST<DestroyIndirectStubsOwner>(Channel, Id); + return call<DestroyIndirectStubsOwner>(Channel, Id); } - ErrorOr<std::tuple<TargetAddress, TargetAddress, uint32_t>> - emitIndirectStubs(ResourceIdMgr::ResourceId Id, - uint32_t NumStubsRequired) { - return callST<EmitIndirectStubs>(Channel, Id, NumStubsRequired); + std::error_code emitIndirectStubs(TargetAddress &StubBase, + TargetAddress &PtrBase, + uint32_t &NumStubsEmitted, + ResourceIdMgr::ResourceId Id, + uint32_t NumStubsRequired) { + if (auto EC = call<EmitIndirectStubs>(Channel, Id, NumStubsRequired)) + return EC; + + return expect<EmitIndirectStubsResponse>( + Channel, readArgs(StubBase, PtrBase, NumStubsEmitted)); } std::error_code emitResolverBlock() { @@ -642,16 +667,24 @@ private: if (ExistingError) return ExistingError; - return callST<EmitResolverBlock>(Channel); + return call<EmitResolverBlock>(Channel); } - ErrorOr<std::tuple<TargetAddress, uint32_t>> - emitTrampolineBlock() { + std::error_code emitTrampolineBlock(TargetAddress &BlockAddr, + uint32_t &NumTrampolines) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return ExistingError; - return callST<EmitTrampolineBlock>(Channel); + if (auto EC = call<EmitTrampolineBlock>(Channel)) + return EC; + + return expect<EmitTrampolineBlockResponse>( + Channel, [&](TargetAddress BAddr, uint32_t NTrampolines) { + BlockAddr = BAddr; + NumTrampolines = NTrampolines; + return std::error_code(); + }); } uint32_t getIndirectStubSize() const { return RemoteIndirectStubSize; } @@ -660,46 +693,67 @@ private: uint32_t getTrampolineSize() const { return RemoteTrampolineSize; } - std::error_code listenForCompileRequests(RPCChannel &C, uint32_t &Id) { + std::error_code listenForCompileRequests(uint32_t &NextId) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return ExistingError; - if (Id == RequestCompileId) { - if (auto EC = handle<RequestCompile>(C, CompileCallback)) + if (auto EC = getNextProcId(Channel, NextId)) + return EC; + + while (NextId == RequestCompileId) { + TargetAddress TrampolineAddr = 0; + if (auto EC = handle<RequestCompile>(Channel, readArgs(TrampolineAddr))) + return EC; + + TargetAddress ImplAddr = CompileCallback(TrampolineAddr); + if (auto EC = call<RequestCompileResponse>(Channel, ImplAddr)) + return EC; + + if (auto EC = getNextProcId(Channel, NextId)) return EC; - return std::error_code(); } - // else - return orcError(OrcErrorCode::UnexpectedRPCCall); + + return std::error_code(); } - ErrorOr<std::vector<char>> readMem(char *Dst, TargetAddress Src, uint64_t Size) { + std::error_code readMem(char *Dst, TargetAddress Src, uint64_t Size) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return ExistingError; - return callST<ReadMem>(Channel, Src, Size); + if (auto EC = call<ReadMem>(Channel, Src, Size)) + return EC; + + if (auto EC = expect<ReadMemResponse>( + Channel, [&]() { return Channel.readBytes(Dst, Size); })) + return EC; + + return std::error_code(); } std::error_code registerEHFrames(TargetAddress &RAddr, uint32_t Size) { - return callST<RegisterEHFrames>(Channel, RAddr, Size); + return call<RegisterEHFrames>(Channel, RAddr, Size); } - ErrorOr<TargetAddress> reserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size, - uint32_t Align) { + std::error_code reserveMem(TargetAddress &RemoteAddr, + ResourceIdMgr::ResourceId Id, uint64_t Size, + uint32_t Align) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return ExistingError; - return callST<ReserveMem>(Channel, Id, Size, Align); + if (std::error_code EC = call<ReserveMem>(Channel, Id, Size, Align)) + return EC; + + return expect<ReserveMemResponse>(Channel, readArgs(RemoteAddr)); } std::error_code setProtections(ResourceIdMgr::ResourceId Id, TargetAddress RemoteSegAddr, unsigned ProtFlags) { - return callST<SetProtections>(Channel, Id, RemoteSegAddr, ProtFlags); + return call<SetProtections>(Channel, Id, RemoteSegAddr, ProtFlags); } std::error_code writeMem(TargetAddress Addr, const char *Src, uint64_t Size) { @@ -707,7 +761,15 @@ private: if (ExistingError) return ExistingError; - return callST<WriteMem>(Channel, DirectBufferWriter(Src, Addr, Size)); + // Make the send call. + if (auto EC = call<WriteMem>(Channel, Addr, Size)) + return EC; + + // Follow this up with the section contents. + if (auto EC = Channel.appendBytes(Src, Size)) + return EC; + + return Channel.send(); } std::error_code writePointer(TargetAddress Addr, TargetAddress PtrVal) { @@ -715,7 +777,7 @@ private: if (ExistingError) return ExistingError; - return callST<WritePtr>(Channel, Addr, PtrVal); + return call<WritePtr>(Channel, Addr, PtrVal); } static std::error_code doNothing() { return std::error_code(); } diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h index e9d4ac7af96..94327d0e320 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -24,48 +24,8 @@ namespace llvm { namespace orc { namespace remote { -class DirectBufferWriter { -public: - DirectBufferWriter() = default; - DirectBufferWriter(const char *Src, TargetAddress Dst, uint64_t Size) - : Src(Src), Dst(Dst), Size(Size) {} - - const char *getSrc() const { return Src; } - TargetAddress getDst() const { return Dst; } - uint64_t getSize() const { return Size; } -private: - const char *Src; - TargetAddress Dst; - uint64_t Size; -}; - -inline std::error_code serialize(RPCChannel &C, - const DirectBufferWriter &DBW) { - if (auto EC = serialize(C, DBW.getDst())) - return EC; - if (auto EC = serialize(C, DBW.getSize())) - return EC; - return C.appendBytes(DBW.getSrc(), DBW.getSize()); -} - -inline std::error_code deserialize(RPCChannel &C, - DirectBufferWriter &DBW) { - TargetAddress Dst; - if (auto EC = deserialize(C, Dst)) - return EC; - uint64_t Size; - if (auto EC = deserialize(C, Size)) - return EC; - char *Addr = reinterpret_cast<char*>(static_cast<uintptr_t>(Dst)); - - DBW = DirectBufferWriter(0, Dst, Size); - - return C.readBytes(Addr, Size); -} - class OrcRemoteTargetRPCAPI : public RPC<RPCChannel> { protected: - class ResourceIdMgr { public: typedef uint64_t ResourceId; @@ -85,111 +45,146 @@ protected: }; public: - enum JITFuncId : uint32_t { - InvalidId = RPCFunctionIdTraits<JITFuncId>::InvalidId, - CallIntVoidId = RPCFunctionIdTraits<JITFuncId>::FirstValidId, + enum JITProcId : uint32_t { + InvalidId = 0, + CallIntVoidId, + CallIntVoidResponseId, CallMainId, + CallMainResponseId, CallVoidVoidId, + CallVoidVoidResponseId, CreateRemoteAllocatorId, CreateIndirectStubsOwnerId, DeregisterEHFramesId, DestroyRemoteAllocatorId, DestroyIndirectStubsOwnerId, EmitIndirectStubsId, + EmitIndirectStubsResponseId, EmitResolverBlockId, EmitTrampolineBlockId, + EmitTrampolineBlockResponseId, GetSymbolAddressId, + GetSymbolAddressResponseId, GetRemoteInfoId, + GetRemoteInfoResponseId, ReadMemId, + ReadMemResponseId, RegisterEHFramesId, ReserveMemId, + ReserveMemResponseId, RequestCompileId, + RequestCompileResponseId, SetProtectionsId, TerminateSessionId, WriteMemId, WritePtrId }; - static const char *getJITFuncIdName(JITFuncId Id); + static const char *getJITProcIdName(JITProcId Id); + + typedef Procedure<CallIntVoidId, void(TargetAddress Addr)> CallIntVoid; - typedef Function<CallIntVoidId, int32_t(TargetAddress Addr)> CallIntVoid; + typedef Procedure<CallIntVoidResponseId, void(int Result)> + CallIntVoidResponse; - typedef Function<CallMainId, int32_t(TargetAddress Addr, - std::vector<std::string> Args)> + typedef Procedure<CallMainId, void(TargetAddress Addr, + std::vector<std::string> Args)> CallMain; - typedef Function<CallVoidVoidId, void(TargetAddress FnAddr)> CallVoidVoid; + typedef Procedure<CallMainResponseId, void(int Result)> CallMainResponse; + + typedef Procedure<CallVoidVoidId, void(TargetAddress FnAddr)> CallVoidVoid; + + typedef Procedure<CallVoidVoidResponseId, void()> CallVoidVoidResponse; - typedef Function<CreateRemoteAllocatorId, - void(ResourceIdMgr::ResourceId AllocatorID)> + typedef Procedure<CreateRemoteAllocatorId, + void(ResourceIdMgr::ResourceId AllocatorID)> CreateRemoteAllocator; - typedef Function<CreateIndirectStubsOwnerId, - void(ResourceIdMgr::ResourceId StubOwnerID)> + typedef Procedure<CreateIndirectStubsOwnerId, + void(ResourceIdMgr::ResourceId StubOwnerID)> CreateIndirectStubsOwner; - typedef Function<DeregisterEHFramesId, - void(TargetAddress Addr, uint32_t Size)> + typedef Procedure<DeregisterEHFramesId, + void(TargetAddress Addr, uint32_t Size)> DeregisterEHFrames; - typedef Function<DestroyRemoteAllocatorId, - void(ResourceIdMgr::ResourceId AllocatorID)> + typedef Procedure<DestroyRemoteAllocatorId, + void(ResourceIdMgr::ResourceId AllocatorID)> DestroyRemoteAllocator; - typedef Function<DestroyIndirectStubsOwnerId, - void(ResourceIdMgr::ResourceId StubsOwnerID)> + typedef Procedure<DestroyIndirectStubsOwnerId, + void(ResourceIdMgr::ResourceId StubsOwnerID)> DestroyIndirectStubsOwner; - /// EmitIndirectStubs result is (StubsBase, PtrsBase, NumStubsEmitted). - typedef Function<EmitIndirectStubsId, - std::tuple<TargetAddress, TargetAddress, uint32_t>( - ResourceIdMgr::ResourceId StubsOwnerID, - uint32_t NumStubsRequired)> + typedef Procedure<EmitIndirectStubsId, + void(ResourceIdMgr::ResourceId StubsOwnerID, + uint32_t NumStubsRequired)> EmitIndirectStubs; - typedef Function<EmitResolverBlockId, void()> EmitResolverBlock; + typedef Procedure<EmitIndirectStubsResponseId, + void(TargetAddress StubsBaseAddr, + TargetAddress PtrsBaseAddr, + uint32_t NumStubsEmitted)> + EmitIndirectStubsResponse; - /// EmitTrampolineBlock result is (BlockAddr, NumTrampolines). - typedef Function<EmitTrampolineBlockId, - std::tuple<TargetAddress, uint32_t>()> EmitTrampolineBlock; + typedef Procedure<EmitResolverBlockId, void()> EmitResolverBlock; - typedef Function<GetSymbolAddressId, TargetAddress(std::string SymbolName)> + typedef Procedure<EmitTrampolineBlockId, void()> EmitTrampolineBlock; + + typedef Procedure<EmitTrampolineBlockResponseId, + void(TargetAddress BlockAddr, uint32_t NumTrampolines)> + EmitTrampolineBlockResponse; + + typedef Procedure<GetSymbolAddressId, void(std::string SymbolName)> GetSymbolAddress; - /// GetRemoteInfo result is (Triple, PointerSize, PageSize, TrampolineSize, - /// IndirectStubsSize). - typedef Function<GetRemoteInfoId, - std::tuple<std::string, uint32_t, uint32_t, uint32_t, - uint32_t>()> GetRemoteInfo; + typedef Procedure<GetSymbolAddressResponseId, void(uint64_t SymbolAddr)> + GetSymbolAddressResponse; + + typedef Procedure<GetRemoteInfoId, void()> GetRemoteInfo; + + typedef Procedure<GetRemoteInfoResponseId, + void(std::string Triple, uint32_t PointerSize, + uint32_t PageSize, uint32_t TrampolineSize, + uint32_t IndirectStubSize)> + GetRemoteInfoResponse; - typedef Function<ReadMemId, - std::vector<char>(TargetAddress Src, uint64_t Size)> + typedef Procedure<ReadMemId, void(TargetAddress Src, uint64_t Size)> ReadMem; - typedef Function<RegisterEHFramesId, - void(TargetAddress Addr, uint32_t Size)> + typedef Procedure<ReadMemResponseId, void()> ReadMemResponse; + + typedef Procedure<RegisterEHFramesId, + void(TargetAddress Addr, uint32_t Size)> RegisterEHFrames; - typedef Function<ReserveMemId, - TargetAddress(ResourceIdMgr::ResourceId AllocID, - uint64_t Size, uint32_t Align)> + typedef Procedure<ReserveMemId, + void(ResourceIdMgr::ResourceId AllocID, uint64_t Size, + uint32_t Align)> ReserveMem; - typedef Function<RequestCompileId, - TargetAddress(TargetAddress TrampolineAddr)> + typedef Procedure<ReserveMemResponseId, void(TargetAddress Addr)> + ReserveMemResponse; + + typedef Procedure<RequestCompileId, void(TargetAddress TrampolineAddr)> RequestCompile; - typedef Function<SetProtectionsId, - void(ResourceIdMgr::ResourceId AllocID, TargetAddress Dst, - uint32_t ProtFlags)> + typedef Procedure<RequestCompileResponseId, void(TargetAddress ImplAddr)> + RequestCompileResponse; + + typedef Procedure<SetProtectionsId, + void(ResourceIdMgr::ResourceId AllocID, TargetAddress Dst, + uint32_t ProtFlags)> SetProtections; - typedef Function<TerminateSessionId, void()> TerminateSession; + typedef Procedure<TerminateSessionId, void()> TerminateSession; - typedef Function<WriteMemId, void(DirectBufferWriter DB)> + typedef Procedure<WriteMemId, + void(TargetAddress Dst, uint64_t Size /* Data to follow */)> WriteMem; - typedef Function<WritePtrId, void(TargetAddress Dst, TargetAddress Val)> + typedef Procedure<WritePtrId, void(TargetAddress Dst, TargetAddress Val)> WritePtr; }; diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h index f15342dfea2..a6afd3183aa 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h @@ -45,14 +45,14 @@ public: EHFramesRegister(std::move(EHFramesRegister)), EHFramesDeregister(std::move(EHFramesDeregister)) {} - std::error_code getNextFuncId(JITFuncId &Id) { + std::error_code getNextProcId(JITProcId &Id) { return deserialize(Channel, Id); } - std::error_code handleKnownFunction(JITFuncId Id) { + std::error_code handleKnownProcedure(JITProcId Id) { typedef OrcRemoteTargetServer ThisT; - DEBUG(dbgs() << "Handling known proc: " << getJITFuncIdName(Id) << "\n"); + DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n"); switch (Id) { case CallIntVoidId: @@ -111,17 +111,27 @@ public: llvm_unreachable("Unhandled JIT RPC procedure Id."); } - ErrorOr<TargetAddress> requestCompile(TargetAddress TrampolineAddr) { - auto Listen = - [&](RPCChannel &C, uint32_t Id) { - return handleKnownFunction(static_cast<JITFuncId>(Id)); - }; + std::error_code requestCompile(TargetAddress &CompiledFnAddr, + TargetAddress TrampolineAddr) { + if (auto EC = call<RequestCompile>(Channel, TrampolineAddr)) + return EC; - return callSTHandling<RequestCompile>(Channel, Listen, TrampolineAddr); - } + while (1) { + JITProcId Id = InvalidId; + if (auto EC = getNextProcId(Id)) + return EC; - void handleTerminateSession() { - handle<TerminateSession>(Channel, [](){ return std::error_code(); }); + switch (Id) { + case RequestCompileResponseId: + return handle<RequestCompileResponse>(Channel, + readArgs(CompiledFnAddr)); + default: + if (auto EC = handleKnownProcedure(Id)) + return EC; + } + } + + llvm_unreachable("Fell through request-compile command loop."); } private: @@ -165,16 +175,18 @@ private: static std::error_code doNothing() { return std::error_code(); } static TargetAddress reenter(void *JITTargetAddr, void *TrampolineAddr) { + TargetAddress CompiledFnAddr = 0; + auto T = static_cast<OrcRemoteTargetServer *>(JITTargetAddr); - auto AddrOrErr = T->requestCompile( - static_cast<TargetAddress>( - reinterpret_cast<uintptr_t>(TrampolineAddr))); - // FIXME: Allow customizable failure substitution functions. - assert(AddrOrErr && "Compile request failed"); - return *AddrOrErr; + auto EC = T->requestCompile( + CompiledFnAddr, static_cast<TargetAddress>( + reinterpret_cast<uintptr_t>(TrampolineAddr))); + assert(!EC && "Compile request failed"); + (void)EC; + return CompiledFnAddr; } - ErrorOr<int32_t> handleCallIntVoid(TargetAddress Addr) { + std::error_code handleCallIntVoid(TargetAddress Addr) { typedef int (*IntVoidFnTy)(); IntVoidFnTy Fn = reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr)); @@ -183,11 +195,11 @@ private: int Result = Fn(); DEBUG(dbgs() << " Result = " << Result << "\n"); - return Result; + return call<CallIntVoidResponse>(Channel, Result); } - ErrorOr<int32_t> handleCallMain(TargetAddress Addr, - std::vector<std::string> Args) { + std::error_code handleCallMain(TargetAddress Addr, + std::vector<std::string> Args) { typedef int (*MainFnTy)(int, const char *[]); MainFnTy Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr)); @@ -202,7 +214,7 @@ private: int Result = Fn(ArgC, ArgV.get()); DEBUG(dbgs() << " Result = " << Result << "\n"); - return Result; + return call<CallMainResponse>(Channel, Result); } std::error_code handleCallVoidVoid(TargetAddress Addr) { @@ -214,7 +226,7 @@ private: Fn(); DEBUG(dbgs() << " Complete.\n"); - return std::error_code(); + return call<CallVoidVoidResponse>(Channel); } std::error_code handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) { @@ -261,9 +273,8 @@ private: return std::error_code(); } - ErrorOr<std::tuple<TargetAddress, TargetAddress, uint32_t>> - handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id, - uint32_t NumStubsRequired) { + std::error_code handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id, + uint32_t NumStubsRequired) { DEBUG(dbgs() << " ISMgr " << Id << " request " << NumStubsRequired << " stubs.\n"); @@ -285,7 +296,8 @@ private: auto &BlockList = StubOwnerItr->second; BlockList.push_back(std::move(IS)); - return std::make_tuple(StubsBase, PtrsBase, NumStubsEmitted); + return call<EmitIndirectStubsResponse>(Channel, StubsBase, PtrsBase, + NumStubsEmitted); } std::error_code handleEmitResolverBlock() { @@ -304,8 +316,7 @@ private: sys::Memory::MF_EXEC); } - ErrorOr<std::tuple<TargetAddress, uint32_t>> - handleEmitTrampolineBlock() { + std::error_code handleEmitTrampolineBlock() { std::error_code EC; auto TrampolineBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( @@ -314,7 +325,7 @@ private: if (EC) return EC; - uint32_t NumTrampolines = + unsigned NumTrampolines = (sys::Process::getPageSize() - TargetT::PointerSize) / TargetT::TrampolineSize; @@ -328,21 +339,20 @@ private: TrampolineBlocks.push_back(std::move(TrampolineBlock)); - auto TrampolineBaseAddr = - static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)); - - return std::make_tuple(TrampolineBaseAddr, NumTrampolines); + return call<EmitTrampolineBlockResponse>( + Channel, + static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)), + NumTrampolines); } - ErrorOr<TargetAddress> handleGetSymbolAddress(const std::string &Name) { + std::error_code handleGetSymbolAddress(const std::string &Name) { TargetAddress Addr = SymbolLookup(Name); DEBUG(dbgs() << " Symbol '" << Name << "' = " << format("0x%016x", Addr) << "\n"); - return Addr; + return call<GetSymbolAddressResponse>(Channel, Addr); } - ErrorOr<std::tuple<std::string, uint32_t, uint32_t, uint32_t, uint32_t>> - handleGetRemoteInfo() { + std::error_code handleGetRemoteInfo() { std::string ProcessTriple = sys::getProcessTriple(); uint32_t PointerSize = TargetT::PointerSize; uint32_t PageSize = sys::Process::getPageSize(); @@ -354,23 +364,24 @@ private: << " page size = " << PageSize << "\n" << " trampoline size = " << TrampolineSize << "\n" << " indirect stub size = " << IndirectStubSize << "\n"); - return std::make_tuple(ProcessTriple, PointerSize, PageSize ,TrampolineSize, - IndirectStubSize); + return call<GetRemoteInfoResponse>(Channel, ProcessTriple, PointerSize, + PageSize, TrampolineSize, + IndirectStubSize); } - ErrorOr<std::vector<char>> - handleReadMem(TargetAddress RSrc, uint64_t Size) { + std::error_code handleReadMem(TargetAddress RSrc, uint64_t Size) { char *Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc)); DEBUG(dbgs() << " Reading " << Size << " bytes from " << format("0x%016x", RSrc) << "\n"); - std::vector<char> Buffer; - Buffer.resize(Size); - for (char *P = Src; Size != 0; --Size) - Buffer.push_back(*P++); + if (auto EC = call<ReadMemResponse>(Channel)) + return EC; + + if (auto EC = Channel.appendBytes(Src, Size)) + return EC; - return Buffer; + return Channel.send(); } std::error_code handleRegisterEHFrames(TargetAddress TAddr, uint32_t Size) { @@ -381,9 +392,8 @@ private: return std::error_code(); } - ErrorOr<TargetAddress> - handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size, - uint32_t Align) { + std::error_code handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size, + uint32_t Align) { auto I = Allocators.find(Id); if (I == Allocators.end()) return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); @@ -398,7 +408,7 @@ private: TargetAddress AllocAddr = static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr)); - return AllocAddr; + return call<ReserveMemResponse>(Channel, AllocAddr); } std::error_code handleSetProtections(ResourceIdMgr::ResourceId Id, @@ -415,10 +425,11 @@ private: return Allocator.setProtections(LocalAddr, Flags); } - std::error_code handleWriteMem(DirectBufferWriter DBW) { - DEBUG(dbgs() << " Writing " << DBW.getSize() << " bytes to " - << format("0x%016x", DBW.getDst()) << "\n"); - return std::error_code(); + std::error_code handleWriteMem(TargetAddress RDst, uint64_t Size) { + char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst)); + DEBUG(dbgs() << " Writing " << Size << " bytes to " + << format("0x%016x", RDst) << "\n"); + return Channel.readBytes(Dst, Size); } std::error_code handleWritePtr(TargetAddress Addr, TargetAddress PtrVal) { diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCChannel.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCChannel.h index 98314aea163..b97b6daf586 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCChannel.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCChannel.h @@ -5,10 +5,8 @@ #include "OrcError.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Endian.h" -#include <mutex> #include <system_error> namespace llvm { @@ -28,68 +26,31 @@ public: /// Flush the stream if possible. virtual std::error_code send() = 0; - - /// Get the lock for stream reading. - std::mutex& getReadLock() { return readLock; } - - /// Get the lock for stream writing. - std::mutex& getWriteLock() { return writeLock; } - -private: - std::mutex readLock, writeLock; }; -/// Notify the channel that we're starting a message send. -/// Locks the channel for writing. -inline std::error_code startSendMessage(RPCChannel &C) { - C.getWriteLock().lock(); - return std::error_code(); -} - -/// Notify the channel that we're ending a message send. -/// Unlocks the channel for writing. -inline std::error_code endSendMessage(RPCChannel &C) { - C.getWriteLock().unlock(); - return std::error_code(); -} - -/// Notify the channel that we're starting a message receive. -/// Locks the channel for reading. -inline std::error_code startReceiveMessage(RPCChannel &C) { - C.getReadLock().lock(); - return std::error_code(); -} - -/// Notify the channel that we're ending a message receive. -/// Unlocks the channel for reading. -inline std::error_code endReceiveMessage(RPCChannel &C) { - C.getReadLock().unlock(); - return std::error_code(); -} - /// RPC channel serialization for a variadic list of arguments. template <typename T, typename... Ts> -std::error_code serializeSeq(RPCChannel &C, const T &Arg, const Ts &... Args) { +std::error_code serialize_seq(RPCChannel &C, const T &Arg, const Ts &... Args) { if (auto EC = serialize(C, Arg)) return EC; - return serializeSeq(C, Args...); + return serialize_seq(C, Args...); } /// RPC channel serialization for an (empty) variadic list of arguments. -inline std::error_code serializeSeq(RPCChannel &C) { +inline std::error_code serialize_seq(RPCChannel &C) { return std::error_code(); } /// RPC channel deserialization for a variadic list of arguments. template <typename T, typename... Ts> -std::error_code deserializeSeq(RPCChannel &C, T &Arg, Ts &... Args) { +std::error_code deserialize_seq(RPCChannel &C, T &Arg, Ts &... Args) { if (auto EC = deserialize(C, Arg)) return EC; - return deserializeSeq(C, Args...); + return deserialize_seq(C, Args...); } /// RPC channel serialization for an (empty) variadic list of arguments. -inline std::error_code deserializeSeq(RPCChannel &C) { +inline std::error_code deserialize_seq(RPCChannel &C) { return std::error_code(); } @@ -177,34 +138,6 @@ inline std::error_code deserialize(RPCChannel &C, std::string &S) { return C.readBytes(&S[0], Count); } -// Serialization helper for std::tuple. -template <typename TupleT, size_t... Is> -inline std::error_code serializeTupleHelper(RPCChannel &C, - const TupleT &V, - llvm::index_sequence<Is...> _) { - return serializeSeq(C, std::get<Is>(V)...); -} - -/// RPC channel serialization for std::tuple. -template <typename... ArgTs> -inline std::error_code serialize(RPCChannel &C, const std::tuple<ArgTs...> &V) { - return serializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>()); -} - -// Serialization helper for std::tuple. -template <typename TupleT, size_t... Is> -inline std::error_code deserializeTupleHelper(RPCChannel &C, - TupleT &V, - llvm::index_sequence<Is...> _) { - return deserializeSeq(C, std::get<Is>(V)...); -} - -/// RPC channel deserialization for std::tuple. -template <typename... ArgTs> -inline std::error_code deserialize(RPCChannel &C, std::tuple<ArgTs...> &V) { - return deserializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>()); -} - /// RPC channel serialization for ArrayRef<T>. template <typename T> std::error_code serialize(RPCChannel &C, const ArrayRef<T> &A) { diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 33ac997c09e..d1b8546268f 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -14,197 +14,46 @@ #ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ExecutionEngine/Orc/OrcError.h" -#include "llvm/Support/ErrorOr.h" -#include <future> -#include <map> namespace llvm { namespace orc { namespace remote { -/// Describes reserved RPC Function Ids. -/// -/// The default implementation will serve for integer and enum function id -/// types. If you want to use a custom type as your FunctionId you can -/// specialize this class and provide unique values for InvalidId, -/// ResponseId and FirstValidId. - -template <typename T> -class RPCFunctionIdTraits { -public: - static constexpr T InvalidId = static_cast<T>(0); - static constexpr T ResponseId = static_cast<T>(1); - static constexpr T FirstValidId = static_cast<T>(2); -}; - // Base class containing utilities that require partial specialization. // These cannot be included in RPC, as template class members cannot be // partially specialized. class RPCBase { protected: - - // RPC Function description type. - // - // This class provides the information and operations needed to support the - // RPC primitive operations (call, expect, etc) for a given function. It - // is specialized for void and non-void functions to deal with the differences - // betwen the two. Both specializations have the same interface: - // - // Id - The function's unique identifier. - // OptionalReturn - The return type for asyncronous calls. - // ErrorReturn - The return type for synchronous calls. - // optionalToErrorReturn - Conversion from a valid OptionalReturn to an - // ErrorReturn. - // readResult - Deserialize a result from a channel. - // abandon - Abandon a promised (asynchronous) result. - // respond - Retun a result on the channel. - template <typename FunctionIdT, FunctionIdT FuncId, typename FnT> - class FunctionHelper {}; - - // RPC Function description specialization for non-void functions. - template <typename FunctionIdT, FunctionIdT FuncId, typename RetT, - typename... ArgTs> - class FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> { + template <typename ProcedureIdT, ProcedureIdT ProcId, typename FnT> + class ProcedureHelper { public: - - static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId && - FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId, - "Cannot define custom function with InvalidId or ResponseId. " - "Please use RPCFunctionTraits<FunctionIdT>::FirstValidId."); - - static const FunctionIdT Id = FuncId; - - typedef Optional<RetT> OptionalReturn; - - typedef ErrorOr<RetT> ErrorReturn; - - static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) { - assert(V && "Return value not available"); - return std::move(*V); - } - - template <typename ChannelT> - static std::error_code readResult(ChannelT &C, - std::promise<OptionalReturn> &P) { - RetT Val; - auto EC = deserialize(C, Val); - // FIXME: Join error EC2 from endReceiveMessage with the deserialize - // error once we switch to using Error. - auto EC2 = endReceiveMessage(C); - (void)EC2; - - if (EC) { - P.set_value(OptionalReturn()); - return EC; - } - P.set_value(std::move(Val)); - return std::error_code(); - } - - static void abandon(std::promise<OptionalReturn> &P) { - P.set_value(OptionalReturn()); - } - - template <typename ChannelT, typename SequenceNumberT> - static std::error_code respond(ChannelT &C, SequenceNumberT SeqNo, - const ErrorReturn &Result) { - FunctionIdT ResponseId = - RPCFunctionIdTraits<FunctionIdT>::ResponseId; - - // If the handler returned an error then bail out with that. - if (!Result) - return Result.getError(); - - // Otherwise open a new message on the channel and send the result. - if (auto EC = startSendMessage(C)) - return EC; - if (auto EC = serializeSeq(C, ResponseId, SeqNo, *Result)) - return EC; - return endSendMessage(C); - } + static const ProcedureIdT Id = ProcId; }; - // RPC Function description specialization for void functions. - template <typename FunctionIdT, FunctionIdT FuncId, typename... ArgTs> - class FunctionHelper<FunctionIdT, FuncId, void(ArgTs...)> { - public: - - static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId && - FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId, - "Cannot define custom function with InvalidId or ResponseId. " - "Please use RPCFunctionTraits<FunctionIdT>::FirstValidId."); - - static const FunctionIdT Id = FuncId; - - typedef bool OptionalReturn; - typedef std::error_code ErrorReturn; - - static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) { - assert(V && "Return value not available"); - return std::error_code(); - } + template <typename ChannelT, typename Proc> class CallHelper; - template <typename ChannelT> - static std::error_code readResult(ChannelT &C, - std::promise<OptionalReturn> &P) { - // Void functions don't have anything to deserialize, so we're good. - P.set_value(true); - return endReceiveMessage(C); - } - - static void abandon(std::promise<OptionalReturn> &P) { - P.set_value(false); - } - - template <typename ChannelT, typename SequenceNumberT> - static std::error_code respond(ChannelT &C, SequenceNumberT SeqNo, - const ErrorReturn &Result) { - const FunctionIdT ResponseId = - RPCFunctionIdTraits<FunctionIdT>::ResponseId; - - // If the handler returned an error then bail out with that. - if (Result) - return Result; - - // Otherwise open a new message on the channel and send the result. - if (auto EC = startSendMessage(C)) - return EC; - if (auto EC = serializeSeq(C, ResponseId, SeqNo)) - return EC; - return endSendMessage(C); - } - }; - - // Helper for the call primitive. - template <typename ChannelT, typename SequenceNumberT, typename Func> - class CallHelper; - - template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT, - FunctionIdT FuncId, typename RetT, typename... ArgTs> - class CallHelper<ChannelT, SequenceNumberT, - FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> { + template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId, + typename... ArgTs> + class CallHelper<ChannelT, + ProcedureHelper<ProcedureIdT, ProcId, void(ArgTs...)>> { public: - static std::error_code call(ChannelT &C, SequenceNumberT SeqNo, - const ArgTs &... Args) { - if (auto EC = startSendMessage(C)) - return EC; - if (auto EC = serializeSeq(C, FuncId, SeqNo, Args...)) + static std::error_code call(ChannelT &C, const ArgTs &... Args) { + if (auto EC = serialize(C, ProcId)) return EC; - return endSendMessage(C); + // If you see a compile-error on this line you're probably calling a + // function with the wrong signature. + return serialize_seq(C, Args...); } }; - // Helper for handle primitive. - template <typename ChannelT, typename SequenceNumberT, typename Func> - class HandlerHelper; + template <typename ChannelT, typename Proc> class HandlerHelper; - template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT, - FunctionIdT FuncId, typename RetT, typename... ArgTs> - class HandlerHelper<ChannelT, SequenceNumberT, - FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> { + template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId, + typename... ArgTs> + class HandlerHelper<ChannelT, + ProcedureHelper<ProcedureIdT, ProcId, void(ArgTs...)>> { public: template <typename HandlerT> static std::error_code handle(ChannelT &C, HandlerT Handler) { @@ -212,46 +61,34 @@ protected: } private: - - typedef FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> Func; - template <typename HandlerT, size_t... Is> static std::error_code readAndHandle(ChannelT &C, HandlerT Handler, llvm::index_sequence<Is...> _) { std::tuple<ArgTs...> RPCArgs; - SequenceNumberT SeqNo; // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning // for RPCArgs. Void cast RPCArgs to work around this for now. // FIXME: Remove this workaround once we can assume a working GCC version. (void)RPCArgs; - if (auto EC = deserializeSeq(C, SeqNo, std::get<Is>(RPCArgs)...)) - return EC; - - // We've deserialized the arguments, so unlock the channel for reading - // before we call the handler. This allows recursive RPC calls. - if (auto EC = endReceiveMessage(C)) + if (auto EC = deserialize_seq(C, std::get<Is>(RPCArgs)...)) return EC; - - return Func::template respond<ChannelT, SequenceNumberT>( - C, SeqNo, Handler(std::get<Is>(RPCArgs)...)); + return Handler(std::get<Is>(RPCArgs)...); } - }; - // Helper for wrapping member functions up as functors. - template <typename ClassT, typename RetT, typename... ArgTs> - class MemberFnWrapper { + template <typename ClassT, typename... ArgTs> class MemberFnWrapper { public: - typedef RetT(ClassT::*MethodT)(ArgTs...); + typedef std::error_code (ClassT::*MethodT)(ArgTs...); MemberFnWrapper(ClassT &Instance, MethodT Method) : Instance(Instance), Method(Method) {} - RetT operator()(ArgTs &... Args) { return (Instance.*Method)(Args...); } + std::error_code operator()(ArgTs &... Args) { + return (Instance.*Method)(Args...); + } + private: ClassT &Instance; MethodT Method; }; - // Helper that provides a Functor for deserializing arguments. template <typename... ArgTs> class ReadArgs { public: std::error_code operator()() { return std::error_code(); } @@ -275,7 +112,7 @@ protected: /// Contains primitive utilities for defining, calling and handling calls to /// remote procedures. ChannelT is a bidirectional stream conforming to the -/// RPCChannel interface (see RPCChannel.h), and FunctionIdT is a procedure +/// RPCChannel interface (see RPCChannel.h), and ProcedureIdT is a procedure /// identifier type that must be serializable on ChannelT. /// /// These utilities support the construction of very primitive RPC utilities. @@ -292,184 +129,120 @@ protected: /// /// Overview (see comments individual types/methods for details): /// -/// Function<Id, Args...> : +/// Procedure<Id, Args...> : /// /// associates a unique serializable id with an argument list. /// /// -/// call<Func>(Channel, Args...) : +/// call<Proc>(Channel, Args...) : /// -/// Calls the remote procedure 'Func' by serializing Func's id followed by its +/// Calls the remote procedure 'Proc' by serializing Proc's id followed by its /// arguments and sending the resulting bytes to 'Channel'. /// /// -/// handle<Func>(Channel, <functor matching std::error_code(Args...)> : +/// handle<Proc>(Channel, <functor matching std::error_code(Args...)> : /// -/// Handles a call to 'Func' by deserializing its arguments and calling the -/// given functor. This assumes that the id for 'Func' has already been +/// Handles a call to 'Proc' by deserializing its arguments and calling the +/// given functor. This assumes that the id for 'Proc' has already been /// deserialized. /// -/// expect<Func>(Channel, <functor matching std::error_code(Args...)> : +/// expect<Proc>(Channel, <functor matching std::error_code(Args...)> : /// /// The same as 'handle', except that the procedure id should not have been -/// read yet. Expect will deserialize the id and assert that it matches Func's +/// read yet. Expect will deserialize the id and assert that it matches Proc's /// id. If it does not, and unexpected RPC call error is returned. -template <typename ChannelT, typename FunctionIdT = uint32_t, - typename SequenceNumberT = uint16_t> + +template <typename ChannelT, typename ProcedureIdT = uint32_t> class RPC : public RPCBase { public: - - RPC() = default; - RPC(const RPC&) = delete; - RPC& operator=(const RPC&) = delete; - RPC(RPC &&Other) : SequenceNumberMgr(std::move(Other.SequenceNumberMgr)), OutstandingResults(std::move(Other.OutstandingResults)) {} - RPC& operator=(RPC&&) = default; - /// Utility class for defining/referring to RPC procedures. /// /// Typedefs of this utility are used when calling/handling remote procedures. /// - /// FuncId should be a unique value of FunctionIdT (i.e. not used with any - /// other Function typedef in the RPC API being defined. + /// ProcId should be a unique value of ProcedureIdT (i.e. not used with any + /// other Procedure typedef in the RPC API being defined. /// /// the template argument Ts... gives the argument list for the remote /// procedure. /// /// E.g. /// - /// typedef Function<0, bool> Func1; - /// typedef Function<1, std::string, std::vector<int>> Func2; + /// typedef Procedure<0, bool> Proc1; + /// typedef Procedure<1, std::string, std::vector<int>> Proc2; /// - /// if (auto EC = call<Func1>(Channel, true)) + /// if (auto EC = call<Proc1>(Channel, true)) /// /* handle EC */; /// - /// if (auto EC = expect<Func2>(Channel, + /// if (auto EC = expect<Proc2>(Channel, /// [](std::string &S, std::vector<int> &V) { /// // Stuff. /// return std::error_code(); /// }) /// /* handle EC */; /// - template <FunctionIdT FuncId, typename FnT> - using Function = FunctionHelper<FunctionIdT, FuncId, FnT>; - - /// Return type for asynchronous call primitives. - template <typename Func> - using AsyncCallResult = - std::pair<std::future<typename Func::OptionalReturn>, SequenceNumberT>; + template <ProcedureIdT ProcId, typename FnT> + using Procedure = ProcedureHelper<ProcedureIdT, ProcId, FnT>; /// Serialize Args... to channel C, but do not call C.send(). /// - /// For void functions returns a std::future<Error>. For functions that - /// return an R, returns a std::future<Optional<R>>. - template <typename Func, typename... ArgTs> - ErrorOr<AsyncCallResult<Func>> - appendCallAsync(ChannelT &C, const ArgTs &... Args) { - auto SeqNo = SequenceNumberMgr.getSequenceNumber(); - std::promise<typename Func::OptionalReturn> Promise; - auto Result = Promise.get_future(); - OutstandingResults[SeqNo] = std::move(Promise); - - if (auto EC = - CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo, - Args...)) { - abandonOutstandingResults(); - return EC; - } else - return AsyncCallResult<Func>(std::move(Result), SeqNo); + /// For buffered channels, this can be used to queue up several calls before + /// flushing the channel. + template <typename Proc, typename... ArgTs> + static std::error_code appendCall(ChannelT &C, const ArgTs &... Args) { + return CallHelper<ChannelT, Proc>::call(C, Args...); } /// Serialize Args... to channel C and call C.send(). - template <typename Func, typename... ArgTs> - ErrorOr<AsyncCallResult<Func>> - callAsync(ChannelT &C, const ArgTs &... Args) { - auto SeqNo = SequenceNumberMgr.getSequenceNumber(); - std::promise<typename Func::OptionalReturn> Promise; - auto Result = Promise.get_future(); - OutstandingResults[SeqNo] = - createOutstandingResult<Func>(std::move(Promise)); - if (auto EC = - CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo, Args...)) { - abandonOutstandingResults(); - return EC; - } - if (auto EC = C.send()) { - abandonOutstandingResults(); + template <typename Proc, typename... ArgTs> + static std::error_code call(ChannelT &C, const ArgTs &... Args) { + if (auto EC = appendCall<Proc>(C, Args...)) return EC; - } - return AsyncCallResult<Func>(std::move(Result), SeqNo); - } - - /// This can be used in single-threaded mode. - template <typename Func, typename HandleFtor, typename... ArgTs> - typename Func::ErrorReturn - callSTHandling(ChannelT &C, HandleFtor &HandleOther, const ArgTs &... Args) { - if (auto ResultAndSeqNoOrErr = callAsync<Func>(C, Args...)) { - auto &ResultAndSeqNo = *ResultAndSeqNoOrErr; - if (auto EC = waitForResult(C, ResultAndSeqNo.second, HandleOther)) - return EC; - return Func::optionalToErrorReturn(ResultAndSeqNo.first.get()); - } else - return ResultAndSeqNoOrErr.getError(); - } - - // This can be used in single-threaded mode. - template <typename Func, typename... ArgTs> - typename Func::ErrorReturn - callST(ChannelT &C, const ArgTs &... Args) { - return callSTHandling<Func>(C, handleNone, Args...); + return C.send(); } - /// Start receiving a new function call. - /// - /// Calls startReceiveMessage on the channel, then deserializes a FunctionId - /// into Id. - std::error_code startReceivingFunction(ChannelT &C, FunctionIdT &Id) { - if (auto EC = startReceiveMessage(C)) - return EC; - + /// Deserialize and return an enum whose underlying type is ProcedureIdT. + static std::error_code getNextProcId(ChannelT &C, ProcedureIdT &Id) { return deserialize(C, Id); } - /// Deserialize args for Func from C and call Handler. The signature of + /// Deserialize args for Proc from C and call Handler. The signature of /// handler must conform to 'std::error_code(Args...)' where Args... matches - /// the arguments used in the Func typedef. - template <typename Func, typename HandlerT> + /// the arguments used in the Proc typedef. + template <typename Proc, typename HandlerT> static std::error_code handle(ChannelT &C, HandlerT Handler) { - return HandlerHelper<ChannelT, SequenceNumberT, Func>::handle(C, Handler); + return HandlerHelper<ChannelT, Proc>::handle(C, Handler); } /// Helper version of 'handle' for calling member functions. - template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + template <typename Proc, typename ClassT, typename... ArgTs> static std::error_code handle(ChannelT &C, ClassT &Instance, - RetT (ClassT::*HandlerMethod)(ArgTs...)) { - return handle<Func>( - C, - MemberFnWrapper<ClassT, RetT, ArgTs...>(Instance, HandlerMethod)); + std::error_code (ClassT::*HandlerMethod)(ArgTs...)) { + return handle<Proc>( + C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod)); } - /// Deserialize a FunctionIdT from C and verify it matches the id for Func. + /// Deserialize a ProcedureIdT from C and verify it matches the id for Proc. /// If the id does match, deserialize the arguments and call the handler /// (similarly to handle). /// If the id does not match, return an unexpect RPC call error and do not /// deserialize any further bytes. - template <typename Func, typename HandlerT> - std::error_code expect(ChannelT &C, HandlerT Handler) { - FunctionIdT FuncId; - if (auto EC = startReceivingFunction(C, FuncId)) - return std::move(EC); - if (FuncId != Func::Id) + template <typename Proc, typename HandlerT> + static std::error_code expect(ChannelT &C, HandlerT Handler) { + ProcedureIdT ProcId; + if (auto EC = getNextProcId(C, ProcId)) + return EC; + if (ProcId != Proc::Id) return orcError(OrcErrorCode::UnexpectedRPCCall); - return handle<Func>(C, Handler); + return handle<Proc>(C, Handler); } /// Helper version of expect for calling member functions. - template <typename Func, typename ClassT, typename... ArgTs> + template <typename Proc, typename ClassT, typename... ArgTs> static std::error_code expect(ChannelT &C, ClassT &Instance, std::error_code (ClassT::*HandlerMethod)(ArgTs...)) { - return expect<Func>( + return expect<Proc>( C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod)); } @@ -478,163 +251,18 @@ public: /// channel. /// E.g. /// - /// typedef Function<0, bool, int> Func1; + /// typedef Procedure<0, bool, int> Proc1; /// /// ... /// bool B; /// int I; - /// if (auto EC = expect<Func1>(Channel, readArgs(B, I))) + /// if (auto EC = expect<Proc1>(Channel, readArgs(B, I))) /// /* Handle Args */ ; /// template <typename... ArgTs> static ReadArgs<ArgTs...> readArgs(ArgTs &... Args) { return ReadArgs<ArgTs...>(Args...); } - - /// Read a response from Channel. - /// This should be called from the receive loop to retrieve results. - std::error_code handleResponse(ChannelT &C, SequenceNumberT &SeqNo) { - if (auto EC = deserialize(C, SeqNo)) { - abandonOutstandingResults(); - return EC; - } - - auto I = OutstandingResults.find(SeqNo); - if (I == OutstandingResults.end()) { - abandonOutstandingResults(); - return orcError(OrcErrorCode::UnexpectedRPCResponse); - } - - if (auto EC = I->second->readResult(C)) { - abandonOutstandingResults(); - // FIXME: Release sequence numbers? - return EC; - } - - OutstandingResults.erase(I); - SequenceNumberMgr.releaseSequenceNumber(SeqNo); - - return std::error_code(); - } - - // Loop waiting for a result with the given sequence number. - // This can be used as a receive loop if the user doesn't have a default. - template <typename HandleOtherFtor> - std::error_code waitForResult(ChannelT &C, SequenceNumberT TgtSeqNo, - HandleOtherFtor &HandleOther = handleNone) { - bool GotTgtResult = false; - - while (!GotTgtResult) { - FunctionIdT Id = - RPCFunctionIdTraits<FunctionIdT>::InvalidId; - if (auto EC = startReceivingFunction(C, Id)) - return EC; - if (Id == RPCFunctionIdTraits<FunctionIdT>::ResponseId) { - SequenceNumberT SeqNo; - if (auto EC = handleResponse(C, SeqNo)) - return EC; - GotTgtResult = (SeqNo == TgtSeqNo); - } else if (auto EC = HandleOther(C, Id)) - return EC; - } - - return std::error_code(); - } - - // Default handler for 'other' (non-response) functions when waiting for a - // result from the channel. - static std::error_code handleNone(ChannelT&, FunctionIdT) { - return orcError(OrcErrorCode::UnexpectedRPCCall); - }; - -private: - - // Manage sequence numbers. - class SequenceNumberManager { - public: - - SequenceNumberManager() = default; - - SequenceNumberManager(SequenceNumberManager &&Other) - : NextSequenceNumber(std::move(Other.NextSequenceNumber)), - FreeSequenceNumbers(std::move(Other.FreeSequenceNumbers)) {} - - SequenceNumberManager& operator=(SequenceNumberManager &&Other) { - NextSequenceNumber = std::move(Other.NextSequenceNumber); - FreeSequenceNumbers = std::move(Other.FreeSequenceNumbers); - } - - void reset() { - std::lock_guard<std::mutex> Lock(SeqNoLock); - NextSequenceNumber = 0; - FreeSequenceNumbers.clear(); - } - - SequenceNumberT getSequenceNumber() { - std::lock_guard<std::mutex> Lock(SeqNoLock); - if (FreeSequenceNumbers.empty()) - return NextSequenceNumber++; - auto SequenceNumber = FreeSequenceNumbers.back(); - FreeSequenceNumbers.pop_back(); - return SequenceNumber; - } - - void releaseSequenceNumber(SequenceNumberT SequenceNumber) { - std::lock_guard<std::mutex> Lock(SeqNoLock); - FreeSequenceNumbers.push_back(SequenceNumber); - } - - private: - std::mutex SeqNoLock; - SequenceNumberT NextSequenceNumber = 0; - std::vector<SequenceNumberT> FreeSequenceNumbers; - }; - - // Base class for results that haven't been returned from the other end of the - // RPC connection yet. - class OutstandingResult { - public: - virtual ~OutstandingResult() {} - virtual std::error_code readResult(ChannelT &C) = 0; - virtual void abandon() = 0; - }; - - // Outstanding results for a specific function. - template <typename Func> - class OutstandingResultImpl : public OutstandingResult { - private: - public: - OutstandingResultImpl(std::promise<typename Func::OptionalReturn> &&P) - : P(std::move(P)) {} - - std::error_code readResult(ChannelT &C) override { - return Func::readResult(C, P); - } - - void abandon() override { Func::abandon(P); } - - private: - std::promise<typename Func::OptionalReturn> P; - }; - - // Create an outstanding result for the given function. - template <typename Func> - std::unique_ptr<OutstandingResult> - createOutstandingResult(std::promise<typename Func::OptionalReturn> &&P) { - return llvm::make_unique<OutstandingResultImpl<Func>>(std::move(P)); - } - - // Abandon all outstanding results. - void abandonOutstandingResults() { - for (auto &KV : OutstandingResults) - KV.second->abandon(); - OutstandingResults.clear(); - SequenceNumberMgr.reset(); - } - - SequenceNumberManager SequenceNumberMgr; - std::map<SequenceNumberT, std::unique_ptr<OutstandingResult>> - OutstandingResults; }; } // end namespace remote diff --git a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp index 5e12c86c704..e95115ec6fe 100644 --- a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp +++ b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp @@ -38,8 +38,6 @@ public: return "Remote indirect stubs owner Id already in use"; case OrcErrorCode::UnexpectedRPCCall: return "Unexpected RPC call"; - case OrcErrorCode::UnexpectedRPCResponse: - return "Unexpected RPC response"; } llvm_unreachable("Unhandled error code"); } diff --git a/llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp b/llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp index d1a021aee3a..81e51a83021 100644 --- a/llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp +++ b/llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp @@ -13,40 +13,50 @@ namespace llvm { namespace orc { namespace remote { -#define FUNCNAME(X) \ +#define PROCNAME(X) \ case X ## Id: \ return #X -const char *OrcRemoteTargetRPCAPI::getJITFuncIdName(JITFuncId Id) { +const char *OrcRemoteTargetRPCAPI::getJITProcIdName(JITProcId Id) { switch (Id) { case InvalidId: - return "*** Invalid JITFuncId ***"; - FUNCNAME(CallIntVoid); - FUNCNAME(CallMain); - FUNCNAME(CallVoidVoid); - FUNCNAME(CreateRemoteAllocator); - FUNCNAME(CreateIndirectStubsOwner); - FUNCNAME(DeregisterEHFrames); - FUNCNAME(DestroyRemoteAllocator); - FUNCNAME(DestroyIndirectStubsOwner); - FUNCNAME(EmitIndirectStubs); - FUNCNAME(EmitResolverBlock); - FUNCNAME(EmitTrampolineBlock); - FUNCNAME(GetSymbolAddress); - FUNCNAME(GetRemoteInfo); - FUNCNAME(ReadMem); - FUNCNAME(RegisterEHFrames); - FUNCNAME(ReserveMem); - FUNCNAME(RequestCompile); - FUNCNAME(SetProtections); - FUNCNAME(TerminateSession); - FUNCNAME(WriteMem); - FUNCNAME(WritePtr); + return "*** Invalid JITProcId ***"; + PROCNAME(CallIntVoid); + PROCNAME(CallIntVoidResponse); + PROCNAME(CallMain); + PROCNAME(CallMainResponse); + PROCNAME(CallVoidVoid); + PROCNAME(CallVoidVoidResponse); + PROCNAME(CreateRemoteAllocator); + PROCNAME(CreateIndirectStubsOwner); + PROCNAME(DeregisterEHFrames); + PROCNAME(DestroyRemoteAllocator); + PROCNAME(DestroyIndirectStubsOwner); + PROCNAME(EmitIndirectStubs); + PROCNAME(EmitIndirectStubsResponse); + PROCNAME(EmitResolverBlock); + PROCNAME(EmitTrampolineBlock); + PROCNAME(EmitTrampolineBlockResponse); + PROCNAME(GetSymbolAddress); + PROCNAME(GetSymbolAddressResponse); + PROCNAME(GetRemoteInfo); + PROCNAME(GetRemoteInfoResponse); + PROCNAME(ReadMem); + PROCNAME(ReadMemResponse); + PROCNAME(RegisterEHFrames); + PROCNAME(ReserveMem); + PROCNAME(ReserveMemResponse); + PROCNAME(RequestCompile); + PROCNAME(RequestCompileResponse); + PROCNAME(SetProtections); + PROCNAME(TerminateSession); + PROCNAME(WriteMem); + PROCNAME(WritePtr); }; return nullptr; } -#undef FUNCNAME +#undef PROCNAME } // end namespace remote } // end namespace orc diff --git a/llvm/tools/lli/ChildTarget/ChildTarget.cpp b/llvm/tools/lli/ChildTarget/ChildTarget.cpp index 33de1850547..93925d6aa87 100644 --- a/llvm/tools/lli/ChildTarget/ChildTarget.cpp +++ b/llvm/tools/lli/ChildTarget/ChildTarget.cpp @@ -54,8 +54,8 @@ int main(int argc, char *argv[]) { JITServer Server(Channel, SymbolLookup, RegisterEHFrames, DeregisterEHFrames); while (1) { - JITServer::JITFuncId Id = JITServer::InvalidId; - if (auto EC = Server.getNextFuncId(Id)) { + JITServer::JITProcId Id = JITServer::InvalidId; + if (auto EC = Server.getNextProcId(Id)) { errs() << "Error: " << EC.message() << "\n"; return 1; } @@ -63,7 +63,7 @@ int main(int argc, char *argv[]) { case JITServer::TerminateSessionId: return 0; default: - if (auto EC = Server.handleKnownFunction(Id)) { + if (auto EC = Server.handleKnownProcedure(Id)) { errs() << "Error: " << EC.message() << "\n"; return 1; } diff --git a/llvm/tools/lli/RemoteJITUtils.h b/llvm/tools/lli/RemoteJITUtils.h index 63915d13dde..d5488ad555c 100644 --- a/llvm/tools/lli/RemoteJITUtils.h +++ b/llvm/tools/lli/RemoteJITUtils.h @@ -16,7 +16,6 @@ #include "llvm/ExecutionEngine/Orc/RPCChannel.h" #include "llvm/ExecutionEngine/RTDyldMemoryManager.h" -#include <mutex> #if !defined(_MSC_VER) && !defined(__MINGW32__) #include <unistd.h> diff --git a/llvm/tools/lli/lli.cpp b/llvm/tools/lli/lli.cpp index ce99b6ac7a0..0f30a4a5ff0 100644 --- a/llvm/tools/lli/lli.cpp +++ b/llvm/tools/lli/lli.cpp @@ -582,7 +582,7 @@ int main(int argc, char **argv, char * const *envp) { // Reset errno to zero on entry to main. errno = 0; - int Result = -1; + int Result; // Sanity check use of remote-jit: LLI currently only supports use of the // remote JIT on Unix platforms. @@ -681,13 +681,12 @@ int main(int argc, char **argv, char * const *envp) { static_cast<ForwardingMemoryManager*>(RTDyldMM)->setResolver( orc::createLambdaResolver( [&](const std::string &Name) { - if (auto AddrOrErr = R->getSymbolAddress(Name)) - return RuntimeDyld::SymbolInfo(*AddrOrErr, JITSymbolFlags::Exported); - else { - errs() << "Failure during symbol lookup: " - << AddrOrErr.getError().message() << "\n"; - exit(1); - } + orc::TargetAddress Addr = 0; + if (auto EC = R->getSymbolAddress(Addr, Name)) { + errs() << "Failure during symbol lookup: " << EC.message() << "\n"; + exit(1); + } + return RuntimeDyld::SymbolInfo(Addr, JITSymbolFlags::Exported); }, [](const std::string &Name) { return nullptr; } )); @@ -699,10 +698,8 @@ int main(int argc, char **argv, char * const *envp) { EE->finalizeObject(); DEBUG(dbgs() << "Executing '" << EntryFn->getName() << "' at 0x" << format("%llx", Entry) << "\n"); - if (auto ResultOrErr = R->callIntVoid(Entry)) - Result = *ResultOrErr; - else - errs() << "ERROR: " << ResultOrErr.getError().message() << "\n"; + if (auto EC = R->callIntVoid(Result, Entry)) + errs() << "ERROR: " << EC.message() << "\n"; // Like static constructors, the remote target MCJIT support doesn't handle // this yet. It could. FIXME. diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 77632e35eb1..3b01c3828b6 100644 --- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -44,25 +44,26 @@ private: class DummyRPC : public testing::Test, public RPC<QueueChannel> { public: - typedef Function<2, void(bool)> BasicVoid; - typedef Function<3, int32_t(bool)> BasicInt; - typedef Function<4, void(int8_t, uint8_t, int16_t, uint16_t, - int32_t, uint32_t, int64_t, uint64_t, - bool, std::string, std::vector<int>)> AllTheTypes; + typedef Procedure<1, void(bool)> Proc1; + typedef Procedure<2, void(int8_t, uint8_t, int16_t, uint16_t, + int32_t, uint32_t, int64_t, uint64_t, + bool, std::string, std::vector<int>)> AllTheTypes; }; -TEST_F(DummyRPC, TestAsyncBasicVoid) { +TEST_F(DummyRPC, TestBasic) { std::queue<char> Queue; QueueChannel C(Queue); - // Make an async call. - auto ResOrErr = callAsync<BasicVoid>(C, true); - EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed"; + { + // Make a call to Proc1. + auto EC = call<Proc1>(C, true); + EXPECT_FALSE(EC) << "Simple call over queue failed"; + } { // Expect a call to Proc1. - auto EC = expect<BasicVoid>(C, + auto EC = expect<Proc1>(C, [&](bool &B) { EXPECT_EQ(B, true) << "Bool serialization broken"; @@ -70,70 +71,30 @@ TEST_F(DummyRPC, TestAsyncBasicVoid) { }); EXPECT_FALSE(EC) << "Simple expect over queue failed"; } - - { - // Wait for the result. - auto EC = waitForResult(C, ResOrErr->second, handleNone); - EXPECT_FALSE(EC) << "Could not read result."; - } - - // Verify that the function returned ok. - auto Val = ResOrErr->first.get(); - EXPECT_TRUE(Val) << "Remote void function failed to execute."; } -TEST_F(DummyRPC, TestAsyncBasicInt) { +TEST_F(DummyRPC, TestSerialization) { std::queue<char> Queue; QueueChannel C(Queue); - // Make an async call. - auto ResOrErr = callAsync<BasicInt>(C, false); - EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed"; - - { - // Expect a call to Proc1. - auto EC = expect<BasicInt>(C, - [&](bool &B) { - EXPECT_EQ(B, false) - << "Bool serialization broken"; - return 42; - }); - EXPECT_FALSE(EC) << "Simple expect over queue failed"; - } - { - // Wait for the result. - auto EC = waitForResult(C, ResOrErr->second, handleNone); - EXPECT_FALSE(EC) << "Could not read result."; + // Make a call to Proc1. + std::vector<int> v({42, 7}); + auto EC = call<AllTheTypes>(C, + -101, + 250, + -10000, + 10000, + -1000000000, + 1000000000, + -10000000000, + 10000000000, + true, + "foo", + v); + EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed"; } - // Verify that the function returned ok. - auto Val = ResOrErr->first.get(); - EXPECT_TRUE(!!Val) << "Remote int function failed to execute."; - EXPECT_EQ(*Val, 42) << "Remote int function return wrong value."; -} - -TEST_F(DummyRPC, TestSerialization) { - std::queue<char> Queue; - QueueChannel C(Queue); - - // Make a call to Proc1. - std::vector<int> v({42, 7}); - auto ResOrErr = callAsync<AllTheTypes>(C, - -101, - 250, - -10000, - 10000, - -1000000000, - 1000000000, - -10000000000, - 10000000000, - true, - "foo", - v); - EXPECT_TRUE(!!ResOrErr) - << "Big (serialization test) call over queue failed"; - { // Expect a call to Proc1. auto EC = expect<AllTheTypes>(C, @@ -175,14 +136,4 @@ TEST_F(DummyRPC, TestSerialization) { }); EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed"; } - - { - // Wait for the result. - auto EC = waitForResult(C, ResOrErr->second, handleNone); - EXPECT_FALSE(EC) << "Could not read result."; - } - - // Verify that the function returned ok. - auto Val = ResOrErr->first.get(); - EXPECT_TRUE(Val) << "Remote void function failed to execute."; } |