diff options
Diffstat (limited to 'llvm')
17 files changed, 1763 insertions, 1362 deletions
diff --git a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h index c95532e8db3..718b99e4b24 100644 --- a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h +++ b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h @@ -14,7 +14,7 @@ #ifndef LLVM_TOOLS_LLI_REMOTEJITUTILS_H #define LLVM_TOOLS_LLI_REMOTEJITUTILS_H -#include "llvm/ExecutionEngine/Orc/RPCByteChannel.h" +#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" #include "llvm/ExecutionEngine/RTDyldMemoryManager.h" #include <mutex> @@ -25,7 +25,7 @@ #endif /// RPC channel that reads from and writes from file descriptors. -class FDRPCChannel final : public llvm::orc::remote::RPCByteChannel { +class FDRPCChannel final : public llvm::orc::rpc::RawByteChannel { public: FDRPCChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} diff --git a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/toy.cpp b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/toy.cpp index 9c21098971a..f5a06cf2bf4 100644 --- a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/toy.cpp +++ b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/toy.cpp @@ -1265,8 +1265,8 @@ int main(int argc, char *argv[]) { BinopPrecedence['*'] = 40; // highest. auto TCPChannel = connect(); - MyRemote Remote = ExitOnErr(MyRemote::Create(*TCPChannel)); - TheJIT = llvm::make_unique<KaleidoscopeJIT>(Remote); + auto Remote = ExitOnErr(MyRemote::Create(*TCPChannel)); + TheJIT = llvm::make_unique<KaleidoscopeJIT>(*Remote); // Automatically inject a definition for 'printExprResult'. FunctionProtos["printExprResult"] = @@ -1288,7 +1288,7 @@ int main(int argc, char *argv[]) { TheJIT = nullptr; // Send a terminate message to the remote to tell it to exit cleanly. - ExitOnErr(Remote.terminateSession()); + ExitOnErr(Remote->terminateSession()); return 0; } diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h index 1b3f25fae16..8841aa77f62 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -29,6 +29,7 @@ enum class OrcErrorCode : int { RemoteIndirectStubsOwnerIdAlreadyInUse, UnexpectedRPCCall, UnexpectedRPCResponse, + UnknownRPCFunction }; Error orcError(OrcErrorCode ErrCode); diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h index d549fc31deb..5b2f8921fef 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// // // This file defines the OrcRemoteTargetClient class and helpers. This class -// can be used to communicate over an RPCByteChannel with an +// can be used to communicate over an RawByteChannel with an // OrcRemoteTargetServer instance to support remote-JITing. // //===----------------------------------------------------------------------===// @@ -36,23 +36,6 @@ namespace remote { template <typename ChannelT> class OrcRemoteTargetClient : public OrcRemoteTargetRPCAPI { public: - // FIXME: Remove move/copy ops once MSVC supports synthesizing move ops. - - OrcRemoteTargetClient(const OrcRemoteTargetClient &) = delete; - OrcRemoteTargetClient &operator=(const OrcRemoteTargetClient &) = delete; - - OrcRemoteTargetClient(OrcRemoteTargetClient &&Other) - : Channel(Other.Channel), ExistingError(std::move(Other.ExistingError)), - RemoteTargetTriple(std::move(Other.RemoteTargetTriple)), - RemotePointerSize(std::move(Other.RemotePointerSize)), - RemotePageSize(std::move(Other.RemotePageSize)), - RemoteTrampolineSize(std::move(Other.RemoteTrampolineSize)), - RemoteIndirectStubSize(std::move(Other.RemoteIndirectStubSize)), - AllocatorIds(std::move(Other.AllocatorIds)), - IndirectStubOwnerIds(std::move(Other.IndirectStubOwnerIds)), - CallbackManager(std::move(Other.CallbackManager)) {} - - OrcRemoteTargetClient &operator=(OrcRemoteTargetClient &&) = delete; /// Remote memory manager. class RCMemoryManager : public RuntimeDyld::MemoryManager { @@ -62,18 +45,10 @@ public: DEBUG(dbgs() << "Created remote allocator " << Id << "\n"); } - RCMemoryManager(RCMemoryManager &&Other) - : Client(std::move(Other.Client)), Id(std::move(Other.Id)), - Unmapped(std::move(Other.Unmapped)), - Unfinalized(std::move(Other.Unfinalized)) {} - - RCMemoryManager operator=(RCMemoryManager &&Other) { - Client = std::move(Other.Client); - Id = std::move(Other.Id); - Unmapped = std::move(Other.Unmapped); - Unfinalized = std::move(Other.Unfinalized); - return *this; - } + RCMemoryManager(const RCMemoryManager&) = delete; + RCMemoryManager& operator=(const RCMemoryManager&) = delete; + RCMemoryManager(RCMemoryManager&&) = default; + RCMemoryManager& operator=(RCMemoryManager&&) = default; ~RCMemoryManager() override { Client.destroyRemoteAllocator(Id); @@ -367,18 +342,10 @@ public: Alloc(uint64_t Size, unsigned Align) : Size(Size), Align(Align), Contents(new char[Size + Align - 1]) {} - Alloc(Alloc &&Other) - : Size(std::move(Other.Size)), Align(std::move(Other.Align)), - Contents(std::move(Other.Contents)), - RemoteAddr(std::move(Other.RemoteAddr)) {} - - Alloc &operator=(Alloc &&Other) { - Size = std::move(Other.Size); - Align = std::move(Other.Align); - Contents = std::move(Other.Contents); - RemoteAddr = std::move(Other.RemoteAddr); - return *this; - } + Alloc(const Alloc&) = delete; + Alloc& operator=(const Alloc&) = delete; + Alloc(Alloc&&) = default; + Alloc& operator=(Alloc&&) = default; uint64_t getSize() const { return Size; } @@ -405,24 +372,10 @@ public: struct ObjectAllocs { ObjectAllocs() = default; - - ObjectAllocs(ObjectAllocs &&Other) - : RemoteCodeAddr(std::move(Other.RemoteCodeAddr)), - RemoteRODataAddr(std::move(Other.RemoteRODataAddr)), - RemoteRWDataAddr(std::move(Other.RemoteRWDataAddr)), - CodeAllocs(std::move(Other.CodeAllocs)), - RODataAllocs(std::move(Other.RODataAllocs)), - RWDataAllocs(std::move(Other.RWDataAllocs)) {} - - ObjectAllocs &operator=(ObjectAllocs &&Other) { - RemoteCodeAddr = std::move(Other.RemoteCodeAddr); - RemoteRODataAddr = std::move(Other.RemoteRODataAddr); - RemoteRWDataAddr = std::move(Other.RemoteRWDataAddr); - CodeAllocs = std::move(Other.CodeAllocs); - RODataAllocs = std::move(Other.RODataAllocs); - RWDataAllocs = std::move(Other.RWDataAllocs); - return *this; - } + ObjectAllocs(const ObjectAllocs &) = delete; + ObjectAllocs& operator=(const ObjectAllocs &) = delete; + ObjectAllocs(ObjectAllocs&&) = default; + ObjectAllocs& operator=(ObjectAllocs&&) = default; JITTargetAddress RemoteCodeAddr = 0; JITTargetAddress RemoteRODataAddr = 0; @@ -588,23 +541,21 @@ public: /// Create an OrcRemoteTargetClient. /// Channel is the ChannelT instance to communicate on. It is assumed that /// the channel is ready to be read from and written to. - static Expected<OrcRemoteTargetClient> Create(ChannelT &Channel) { + static Expected<std::unique_ptr<OrcRemoteTargetClient>> + Create(ChannelT &Channel) { Error Err = Error::success(); - OrcRemoteTargetClient H(Channel, Err); + std::unique_ptr<OrcRemoteTargetClient> + Client(new OrcRemoteTargetClient(Channel, Err)); if (Err) return std::move(Err); - return Expected<OrcRemoteTargetClient>(std::move(H)); + return std::move(Client); } /// Call the int(void) function at the given address in the target and return /// its result. Expected<int> callIntVoid(JITTargetAddress Addr) { DEBUG(dbgs() << "Calling int(*)(void) " << format("0x%016x", Addr) << "\n"); - - auto Listen = [&](RPCByteChannel &C, uint32_t Id) { - return listenForCompileRequests(C, Id); - }; - return callSTHandling<CallIntVoid>(Channel, Listen, Addr); + return callB<CallIntVoid>(Addr); } /// Call the int(int, char*[]) function at the given address in the target and @@ -613,11 +564,7 @@ public: const std::vector<std::string> &Args) { DEBUG(dbgs() << "Calling int(*)(int, char*[]) " << format("0x%016x", Addr) << "\n"); - - auto Listen = [&](RPCByteChannel &C, uint32_t Id) { - return listenForCompileRequests(C, Id); - }; - return callSTHandling<CallMain>(Channel, Listen, Addr, Args); + return callB<CallMain>(Addr, Args); } /// Call the void() function at the given address in the target and wait for @@ -625,11 +572,7 @@ public: Error callVoidVoid(JITTargetAddress Addr) { DEBUG(dbgs() << "Calling void(*)(void) " << format("0x%016x", Addr) << "\n"); - - auto Listen = [&](RPCByteChannel &C, uint32_t Id) { - return listenForCompileRequests(C, Id); - }; - return callSTHandling<CallVoidVoid>(Channel, Listen, Addr); + return callB<CallVoidVoid>(Addr); } /// Create an RCMemoryManager which will allocate its memory on the remote @@ -638,7 +581,7 @@ public: assert(!MM && "MemoryManager should be null before creation."); auto Id = AllocatorIds.getNext(); - if (auto Err = callST<CreateRemoteAllocator>(Channel, Id)) + if (auto Err = callB<CreateRemoteAllocator>(Id)) return Err; MM = llvm::make_unique<RCMemoryManager>(*this, Id); return Error::success(); @@ -649,7 +592,7 @@ public: Error createIndirectStubsManager(std::unique_ptr<RCIndirectStubsManager> &I) { assert(!I && "Indirect stubs manager should be null before creation."); auto Id = IndirectStubOwnerIds.getNext(); - if (auto Err = callST<CreateIndirectStubsOwner>(Channel, Id)) + if (auto Err = callB<CreateIndirectStubsOwner>(Id)) return Err; I = llvm::make_unique<RCIndirectStubsManager>(*this, Id); return Error::success(); @@ -662,7 +605,7 @@ public: return std::move(ExistingError); // Emit the resolver block on the JIT server. - if (auto Err = callST<EmitResolverBlock>(Channel)) + if (auto Err = callB<EmitResolverBlock>()) return std::move(Err); // Create the callback manager. @@ -679,18 +622,28 @@ public: if (ExistingError) return std::move(ExistingError); - return callST<GetSymbolAddress>(Channel, Name); + return callB<GetSymbolAddress>(Name); } /// Get the triple for the remote target. const std::string &getTargetTriple() const { return RemoteTargetTriple; } - Error terminateSession() { return callST<TerminateSession>(Channel); } + Error terminateSession() { return callB<TerminateSession>(); } private: - OrcRemoteTargetClient(ChannelT &Channel, Error &Err) : Channel(Channel) { + + OrcRemoteTargetClient(ChannelT &Channel, Error &Err) + : OrcRemoteTargetRPCAPI(Channel) { ErrorAsOutParameter EAO(&Err); - if (auto RIOrErr = callST<GetRemoteInfo>(Channel)) { + + addHandler<RequestCompile>( + [this](JITTargetAddress Addr) -> JITTargetAddress { + if (CallbackManager) + return CallbackManager->executeCompileCallback(Addr); + return 0; + }); + + if (auto RIOrErr = callB<GetRemoteInfo>()) { std::tie(RemoteTargetTriple, RemotePointerSize, RemotePageSize, RemoteTrampolineSize, RemoteIndirectStubSize) = *RIOrErr; Err = Error::success(); @@ -700,11 +653,11 @@ private: } Error deregisterEHFrames(JITTargetAddress Addr, uint32_t Size) { - return callST<RegisterEHFrames>(Channel, Addr, Size); + return callB<RegisterEHFrames>(Addr, Size); } void destroyRemoteAllocator(ResourceIdMgr::ResourceId Id) { - if (auto Err = callST<DestroyRemoteAllocator>(Channel, Id)) { + if (auto Err = callB<DestroyRemoteAllocator>(Id)) { // FIXME: This will be triggered by a removeModuleSet call: Propagate // error return up through that. llvm_unreachable("Failed to destroy remote allocator."); @@ -714,12 +667,12 @@ private: Error destroyIndirectStubsManager(ResourceIdMgr::ResourceId Id) { IndirectStubOwnerIds.release(Id); - return callST<DestroyIndirectStubsOwner>(Channel, Id); + return callB<DestroyIndirectStubsOwner>(Id); } Expected<std::tuple<JITTargetAddress, JITTargetAddress, uint32_t>> emitIndirectStubs(ResourceIdMgr::ResourceId Id, uint32_t NumStubsRequired) { - return callST<EmitIndirectStubs>(Channel, Id, NumStubsRequired); + return callB<EmitIndirectStubs>(Id, NumStubsRequired); } Expected<std::tuple<JITTargetAddress, uint32_t>> emitTrampolineBlock() { @@ -727,7 +680,7 @@ private: if (ExistingError) return std::move(ExistingError); - return callST<EmitTrampolineBlock>(Channel); + return callB<EmitTrampolineBlock>(); } uint32_t getIndirectStubSize() const { return RemoteIndirectStubSize; } @@ -736,42 +689,17 @@ private: uint32_t getTrampolineSize() const { return RemoteTrampolineSize; } - Error listenForCompileRequests(RPCByteChannel &C, uint32_t &Id) { - assert(CallbackManager && - "No calback manager. enableCompileCallbacks must be called first"); - - // Check for an 'out-of-band' error, e.g. from an MM destructor. - if (ExistingError) - return std::move(ExistingError); - - // FIXME: CompileCallback could be an anonymous lambda defined at the use - // site below, but that triggers a GCC 4.7 ICE. When we move off - // GCC 4.7, tidy this up. - auto CompileCallback = - [this](JITTargetAddress Addr) -> Expected<JITTargetAddress> { - return this->CallbackManager->executeCompileCallback(Addr); - }; - - if (Id == RequestCompileId) { - if (auto Err = handle<RequestCompile>(C, CompileCallback)) - return Err; - return Error::success(); - } - // else - return orcError(OrcErrorCode::UnexpectedRPCCall); - } - Expected<std::vector<char>> readMem(char *Dst, JITTargetAddress Src, uint64_t Size) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return std::move(ExistingError); - return callST<ReadMem>(Channel, Src, Size); + return callB<ReadMem>(Src, Size); } Error registerEHFrames(JITTargetAddress &RAddr, uint32_t Size) { - return callST<RegisterEHFrames>(Channel, RAddr, Size); + return callB<RegisterEHFrames>(RAddr, Size); } Expected<JITTargetAddress> reserveMem(ResourceIdMgr::ResourceId Id, @@ -781,12 +709,12 @@ private: if (ExistingError) return std::move(ExistingError); - return callST<ReserveMem>(Channel, Id, Size, Align); + return callB<ReserveMem>(Id, Size, Align); } Error setProtections(ResourceIdMgr::ResourceId Id, JITTargetAddress RemoteSegAddr, unsigned ProtFlags) { - return callST<SetProtections>(Channel, Id, RemoteSegAddr, ProtFlags); + return callB<SetProtections>(Id, RemoteSegAddr, ProtFlags); } Error writeMem(JITTargetAddress Addr, const char *Src, uint64_t Size) { @@ -794,7 +722,7 @@ private: if (ExistingError) return std::move(ExistingError); - return callST<WriteMem>(Channel, DirectBufferWriter(Src, Addr, Size)); + return callB<WriteMem>(DirectBufferWriter(Src, Addr, Size)); } Error writePointer(JITTargetAddress Addr, JITTargetAddress PtrVal) { @@ -802,12 +730,11 @@ private: if (ExistingError) return std::move(ExistingError); - return callST<WritePtr>(Channel, Addr, PtrVal); + return callB<WritePtr>(Addr, PtrVal); } static Error doNothing() { return Error::success(); } - ChannelT &Channel; Error ExistingError = Error::success(); std::string RemoteTargetTriple; uint32_t RemotePointerSize = 0; diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h index 33d6b604c61..413e286a347 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -16,7 +16,7 @@ #ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETRPCAPI_H #define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETRPCAPI_H -#include "RPCByteChannel.h" +#include "RawByteChannel.h" #include "RPCUtils.h" #include "llvm/ExecutionEngine/JITSymbol.h" @@ -40,13 +40,24 @@ private: uint64_t Size; }; +} // end namespace remote + +namespace rpc { + template <> -class SerializationTraits<RPCByteChannel, DirectBufferWriter> { +class RPCTypeName<remote::DirectBufferWriter> { public: + static const char *getName() { return "DirectBufferWriter"; } +}; - static const char* getName() { return "DirectBufferWriter"; } +template <typename ChannelT> +class SerializationTraits<ChannelT, remote::DirectBufferWriter, remote::DirectBufferWriter, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>:: + value>::type> { +public: - static Error serialize(RPCByteChannel &C, const DirectBufferWriter &DBW) { + static Error serialize(ChannelT &C, const remote::DirectBufferWriter &DBW) { if (auto EC = serializeSeq(C, DBW.getDst())) return EC; if (auto EC = serializeSeq(C, DBW.getSize())) @@ -54,7 +65,7 @@ public: return C.appendBytes(DBW.getSrc(), DBW.getSize()); } - static Error deserialize(RPCByteChannel &C, DirectBufferWriter &DBW) { + static Error deserialize(ChannelT &C, remote::DirectBufferWriter &DBW) { JITTargetAddress Dst; if (auto EC = deserializeSeq(C, Dst)) return EC; @@ -63,13 +74,18 @@ public: return EC; char *Addr = reinterpret_cast<char *>(static_cast<uintptr_t>(Dst)); - DBW = DirectBufferWriter(0, Dst, Size); + DBW = remote::DirectBufferWriter(0, Dst, Size); return C.readBytes(Addr, Size); } }; -class OrcRemoteTargetRPCAPI : public RPC<RPCByteChannel> { +} // end namespace rpc + +namespace remote { + +class OrcRemoteTargetRPCAPI + : public rpc::SingleThreadedRPC<rpc::RawByteChannel> { protected: class ResourceIdMgr { public: @@ -93,119 +109,162 @@ protected: public: // FIXME: Remove constructors once MSVC supports synthesizing move-ops. - OrcRemoteTargetRPCAPI() = default; - OrcRemoteTargetRPCAPI(const OrcRemoteTargetRPCAPI &) = delete; - OrcRemoteTargetRPCAPI &operator=(const OrcRemoteTargetRPCAPI &) = delete; - - OrcRemoteTargetRPCAPI(OrcRemoteTargetRPCAPI &&) {} - OrcRemoteTargetRPCAPI &operator=(OrcRemoteTargetRPCAPI &&) { return *this; } - - enum JITFuncId : uint32_t { - InvalidId = RPCFunctionIdTraits<JITFuncId>::InvalidId, - CallIntVoidId = RPCFunctionIdTraits<JITFuncId>::FirstValidId, - CallMainId, - CallVoidVoidId, - CreateRemoteAllocatorId, - CreateIndirectStubsOwnerId, - DeregisterEHFramesId, - DestroyRemoteAllocatorId, - DestroyIndirectStubsOwnerId, - EmitIndirectStubsId, - EmitResolverBlockId, - EmitTrampolineBlockId, - GetSymbolAddressId, - GetRemoteInfoId, - ReadMemId, - RegisterEHFramesId, - ReserveMemId, - RequestCompileId, - SetProtectionsId, - TerminateSessionId, - WriteMemId, - WritePtrId - }; - - static const char *getJITFuncIdName(JITFuncId Id); - - typedef Function<CallIntVoidId, int32_t(JITTargetAddress Addr)> CallIntVoid; - - typedef Function<CallMainId, - int32_t(JITTargetAddress Addr, - std::vector<std::string> Args)> - CallMain; - - typedef Function<CallVoidVoidId, void(JITTargetAddress FnAddr)> CallVoidVoid; - - typedef Function<CreateRemoteAllocatorId, - void(ResourceIdMgr::ResourceId AllocatorID)> - CreateRemoteAllocator; - - typedef Function<CreateIndirectStubsOwnerId, - void(ResourceIdMgr::ResourceId StubOwnerID)> - CreateIndirectStubsOwner; - - typedef Function<DeregisterEHFramesId, - void(JITTargetAddress Addr, uint32_t Size)> - DeregisterEHFrames; - - typedef Function<DestroyRemoteAllocatorId, - void(ResourceIdMgr::ResourceId AllocatorID)> - DestroyRemoteAllocator; - - typedef Function<DestroyIndirectStubsOwnerId, - void(ResourceIdMgr::ResourceId StubsOwnerID)> - DestroyIndirectStubsOwner; + OrcRemoteTargetRPCAPI(rpc::RawByteChannel &C) + : rpc::SingleThreadedRPC<rpc::RawByteChannel>(C, true) {} + + class CallIntVoid : public rpc::Function<CallIntVoid, + int32_t(JITTargetAddress Addr)> { + public: + static const char* getName() { return "CallIntVoid"; } + }; + + class CallMain + : public rpc::Function<CallMain, + int32_t(JITTargetAddress Addr, + std::vector<std::string> Args)> { + public: + static const char* getName() { return "CallMain"; } + }; + + class CallVoidVoid : public rpc::Function<CallVoidVoid, + void(JITTargetAddress FnAddr)> { + public: + static const char* getName() { return "CallVoidVoid"; } + }; + + class CreateRemoteAllocator + : public rpc::Function<CreateRemoteAllocator, + void(ResourceIdMgr::ResourceId AllocatorID)> { + public: + static const char* getName() { return "CreateRemoteAllocator"; } + }; + + class CreateIndirectStubsOwner + : public rpc::Function<CreateIndirectStubsOwner, + void(ResourceIdMgr::ResourceId StubOwnerID)> { + public: + static const char* getName() { return "CreateIndirectStubsOwner"; } + }; + + class DeregisterEHFrames + : public rpc::Function<DeregisterEHFrames, + void(JITTargetAddress Addr, uint32_t Size)> { + public: + static const char* getName() { return "DeregisterEHFrames"; } + }; + + class DestroyRemoteAllocator + : public rpc::Function<DestroyRemoteAllocator, + void(ResourceIdMgr::ResourceId AllocatorID)> { + public: + static const char* getName() { return "DestroyRemoteAllocator"; } + }; + + class DestroyIndirectStubsOwner + : public rpc::Function<DestroyIndirectStubsOwner, + void(ResourceIdMgr::ResourceId StubsOwnerID)> { + public: + static const char* getName() { return "DestroyIndirectStubsOwner"; } + }; /// EmitIndirectStubs result is (StubsBase, PtrsBase, NumStubsEmitted). - typedef Function<EmitIndirectStubsId, - std::tuple<JITTargetAddress, JITTargetAddress, uint32_t>( - ResourceIdMgr::ResourceId StubsOwnerID, - uint32_t NumStubsRequired)> - EmitIndirectStubs; + class EmitIndirectStubs + : public rpc::Function<EmitIndirectStubs, + std::tuple<JITTargetAddress, JITTargetAddress, + uint32_t>( + ResourceIdMgr::ResourceId StubsOwnerID, + uint32_t NumStubsRequired)> { + public: + static const char* getName() { return "EmitIndirectStubs"; } + }; - typedef Function<EmitResolverBlockId, void()> EmitResolverBlock; + class EmitResolverBlock : public rpc::Function<EmitResolverBlock, void()> { + public: + static const char* getName() { return "EmitResolverBlock"; } + }; /// EmitTrampolineBlock result is (BlockAddr, NumTrampolines). - typedef Function<EmitTrampolineBlockId, - std::tuple<JITTargetAddress, uint32_t>()> - EmitTrampolineBlock; + class EmitTrampolineBlock + : public rpc::Function<EmitTrampolineBlock, + std::tuple<JITTargetAddress, uint32_t>()> { + public: + static const char* getName() { return "EmitTrampolineBlock"; } + }; - typedef Function<GetSymbolAddressId, JITTargetAddress(std::string SymbolName)> - GetSymbolAddress; + class GetSymbolAddress + : public rpc::Function<GetSymbolAddress, + JITTargetAddress(std::string SymbolName)> { + public: + static const char* getName() { return "GetSymbolAddress"; } + }; /// GetRemoteInfo result is (Triple, PointerSize, PageSize, TrampolineSize, /// IndirectStubsSize). - typedef Function<GetRemoteInfoId, std::tuple<std::string, uint32_t, uint32_t, - uint32_t, uint32_t>()> - GetRemoteInfo; + class GetRemoteInfo + : public rpc::Function<GetRemoteInfo, + std::tuple<std::string, uint32_t, uint32_t, + uint32_t, uint32_t>()> { + public: + static const char* getName() { return "GetRemoteInfo"; } + }; - typedef Function<ReadMemId, - std::vector<char>(JITTargetAddress Src, uint64_t Size)> - ReadMem; + class ReadMem + : public rpc::Function<ReadMem, + std::vector<uint8_t>(JITTargetAddress Src, + uint64_t Size)> { + public: + static const char* getName() { return "ReadMem"; } + }; - typedef Function<RegisterEHFramesId, void(JITTargetAddress Addr, uint32_t Size)> - RegisterEHFrames; + class RegisterEHFrames + : public rpc::Function<RegisterEHFrames, + void(JITTargetAddress Addr, uint32_t Size)> { + public: + static const char* getName() { return "RegisterEHFrames"; } + }; - typedef Function<ReserveMemId, - JITTargetAddress(ResourceIdMgr::ResourceId AllocID, - uint64_t Size, uint32_t Align)> - ReserveMem; + class ReserveMem + : public rpc::Function<ReserveMem, + JITTargetAddress(ResourceIdMgr::ResourceId AllocID, + uint64_t Size, uint32_t Align)> { + public: + static const char* getName() { return "ReserveMem"; } + }; - typedef Function<RequestCompileId, - JITTargetAddress(JITTargetAddress TrampolineAddr)> - RequestCompile; + class RequestCompile + : public rpc::Function<RequestCompile, + JITTargetAddress(JITTargetAddress TrampolineAddr)> { + public: + static const char* getName() { return "RequestCompile"; } + }; + + class SetProtections + : public rpc::Function<SetProtections, + void(ResourceIdMgr::ResourceId AllocID, + JITTargetAddress Dst, + uint32_t ProtFlags)> { + public: + static const char* getName() { return "SetProtections"; } + }; - typedef Function<SetProtectionsId, - void(ResourceIdMgr::ResourceId AllocID, JITTargetAddress Dst, - uint32_t ProtFlags)> - SetProtections; + class TerminateSession : public rpc::Function<TerminateSession, void()> { + public: + static const char* getName() { return "TerminateSession"; } + }; - typedef Function<TerminateSessionId, void()> TerminateSession; + class WriteMem : public rpc::Function<WriteMem, + void(remote::DirectBufferWriter DB)> { + public: + static const char* getName() { return "WriteMem"; } + }; - typedef Function<WriteMemId, void(DirectBufferWriter DB)> WriteMem; + class WritePtr + : public rpc::Function<WritePtr, + void(JITTargetAddress Dst, JITTargetAddress Val)> { + public: + static const char* getName() { return "WritePtr"; } + }; - typedef Function<WritePtrId, void(JITTargetAddress Dst, JITTargetAddress Val)> - WritePtr; }; } // end namespace remote diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h index e3dfaf77566..bda4cd15342 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h @@ -41,94 +41,51 @@ public: OrcRemoteTargetServer(ChannelT &Channel, SymbolLookupFtor SymbolLookup, EHFrameRegistrationFtor EHFramesRegister, EHFrameRegistrationFtor EHFramesDeregister) - : Channel(Channel), SymbolLookup(std::move(SymbolLookup)), + : OrcRemoteTargetRPCAPI(Channel), SymbolLookup(std::move(SymbolLookup)), EHFramesRegister(std::move(EHFramesRegister)), - EHFramesDeregister(std::move(EHFramesDeregister)) {} + EHFramesDeregister(std::move(EHFramesDeregister)), + TerminateFlag(false) { + + using ThisT = typename std::remove_reference<decltype(*this)>::type; + addHandler<CallIntVoid>(*this, &ThisT::handleCallIntVoid); + addHandler<CallMain>(*this, &ThisT::handleCallMain); + addHandler<CallVoidVoid>(*this, &ThisT::handleCallVoidVoid); + addHandler<CreateRemoteAllocator>(*this, + &ThisT::handleCreateRemoteAllocator); + addHandler<CreateIndirectStubsOwner>(*this, + &ThisT::handleCreateIndirectStubsOwner); + addHandler<DeregisterEHFrames>(*this, &ThisT::handleDeregisterEHFrames); + addHandler<DestroyRemoteAllocator>(*this, + &ThisT::handleDestroyRemoteAllocator); + addHandler<DestroyIndirectStubsOwner>(*this, + &ThisT::handleDestroyIndirectStubsOwner); + addHandler<EmitIndirectStubs>(*this, &ThisT::handleEmitIndirectStubs); + addHandler<EmitResolverBlock>(*this, &ThisT::handleEmitResolverBlock); + addHandler<EmitTrampolineBlock>(*this, &ThisT::handleEmitTrampolineBlock); + addHandler<GetSymbolAddress>(*this, &ThisT::handleGetSymbolAddress); + addHandler<GetRemoteInfo>(*this, &ThisT::handleGetRemoteInfo); + addHandler<ReadMem>(*this, &ThisT::handleReadMem); + addHandler<RegisterEHFrames>(*this, &ThisT::handleRegisterEHFrames); + addHandler<ReserveMem>(*this, &ThisT::handleReserveMem); + addHandler<SetProtections>(*this, &ThisT::handleSetProtections); + addHandler<TerminateSession>(*this, &ThisT::handleTerminateSession); + addHandler<WriteMem>(*this, &ThisT::handleWriteMem); + addHandler<WritePtr>(*this, &ThisT::handleWritePtr); + } // FIXME: Remove move/copy ops once MSVC supports synthesizing move ops. OrcRemoteTargetServer(const OrcRemoteTargetServer &) = delete; OrcRemoteTargetServer &operator=(const OrcRemoteTargetServer &) = delete; - OrcRemoteTargetServer(OrcRemoteTargetServer &&Other) - : Channel(Other.Channel), SymbolLookup(std::move(Other.SymbolLookup)), - EHFramesRegister(std::move(Other.EHFramesRegister)), - EHFramesDeregister(std::move(Other.EHFramesDeregister)) {} - + OrcRemoteTargetServer(OrcRemoteTargetServer &&Other) = default; OrcRemoteTargetServer &operator=(OrcRemoteTargetServer &&) = delete; - Error handleKnownFunction(JITFuncId Id) { - typedef OrcRemoteTargetServer ThisT; - - DEBUG(dbgs() << "Handling known proc: " << getJITFuncIdName(Id) << "\n"); - - switch (Id) { - case CallIntVoidId: - return handle<CallIntVoid>(Channel, *this, &ThisT::handleCallIntVoid); - case CallMainId: - return handle<CallMain>(Channel, *this, &ThisT::handleCallMain); - case CallVoidVoidId: - return handle<CallVoidVoid>(Channel, *this, &ThisT::handleCallVoidVoid); - case CreateRemoteAllocatorId: - return handle<CreateRemoteAllocator>(Channel, *this, - &ThisT::handleCreateRemoteAllocator); - case CreateIndirectStubsOwnerId: - return handle<CreateIndirectStubsOwner>( - Channel, *this, &ThisT::handleCreateIndirectStubsOwner); - case DeregisterEHFramesId: - return handle<DeregisterEHFrames>(Channel, *this, - &ThisT::handleDeregisterEHFrames); - case DestroyRemoteAllocatorId: - return handle<DestroyRemoteAllocator>( - Channel, *this, &ThisT::handleDestroyRemoteAllocator); - case DestroyIndirectStubsOwnerId: - return handle<DestroyIndirectStubsOwner>( - Channel, *this, &ThisT::handleDestroyIndirectStubsOwner); - case EmitIndirectStubsId: - return handle<EmitIndirectStubs>(Channel, *this, - &ThisT::handleEmitIndirectStubs); - case EmitResolverBlockId: - return handle<EmitResolverBlock>(Channel, *this, - &ThisT::handleEmitResolverBlock); - case EmitTrampolineBlockId: - return handle<EmitTrampolineBlock>(Channel, *this, - &ThisT::handleEmitTrampolineBlock); - case GetSymbolAddressId: - return handle<GetSymbolAddress>(Channel, *this, - &ThisT::handleGetSymbolAddress); - case GetRemoteInfoId: - return handle<GetRemoteInfo>(Channel, *this, &ThisT::handleGetRemoteInfo); - case ReadMemId: - return handle<ReadMem>(Channel, *this, &ThisT::handleReadMem); - case RegisterEHFramesId: - return handle<RegisterEHFrames>(Channel, *this, - &ThisT::handleRegisterEHFrames); - case ReserveMemId: - return handle<ReserveMem>(Channel, *this, &ThisT::handleReserveMem); - case SetProtectionsId: - return handle<SetProtections>(Channel, *this, - &ThisT::handleSetProtections); - case WriteMemId: - return handle<WriteMem>(Channel, *this, &ThisT::handleWriteMem); - case WritePtrId: - return handle<WritePtr>(Channel, *this, &ThisT::handleWritePtr); - default: - return orcError(OrcErrorCode::UnexpectedRPCCall); - } - - llvm_unreachable("Unhandled JIT RPC procedure Id."); - } Expected<JITTargetAddress> requestCompile(JITTargetAddress TrampolineAddr) { - auto Listen = [&](RPCByteChannel &C, uint32_t Id) { - return handleKnownFunction(static_cast<JITFuncId>(Id)); - }; - - return callSTHandling<RequestCompile>(Channel, Listen, TrampolineAddr); + return callB<RequestCompile>(TrampolineAddr); } - Error handleTerminateSession() { - return handle<TerminateSession>(Channel, []() { return Error::success(); }); - } + bool receivedTerminate() const { return TerminateFlag; } private: struct Allocator { @@ -365,15 +322,16 @@ private: IndirectStubSize); } - Expected<std::vector<char>> handleReadMem(JITTargetAddress RSrc, uint64_t Size) { - char *Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc)); + Expected<std::vector<uint8_t>> handleReadMem(JITTargetAddress RSrc, + uint64_t Size) { + uint8_t *Src = reinterpret_cast<uint8_t*>(static_cast<uintptr_t>(RSrc)); DEBUG(dbgs() << " Reading " << Size << " bytes from " << format("0x%016x", RSrc) << "\n"); - std::vector<char> Buffer; + std::vector<uint8_t> Buffer; Buffer.resize(Size); - for (char *P = Src; Size != 0; --Size) + for (uint8_t *P = Src; Size != 0; --Size) Buffer.push_back(*P++); return Buffer; @@ -421,6 +379,11 @@ private: return Allocator.setProtections(LocalAddr, Flags); } + Error handleTerminateSession() { + TerminateFlag = true; + return Error::success(); + } + Error handleWriteMem(DirectBufferWriter DBW) { DEBUG(dbgs() << " Writing " << DBW.getSize() << " bytes to " << format("0x%016x", DBW.getDst()) << "\n"); @@ -436,7 +399,6 @@ private: return Error::success(); } - ChannelT &Channel; SymbolLookupFtor SymbolLookup; EHFrameRegistrationFtor EHFramesRegister, EHFramesDeregister; std::map<ResourceIdMgr::ResourceId, Allocator> Allocators; @@ -444,6 +406,7 @@ private: std::map<ResourceIdMgr::ResourceId, ISBlockOwnerList> IndirectStubsOwners; sys::OwningMemoryBlock ResolverBlock; std::vector<sys::OwningMemoryBlock> TrampolineBlocks; + bool TerminateFlag; }; } // end namespace remote diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCByteChannel.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCByteChannel.h deleted file mode 100644 index c8cb42d5374..00000000000 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCByteChannel.h +++ /dev/null @@ -1,231 +0,0 @@ -//===- llvm/ExecutionEngine/Orc/RPCByteChannel.h ----------------*- C++ -*-===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_RPCBYTECHANNEL_H -#define LLVM_EXECUTIONENGINE_ORC_RPCBYTECHANNEL_H - -#include "OrcError.h" -#include "RPCSerialization.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Endian.h" -#include "llvm/Support/Error.h" -#include <cstddef> -#include <cstdint> -#include <mutex> -#include <string> -#include <tuple> -#include <type_traits> -#include <vector> - -namespace llvm { -namespace orc { -namespace remote { - -/// Interface for byte-streams to be used with RPC. -class RPCByteChannel { -public: - virtual ~RPCByteChannel() {} - - /// Read Size bytes from the stream into *Dst. - virtual Error readBytes(char *Dst, unsigned Size) = 0; - - /// Read size bytes from *Src and append them to the stream. - virtual Error appendBytes(const char *Src, unsigned Size) = 0; - - /// Flush the stream if possible. - virtual Error send() = 0; - - /// Get the lock for stream reading. - std::mutex &getReadLock() { return readLock; } - - /// Get the lock for stream writing. - std::mutex &getWriteLock() { return writeLock; } - -private: - std::mutex readLock, writeLock; -}; - -/// Notify the channel that we're starting a message send. -/// Locks the channel for writing. -inline Error startSendMessage(RPCByteChannel &C) { - C.getWriteLock().lock(); - return Error::success(); -} - -/// Notify the channel that we're ending a message send. -/// Unlocks the channel for writing. -inline Error endSendMessage(RPCByteChannel &C) { - C.getWriteLock().unlock(); - return Error::success(); -} - -/// Notify the channel that we're starting a message receive. -/// Locks the channel for reading. -inline Error startReceiveMessage(RPCByteChannel &C) { - C.getReadLock().lock(); - return Error::success(); -} - -/// Notify the channel that we're ending a message receive. -/// Unlocks the channel for reading. -inline Error endReceiveMessage(RPCByteChannel &C) { - C.getReadLock().unlock(); - return Error::success(); -} - -template <typename ChannelT, typename T, - typename = - typename std::enable_if< - std::is_base_of<RPCByteChannel, ChannelT>::value>:: - type> -class RPCByteChannelPrimitiveSerialization { -public: - static Error serialize(ChannelT &C, T V) { - support::endian::byte_swap<T, support::big>(V); - return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); - }; - - static Error deserialize(ChannelT &C, T &V) { - if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) - return Err; - support::endian::byte_swap<T, support::big>(V); - return Error::success(); - }; -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, uint64_t> - : public RPCByteChannelPrimitiveSerialization<ChannelT, uint64_t> { -public: - static const char* getName() { return "uint64_t"; } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, int64_t> - : public RPCByteChannelPrimitiveSerialization<ChannelT, int64_t> { -public: - static const char* getName() { return "int64_t"; } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, uint32_t> - : public RPCByteChannelPrimitiveSerialization<ChannelT, uint32_t> { -public: - static const char* getName() { return "uint32_t"; } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, int32_t> - : public RPCByteChannelPrimitiveSerialization<ChannelT, int32_t> { -public: - static const char* getName() { return "int32_t"; } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, uint16_t> - : public RPCByteChannelPrimitiveSerialization<ChannelT, uint16_t> { -public: - static const char* getName() { return "uint16_t"; } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, int16_t> - : public RPCByteChannelPrimitiveSerialization<ChannelT, int16_t> { -public: - static const char* getName() { return "int16_t"; } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, uint8_t> - : public RPCByteChannelPrimitiveSerialization<ChannelT, uint8_t> { -public: - static const char* getName() { return "uint8_t"; } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, int8_t> - : public RPCByteChannelPrimitiveSerialization<ChannelT, int8_t> { -public: - static const char* getName() { return "int8_t"; } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, char> - : public RPCByteChannelPrimitiveSerialization<ChannelT, uint8_t> { -public: - static const char* getName() { return "char"; } - - static Error serialize(RPCByteChannel &C, char V) { - return serializeSeq(C, static_cast<uint8_t>(V)); - }; - - static Error deserialize(RPCByteChannel &C, char &V) { - uint8_t VV; - if (auto Err = deserializeSeq(C, VV)) - return Err; - V = static_cast<char>(V); - return Error::success(); - }; -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, bool, - typename std::enable_if< - std::is_base_of<RPCByteChannel, ChannelT>::value>:: - type> { -public: - static const char* getName() { return "bool"; } - - static Error serialize(ChannelT &C, bool V) { - return C.appendBytes(reinterpret_cast<const char *>(&V), 1); - } - - static Error deserialize(ChannelT &C, bool &V) { - return C.readBytes(reinterpret_cast<char *>(&V), 1); - } -}; - -template <typename ChannelT> -class SerializationTraits<ChannelT, std::string, - typename std::enable_if< - std::is_base_of<RPCByteChannel, ChannelT>::value>:: - type> { -public: - static const char* getName() { return "std::string"; } - - static Error serialize(RPCByteChannel &C, StringRef S) { - if (auto Err = SerializationTraits<RPCByteChannel, uint64_t>:: - serialize(C, static_cast<uint64_t>(S.size()))) - return Err; - return C.appendBytes((const char *)S.bytes_begin(), S.size()); - } - - /// RPC channel serialization for std::strings. - static Error serialize(RPCByteChannel &C, const std::string &S) { - return serialize(C, StringRef(S)); - } - - /// RPC channel deserialization for std::strings. - static Error deserialize(RPCByteChannel &C, std::string &S) { - uint64_t Count = 0; - if (auto Err = SerializationTraits<RPCByteChannel, uint64_t>:: - deserialize(C, Count)) - return Err; - S.resize(Count); - return C.readBytes(&S[0], Count); - } -}; - -} // end namespace remote -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_RPCBYTECHANNEL_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCSerialization.h index 0e9f5157f29..d1503e91b4f 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCSerialization.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -17,7 +17,164 @@ namespace llvm { namespace orc { -namespace remote { +namespace rpc { + +template <typename T> +class RPCTypeName; + +/// TypeNameSequence is a utility for rendering sequences of types to a string +/// by rendering each type, separated by ", ". +template <typename... ArgTs> class RPCTypeNameSequence {}; + +/// Render an empty TypeNameSequence to an ostream. +template <typename OStream> +OStream &operator<<(OStream &OS, const RPCTypeNameSequence<> &V) { + return OS; +} + +/// Render a TypeNameSequence of a single type to an ostream. +template <typename OStream, typename ArgT> +OStream &operator<<(OStream &OS, const RPCTypeNameSequence<ArgT> &V) { + OS << RPCTypeName<ArgT>::getName(); + return OS; +} + +/// Render a TypeNameSequence of more than one type to an ostream. +template <typename OStream, typename ArgT1, typename ArgT2, typename... ArgTs> +OStream& +operator<<(OStream &OS, const RPCTypeNameSequence<ArgT1, ArgT2, ArgTs...> &V) { + OS << RPCTypeName<ArgT1>::getName() << ", " + << RPCTypeNameSequence<ArgT2, ArgTs...>(); + return OS; +} + +template <> +class RPCTypeName<void> { +public: + static const char* getName() { return "void"; } +}; + +template <> +class RPCTypeName<int8_t> { +public: + static const char* getName() { return "int8_t"; } +}; + +template <> +class RPCTypeName<uint8_t> { +public: + static const char* getName() { return "uint8_t"; } +}; + +template <> +class RPCTypeName<int16_t> { +public: + static const char* getName() { return "int16_t"; } +}; + +template <> +class RPCTypeName<uint16_t> { +public: + static const char* getName() { return "uint16_t"; } +}; + +template <> +class RPCTypeName<int32_t> { +public: + static const char* getName() { return "int32_t"; } +}; + +template <> +class RPCTypeName<uint32_t> { +public: + static const char* getName() { return "uint32_t"; } +}; + +template <> +class RPCTypeName<int64_t> { +public: + static const char* getName() { return "int64_t"; } +}; + +template <> +class RPCTypeName<uint64_t> { +public: + static const char* getName() { return "uint64_t"; } +}; + +template <> +class RPCTypeName<bool> { +public: + static const char* getName() { return "bool"; } +}; + +template <> +class RPCTypeName<std::string> { +public: + static const char* getName() { return "std::string"; } +}; + +template <typename T1, typename T2> +class RPCTypeName<std::pair<T1, T2>> { +public: + static const char* getName() { + std::lock_guard<std::mutex> Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) << "std::pair<" << RPCTypeNameSequence<T1, T2>() + << ">"; + return Name.data(); + } +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template <typename T1, typename T2> +std::mutex RPCTypeName<std::pair<T1, T2>>::NameMutex; +template <typename T1, typename T2> +std::string RPCTypeName<std::pair<T1, T2>>::Name; + +template <typename... ArgTs> +class RPCTypeName<std::tuple<ArgTs...>> { +public: + static const char* getName() { + std::lock_guard<std::mutex> Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) << "std::tuple<" + << RPCTypeNameSequence<ArgTs...>() << ">"; + return Name.data(); + } +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template <typename... ArgTs> +std::mutex RPCTypeName<std::tuple<ArgTs...>>::NameMutex; +template <typename... ArgTs> +std::string RPCTypeName<std::tuple<ArgTs...>>::Name; + +template <typename T> +class RPCTypeName<std::vector<T>> { +public: + static const char*getName() { + std::lock_guard<std::mutex> Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) << "std::vector<" << RPCTypeName<T>::getName() + << ">"; + return Name.data(); + } + +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template <typename T> +std::mutex RPCTypeName<std::vector<T>>::NameMutex; +template <typename T> +std::string RPCTypeName<std::vector<T>>::Name; + /// The SerializationTraits<ChannelT, T> class describes how to serialize and /// deserialize an instance of type T to/from an abstract channel of type @@ -51,71 +208,92 @@ namespace remote { /// } /// /// @endcode -template <typename ChannelT, typename T, typename = void> +template <typename ChannelT, typename WireType, typename From = WireType, + typename = void> class SerializationTraits {}; -/// TypeNameSequence is a utility for rendering sequences of types to a string -/// by rendering each type, separated by ", ". -template <typename ChannelT, typename... ArgTs> class TypeNameSequence {}; +template <typename ChannelT> +class SequenceTraits { +public: + static Error emitSeparator(ChannelT &C) { return Error::success(); } + static Error consumeSeparator(ChannelT &C) { return Error::success(); } +}; -/// Render a TypeNameSequence of a single type to an ostream. -template <typename OStream, typename ChannelT, typename ArgT> -OStream &operator<<(OStream &OS, const TypeNameSequence<ChannelT, ArgT> &V) { - OS << SerializationTraits<ChannelT, ArgT>::getName(); - return OS; -} +/// Utility class for serializing sequences of values of varying types. +/// Specializations of this class contain 'serialize' and 'deserialize' methods +/// for the given channel. The ArgTs... list will determine the "over-the-wire" +/// types to be serialized. The serialize and deserialize methods take a list +/// CArgTs... ("caller arg types") which must be the same length as ArgTs..., +/// but may be different types from ArgTs, provided that for each CArgT there +/// is a SerializationTraits specialization +/// SerializeTraits<ChannelT, ArgT, CArgT> with methods that can serialize the +/// caller argument to over-the-wire value. +template <typename ChannelT, typename... ArgTs> +class SequenceSerialization; -/// Render a TypeNameSequence of more than one type to an ostream. -template <typename OStream, typename ChannelT, typename ArgT1, typename ArgT2, - typename... ArgTs> -OStream & -operator<<(OStream &OS, - const TypeNameSequence<ChannelT, ArgT1, ArgT2, ArgTs...> &V) { - OS << SerializationTraits<ChannelT, ArgT1>::getName() << ", " - << TypeNameSequence<ChannelT, ArgT2, ArgTs...>(); - return OS; -} +template <typename ChannelT> +class SequenceSerialization<ChannelT> { +public: + static Error serialize(ChannelT &C) { return Error::success(); } + static Error deserialize(ChannelT &C) { return Error::success(); } +}; -/// RPC channel serialization for a variadic list of arguments. -template <typename ChannelT, typename T, typename... Ts> -Error serializeSeq(ChannelT &C, const T &Arg, const Ts &... Args) { - if (auto Err = SerializationTraits<ChannelT, T>::serialize(C, Arg)) - return Err; - return serializeSeq(C, Args...); -} +template <typename ChannelT, typename ArgT> +class SequenceSerialization<ChannelT, ArgT> { +public: -/// RPC channel serialization for an (empty) variadic list of arguments. -template <typename ChannelT> Error serializeSeq(ChannelT &C) { - return Error::success(); -} + template <typename CArgT> + static Error serialize(ChannelT &C, const CArgT &CArg) { + return SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg); + } + + template <typename CArgT> + static Error deserialize(ChannelT &C, CArgT &CArg) { + return SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg); + } +}; + +template <typename ChannelT, typename ArgT, typename... ArgTs> +class SequenceSerialization<ChannelT, ArgT, ArgTs...> { +public: -/// RPC channel deserialization for a variadic list of arguments. -template <typename ChannelT, typename T, typename... Ts> -Error deserializeSeq(ChannelT &C, T &Arg, Ts &... Args) { - if (auto Err = SerializationTraits<ChannelT, T>::deserialize(C, Arg)) - return Err; - return deserializeSeq(C, Args...); + template <typename CArgT, typename... CArgTs> + static Error serialize(ChannelT &C, const CArgT &CArg, + const CArgTs&... CArgs) { + if (auto Err = + SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg)) + return Err; + if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C)) + return Err; + return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); + } + + template <typename CArgT, typename... CArgTs> + static Error deserialize(ChannelT &C, CArgT &CArg, + CArgTs&... CArgs) { + if (auto Err = + SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg)) + return Err; + if (auto Err = SequenceTraits<ChannelT>::consumeSeparator(C)) + return Err; + return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, CArgs...); + } +}; + +template <typename ChannelT, typename... ArgTs> +Error serializeSeq(ChannelT &C, const ArgTs &... Args) { + return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, Args...); } -/// RPC channel serialization for an (empty) variadic list of arguments. -template <typename ChannelT> Error deserializeSeq(ChannelT &C) { - return Error::success(); +template <typename ChannelT, typename... ArgTs> +Error deserializeSeq(ChannelT &C, ArgTs &... Args) { + return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...); } /// SerializationTraits default specialization for std::pair. template <typename ChannelT, typename T1, typename T2> class SerializationTraits<ChannelT, std::pair<T1, T2>> { public: - static const char *getName() { - std::lock_guard<std::mutex> Lock(NameMutex); - if (Name.empty()) - Name = (std::ostringstream() - << "std::pair<" << TypeNameSequence<ChannelT, T1, T2>() << ">") - .str(); - - return Name.data(); - } - static Error serialize(ChannelT &C, const std::pair<T1, T2> &V) { return serializeSeq(C, V.first, V.second); } @@ -123,31 +301,12 @@ public: static Error deserialize(ChannelT &C, std::pair<T1, T2> &V) { return deserializeSeq(C, V.first, V.second); } - -private: - static std::mutex NameMutex; - static std::string Name; }; -template <typename ChannelT, typename T1, typename T2> -std::mutex SerializationTraits<ChannelT, std::pair<T1, T2>>::NameMutex; - -template <typename ChannelT, typename T1, typename T2> -std::string SerializationTraits<ChannelT, std::pair<T1, T2>>::Name; - /// SerializationTraits default specialization for std::tuple. template <typename ChannelT, typename... ArgTs> class SerializationTraits<ChannelT, std::tuple<ArgTs...>> { public: - static const char *getName() { - std::lock_guard<std::mutex> Lock(NameMutex); - if (Name.empty()) - Name = (std::ostringstream() - << "std::tuple<" << TypeNameSequence<ChannelT, ArgTs...>() << ">") - .str(); - - return Name.data(); - } /// RPC channel serialization for std::tuple. static Error serialize(ChannelT &C, const std::tuple<ArgTs...> &V) { @@ -173,68 +332,41 @@ private: llvm::index_sequence<Is...> _) { return deserializeSeq(C, std::get<Is>(V)...); } - - static std::mutex NameMutex; - static std::string Name; }; -template <typename ChannelT, typename... ArgTs> -std::mutex SerializationTraits<ChannelT, std::tuple<ArgTs...>>::NameMutex; - -template <typename ChannelT, typename... ArgTs> -std::string SerializationTraits<ChannelT, std::tuple<ArgTs...>>::Name; - /// SerializationTraits default specialization for std::vector. template <typename ChannelT, typename T> class SerializationTraits<ChannelT, std::vector<T>> { public: - static const char *getName() { - std::lock_guard<std::mutex> Lock(NameMutex); - if (Name.empty()) - Name = (std::ostringstream() << "std::vector<" - << TypeNameSequence<ChannelT, T>() << ">") - .str(); - return Name.data(); - } + /// Serialize a std::vector<T> from std::vector<T>. static Error serialize(ChannelT &C, const std::vector<T> &V) { - if (auto Err = SerializationTraits<ChannelT, uint64_t>::serialize( - C, static_cast<uint64_t>(V.size()))) + if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) return Err; for (const auto &E : V) - if (auto Err = SerializationTraits<ChannelT, T>::serialize(C, E)) + if (auto Err = serializeSeq(C, E)) return Err; return Error::success(); } + /// Deserialize a std::vector<T> to a std::vector<T>. static Error deserialize(ChannelT &C, std::vector<T> &V) { uint64_t Count = 0; - if (auto Err = - SerializationTraits<ChannelT, uint64_t>::deserialize(C, Count)) + if (auto Err = deserializeSeq(C, Count)) return Err; V.resize(Count); for (auto &E : V) - if (auto Err = SerializationTraits<ChannelT, T>::deserialize(C, E)) + if (auto Err = deserializeSeq(C, E)) return Err; return Error::success(); } - -private: - static std::mutex NameMutex; - static std::string Name; }; -template <typename ChannelT, typename T> -std::mutex SerializationTraits<ChannelT, std::vector<T>>::NameMutex; - -template <typename ChannelT, typename T> -std::string SerializationTraits<ChannelT, std::vector<T>>::Name; - -} // end namespace remote +} // end namespace rpc } // end namespace orc } // end namespace llvm diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 436c037e920..2ff27efd72d 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -1,4 +1,4 @@ -//===----- RPCUTils.h - Basic tilities for building RPC APIs ----*- C++ -*-===// +//===------- RPCUTils.h - Utilities for building RPC APIs -------*- C++ -*-===// // // The LLVM Compiler Infrastructure // @@ -7,7 +7,11 @@ // //===----------------------------------------------------------------------===// // -// Basic utilities for building RPC APIs. +// Utilities to support construction of simple RPC APIs. +// +// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ +// programmers, high performance, low memory overhead, and efficient use of the +// communications channel. // //===----------------------------------------------------------------------===// @@ -15,10 +19,12 @@ #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H #include <map> +#include <thread> #include <vector> #include "llvm/ADT/STLExtras.h" #include "llvm/ExecutionEngine/Orc/OrcError.h" +#include "llvm/ExecutionEngine/Orc/RPCSerialization.h" #ifdef _MSC_VER // concrt.h depends on eh.h for __uncaught_exception declaration @@ -39,32 +45,92 @@ namespace llvm { namespace orc { -namespace remote { +namespace rpc { -/// Describes reserved RPC Function Ids. -/// -/// The default implementation will serve for integer and enum function id -/// types. If you want to use a custom type as your FunctionId you can -/// specialize this class and provide unique values for InvalidId, -/// ResponseId and FirstValidId. +template <typename DerivedFunc, typename FnT> +class Function; -template <typename T> class RPCFunctionIdTraits { +// RPC Function class. +// DerivedFunc should be a user defined class with a static 'getName()' method +// returning a const char* representing the function's name. +template <typename DerivedFunc, typename RetT, typename... ArgTs> +class Function<DerivedFunc, RetT(ArgTs...)> { public: - static const T InvalidId = static_cast<T>(0); - static const T ResponseId = static_cast<T>(1); - static const T FirstValidId = static_cast<T>(2); + + /// User defined function type. + using Type = RetT(ArgTs...); + + /// Return type. + using ReturnType = RetT; + + /// Returns the full function prototype as a string. + static const char *getPrototype() { + std::lock_guard<std::mutex> Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) + << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName() + << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")"; + return Name.data(); + } +private: + static std::mutex NameMutex; + static std::string Name; }; -// Base class containing utilities that require partial specialization. -// These cannot be included in RPC, as template class members cannot be -// partially specialized. -class RPCBase { -protected: - // FIXME: Remove MSVCPError/MSVCPExpected once MSVC's future implementation - // supports classes without default constructors. +template <typename DerivedFunc, typename RetT, typename... ArgTs> +std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex; + +template <typename DerivedFunc, typename RetT, typename... ArgTs> +std::string Function<DerivedFunc, RetT(ArgTs...)>::Name; + +/// Allocates RPC function ids during autonegotiation. +/// Specializations of this class must provide four members: +/// +/// static T getInvalidId(): +/// Should return a reserved id that will be used to represent missing +/// functions during autonegotiation. +/// +/// static T getResponseId(): +/// Should return a reserved id that will be used to send function responses +/// (return values). +/// +/// static T getNegotiateId(): +/// Should return a reserved id for the negotiate function, which will be used +/// to negotiate ids for user defined functions. +/// +/// template <typename Func> T allocate(): +/// Allocate a unique id for function Func. +template <typename T, typename = void> +class RPCFunctionIdAllocator; + +/// This specialization of RPCFunctionIdAllocator provides a default +/// implementation for integral types. +template <typename T> +class RPCFunctionIdAllocator<T, + typename std::enable_if< + std::is_integral<T>::value + >::type> { +public: + + static T getInvalidId() { return T(0); } + static T getResponseId() { return T(1); } + static T getNegotiateId() { return T(2); } + + template <typename Func> + T allocate(){ return NextId++; } +private: + T NextId = 3; +}; + +namespace detail { + +// FIXME: Remove MSVCPError/MSVCPExpected once MSVC's future implementation +// supports classes without default constructors. #ifdef _MSC_VER +namespace msvc_hacks { + // Work around MSVC's future implementation's use of default constructors: // A default constructed value in the promise will be overwritten when the // real error is set - so the default constructed Error has to be checked @@ -86,7 +152,7 @@ protected: MSVCPError(Error Err) : Error(std::move(Err)) {} }; - // Likewise for Expected: + // Work around MSVC's future implementation, similar to MSVCPError. template <typename T> class MSVCPExpected : public Expected<T> { public: @@ -123,488 +189,531 @@ protected: nullptr) : Expected<T>(std::move(Other)) {} }; +} // end namespace msvc_hacks + #endif // _MSC_VER - // RPC Function description type. - // - // This class provides the information and operations needed to support the - // RPC primitive operations (call, expect, etc) for a given function. It - // is specialized for void and non-void functions to deal with the differences - // betwen the two. Both specializations have the same interface: - // - // Id - The function's unique identifier. - // ErrorReturn - The return type for blocking calls. - // readResult - Deserialize a result from a channel. - // abandon - Abandon a promised result. - // respond - Retun a result on the channel. - template <typename FunctionIdT, FunctionIdT FuncId, typename FnT> - class FunctionHelper {}; - - // RPC Function description specialization for non-void functions. - template <typename FunctionIdT, FunctionIdT FuncId, typename RetT, - typename... ArgTs> - class FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> { - public: - static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId && - FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId, - "Cannot define custom function with InvalidId or ResponseId. " - "Please use RPCFunctionTraits<FunctionIdT>::FirstValidId."); +// ResultTraits provides typedefs and utilities specific to the return type +// of functions. +template <typename RetT> +class ResultTraits { +public: + + // The return type wrapped in llvm::Expected. + using ErrorReturnType = Expected<RetT>; + +#ifdef _MSC_VER + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<msvc_hacks::MSVCPExpected<RetT>>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<msvc_hacks::MSVCPExpected<RetT>>; +#else + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<ErrorReturnType>; - static const FunctionIdT Id = FuncId; + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<ErrorReturnType>; +#endif - typedef Expected<RetT> ErrorReturn; + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType(RetT()); + } + + // Consume an abandoned ErrorReturnType. + static void consumeAbandoned(ErrorReturnType RetOrErr) { + consumeError(RetOrErr.takeError()); + } +}; + +// ResultTraits specialization for void functions. +template <> +class ResultTraits<void> { +public: + + // For void functions, ErrorReturnType is llvm::Error. + using ErrorReturnType = Error; - // FIXME: Ditch PErrorReturn (replace it with plain ErrorReturn) once MSVC's - // std::future implementation supports types without default - // constructors. #ifdef _MSC_VER - typedef MSVCPExpected<RetT> PErrorReturn; + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<msvc_hacks::MSVCPError>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<msvc_hacks::MSVCPError>; #else - typedef Expected<RetT> PErrorReturn; + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise<ErrorReturnType>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future<ErrorReturnType>; #endif - template <typename ChannelT> - static Error readResult(ChannelT &C, std::promise<PErrorReturn> &P) { - RetT Val; - auto Err = deserializeSeq(C, Val); - auto Err2 = endReceiveMessage(C); - Err = joinErrors(std::move(Err), std::move(Err2)); - if (Err) - return Err; + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType::success(); + } - P.set_value(std::move(Val)); - return Error::success(); - } + // Consume an abandoned ErrorReturnType. + static void consumeAbandoned(ErrorReturnType Err) { + consumeError(std::move(Err)); + } +}; - static void abandon(std::promise<PErrorReturn> &P) { - P.set_value( - make_error<StringError>("RPC function call failed to return", - inconvertibleErrorCode())); - } +// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows +// handlers for void RPC functions to return either void (in which case they +// implicitly succeed) or Error (in which case their error return is +// propagated). See usage in HandlerTraits::runHandlerHelper. +template <> +class ResultTraits<Error> : public ResultTraits<void> {}; + +// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows +// handlers for RPC functions returning a T to return either a T (in which +// case they implicitly succeed) or Expected<T> (in which case their error +// return is propagated). See usage in HandlerTraits::runHandlerHelper. +template <typename RetT> +class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; + +// Send a response of the given wire return type (WireRetT) over the +// channel, with the given sequence number. +template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> +static Error respond(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) { + // If this was an error bail out. + // FIXME: Send an "error" message to the client if this is not a channel + // failure? + if (auto Err = ResultOrErr.takeError()) + return Err; + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = SerializationTraits<ChannelT, WireRetT, HandlerRetT>:: + serialize(C, *ResultOrErr)) + return Err; + + // Close the response message. + return C.endSendMessage(); +} + +// Send an empty response message on the given channel to indicate that +// the handler ran. +template <typename WireRetT, typename ChannelT, typename FunctionIdT, + typename SequenceNumberT> +static Error respond(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + return C.endSendMessage(); +} + +// This template class provides utilities related to RPC function handlers. +// The base case applies to non-function types (the template class is +// specialized for function types) and inherits from the appropriate +// speciilization for the given non-function type's call operator. +template <typename HandlerT> +class HandlerTraits + : public HandlerTraits<decltype( + &std::remove_reference<HandlerT>::type::operator())> {}; + +// Traits for handlers with a given function type. +template <typename RetT, typename... ArgTs> +class HandlerTraits<RetT(ArgTs...)> { +public: - static void consumeAbandoned(std::future<PErrorReturn> &P) { - consumeError(P.get().takeError()); - } + // Function type of the handler. + using Type = RetT(ArgTs...); - template <typename ChannelT, typename SequenceNumberT> - static Error respond(ChannelT &C, SequenceNumberT SeqNo, - ErrorReturn &Result) { - FunctionIdT ResponseId = RPCFunctionIdTraits<FunctionIdT>::ResponseId; + // Return type of the handler. + using ReturnType = RetT; - // If the handler returned an error then bail out with that. - if (!Result) - return Result.takeError(); + // A std::tuple wrapping the handler arguments. + using ArgStorage = + std::tuple< + typename std::decay< + typename std::remove_reference<ArgTs>::type>::type...>; - // Otherwise open a new message on the channel and send the result. - if (auto Err = startSendMessage(C)) - return Err; - if (auto Err = serializeSeq(C, ResponseId, SeqNo, *Result)) - return Err; - return endSendMessage(C); - } - }; + // Call the given handler with the given arguments. + template <typename HandlerT> + static typename ResultTraits<RetT>::ErrorReturnType + runHandler(HandlerT &Handler, ArgStorage &Args) { + return runHandlerHelper<RetT>(Handler, Args, + llvm::index_sequence_for<ArgTs...>()); + } - // RPC Function description specialization for void functions. - template <typename FunctionIdT, FunctionIdT FuncId, typename... ArgTs> - class FunctionHelper<FunctionIdT, FuncId, void(ArgTs...)> { - public: - static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId && - FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId, - "Cannot define custom function with InvalidId or ResponseId. " - "Please use RPCFunctionTraits<FunctionIdT>::FirstValidId."); + // Serialize arguments to the channel. + template <typename ChannelT, typename... CArgTs> + static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) { + return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); + } - static const FunctionIdT Id = FuncId; + // Deserialize arguments from the channel. + template <typename ChannelT, typename... CArgTs> + static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) { + return deserializeArgsHelper(C, Args, + llvm::index_sequence_for<CArgTs...>()); + } - typedef Error ErrorReturn; +private: - // FIXME: Ditch PErrorReturn (replace it with plain ErrorReturn) once MSVC's - // std::future implementation supports types without default - // constructors. -#ifdef _MSC_VER - typedef MSVCPError PErrorReturn; -#else - typedef Error PErrorReturn; -#endif + // For non-void user handlers: unwrap the args tuple and call the handler, + // returning the result. + template <typename RetTAlt, typename HandlerT, size_t... Indexes> + static typename std::enable_if< + !std::is_void<RetTAlt>::value, + typename ResultTraits<RetT>::ErrorReturnType>::type + runHandlerHelper(HandlerT &Handler, ArgStorage &Args, + llvm::index_sequence<Indexes...>) { + return Handler(std::move(std::get<Indexes>(Args))...); + } - template <typename ChannelT> - static Error readResult(ChannelT &C, std::promise<PErrorReturn> &P) { - // Void functions don't have anything to deserialize, so we're good. - P.set_value(Error::success()); - return endReceiveMessage(C); - } + // For void user handlers: unwrap the args tuple and call the handler, then + // return Error::success(). + template <typename RetTAlt, typename HandlerT, size_t... Indexes> + static typename std::enable_if< + std::is_void<RetTAlt>::value, + typename ResultTraits<RetT>::ErrorReturnType>::type + runHandlerHelper(HandlerT &Handler, ArgStorage &Args, + llvm::index_sequence<Indexes...>) { + Handler(std::move(std::get<Indexes>(Args))...); + return ResultTraits<RetT>::ErrorReturnType::success(); + } - static void abandon(std::promise<PErrorReturn> &P) { - P.set_value( - make_error<StringError>("RPC function call failed to return", - inconvertibleErrorCode())); - } + template <typename ChannelT, typename... CArgTs, size_t... Indexes> + static + Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args, + llvm::index_sequence<Indexes...> _) { + return SequenceSerialization<ChannelT, ArgTs...>:: + deserialize(C, std::get<Indexes>(Args)...); + } - static void consumeAbandoned(std::future<PErrorReturn> &P) { - consumeError(P.get()); - } +}; - template <typename ChannelT, typename SequenceNumberT> - static Error respond(ChannelT &C, SequenceNumberT SeqNo, - ErrorReturn &Result) { - const FunctionIdT ResponseId = - RPCFunctionIdTraits<FunctionIdT>::ResponseId; +// Handler traits for class methods (especially call operators for lambdas). +template <typename Class, typename RetT, typename... ArgTs> +class HandlerTraits<RetT (Class::*)(ArgTs...)> + : public HandlerTraits<RetT(ArgTs...)> {}; - // If the handler returned an error then bail out with that. - if (Result) - return std::move(Result); +// Handler traits for const class methods (especially call operators for +// lambdas). +template <typename Class, typename RetT, typename... ArgTs> +class HandlerTraits<RetT (Class::*)(ArgTs...) const> + : public HandlerTraits<RetT(ArgTs...)> {}; - // Otherwise open a new message on the channel and send the result. - if (auto Err = startSendMessage(C)) - return Err; - if (auto Err = serializeSeq(C, ResponseId, SeqNo)) - return Err; - return endSendMessage(C); - } - }; +// Utility to peel the Expected wrapper off a response handler error type. +template <typename HandlerT> +class UnwrapResponseHandlerArg; - // Helper for the call primitive. - template <typename ChannelT, typename SequenceNumberT, typename Func> - class CallHelper; +template <typename ArgT> +class UnwrapResponseHandlerArg<Error(Expected<ArgT>)> { +public: + using ArgType = ArgT; +}; - template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT, - FunctionIdT FuncId, typename RetT, typename... ArgTs> - class CallHelper<ChannelT, SequenceNumberT, - FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> { - public: - static Error call(ChannelT &C, SequenceNumberT SeqNo, - const ArgTs &... Args) { - if (auto Err = startSendMessage(C)) - return Err; - if (auto Err = serializeSeq(C, FuncId, SeqNo, Args...)) - return Err; - return endSendMessage(C); +template <typename ArgT> +class UnwrapResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> { +public: + using ArgType = ArgT; +}; + + +// ResponseHandler represents a handler for a not-yet-received function call +// result. +template <typename ChannelT> +class ResponseHandler { +public: + virtual ~ResponseHandler() {} + + // Reads the function result off the wire and acts on it. The meaning of + // "act" will depend on how this method is implemented in any given + // ResponseHandler subclass but could, for example, mean running a + // user-specified handler or setting a promise value. + virtual Error handleResponse(ChannelT &C) = 0; + + // Abandons this outstanding result. + virtual void abandon() = 0; + + // Create an error instance representing an abandoned response. + static Error createAbandonedResponseError() { + return make_error<StringError>("RPC function call failed to return", + inconvertibleErrorCode()); + } +}; + +// ResponseHandler subclass for RPC functions with non-void returns. +template <typename ChannelT, typename FuncRetT, typename HandlerT> +class ResponseHandlerImpl : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) + : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using ArgType = typename UnwrapResponseHandlerArg< + typename HandlerTraits<HandlerT>::Type>::ArgType; + ArgType Result; + if (auto Err = SerializationTraits<ChannelT, FuncRetT, ArgType>:: + deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(Result); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); } - }; + } - // Helper for handle primitive. - template <typename ChannelT, typename SequenceNumberT, typename Func> - class HandlerHelper; +private: + HandlerT Handler; +}; - template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT, - FunctionIdT FuncId, typename RetT, typename... ArgTs> - class HandlerHelper<ChannelT, SequenceNumberT, - FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> { - public: - template <typename HandlerT> - static Error handle(ChannelT &C, HandlerT Handler) { - return readAndHandle(C, Handler, llvm::index_sequence_for<ArgTs...>()); +// ResponseHandler subclass for RPC functions with void returns. +template <typename ChannelT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, void, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) + : Handler(std::move(Handler)) {} + + // Handle the result (no actual value, just a notification that the function + // has completed on the remote end) by calling the user-defined handler with + // Error::success(). + Error handleResponse(ChannelT &C) override { + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(Error::success()); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); } + } - private: - typedef FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> Func; - - template <typename HandlerT, size_t... Is> - static Error readAndHandle(ChannelT &C, HandlerT Handler, - llvm::index_sequence<Is...> _) { - std::tuple<ArgTs...> RPCArgs; - SequenceNumberT SeqNo; - // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning - // for RPCArgs. Void cast RPCArgs to work around this for now. - // FIXME: Remove this workaround once we can assume a working GCC version. - (void)RPCArgs; - if (auto Err = deserializeSeq(C, SeqNo, std::get<Is>(RPCArgs)...)) - return Err; +private: + HandlerT Handler; +}; - // We've deserialized the arguments, so unlock the channel for reading - // before we call the handler. This allows recursive RPC calls. - if (auto Err = endReceiveMessage(C)) - return Err; +// Create a ResponseHandler from a given user handler. +template <typename ChannelT, typename FuncRetT, typename HandlerT> +std::unique_ptr<ResponseHandler<ChannelT>> +createResponseHandler(HandlerT H) { + return llvm::make_unique< + ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(std::move(H)); +} + +// Helper for wrapping member functions up as functors. This is useful for +// installing methods as result handlers. +template <typename ClassT, typename RetT, typename... ArgTs> +class MemberFnWrapper { +public: + using MethodT = RetT(ClassT::*)(ArgTs...); + MemberFnWrapper(ClassT &Instance, MethodT Method) + : Instance(Instance), Method(Method) {} + RetT operator()(ArgTs &&... Args) { + return (Instance.*Method)(std::move(Args)...); + } +private: + ClassT &Instance; + MethodT Method; +}; - // Run the handler and get the result. - auto Result = Handler(std::get<Is>(RPCArgs)...); +// Helper that provides a Functor for deserializing arguments. +template <typename... ArgTs> class ReadArgs { +public: + Error operator()() { return Error::success(); } +}; - // Return the result to the client. - return Func::template respond<ChannelT, SequenceNumberT>(C, SeqNo, - Result); - } - }; +template <typename ArgT, typename... ArgTs> +class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> { +public: + ReadArgs(ArgT &Arg, ArgTs &... Args) + : ReadArgs<ArgTs...>(Args...), Arg(Arg) {} - // Helper for wrapping member functions up as functors. - template <typename ClassT, typename RetT, typename... ArgTs> - class MemberFnWrapper { - public: - typedef RetT (ClassT::*MethodT)(ArgTs...); - MemberFnWrapper(ClassT &Instance, MethodT Method) - : Instance(Instance), Method(Method) {} - RetT operator()(ArgTs &... Args) { return (Instance.*Method)(Args...); } - - private: - ClassT &Instance; - MethodT Method; - }; + Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) { + this->Arg = std::move(ArgVal); + return ReadArgs<ArgTs...>::operator()(ArgVals...); + } +private: + ArgT &Arg; +}; - // Helper that provides a Functor for deserializing arguments. - template <typename... ArgTs> class ReadArgs { - public: - Error operator()() { return Error::success(); } - }; +// Manage sequence numbers. +template <typename SequenceNumberT> +class SequenceNumberManager { +public: + // Reset, making all sequence numbers available. + void reset() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + NextSequenceNumber = 0; + FreeSequenceNumbers.clear(); + } - template <typename ArgT, typename... ArgTs> - class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> { - public: - ReadArgs(ArgT &Arg, ArgTs &... Args) - : ReadArgs<ArgTs...>(Args...), Arg(Arg) {} + // Get the next available sequence number. Will re-use numbers that have + // been released. + SequenceNumberT getSequenceNumber() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + if (FreeSequenceNumbers.empty()) + return NextSequenceNumber++; + auto SequenceNumber = FreeSequenceNumbers.back(); + FreeSequenceNumbers.pop_back(); + return SequenceNumber; + } - Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) { - this->Arg = std::move(ArgVal); - return ReadArgs<ArgTs...>::operator()(ArgVals...); - } + // Release a sequence number, making it available for re-use. + void releaseSequenceNumber(SequenceNumberT SequenceNumber) { + std::lock_guard<std::mutex> Lock(SeqNoLock); + FreeSequenceNumbers.push_back(SequenceNumber); + } - private: - ArgT &Arg; - }; +private: + std::mutex SeqNoLock; + SequenceNumberT NextSequenceNumber = 0; + std::vector<SequenceNumberT> FreeSequenceNumbers; }; /// Contains primitive utilities for defining, calling and handling calls to /// remote procedures. ChannelT is a bidirectional stream conforming to the -/// RPCChannel interface (see RPCChannel.h), and FunctionIdT is a procedure -/// identifier type that must be serializable on ChannelT. +/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure +/// identifier type that must be serializable on ChannelT, and SequenceNumberT +/// is an integral type that will be used to number in-flight function calls. /// /// These utilities support the construction of very primitive RPC utilities. /// Their intent is to ensure correct serialization and deserialization of /// procedure arguments, and to keep the client and server's view of the API in /// sync. -/// -/// These utilities do not support return values. These can be handled by -/// declaring a corresponding '.*Response' procedure and expecting it after a -/// call). They also do not support versioning: the client and server *must* be -/// compiled with the same procedure definitions. -/// -/// -/// -/// Overview (see comments individual types/methods for details): -/// -/// Function<Id, Args...> : -/// -/// associates a unique serializable id with an argument list. -/// -/// -/// call<Func>(Channel, Args...) : -/// -/// Calls the remote procedure 'Func' by serializing Func's id followed by its -/// arguments and sending the resulting bytes to 'Channel'. -/// -/// -/// handle<Func>(Channel, <functor matching Error(Args...)> : -/// -/// Handles a call to 'Func' by deserializing its arguments and calling the -/// given functor. This assumes that the id for 'Func' has already been -/// deserialized. -/// -/// expect<Func>(Channel, <functor matching Error(Args...)> : -/// -/// The same as 'handle', except that the procedure id should not have been -/// read yet. Expect will deserialize the id and assert that it matches Func's -/// id. If it does not, and unexpected RPC call error is returned. -template <typename ChannelT, typename FunctionIdT = uint32_t, - typename SequenceNumberT = uint16_t> -class RPC : public RPCBase { -public: - /// RPC default constructor. - RPC() = default; +template <typename ImplT, typename ChannelT, typename FunctionIdT, + typename SequenceNumberT> +class RPCBase { +protected: - /// RPC instances cannot be copied. - RPC(RPC &&) = default; - RPC &operator=(RPC &&) = default; + class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> { + public: + static const char *getName() { return "__orc_rpc$invalid"; } + }; - /// Utility class for defining/referring to RPC procedures. - /// - /// Typedefs of this utility are used when calling/handling remote procedures. - /// - /// FuncId should be a unique value of FunctionIdT (i.e. not used with any - /// other Function typedef in the RPC API being defined. - /// - /// the template argument Ts... gives the argument list for the remote - /// procedure. - /// - /// E.g. - /// - /// typedef Function<0, bool> Func1; - /// typedef Function<1, std::string, std::vector<int>> Func2; - /// - /// if (auto Err = call<Func1>(Channel, true)) - /// /* handle Err */; - /// - /// if (auto Err = expect<Func2>(Channel, - /// [](std::string &S, std::vector<int> &V) { - /// // Stuff. - /// return Error::success(); - /// }) - /// /* handle Err */; - /// - template <FunctionIdT FuncId, typename FnT> - using Function = FunctionHelper<FunctionIdT, FuncId, FnT>; + class OrcRPCResponse : public Function<OrcRPCResponse, void()> { + public: + static const char *getName() { return "__orc_rpc$response"; } + }; - /// Return type for non-blocking call primitives. - template <typename Func> - using NonBlockingCallResult = std::future<typename Func::PErrorReturn>; + class OrcRPCNegotiate + : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> { + public: + static const char *getName() { return "__orc_rpc$negotiate"; } + }; - /// Return type for non-blocking call-with-seq primitives. - template <typename Func> - using NonBlockingCallWithSeqResult = - std::pair<NonBlockingCallResult<Func>, SequenceNumberT>; +public: - /// Call Func on Channel C. Does not block, does not call send. Returns a pair - /// of a future result and the sequence number assigned to the result. - /// - /// This utility function is primarily used for single-threaded mode support, - /// where the sequence number can be used to wait for the corresponding - /// result. In multi-threaded mode the appendCallNB method, which does not - /// return the sequence numeber, should be preferred. - template <typename Func, typename... ArgTs> - Expected<NonBlockingCallWithSeqResult<Func>> - appendCallNBWithSeq(ChannelT &C, const ArgTs &... Args) { - auto SeqNo = SequenceNumberMgr.getSequenceNumber(); - std::promise<typename Func::PErrorReturn> Promise; - auto Result = Promise.get_future(); - OutstandingResults[SeqNo] = - createOutstandingResult<Func>(std::move(Promise)); - - if (auto Err = CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo, - Args...)) { - abandonOutstandingResults(); - Func::consumeAbandoned(Result); - return std::move(Err); - } else - return NonBlockingCallWithSeqResult<Func>(std::move(Result), SeqNo); + /// Construct an RPC instance on a channel. + RPCBase(ChannelT &C, bool LazyAutoNegotiation) + : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { + // Hold ResponseId in a special variable, since we expect Response to be + // called relatively frequently, and want to avoid the map lookup. + ResponseId = FnIdAllocator.getResponseId(); + RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId; + + // Register the negotiate function id and handler. + auto NegotiateId = FnIdAllocator.getNegotiateId(); + RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; + Handlers[NegotiateId] = + wrapHandler<OrcRPCNegotiate>([this](const std::string &Name) { + return handleNegotiate(Name); + }, LaunchPolicy()); } - /// The same as appendCallNBWithSeq, except that it calls C.send() to - /// flush the channel after serializing the call. - template <typename Func, typename... ArgTs> - Expected<NonBlockingCallWithSeqResult<Func>> - callNBWithSeq(ChannelT &C, const ArgTs &... Args) { - auto Result = appendCallNBWithSeq<Func>(C, Args...); - if (!Result) - return Result; - if (auto Err = C.send()) { - abandonOutstandingResults(); - Func::consumeAbandoned(Result->first); - return std::move(Err); + /// Append a call Func, does not call send on the channel. + /// The first argument specifies a user-defined handler to be run when the + /// function returns. The handler should take an Expected<Func::ReturnType>, + /// or an Error (if Func::ReturnType is void). The handler will be called + /// with an error if the return value is abandoned due to a channel error. + template <typename Func, typename HandlerT, typename... ArgTs> + Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) { + // Look up the function ID. + FunctionIdT FnId; + if (auto FnIdOrErr = getRemoteFunctionId<Func>()) + FnId = *FnIdOrErr; + else { + // This isn't a channel error so we don't want to abandon other pending + // responses, but we still need to run the user handler with an error to + // let them know the call failed. + if (auto Err = Handler(orcError(OrcErrorCode::UnknownRPCFunction))) + report_fatal_error(std::move(Err)); + return FnIdOrErr.takeError(); } - return Result; - } - /// Serialize Args... to channel C, but do not call send. - /// Returns an error if serialization fails, otherwise returns a - /// std::future<Expected<T>> (or a future<Error> for void functions). - template <typename Func, typename... ArgTs> - Expected<NonBlockingCallResult<Func>> appendCallNB(ChannelT &C, - const ArgTs &... Args) { - auto FutureResAndSeqOrErr = appendCallNBWithSeq<Func>(C, Args...); - if (FutureResAndSeqOrErr) - return std::move(FutureResAndSeqOrErr->first); - return FutureResAndSeqOrErr.takeError(); - } + // Allocate a sequence number. + auto SeqNo = SequenceNumberMgr.getSequenceNumber(); + assert(!PendingResponses.count(SeqNo) && + "Sequence number already allocated"); + + // Install the user handler. + PendingResponses[SeqNo] = + detail::createResponseHandler<ChannelT, typename Func::ReturnType>( + std::move(Handler)); + + // Open the function call message. + if (auto Err = C.startSendMessage(FnId, SeqNo)) { + abandonPendingResponses(); + return joinErrors(std::move(Err), C.endSendMessage()); + } - /// The same as appendCallNB, except that it calls C.send to flush the - /// channel after serializing the call. - template <typename Func, typename... ArgTs> - Expected<NonBlockingCallResult<Func>> callNB(ChannelT &C, - const ArgTs &... Args) { - auto FutureResAndSeqOrErr = callNBWithSeq<Func>(C, Args...); - if (FutureResAndSeqOrErr) - return std::move(FutureResAndSeqOrErr->first); - return FutureResAndSeqOrErr.takeError(); - } + // Serialize the call arguments. + if (auto Err = + detail::HandlerTraits<typename Func::Type>:: + serializeArgs(C, Args...)) { + abandonPendingResponses(); + return joinErrors(std::move(Err), C.endSendMessage()); + } - /// Call Func on Channel C. Blocks waiting for a result. Returns an Error - /// for void functions or an Expected<T> for functions returning a T. - /// - /// This function is for use in threaded code where another thread is - /// handling responses and incoming calls. - template <typename Func, typename... ArgTs> - typename Func::ErrorReturn callB(ChannelT &C, const ArgTs &... Args) { - if (auto FutureResOrErr = callNBWithSeq<Func>(C, Args...)) { - if (auto Err = C.send()) { - abandonOutstandingResults(); - Func::consumeAbandoned(FutureResOrErr->first); - return std::move(Err); - } - return FutureResOrErr->first.get(); - } else - return FutureResOrErr.takeError(); - } + // Close the function call messagee. + if (auto Err = C.endSendMessage()) { + abandonPendingResponses(); + return std::move(Err); + } - /// Call Func on Channel C. Block waiting for a result. While blocked, run - /// HandleOther to handle incoming calls (Response calls will be handled - /// implicitly before calling HandleOther). Returns an Error for void - /// functions or an Expected<T> for functions returning a T. - /// - /// This function is for use in single threaded mode when the calling thread - /// must act as both sender and receiver. - template <typename Func, typename HandleFtor, typename... ArgTs> - typename Func::ErrorReturn - callSTHandling(ChannelT &C, HandleFtor &HandleOther, const ArgTs &... Args) { - if (auto ResultAndSeqNoOrErr = callNBWithSeq<Func>(C, Args...)) { - auto &ResultAndSeqNo = *ResultAndSeqNoOrErr; - if (auto Err = waitForResult(C, ResultAndSeqNo.second, HandleOther)) - return std::move(Err); - return ResultAndSeqNo.first.get(); - } else - return ResultAndSeqNoOrErr.takeError(); + return Error::success(); } - /// Call Func on Channel C. Block waiting for a result. Returns an Error for - /// void functions or an Expected<T> for functions returning a T. - template <typename Func, typename... ArgTs> - typename Func::ErrorReturn callST(ChannelT &C, const ArgTs &... Args) { - return callSTHandling<Func>(C, handleNone, Args...); - } - /// Start receiving a new function call. - /// - /// Calls startReceiveMessage on the channel, then deserializes a FunctionId - /// into Id. - Error startReceivingFunction(ChannelT &C, FunctionIdT &Id) { - if (auto Err = startReceiveMessage(C)) + template <typename Func, typename HandlerT, typename... ArgTs> + Error callAsync(HandlerT Handler, const ArgTs &... Args) { + if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...)) return Err; - - return deserializeSeq(C, Id); + return C.send(); } - /// Deserialize args for Func from C and call Handler. The signature of - /// handler must conform to 'Error(Args...)' where Args... matches - /// the arguments used in the Func typedef. - template <typename Func, typename HandlerT> - static Error handle(ChannelT &C, HandlerT Handler) { - return HandlerHelper<ChannelT, SequenceNumberT, Func>::handle(C, Handler); - } - - /// Helper version of 'handle' for calling member functions. - template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - static Error handle(ChannelT &C, ClassT &Instance, - RetT (ClassT::*HandlerMethod)(ArgTs...)) { - return handle<Func>( - C, MemberFnWrapper<ClassT, RetT, ArgTs...>(Instance, HandlerMethod)); - } - - /// Deserialize a FunctionIdT from C and verify it matches the id for Func. - /// If the id does match, deserialize the arguments and call the handler - /// (similarly to handle). - /// If the id does not match, return an unexpect RPC call error and do not - /// deserialize any further bytes. - template <typename Func, typename HandlerT> - Error expect(ChannelT &C, HandlerT Handler) { - FunctionIdT FuncId; - if (auto Err = startReceivingFunction(C, FuncId)) - return std::move(Err); - if (FuncId != Func::Id) - return orcError(OrcErrorCode::UnexpectedRPCCall); - return handle<Func>(C, Handler); - } + /// Handle one incoming call. + Error handleOne() { + FunctionIdT FnId; + SequenceNumberT SeqNo; + if (auto Err = C.startReceiveMessage(FnId, SeqNo)) + return Err; + if (FnId == ResponseId) + return handleResponse(SeqNo); + auto I = Handlers.find(FnId); + if (I != Handlers.end()) + return I->second(C, SeqNo); - /// Helper version of expect for calling member functions. - template <typename Func, typename ClassT, typename... ArgTs> - static Error expect(ChannelT &C, ClassT &Instance, - Error (ClassT::*HandlerMethod)(ArgTs...)) { - return expect<Func>( - C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod)); + // else: No handler found. Report error to client? + return orcError(OrcErrorCode::UnexpectedRPCCall); } /// Helper for handling setter procedures - this method returns a functor that @@ -621,160 +730,417 @@ public: /// /* Handle Args */ ; /// template <typename... ArgTs> - static ReadArgs<ArgTs...> readArgs(ArgTs &... Args) { - return ReadArgs<ArgTs...>(Args...); + static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) { + return detail::ReadArgs<ArgTs...>(Args...); } - /// Read a response from Channel. - /// This should be called from the receive loop to retrieve results. - Error handleResponse(ChannelT &C, SequenceNumberT *SeqNoRet = nullptr) { - SequenceNumberT SeqNo; - if (auto Err = deserializeSeq(C, SeqNo)) { - abandonOutstandingResults(); - return Err; - } +protected: + // The LaunchPolicy type allows a launch policy to be specified when adding + // a function handler. See addHandlerImpl. + using LaunchPolicy = std::function<Error(std::function<Error()>)>; + + /// Add the given handler to the handler map and make it available for + /// autonegotiation and execution. + template <typename Func, typename HandlerT> + void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) { + FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); + LocalFunctionIds[Func::getPrototype()] = NewFnId; + Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler), + std::move(Launch)); + } - if (SeqNoRet) - *SeqNoRet = SeqNo; + // Abandon all outstanding results. + void abandonPendingResponses() { + for (auto &KV : PendingResponses) + KV.second->abandon(); + PendingResponses.clear(); + SequenceNumberMgr.reset(); + } - auto I = OutstandingResults.find(SeqNo); - if (I == OutstandingResults.end()) { - abandonOutstandingResults(); + Error handleResponse(SequenceNumberT SeqNo) { + auto I = PendingResponses.find(SeqNo); + if (I == PendingResponses.end()) { + abandonPendingResponses(); return orcError(OrcErrorCode::UnexpectedRPCResponse); } - if (auto Err = I->second->readResult(C)) { - abandonOutstandingResults(); - // FIXME: Release sequence numbers? + auto PRHandler = std::move(I->second); + PendingResponses.erase(I); + SequenceNumberMgr.releaseSequenceNumber(SeqNo); + + if (auto Err = PRHandler->handleResponse(C)) { + abandonPendingResponses(); + SequenceNumberMgr.reset(); return Err; } - OutstandingResults.erase(I); - SequenceNumberMgr.releaseSequenceNumber(SeqNo); - return Error::success(); } - // Loop waiting for a result with the given sequence number. - // This can be used as a receive loop if the user doesn't have a default. - template <typename HandleOtherFtor> - Error waitForResult(ChannelT &C, SequenceNumberT TgtSeqNo, - HandleOtherFtor &HandleOther = handleNone) { - bool GotTgtResult = false; + FunctionIdT handleNegotiate(const std::string &Name) { + auto I = LocalFunctionIds.find(Name); + if (I == LocalFunctionIds.end()) + return FnIdAllocator.getInvalidId(); + return I->second; + } - while (!GotTgtResult) { - FunctionIdT Id = RPCFunctionIdTraits<FunctionIdT>::InvalidId; - if (auto Err = startReceivingFunction(C, Id)) - return Err; - if (Id == RPCFunctionIdTraits<FunctionIdT>::ResponseId) { - SequenceNumberT SeqNo; - if (auto Err = handleResponse(C, &SeqNo)) - return Err; - GotTgtResult = (SeqNo == TgtSeqNo); - } else if (auto Err = HandleOther(C, Id)) - return Err; + // Find the remote FunctionId for the given function, which must be in the + // RemoteFunctionIds map. + template <typename Func> + Expected<FunctionIdT> getRemoteFunctionId() { + // Try to find the id for the given function. + auto I = RemoteFunctionIds.find(Func::getPrototype()); + + // If we have it in the map, return it. + if (I != RemoteFunctionIds.end()) + return I->second; + + // Otherwise, if we have auto-negotiation enabled, try to negotiate it. + if (LazyAutoNegotiation) { + auto &Impl = static_cast<ImplT&>(*this); + if (auto RemoteIdOrErr = + Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { + auto &RemoteId = *RemoteIdOrErr; + + // If autonegotiation indicates that the remote end doesn't support this + // function, return an unknown function error. + if (RemoteId == FnIdAllocator.getInvalidId()) + return orcError(OrcErrorCode::UnknownRPCFunction); + + // Autonegotiation succeeded and returned a valid id. Update the map and + // return the id. + RemoteFunctionIds[Func::getPrototype()] = RemoteId; + return RemoteId; + } else { + // Autonegotiation failed. Return the error. + return RemoteIdOrErr.takeError(); + } } - return Error::success(); + // No key was available in the map and autonegotiation wasn't enabled. + // Return an unknown function error. + return orcError(OrcErrorCode::UnknownRPCFunction); } - // Default handler for 'other' (non-response) functions when waiting for a - // result from the channel. - static Error handleNone(ChannelT &, FunctionIdT) { - return orcError(OrcErrorCode::UnexpectedRPCCall); - }; + using WrappedHandlerFn = std::function<Error(ChannelT&, SequenceNumberT)>; + + // Wrap the given user handler in the necessary argument-deserialization code, + // result-serialization code, and call to the launch policy (if present). + template <typename Func, typename HandlerT> + WrappedHandlerFn wrapHandler(HandlerT Handler, LaunchPolicy Launch) { + return + [this, Handler, Launch](ChannelT &Channel, SequenceNumberT SeqNo) -> Error { + // Start by deserializing the arguments. + auto Args = + std::make_shared<typename detail::HandlerTraits<HandlerT>::ArgStorage>(); + if (auto Err = detail::HandlerTraits<typename Func::Type>:: + deserializeArgs(Channel, *Args)) + return Err; + + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)Args; + + // End receieve message, unlocking the channel for reading. + if (auto Err = Channel.endReceiveMessage()) + return Err; + + // Build the handler/responder. + auto Responder = + [this, Handler, Args, &Channel, SeqNo]() mutable -> Error { + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, + HTraits::runHandler(Handler, + *Args)); + }; + + // If there is an explicit launch policy then use it to launch the + // handler. + if (Launch) + return Launch(std::move(Responder)); + + // Otherwise run the handler on the listener thread. + return Responder(); + }; + } + + ChannelT &C; + + bool LazyAutoNegotiation; + RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator; + + FunctionIdT ResponseId; + std::map<std::string, FunctionIdT> LocalFunctionIds; + std::map<const char*, FunctionIdT> RemoteFunctionIds; + + std::map<FunctionIdT, WrappedHandlerFn> Handlers; + + detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr; + std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>> + PendingResponses; +}; + +} // end namespace detail + + +template <typename ChannelT, + typename FunctionIdT = uint32_t, + typename SequenceNumberT = uint32_t> +class MultiThreadedRPC + : public detail::RPCBase<MultiThreadedRPC<ChannelT, FunctionIdT, + SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT> { private: - // Manage sequence numbers. - class SequenceNumberManager { - public: - SequenceNumberManager() = default; + using BaseClass = + detail::RPCBase<MultiThreadedRPC<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT>; - SequenceNumberManager(const SequenceNumberManager &) = delete; - SequenceNumberManager &operator=(const SequenceNumberManager &) = delete; +public: - SequenceNumberManager(SequenceNumberManager &&Other) - : NextSequenceNumber(std::move(Other.NextSequenceNumber)), - FreeSequenceNumbers(std::move(Other.FreeSequenceNumbers)) {} + MultiThreadedRPC(ChannelT &C, bool LazyAutoNegotiation) + : BaseClass(C, LazyAutoNegotiation) {} - SequenceNumberManager &operator=(SequenceNumberManager &&Other) { - NextSequenceNumber = std::move(Other.NextSequenceNumber); - FreeSequenceNumbers = std::move(Other.FreeSequenceNumbers); - return *this; - } + /// The LaunchPolicy type allows a launch policy to be specified when adding + /// a function handler. See addHandler. + using LaunchPolicy = typename BaseClass::LaunchPolicy; - void reset() { - std::lock_guard<std::mutex> Lock(SeqNoLock); - NextSequenceNumber = 0; - FreeSequenceNumbers.clear(); - } + /// Add a handler for the given RPC function. + /// This installs the given handler functor for the given RPC Function, and + /// makes the RPC function available for negotiation/calling from the remote. + /// + /// The optional LaunchPolicy argument can be used to control how the handler + /// is run when called: + /// + /// * If no LaunchPolicy is given, the handler code will be run on the RPC + /// handler thread that is reading from the channel. This handler cannot + /// make blocking RPC calls (since it would be blocking the thread used to + /// get the result), but can make non-blocking calls. + /// + /// * If a LaunchPolicy is given, the user's handler will be wrapped in a + /// call to serialize and send the result, and the resulting functor (with + /// type 'Error()' will be passed to the LaunchPolicy. The user can then + /// choose to add the wrapped handler to a work queue, spawn a new thread, + /// or anything else. + template <typename Func, typename HandlerT> + void addHandler(HandlerT Handler, LaunchPolicy Launch = LaunchPolicy()) { + return this->template addHandlerImpl<Func>(std::move(Handler), + std::move(Launch)); + } + + /// Negotiate a function id for Func with the other end of the channel. + template <typename Func> + Error negotiateFunction() { + using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; + + if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) { + this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + return Error::success(); + } else + return RemoteIdOrErr.takeError(); + } - SequenceNumberT getSequenceNumber() { - std::lock_guard<std::mutex> Lock(SeqNoLock); - if (FreeSequenceNumbers.empty()) - return NextSequenceNumber++; - auto SequenceNumber = FreeSequenceNumbers.back(); - FreeSequenceNumbers.pop_back(); - return SequenceNumber; + /// Convenience method for negotiating multiple functions at once. + template <typename Func> + Error negotiateFunctions() { + return negotiateFunction<Func>(); + } + + /// Convenience method for negotiating multiple functions at once. + template <typename Func1, typename Func2, typename... Funcs> + Error negotiateFunctions() { + if (auto Err = negotiateFunction<Func1>()) + return Err; + return negotiateFunctions<Func2, Funcs...>(); + } + + /// Return type for non-blocking call primitives. + template <typename Func> + using NonBlockingCallResult = + typename detail::ResultTraits<typename Func::ReturnType>::ReturnFutureType; + + /// Call Func on Channel C. Does not block, does not call send. Returns a pair + /// of a future result and the sequence number assigned to the result. + /// + /// This utility function is primarily used for single-threaded mode support, + /// where the sequence number can be used to wait for the corresponding + /// result. In multi-threaded mode the appendCallNB method, which does not + /// return the sequence numeber, should be preferred. + template <typename Func, typename... ArgTs> + Expected<NonBlockingCallResult<Func>> + appendCallNB(const ArgTs &... Args) { + using RTraits = detail::ResultTraits<typename Func::ReturnType>; + using ErrorReturn = typename RTraits::ErrorReturnType; + using ErrorReturnPromise = typename RTraits::ReturnPromiseType; + + // FIXME: Stack allocate and move this into the handler once LLVM builds + // with C++14. + auto Promise = std::make_shared<ErrorReturnPromise>(); + auto FutureResult = Promise->get_future(); + + if (auto Err = this->template appendCallAsync<Func>( + [Promise](ErrorReturn RetOrErr) { + Promise->set_value(std::move(RetOrErr)); + return Error::success(); + }, Args...)) { + this->abandonPendingResponses(); + RTraits::consumeAbandoned(FutureResult.get()); + return std::move(Err); } + return std::move(FutureResult); + } - void releaseSequenceNumber(SequenceNumberT SequenceNumber) { - std::lock_guard<std::mutex> Lock(SeqNoLock); - FreeSequenceNumbers.push_back(SequenceNumber); + /// The same as appendCallNBWithSeq, except that it calls C.send() to + /// flush the channel after serializing the call. + template <typename Func, typename... ArgTs> + Expected<NonBlockingCallResult<Func>> + callNB(const ArgTs &... Args) { + auto Result = appendCallNB<Func>(Args...); + if (!Result) + return Result; + if (auto Err = this->C.send()) { + this->abandonPendingResponses(); + detail::ResultTraits<typename Func::ReturnType>:: + consumeAbandoned(std::move(Result->get())); + return std::move(Err); } + return Result; + } - private: - std::mutex SeqNoLock; - SequenceNumberT NextSequenceNumber = 0; - std::vector<SequenceNumberT> FreeSequenceNumbers; - }; + /// Call Func on Channel C. Blocks waiting for a result. Returns an Error + /// for void functions or an Expected<T> for functions returning a T. + /// + /// This function is for use in threaded code where another thread is + /// handling responses and incoming calls. + template <typename Func, typename... ArgTs, + typename AltRetT = typename Func::ReturnType> + typename detail::ResultTraits<AltRetT>::ErrorReturnType + callB(const ArgTs &... Args) { + if (auto FutureResOrErr = callNB<Func>(Args...)) { + if (auto Err = this->C.send()) { + this->abandonPendingResponses(); + detail::ResultTraits<typename Func::ReturnType>:: + consumeAbandoned(std::move(FutureResOrErr->get())); + return std::move(Err); + } + return FutureResOrErr->get(); + } else + return FutureResOrErr.takeError(); + } - // Base class for results that haven't been returned from the other end of the - // RPC connection yet. - class OutstandingResult { - public: - virtual ~OutstandingResult() {} - virtual Error readResult(ChannelT &C) = 0; - virtual void abandon() = 0; - }; + /// Handle incoming RPC calls. + Error handlerLoop() { + while (true) + if (auto Err = this->handleOne()) + return Err; + return Error::success(); + } - // Outstanding results for a specific function. - template <typename Func> - class OutstandingResultImpl : public OutstandingResult { - private: - public: - OutstandingResultImpl(std::promise<typename Func::PErrorReturn> &&P) - : P(std::move(P)) {} +}; - Error readResult(ChannelT &C) override { return Func::readResult(C, P); } +template <typename ChannelT, + typename FunctionIdT = uint32_t, + typename SequenceNumberT = uint32_t> +class SingleThreadedRPC + : public detail::RPCBase<SingleThreadedRPC<ChannelT, FunctionIdT, + SequenceNumberT>, + ChannelT, FunctionIdT, + SequenceNumberT> { +private: - void abandon() override { Func::abandon(P); } + using BaseClass = detail::RPCBase<SingleThreadedRPC<ChannelT, FunctionIdT, + SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT>; - private: - std::promise<typename Func::PErrorReturn> P; - }; + using LaunchPolicy = typename BaseClass::LaunchPolicy; + +public: + + SingleThreadedRPC(ChannelT &C, bool LazyAutoNegotiation) + : BaseClass(C, LazyAutoNegotiation) {} - // Create an outstanding result for the given function. + template <typename Func, typename HandlerT> + void addHandler(HandlerT Handler) { + return this->template addHandlerImpl<Func>(std::move(Handler), + LaunchPolicy()); + } + + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + /// Negotiate a function id for Func with the other end of the channel. template <typename Func> - std::unique_ptr<OutstandingResult> - createOutstandingResult(std::promise<typename Func::PErrorReturn> &&P) { - return llvm::make_unique<OutstandingResultImpl<Func>>(std::move(P)); + Error negotiateFunction() { + using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; + + if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) { + this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + return Error::success(); + } else + return RemoteIdOrErr.takeError(); } - // Abandon all outstanding results. - void abandonOutstandingResults() { - for (auto &KV : OutstandingResults) - KV.second->abandon(); - OutstandingResults.clear(); - SequenceNumberMgr.reset(); + /// Convenience method for negotiating multiple functions at once. + template <typename Func> + Error negotiateFunctions() { + return negotiateFunction<Func>(); + } + + /// Convenience method for negotiating multiple functions at once. + template <typename Func1, typename Func2, typename... Funcs> + Error negotiateFunctions() { + if (auto Err = negotiateFunction<Func1>()) + return Err; + return negotiateFunctions<Func2, Funcs...>(); } - SequenceNumberManager SequenceNumberMgr; - std::map<SequenceNumberT, std::unique_ptr<OutstandingResult>> - OutstandingResults; + template <typename Func, typename... ArgTs, + typename AltRetT = typename Func::ReturnType> + typename detail::ResultTraits<AltRetT>::ErrorReturnType + callB(const ArgTs &... Args) { + bool ReceivedResponse = false; + using ResultType = + typename detail::ResultTraits<AltRetT>::ErrorReturnType; + auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue(); + + // We have to 'Check' result (which we know is in a success state at this + // point) so that it can be overwritten in the async handler. + (void)!!Result; + + if (auto Err = this->template appendCallAsync<Func>( + [&](ResultType R) { + Result = std::move(R); + ReceivedResponse = true; + return Error::success(); + }, Args...)) { + this->abandonPendingResponses(); + detail::ResultTraits<typename Func::ReturnType>:: + consumeAbandoned(std::move(Result)); + return std::move(Err); + } + + while (!ReceivedResponse) { + if (auto Err = this->handleOne()) { + this->abandonPendingResponses(); + detail::ResultTraits<typename Func::ReturnType>:: + consumeAbandoned(std::move(Result)); + return std::move(Err); + } + } + + return Result; + } + + //using detail::RPCBase<ChannelT, FunctionIdT, SequenceNumberT>::handleOne; + }; -} // end namespace remote +} // end namespace rpc } // end namespace orc } // end namespace llvm diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RawByteChannel.h b/llvm/include/llvm/ExecutionEngine/Orc/RawByteChannel.h new file mode 100644 index 00000000000..c80074ffd7f --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/RawByteChannel.h @@ -0,0 +1,182 @@ +//===- llvm/ExecutionEngine/Orc/RawByteChannel.h ----------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H +#define LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H + +#include "OrcError.h" +#include "RPCSerialization.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/Error.h" +#include <cstddef> +#include <cstdint> +#include <mutex> +#include <string> +#include <tuple> +#include <type_traits> +#include <vector> + +namespace llvm { +namespace orc { +namespace rpc { + +/// Interface for byte-streams to be used with RPC. +class RawByteChannel { +public: + virtual ~RawByteChannel() {} + + /// Read Size bytes from the stream into *Dst. + virtual Error readBytes(char *Dst, unsigned Size) = 0; + + /// Read size bytes from *Src and append them to the stream. + virtual Error appendBytes(const char *Src, unsigned Size) = 0; + + /// Flush the stream if possible. + virtual Error send() = 0; + + /// Notify the channel that we're starting a message send. + /// Locks the channel for writing. + template <typename FunctionIdT, typename SequenceIdT> + Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { + if (auto Err = serializeSeq(*this, FnId, SeqNo)) + return Err; + writeLock.lock(); + return Error::success(); + } + + /// Notify the channel that we're ending a message send. + /// Unlocks the channel for writing. + Error endSendMessage() { + writeLock.unlock(); + return Error::success(); + } + + /// Notify the channel that we're starting a message receive. + /// Locks the channel for reading. + template <typename FunctionIdT, typename SequenceNumberT> + Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { + readLock.lock(); + return deserializeSeq(*this, FnId, SeqNo); + } + + /// Notify the channel that we're ending a message receive. + /// Unlocks the channel for reading. + Error endReceiveMessage() { + readLock.unlock(); + return Error::success(); + } + + /// Get the lock for stream reading. + std::mutex &getReadLock() { return readLock; } + + /// Get the lock for stream writing. + std::mutex &getWriteLock() { return writeLock; } + +private: + std::mutex readLock, writeLock; +}; + +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, T, T, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>::value && + (std::is_same<T, uint8_t>::value || + std::is_same<T, int8_t>::value || + std::is_same<T, uint16_t>::value || + std::is_same<T, int16_t>::value || + std::is_same<T, uint32_t>::value || + std::is_same<T, int32_t>::value || + std::is_same<T, uint64_t>::value || + std::is_same<T, int64_t>::value || + std::is_same<T, char>::value)>::type> { +public: + static Error serialize(ChannelT &C, T V) { + support::endian::byte_swap<T, support::big>(V); + return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); + }; + + static Error deserialize(ChannelT &C, T &V) { + if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) + return Err; + support::endian::byte_swap<T, support::big>(V); + return Error::success(); + }; +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, bool, bool, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>::value>:: + type> { +public: + static Error serialize(ChannelT &C, bool V) { + return C.appendBytes(reinterpret_cast<const char *>(&V), 1); + } + + static Error deserialize(ChannelT &C, bool &V) { + return C.readBytes(reinterpret_cast<char *>(&V), 1); + } +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, std::string, StringRef, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>::value>:: + type> { +public: + /// RPC channel serialization for std::strings. + static Error serialize(RawByteChannel &C, StringRef S) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) + return Err; + return C.appendBytes((const char *)S.data(), S.size()); + } +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, std::string, const char*, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>::value>:: + type> { +public: + static Error serialize(RawByteChannel &C, const char *S) { + return SerializationTraits<ChannelT, std::string, StringRef>:: + serialize(C, S); + } +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, std::string, std::string, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>::value>:: + type> { +public: + /// RPC channel serialization for std::strings. + static Error serialize(RawByteChannel &C, const std::string &S) { + return SerializationTraits<ChannelT, std::string, StringRef>:: + serialize(C, S); + } + + /// RPC channel deserialization for std::strings. + static Error deserialize(RawByteChannel &C, std::string &S) { + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + S.resize(Count); + return C.readBytes(&S[0], Count); + } +}; + +} // end namespace rpc +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H diff --git a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt index 76720a7c52e..685e882e4a8 100644 --- a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt +++ b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt @@ -6,7 +6,6 @@ add_llvm_library(LLVMOrcJIT OrcCBindings.cpp OrcError.cpp OrcMCJITReplacement.cpp - OrcRemoteTargetRPCAPI.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/ExecutionEngine/Orc diff --git a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp index 64472f9ba37..48dcd442266 100644 --- a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp +++ b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp @@ -43,6 +43,8 @@ public: return "Unexpected RPC call"; case OrcErrorCode::UnexpectedRPCResponse: return "Unexpected RPC response"; + case OrcErrorCode::UnknownRPCFunction: + return "Unknown RPC function"; } llvm_unreachable("Unhandled error code"); } diff --git a/llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp b/llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp deleted file mode 100644 index d1a021aee3a..00000000000 --- a/llvm/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp +++ /dev/null @@ -1,53 +0,0 @@ -//===------- OrcRemoteTargetRPCAPI.cpp - ORC Remote API utilities ---------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// - -#include "llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h" - -namespace llvm { -namespace orc { -namespace remote { - -#define FUNCNAME(X) \ - case X ## Id: \ - return #X - -const char *OrcRemoteTargetRPCAPI::getJITFuncIdName(JITFuncId Id) { - switch (Id) { - case InvalidId: - return "*** Invalid JITFuncId ***"; - FUNCNAME(CallIntVoid); - FUNCNAME(CallMain); - FUNCNAME(CallVoidVoid); - FUNCNAME(CreateRemoteAllocator); - FUNCNAME(CreateIndirectStubsOwner); - FUNCNAME(DeregisterEHFrames); - FUNCNAME(DestroyRemoteAllocator); - FUNCNAME(DestroyIndirectStubsOwner); - FUNCNAME(EmitIndirectStubs); - FUNCNAME(EmitResolverBlock); - FUNCNAME(EmitTrampolineBlock); - FUNCNAME(GetSymbolAddress); - FUNCNAME(GetRemoteInfo); - FUNCNAME(ReadMem); - FUNCNAME(RegisterEHFrames); - FUNCNAME(ReserveMem); - FUNCNAME(RequestCompile); - FUNCNAME(SetProtections); - FUNCNAME(TerminateSession); - FUNCNAME(WriteMem); - FUNCNAME(WritePtr); - }; - return nullptr; -} - -#undef FUNCNAME - -} // end namespace remote -} // end namespace orc -} // end namespace llvm diff --git a/llvm/tools/lli/ChildTarget/ChildTarget.cpp b/llvm/tools/lli/ChildTarget/ChildTarget.cpp index f6d2413655e..77b1d47a946 100644 --- a/llvm/tools/lli/ChildTarget/ChildTarget.cpp +++ b/llvm/tools/lli/ChildTarget/ChildTarget.cpp @@ -53,23 +53,12 @@ int main(int argc, char *argv[]) { RTDyldMemoryManager::deregisterEHFramesInProcess(Addr, Size); }; - FDRPCChannel Channel(InFD, OutFD); - typedef remote::OrcRemoteTargetServer<FDRPCChannel, HostOrcArch> JITServer; + FDRawChannel Channel(InFD, OutFD); + typedef remote::OrcRemoteTargetServer<FDRawChannel, HostOrcArch> JITServer; JITServer Server(Channel, SymbolLookup, RegisterEHFrames, DeregisterEHFrames); - while (1) { - uint32_t RawId; - ExitOnErr(Server.startReceivingFunction(Channel, RawId)); - auto Id = static_cast<JITServer::JITFuncId>(RawId); - switch (Id) { - case JITServer::TerminateSessionId: - ExitOnErr(Server.handleTerminateSession()); - return 0; - default: - ExitOnErr(Server.handleKnownFunction(Id)); - break; - } - } + while (!Server.receivedTerminate()) + ExitOnErr(Server.handleOne()); close(InFD); close(OutFD); diff --git a/llvm/tools/lli/RemoteJITUtils.h b/llvm/tools/lli/RemoteJITUtils.h index d47716cb880..89a51420256 100644 --- a/llvm/tools/lli/RemoteJITUtils.h +++ b/llvm/tools/lli/RemoteJITUtils.h @@ -14,7 +14,7 @@ #ifndef LLVM_TOOLS_LLI_REMOTEJITUTILS_H #define LLVM_TOOLS_LLI_REMOTEJITUTILS_H -#include "llvm/ExecutionEngine/Orc/RPCByteChannel.h" +#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" #include "llvm/ExecutionEngine/RTDyldMemoryManager.h" #include <mutex> @@ -25,9 +25,9 @@ #endif /// RPC channel that reads from and writes from file descriptors. -class FDRPCChannel final : public llvm::orc::remote::RPCByteChannel { +class FDRawChannel final : public llvm::orc::rpc::RawByteChannel { public: - FDRPCChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} + FDRawChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} llvm::Error readBytes(char *Dst, unsigned Size) override { assert(Dst && "Attempt to read into null."); @@ -72,11 +72,12 @@ private: }; // launch the remote process (see lli.cpp) and return a channel to it. -std::unique_ptr<FDRPCChannel> launchRemote(); +std::unique_ptr<FDRawChannel> launchRemote(); namespace llvm { -// ForwardingMM - Adapter to connect MCJIT to Orc's Remote memory manager. +// ForwardingMM - Adapter to connect MCJIT to Orc's Remote8 +// memory manager. class ForwardingMemoryManager : public llvm::RTDyldMemoryManager { public: void setMemMgr(std::unique_ptr<RuntimeDyld::MemoryManager> MemMgr) { diff --git a/llvm/tools/lli/lli.cpp b/llvm/tools/lli/lli.cpp index 9dbe658beff..836a94037d7 100644 --- a/llvm/tools/lli/lli.cpp +++ b/llvm/tools/lli/lli.cpp @@ -654,20 +654,20 @@ int main(int argc, char **argv, char * const *envp) { // MCJIT itself. FIXME. // Lanch the remote process and get a channel to it. - std::unique_ptr<FDRPCChannel> C = launchRemote(); + std::unique_ptr<FDRawChannel> C = launchRemote(); if (!C) { errs() << "Failed to launch remote JIT.\n"; exit(1); } // Create a remote target client running over the channel. - typedef orc::remote::OrcRemoteTargetClient<orc::remote::RPCByteChannel> + typedef orc::remote::OrcRemoteTargetClient<orc::rpc::RawByteChannel> MyRemote; - MyRemote R = ExitOnErr(MyRemote::Create(*C)); + auto R = ExitOnErr(MyRemote::Create(*C)); // Create a remote memory manager. std::unique_ptr<MyRemote::RCMemoryManager> RemoteMM; - ExitOnErr(R.createRemoteMemoryManager(RemoteMM)); + ExitOnErr(R->createRemoteMemoryManager(RemoteMM)); // Forward MCJIT's memory manager calls to the remote memory manager. static_cast<ForwardingMemoryManager*>(RTDyldMM)->setMemMgr( @@ -678,7 +678,7 @@ int main(int argc, char **argv, char * const *envp) { orc::createLambdaResolver( [](const std::string &Name) { return nullptr; }, [&](const std::string &Name) { - if (auto Addr = ExitOnErr(R.getSymbolAddress(Name))) + if (auto Addr = ExitOnErr(R->getSymbolAddress(Name))) return JITSymbol(Addr, JITSymbolFlags::Exported); return JITSymbol(nullptr); } @@ -691,7 +691,7 @@ int main(int argc, char **argv, char * const *envp) { EE->finalizeObject(); DEBUG(dbgs() << "Executing '" << EntryFn->getName() << "' at 0x" << format("%llx", Entry) << "\n"); - Result = ExitOnErr(R.callIntVoid(Entry)); + Result = ExitOnErr(R->callIntVoid(Entry)); // Like static constructors, the remote target MCJIT support doesn't handle // this yet. It could. FIXME. @@ -702,13 +702,13 @@ int main(int argc, char **argv, char * const *envp) { EE.reset(); // Signal the remote target that we're done JITing. - ExitOnErr(R.terminateSession()); + ExitOnErr(R->terminateSession()); } return Result; } -std::unique_ptr<FDRPCChannel> launchRemote() { +std::unique_ptr<FDRawChannel> launchRemote() { #ifndef LLVM_ON_UNIX llvm_unreachable("launchRemote not supported on non-Unix platforms"); #else @@ -758,6 +758,6 @@ std::unique_ptr<FDRPCChannel> launchRemote() { close(PipeFD[1][1]); // Return an RPC channel connected to our end of the pipes. - return llvm::make_unique<FDRPCChannel>(PipeFD[1][0], PipeFD[0][1]); + return llvm::make_unique<FDRawChannel>(PipeFD[1][0], PipeFD[0][1]); #endif } diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 259a75a203f..4d703c78a0e 100644 --- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ExecutionEngine/Orc/RPCByteChannel.h" +#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" #include "llvm/ExecutionEngine/Orc/RPCUtils.h" #include "gtest/gtest.h" @@ -15,7 +15,7 @@ using namespace llvm; using namespace llvm::orc; -using namespace llvm::orc::remote; +using namespace llvm::orc::rpc; class Queue : public std::queue<char> { public: @@ -25,7 +25,7 @@ private: std::mutex Lock; }; -class QueueChannel : public RPCByteChannel { +class QueueChannel : public RawByteChannel { public: QueueChannel(Queue &InQueue, Queue &OutQueue) : InQueue(InQueue), OutQueue(OutQueue) {} @@ -61,126 +61,190 @@ private: Queue &OutQueue; }; -class DummyRPC : public testing::Test, public RPC<QueueChannel> { +class DummyRPCAPI { public: - enum FuncId : uint32_t { - VoidBoolId = RPCFunctionIdTraits<FuncId>::FirstValidId, - IntIntId, - AllTheTypesId + + class VoidBool : public Function<VoidBool, void(bool)> { + public: + static const char* getName() { return "VoidBool"; } + }; + + class IntInt : public Function<IntInt, int32_t(int32_t)> { + public: + static const char* getName() { return "IntInt"; } + }; + + class AllTheTypes + : public Function<AllTheTypes, + void(int8_t, uint8_t, int16_t, uint16_t, int32_t, + uint32_t, int64_t, uint64_t, bool, std::string, + std::vector<int>)> { + public: + static const char* getName() { return "AllTheTypes"; } }; +}; - typedef Function<VoidBoolId, void(bool)> VoidBool; - typedef Function<IntIntId, int32_t(int32_t)> IntInt; - typedef Function<AllTheTypesId, - void(int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, - int64_t, uint64_t, bool, std::string, std::vector<int>)> - AllTheTypes; +class DummyRPCEndpoint : public DummyRPCAPI, + public SingleThreadedRPC<QueueChannel> { +public: + DummyRPCEndpoint(Queue &Q1, Queue &Q2) + : SingleThreadedRPC(C, true), C(Q1, Q2) {} +private: + QueueChannel C; }; -TEST_F(DummyRPC, TestAsyncVoidBool) { +TEST(DummyRPC, TestAsyncVoidBool) { Queue Q1, Q2; - QueueChannel C1(Q1, Q2); - QueueChannel C2(Q2, Q1); + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); - // Make an async call. - auto ResOrErr = callNBWithSeq<VoidBool>(C1, true); - EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed"; + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::VoidBool>( + [](bool B) { + EXPECT_EQ(B, true) + << "Server void(bool) received unexpected result"; + }); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the VoidBool call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)"; + } + }); { - // Expect a call to Proc1. - auto EC = expect<VoidBool>(C2, [&](bool &B) { - EXPECT_EQ(B, true) << "Bool serialization broken"; - return Error::success(); - }); - EXPECT_FALSE(EC) << "Simple expect over queue failed"; + // Make an async call. + auto Err = Client.callAsync<DummyRPCAPI::VoidBool>( + [](Error Err) { + EXPECT_FALSE(!!Err) << "Async void(bool) response handler failed"; + return Error::success(); + }, true); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for void(bool)"; } { - // Wait for the result. - auto EC = waitForResult(C1, ResOrErr->second, handleNone); - EXPECT_FALSE(EC) << "Could not read result."; + // Poke the client to process the result of the void(bool) call. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; } - // Verify that the function returned ok. - auto Err = ResOrErr->first.get(); - EXPECT_FALSE(!!Err) << "Remote void function failed to execute."; + ServerThread.join(); } -TEST_F(DummyRPC, TestAsyncIntInt) { +TEST(DummyRPC, TestAsyncIntInt) { Queue Q1, Q2; - QueueChannel C1(Q1, Q2); - QueueChannel C2(Q2, Q1); + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); - // Make an async call. - auto ResOrErr = callNBWithSeq<IntInt>(C1, 21); - EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed"; + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::IntInt>( + [](int X) -> int { + EXPECT_EQ(X, 21) << "Server int(int) receieved unexpected result"; + return 2 * X; + }); - { - // Expect a call to Proc1. - auto EC = expect<IntInt>(C2, [&](int32_t I) -> Expected<int32_t> { - EXPECT_EQ(I, 21) << "Bool serialization broken"; - return 2 * I; + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the int(int) call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to int(int)"; + } }); - EXPECT_FALSE(EC) << "Simple expect over queue failed"; + + { + auto Err = Client.callAsync<DummyRPCAPI::IntInt>( + [](Expected<int> Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + EXPECT_EQ(*Result, 42) + << "Async int(int) response handler received incorrect result"; + return Error::success(); + }, 21); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for int(int)"; } { - // Wait for the result. - auto EC = waitForResult(C1, ResOrErr->second, handleNone); - EXPECT_FALSE(EC) << "Could not read result."; + // Poke the client to process the result. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; } - // Verify that the function returned ok. - auto Val = ResOrErr->first.get(); - EXPECT_TRUE(!!Val) << "Remote int function failed to execute."; - EXPECT_EQ(*Val, 42) << "Remote int function return wrong value."; + ServerThread.join(); } -TEST_F(DummyRPC, TestSerialization) { +TEST(DummyRPC, TestSerialization) { Queue Q1, Q2; - QueueChannel C1(Q1, Q2); - QueueChannel C2(Q2, Q1); + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); - // Make a call to Proc1. - std::vector<int> v({42, 7}); - auto ResOrErr = callNBWithSeq<AllTheTypes>( - C1, -101, 250, -10000, 10000, -1000000000, 1000000000, -10000000000, - 10000000000, true, "foo", v); - EXPECT_TRUE(!!ResOrErr) << "Big (serialization test) call over queue failed"; + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::AllTheTypes>( + [&](int8_t S8, uint8_t U8, int16_t S16, uint16_t U16, + int32_t S32, uint32_t U32, int64_t S64, uint64_t U64, + bool B, std::string S, std::vector<int> V) { - { - // Expect a call to Proc1. - auto EC = expect<AllTheTypes>( - C2, [&](int8_t &s8, uint8_t &u8, int16_t &s16, uint16_t &u16, - int32_t &s32, uint32_t &u32, int64_t &s64, uint64_t &u64, - bool &b, std::string &s, std::vector<int> &v) { - - EXPECT_EQ(s8, -101) << "int8_t serialization broken"; - EXPECT_EQ(u8, 250) << "uint8_t serialization broken"; - EXPECT_EQ(s16, -10000) << "int16_t serialization broken"; - EXPECT_EQ(u16, 10000) << "uint16_t serialization broken"; - EXPECT_EQ(s32, -1000000000) << "int32_t serialization broken"; - EXPECT_EQ(u32, 1000000000ULL) << "uint32_t serialization broken"; - EXPECT_EQ(s64, -10000000000) << "int64_t serialization broken"; - EXPECT_EQ(u64, 10000000000ULL) << "uint64_t serialization broken"; - EXPECT_EQ(b, true) << "bool serialization broken"; - EXPECT_EQ(s, "foo") << "std::string serialization broken"; - EXPECT_EQ(v, std::vector<int>({42, 7})) + EXPECT_EQ(S8, -101) << "int8_t serialization broken"; + EXPECT_EQ(U8, 250) << "uint8_t serialization broken"; + EXPECT_EQ(S16, -10000) << "int16_t serialization broken"; + EXPECT_EQ(U16, 10000) << "uint16_t serialization broken"; + EXPECT_EQ(S32, -1000000000) << "int32_t serialization broken"; + EXPECT_EQ(U32, 1000000000ULL) << "uint32_t serialization broken"; + EXPECT_EQ(S64, -10000000000) << "int64_t serialization broken"; + EXPECT_EQ(U64, 10000000000ULL) << "uint64_t serialization broken"; + EXPECT_EQ(B, true) << "bool serialization broken"; + EXPECT_EQ(S, "foo") << "std::string serialization broken"; + EXPECT_EQ(V, std::vector<int>({42, 7})) << "std::vector serialization broken"; + return Error::success(); + }); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the AllTheTypes call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)"; + } + }); + + + { + // Make an async call. + std::vector<int> v({42, 7}); + auto Err = Client.callAsync<DummyRPCAPI::AllTheTypes>( + [](Error Err) { + EXPECT_FALSE(!!Err) << "Async AllTheTypes response handler failed"; return Error::success(); - }); - EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed"; + }, + static_cast<int8_t>(-101), static_cast<uint8_t>(250), + static_cast<int16_t>(-10000), static_cast<uint16_t>(10000), + static_cast<int32_t>(-1000000000), static_cast<uint32_t>(1000000000), + static_cast<int64_t>(-10000000000), static_cast<uint64_t>(10000000000), + true, std::string("foo"), v); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for AllTheTypes"; } { - // Wait for the result. - auto EC = waitForResult(C1, ResOrErr->second, handleNone); - EXPECT_FALSE(EC) << "Could not read result."; + // Poke the client to process the result of the AllTheTypes call. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from AllTheTypes"; } - // Verify that the function returned ok. - auto Err = ResOrErr->first.get(); - EXPECT_FALSE(!!Err) << "Remote void function failed to execute."; + ServerThread.join(); } // Test the synchronous call API. |