From a34680a33eb1caa5e224a9432e9f3e643824dc2d Mon Sep 17 00:00:00 2001 From: Chris Bieneman Date: Wed, 9 Oct 2019 14:27:52 -0700 Subject: Break out OrcError and RPC Summary: When createing an ORC remote JIT target the current library split forces the target process to link large portions of LLVM (Core, Execution Engine, JITLink, Object, MC, Passes, RuntimeDyld, Support, Target, and TransformUtils). This occurs because the ORC RPC interfaces rely on the static globals the ORC Error types require, which starts a cycle of pulling in more and more. This patch breaks the ORC RPC Error implementations out into an "OrcError" library which only depends on LLVM Support. It also pulls the ORC RPC headers into their own subdirectory. With this patch code can include the Orc/RPC/*.h headers and will only incur link dependencies on LLVMOrcError and LLVMSupport. Reviewers: lhames Reviewed By: lhames Subscribers: mgorny, hiraditya, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D68732 --- .../BuildingAJIT/Chapter5/RemoteJITUtils.h | 2 +- .../ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h | 4 +- .../ExecutionEngine/Orc/RPC/RPCSerialization.h | 703 ++++++++ .../llvm/ExecutionEngine/Orc/RPC/RPCUtils.h | 1690 ++++++++++++++++++++ .../llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h | 184 +++ .../llvm/ExecutionEngine/Orc/RPCSerialization.h | 703 -------- llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h | 1690 -------------------- .../llvm/ExecutionEngine/Orc/RawByteChannel.h | 184 --- llvm/lib/ExecutionEngine/CMakeLists.txt | 1 + llvm/lib/ExecutionEngine/LLVMBuild.txt | 2 +- llvm/lib/ExecutionEngine/Orc/CMakeLists.txt | 2 - llvm/lib/ExecutionEngine/Orc/LLVMBuild.txt | 4 +- llvm/lib/ExecutionEngine/Orc/OrcError.cpp | 115 -- llvm/lib/ExecutionEngine/Orc/RPCUtils.cpp | 54 - llvm/lib/ExecutionEngine/OrcError/CMakeLists.txt | 6 + llvm/lib/ExecutionEngine/OrcError/LLVMBuild.txt | 21 + llvm/lib/ExecutionEngine/OrcError/OrcError.cpp | 115 ++ llvm/lib/ExecutionEngine/OrcError/RPCError.cpp | 54 + llvm/tools/lli/RemoteJITUtils.h | 2 +- llvm/unittests/ExecutionEngine/Orc/QueueChannel.h | 2 +- .../unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp | 2 +- 21 files changed, 2783 insertions(+), 2757 deletions(-) create mode 100644 llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h create mode 100644 llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCUtils.h create mode 100644 llvm/include/llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h delete mode 100644 llvm/include/llvm/ExecutionEngine/Orc/RPCSerialization.h delete mode 100644 llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h delete mode 100644 llvm/include/llvm/ExecutionEngine/Orc/RawByteChannel.h delete mode 100644 llvm/lib/ExecutionEngine/Orc/OrcError.cpp delete mode 100644 llvm/lib/ExecutionEngine/Orc/RPCUtils.cpp create mode 100644 llvm/lib/ExecutionEngine/OrcError/CMakeLists.txt create mode 100644 llvm/lib/ExecutionEngine/OrcError/LLVMBuild.txt create mode 100644 llvm/lib/ExecutionEngine/OrcError/OrcError.cpp create mode 100644 llvm/lib/ExecutionEngine/OrcError/RPCError.cpp diff --git a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h index 14a2815bc02..c7d15bb8dd9 100644 --- a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h +++ b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h @@ -13,7 +13,7 @@ #ifndef LLVM_TOOLS_LLI_REMOTEJITUTILS_H #define LLVM_TOOLS_LLI_REMOTEJITUTILS_H -#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" +#include "llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h" #include "llvm/Support/Error.h" #include #include diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h index e7b598d8f81..3ff5a5f6e90 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -16,8 +16,8 @@ #define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETRPCAPI_H #include "llvm/ExecutionEngine/JITSymbol.h" -#include "llvm/ExecutionEngine/Orc/RPCUtils.h" -#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" +#include "llvm/ExecutionEngine/Orc/RPC/RPCUtils.h" +#include "llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h" namespace llvm { namespace orc { diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h b/llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h new file mode 100644 index 00000000000..9c69a84f4c6 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h @@ -0,0 +1,703 @@ +//===- llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H +#define LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H + +#include "llvm/ExecutionEngine/Orc/OrcError.h" +#include "llvm/Support/thread.h" +#include +#include +#include +#include +#include +#include + +namespace llvm { +namespace orc { +namespace rpc { + +template +class RPCTypeName; + +/// TypeNameSequence is a utility for rendering sequences of types to a string +/// by rendering each type, separated by ", ". +template class RPCTypeNameSequence {}; + +/// Render an empty TypeNameSequence to an ostream. +template +OStream &operator<<(OStream &OS, const RPCTypeNameSequence<> &V) { + return OS; +} + +/// Render a TypeNameSequence of a single type to an ostream. +template +OStream &operator<<(OStream &OS, const RPCTypeNameSequence &V) { + OS << RPCTypeName::getName(); + return OS; +} + +/// Render a TypeNameSequence of more than one type to an ostream. +template +OStream& +operator<<(OStream &OS, const RPCTypeNameSequence &V) { + OS << RPCTypeName::getName() << ", " + << RPCTypeNameSequence(); + return OS; +} + +template <> +class RPCTypeName { +public: + static const char* getName() { return "void"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "int8_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "uint8_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "int16_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "uint16_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "int32_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "uint32_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "int64_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "uint64_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "bool"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "std::string"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "Error"; } +}; + +template +class RPCTypeName> { +public: + static const char* getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "Expected<" + << RPCTypeNameSequence() + << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template +class RPCTypeName> { +public: + static const char* getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "std::pair<" << RPCTypeNameSequence() + << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template +class RPCTypeName> { +public: + static const char* getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "std::tuple<" + << RPCTypeNameSequence() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template +class RPCTypeName> { +public: + static const char*getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "std::vector<" << RPCTypeName::getName() + << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template class RPCTypeName> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::set<" << RPCTypeName::getName() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template class RPCTypeName> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::map<" << RPCTypeNameSequence() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +/// The SerializationTraits class describes how to serialize and +/// deserialize an instance of type T to/from an abstract channel of type +/// ChannelT. It also provides a representation of the type's name via the +/// getName method. +/// +/// Specializations of this class should provide the following functions: +/// +/// @code{.cpp} +/// +/// static const char* getName(); +/// static Error serialize(ChannelT&, const T&); +/// static Error deserialize(ChannelT&, T&); +/// +/// @endcode +/// +/// The third argument of SerializationTraits is intended to support SFINAE. +/// E.g.: +/// +/// @code{.cpp} +/// +/// class MyVirtualChannel { ... }; +/// +/// template +/// class SerializationTraits::value +/// >::type> { +/// public: +/// static const char* getName() { ... }; +/// } +/// +/// @endcode +template +class SerializationTraits; + +template +class SequenceTraits { +public: + static Error emitSeparator(ChannelT &C) { return Error::success(); } + static Error consumeSeparator(ChannelT &C) { return Error::success(); } +}; + +/// Utility class for serializing sequences of values of varying types. +/// Specializations of this class contain 'serialize' and 'deserialize' methods +/// for the given channel. The ArgTs... list will determine the "over-the-wire" +/// types to be serialized. The serialize and deserialize methods take a list +/// CArgTs... ("caller arg types") which must be the same length as ArgTs..., +/// but may be different types from ArgTs, provided that for each CArgT there +/// is a SerializationTraits specialization +/// SerializeTraits with methods that can serialize the +/// caller argument to over-the-wire value. +template +class SequenceSerialization; + +template +class SequenceSerialization { +public: + static Error serialize(ChannelT &C) { return Error::success(); } + static Error deserialize(ChannelT &C) { return Error::success(); } +}; + +template +class SequenceSerialization { +public: + + template + static Error serialize(ChannelT &C, CArgT &&CArg) { + return SerializationTraits::type>:: + serialize(C, std::forward(CArg)); + } + + template + static Error deserialize(ChannelT &C, CArgT &CArg) { + return SerializationTraits::deserialize(C, CArg); + } +}; + +template +class SequenceSerialization { +public: + + template + static Error serialize(ChannelT &C, CArgT &&CArg, + CArgTs &&... CArgs) { + if (auto Err = + SerializationTraits::type>:: + serialize(C, std::forward(CArg))) + return Err; + if (auto Err = SequenceTraits::emitSeparator(C)) + return Err; + return SequenceSerialization:: + serialize(C, std::forward(CArgs)...); + } + + template + static Error deserialize(ChannelT &C, CArgT &CArg, + CArgTs &... CArgs) { + if (auto Err = + SerializationTraits::deserialize(C, CArg)) + return Err; + if (auto Err = SequenceTraits::consumeSeparator(C)) + return Err; + return SequenceSerialization::deserialize(C, CArgs...); + } +}; + +template +Error serializeSeq(ChannelT &C, ArgTs &&... Args) { + return SequenceSerialization::type...>:: + serialize(C, std::forward(Args)...); +} + +template +Error deserializeSeq(ChannelT &C, ArgTs &... Args) { + return SequenceSerialization::deserialize(C, Args...); +} + +template +class SerializationTraits { +public: + + using WrappedErrorSerializer = + std::function; + + using WrappedErrorDeserializer = + std::function; + + template + static void registerErrorType(std::string Name, SerializeFtor Serialize, + DeserializeFtor Deserialize) { + assert(!Name.empty() && + "The empty string is reserved for the Success value"); + + const std::string *KeyName = nullptr; + { + // We're abusing the stability of std::map here: We take a reference to the + // key of the deserializers map to save us from duplicating the string in + // the serializer. This should be changed to use a stringpool if we switch + // to a map type that may move keys in memory. + std::lock_guard Lock(DeserializersMutex); + auto I = + Deserializers.insert(Deserializers.begin(), + std::make_pair(std::move(Name), + std::move(Deserialize))); + KeyName = &I->first; + } + + { + assert(KeyName != nullptr && "No keyname pointer"); + std::lock_guard Lock(SerializersMutex); + Serializers[ErrorInfoT::classID()] = + [KeyName, Serialize = std::move(Serialize)]( + ChannelT &C, const ErrorInfoBase &EIB) -> Error { + assert(EIB.dynamicClassID() == ErrorInfoT::classID() && + "Serializer called for wrong error type"); + if (auto Err = serializeSeq(C, *KeyName)) + return Err; + return Serialize(C, static_cast(EIB)); + }; + } + } + + static Error serialize(ChannelT &C, Error &&Err) { + std::lock_guard Lock(SerializersMutex); + + if (!Err) + return serializeSeq(C, std::string()); + + return handleErrors(std::move(Err), + [&C](const ErrorInfoBase &EIB) { + auto SI = Serializers.find(EIB.dynamicClassID()); + if (SI == Serializers.end()) + return serializeAsStringError(C, EIB); + return (SI->second)(C, EIB); + }); + } + + static Error deserialize(ChannelT &C, Error &Err) { + std::lock_guard Lock(DeserializersMutex); + + std::string Key; + if (auto Err = deserializeSeq(C, Key)) + return Err; + + if (Key.empty()) { + ErrorAsOutParameter EAO(&Err); + Err = Error::success(); + return Error::success(); + } + + auto DI = Deserializers.find(Key); + assert(DI != Deserializers.end() && "No deserializer for error type"); + return (DI->second)(C, Err); + } + +private: + + static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { + std::string ErrMsg; + { + raw_string_ostream ErrMsgStream(ErrMsg); + EIB.log(ErrMsgStream); + } + return serialize(C, make_error(std::move(ErrMsg), + inconvertibleErrorCode())); + } + + static std::recursive_mutex SerializersMutex; + static std::recursive_mutex DeserializersMutex; + static std::map Serializers; + static std::map Deserializers; +}; + +template +std::recursive_mutex SerializationTraits::SerializersMutex; + +template +std::recursive_mutex SerializationTraits::DeserializersMutex; + +template +std::map::WrappedErrorSerializer> +SerializationTraits::Serializers; + +template +std::map::WrappedErrorDeserializer> +SerializationTraits::Deserializers; + +/// Registers a serializer and deserializer for the given error type on the +/// given channel type. +template +void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize, + DeserializeFtor &&Deserialize) { + SerializationTraits::template registerErrorType( + std::move(Name), + std::forward(Serialize), + std::forward(Deserialize)); +} + +/// Registers serialization/deserialization for StringError. +template +void registerStringError() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + registerErrorSerialization( + "StringError", + [](ChannelT &C, const StringError &SE) { + return serializeSeq(C, SE.getMessage()); + }, + [](ChannelT &C, Error &Err) -> Error { + ErrorAsOutParameter EAO(&Err); + std::string Msg; + if (auto E2 = deserializeSeq(C, Msg)) + return E2; + Err = + make_error(std::move(Msg), + orcError( + OrcErrorCode::UnknownErrorCodeFromRemote)); + return Error::success(); + }); + AlreadyRegistered = true; + } +} + +/// SerializationTraits for Expected from an Expected. +template +class SerializationTraits, Expected> { +public: + + static Error serialize(ChannelT &C, Expected &&ValOrErr) { + if (ValOrErr) { + if (auto Err = serializeSeq(C, true)) + return Err; + return SerializationTraits::serialize(C, *ValOrErr); + } + if (auto Err = serializeSeq(C, false)) + return Err; + return serializeSeq(C, ValOrErr.takeError()); + } + + static Error deserialize(ChannelT &C, Expected &ValOrErr) { + ExpectedAsOutParameter EAO(&ValOrErr); + bool HasValue; + if (auto Err = deserializeSeq(C, HasValue)) + return Err; + if (HasValue) + return SerializationTraits::deserialize(C, *ValOrErr); + Error Err = Error::success(); + if (auto E2 = deserializeSeq(C, Err)) + return E2; + ValOrErr = std::move(Err); + return Error::success(); + } +}; + +/// SerializationTraits for Expected from a T2. +template +class SerializationTraits, T2> { +public: + + static Error serialize(ChannelT &C, T2 &&Val) { + return serializeSeq(C, Expected(std::forward(Val))); + } +}; + +/// SerializationTraits for Expected from an Error. +template +class SerializationTraits, Error> { +public: + + static Error serialize(ChannelT &C, Error &&Err) { + return serializeSeq(C, Expected(std::move(Err))); + } +}; + +/// SerializationTraits default specialization for std::pair. +template +class SerializationTraits, std::pair> { +public: + static Error serialize(ChannelT &C, const std::pair &V) { + if (auto Err = SerializationTraits::serialize(C, V.first)) + return Err; + return SerializationTraits::serialize(C, V.second); + } + + static Error deserialize(ChannelT &C, std::pair &V) { + if (auto Err = + SerializationTraits::deserialize(C, V.first)) + return Err; + return SerializationTraits::deserialize(C, V.second); + } +}; + +/// SerializationTraits default specialization for std::tuple. +template +class SerializationTraits> { +public: + + /// RPC channel serialization for std::tuple. + static Error serialize(ChannelT &C, const std::tuple &V) { + return serializeTupleHelper(C, V, std::index_sequence_for()); + } + + /// RPC channel deserialization for std::tuple. + static Error deserialize(ChannelT &C, std::tuple &V) { + return deserializeTupleHelper(C, V, std::index_sequence_for()); + } + +private: + // Serialization helper for std::tuple. + template + static Error serializeTupleHelper(ChannelT &C, const std::tuple &V, + std::index_sequence _) { + return serializeSeq(C, std::get(V)...); + } + + // Serialization helper for std::tuple. + template + static Error deserializeTupleHelper(ChannelT &C, std::tuple &V, + std::index_sequence _) { + return deserializeSeq(C, std::get(V)...); + } +}; + +/// SerializationTraits default specialization for std::vector. +template +class SerializationTraits> { +public: + + /// Serialize a std::vector from std::vector. + static Error serialize(ChannelT &C, const std::vector &V) { + if (auto Err = serializeSeq(C, static_cast(V.size()))) + return Err; + + for (const auto &E : V) + if (auto Err = serializeSeq(C, E)) + return Err; + + return Error::success(); + } + + /// Deserialize a std::vector to a std::vector. + static Error deserialize(ChannelT &C, std::vector &V) { + assert(V.empty() && + "Expected default-constructed vector to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + V.resize(Count); + for (auto &E : V) + if (auto Err = deserializeSeq(C, E)) + return Err; + + return Error::success(); + } +}; + +template +class SerializationTraits, std::set> { +public: + /// Serialize a std::set from std::set. + static Error serialize(ChannelT &C, const std::set &S) { + if (auto Err = serializeSeq(C, static_cast(S.size()))) + return Err; + + for (const auto &E : S) + if (auto Err = SerializationTraits::serialize(C, E)) + return Err; + + return Error::success(); + } + + /// Deserialize a std::set to a std::set. + static Error deserialize(ChannelT &C, std::set &S) { + assert(S.empty() && "Expected default-constructed set to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + T2 Val; + if (auto Err = SerializationTraits::deserialize(C, Val)) + return Err; + + auto Added = S.insert(Val).second; + if (!Added) + return make_error("Duplicate element in deserialized set", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +template +class SerializationTraits, std::map> { +public: + /// Serialize a std::map from std::map. + static Error serialize(ChannelT &C, const std::map &M) { + if (auto Err = serializeSeq(C, static_cast(M.size()))) + return Err; + + for (const auto &E : M) { + if (auto Err = + SerializationTraits::serialize(C, E.first)) + return Err; + if (auto Err = + SerializationTraits::serialize(C, E.second)) + return Err; + } + + return Error::success(); + } + + /// Deserialize a std::map to a std::map. + static Error deserialize(ChannelT &C, std::map &M) { + assert(M.empty() && "Expected default-constructed map to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + std::pair Val; + if (auto Err = + SerializationTraits::deserialize(C, Val.first)) + return Err; + + if (auto Err = + SerializationTraits::deserialize(C, Val.second)) + return Err; + + auto Added = M.insert(Val).second; + if (!Added) + return make_error("Duplicate element in deserialized map", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +} // end namespace rpc +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCUtils.h new file mode 100644 index 00000000000..ed09363dcec --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCUtils.h @@ -0,0 +1,1690 @@ +//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities to support construction of simple RPC APIs. +// +// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ +// programmers, high performance, low memory overhead, and efficient use of the +// communications channel. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H +#define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H + +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ExecutionEngine/Orc/OrcError.h" +#include "llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h" +#include "llvm/Support/MSVCErrorWorkarounds.h" + +#include + +namespace llvm { +namespace orc { +namespace rpc { + +/// Base class of all fatal RPC errors (those that necessarily result in the +/// termination of the RPC session). +class RPCFatalError : public ErrorInfo { +public: + static char ID; +}; + +/// RPCConnectionClosed is returned from RPC operations if the RPC connection +/// has already been closed due to either an error or graceful disconnection. +class ConnectionClosed : public ErrorInfo { +public: + static char ID; + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// BadFunctionCall is returned from handleOne when the remote makes a call with +/// an unrecognized function id. +/// +/// This error is fatal because Orc RPC needs to know how to parse a function +/// call to know where the next call starts, and if it doesn't recognize the +/// function id it cannot parse the call. +template +class BadFunctionCall + : public ErrorInfo, RPCFatalError> { +public: + static char ID; + + BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) + : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + } + + void log(raw_ostream &OS) const override { + OS << "Call to invalid RPC function id '" << FnId << "' with " + "sequence number " << SeqNo; + } + +private: + FnIdT FnId; + SeqNoT SeqNo; +}; + +template +char BadFunctionCall::ID = 0; + +/// InvalidSequenceNumberForResponse is returned from handleOne when a response +/// call arrives with a sequence number that doesn't correspond to any in-flight +/// function call. +/// +/// This error is fatal because Orc RPC needs to know how to parse the rest of +/// the response call to know where the next call starts, and if it doesn't have +/// a result parser for this sequence number it can't do that. +template +class InvalidSequenceNumberForResponse + : public ErrorInfo, RPCFatalError> { +public: + static char ID; + + InvalidSequenceNumberForResponse(SeqNoT SeqNo) + : SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + }; + + void log(raw_ostream &OS) const override { + OS << "Response has unknown sequence number " << SeqNo; + } +private: + SeqNoT SeqNo; +}; + +template +char InvalidSequenceNumberForResponse::ID = 0; + +/// This non-fatal error will be passed to asynchronous result handlers in place +/// of a result if the connection goes down before a result returns, or if the +/// function to be called cannot be negotiated with the remote. +class ResponseAbandoned : public ErrorInfo { +public: + static char ID; + + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// This error is returned if the remote does not have a handler installed for +/// the given RPC function. +class CouldNotNegotiate : public ErrorInfo { +public: + static char ID; + + CouldNotNegotiate(std::string Signature); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSignature() const { return Signature; } +private: + std::string Signature; +}; + +template class Function; + +// RPC Function class. +// DerivedFunc should be a user defined class with a static 'getName()' method +// returning a const char* representing the function's name. +template +class Function { +public: + /// User defined function type. + using Type = RetT(ArgTs...); + + /// Return type. + using ReturnType = RetT; + + /// Returns the full function prototype as a string. + static const char *getPrototype() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << RPCTypeName::getName() << " " << DerivedFunc::getName() + << "(" << llvm::orc::rpc::RPCTypeNameSequence() << ")"; + return Name; + }(); + return Name.data(); + } +}; + +/// Allocates RPC function ids during autonegotiation. +/// Specializations of this class must provide four members: +/// +/// static T getInvalidId(): +/// Should return a reserved id that will be used to represent missing +/// functions during autonegotiation. +/// +/// static T getResponseId(): +/// Should return a reserved id that will be used to send function responses +/// (return values). +/// +/// static T getNegotiateId(): +/// Should return a reserved id for the negotiate function, which will be used +/// to negotiate ids for user defined functions. +/// +/// template T allocate(): +/// Allocate a unique id for function Func. +template class RPCFunctionIdAllocator; + +/// This specialization of RPCFunctionIdAllocator provides a default +/// implementation for integral types. +template +class RPCFunctionIdAllocator< + T, typename std::enable_if::value>::type> { +public: + static T getInvalidId() { return T(0); } + static T getResponseId() { return T(1); } + static T getNegotiateId() { return T(2); } + + template T allocate() { return NextId++; } + +private: + T NextId = 3; +}; + +namespace detail { + +/// Provides a typedef for a tuple containing the decayed argument types. +template class FunctionArgsTuple; + +template +class FunctionArgsTuple { +public: + using Type = std::tuple::type>::type...>; +}; + +// ResultTraits provides typedefs and utilities specific to the return type +// of functions. +template class ResultTraits { +public: + // The return type wrapped in llvm::Expected. + using ErrorReturnType = Expected; + +#ifdef _MSC_VER + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future>; +#else + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future; +#endif + + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType(RetT()); + } + + // Consume an abandoned ErrorReturnType. + static void consumeAbandoned(ErrorReturnType RetOrErr) { + consumeError(RetOrErr.takeError()); + } +}; + +// ResultTraits specialization for void functions. +template <> class ResultTraits { +public: + // For void functions, ErrorReturnType is llvm::Error. + using ErrorReturnType = Error; + +#ifdef _MSC_VER + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future; +#else + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future; +#endif + + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType::success(); + } + + // Consume an abandoned ErrorReturnType. + static void consumeAbandoned(ErrorReturnType Err) { + consumeError(std::move(Err)); + } +}; + +// ResultTraits is equivalent to ResultTraits. This allows +// handlers for void RPC functions to return either void (in which case they +// implicitly succeed) or Error (in which case their error return is +// propagated). See usage in HandlerTraits::runHandlerHelper. +template <> class ResultTraits : public ResultTraits {}; + +// ResultTraits> is equivalent to ResultTraits. This allows +// handlers for RPC functions returning a T to return either a T (in which +// case they implicitly succeed) or Expected (in which case their error +// return is propagated). See usage in HandlerTraits::runHandlerHelper. +template +class ResultTraits> : public ResultTraits {}; + +// Determines whether an RPC function's defined error return type supports +// error return value. +template +class SupportsErrorReturn { +public: + static const bool value = false; +}; + +template <> +class SupportsErrorReturn { +public: + static const bool value = true; +}; + +template +class SupportsErrorReturn> { +public: + static const bool value = true; +}; + +// RespondHelper packages return values based on whether or not the declared +// RPC function return type supports error returns. +template +class RespondHelper; + +// RespondHelper specialization for functions that support error returns. +template <> +class RespondHelper { +public: + + // Send Expected. + template + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected ResultOrErr) { + if (!ResultOrErr && ResultOrErr.template errorIsA()) + return ResultOrErr.takeError(); + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits>::serialize( + C, std::move(ResultOrErr))) + return Err; + + // Close the response message. + if (auto Err = C.endSendMessage()) + return Err; + return C.send(); + } + + template + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err && Err.isA()) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + if (auto Err2 = serializeSeq(C, std::move(Err))) + return Err2; + if (auto Err2 = C.endSendMessage()) + return Err2; + return C.send(); + } + +}; + +// RespondHelper specialization for functions that do not support error returns. +template <> +class RespondHelper { +public: + + template + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected ResultOrErr) { + if (auto Err = ResultOrErr.takeError()) + return Err; + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits::serialize( + C, *ResultOrErr)) + return Err; + + // End the response message. + if (auto Err = C.endSendMessage()) + return Err; + + return C.send(); + } + + template + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + if (auto Err2 = C.endSendMessage()) + return Err2; + return C.send(); + } + +}; + + +// Send a response of the given wire return type (WireRetT) over the +// channel, with the given sequence number. +template +Error respond(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Expected ResultOrErr) { + return RespondHelper::value>:: + template sendResult(C, ResponseId, SeqNo, std::move(ResultOrErr)); +} + +// Send an empty response message on the given channel to indicate that +// the handler ran. +template +Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, + Error Err) { + return RespondHelper::value>:: + sendResult(C, ResponseId, SeqNo, std::move(Err)); +} + +// Converts a given type to the equivalent error return type. +template class WrappedHandlerReturn { +public: + using Type = Expected; +}; + +template class WrappedHandlerReturn> { +public: + using Type = Expected; +}; + +template <> class WrappedHandlerReturn { +public: + using Type = Error; +}; + +template <> class WrappedHandlerReturn { +public: + using Type = Error; +}; + +template <> class WrappedHandlerReturn { +public: + using Type = Error; +}; + +// Traits class that strips the response function from the list of handler +// arguments. +template class AsyncHandlerTraits; + +template +class AsyncHandlerTraits)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Expected; +}; + +template +class AsyncHandlerTraits, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template +class AsyncHandlerTraits, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template +class AsyncHandlerTraits, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template +class AsyncHandlerTraits : + public AsyncHandlerTraits::type, + ArgTs...)> {}; + +// This template class provides utilities related to RPC function handlers. +// The base case applies to non-function types (the template class is +// specialized for function types) and inherits from the appropriate +// speciilization for the given non-function type's call operator. +template +class HandlerTraits : public HandlerTraits::type::operator())> { +}; + +// Traits for handlers with a given function type. +template +class HandlerTraits { +public: + // Function type of the handler. + using Type = RetT(ArgTs...); + + // Return type of the handler. + using ReturnType = RetT; + + // Call the given handler with the given arguments. + template + static typename WrappedHandlerReturn::Type + unpackAndRun(HandlerT &Handler, std::tuple &Args) { + return unpackAndRunHelper(Handler, Args, + std::index_sequence_for()); + } + + // Call the given handler with the given arguments. + template + static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, + std::tuple &Args) { + return unpackAndRunAsyncHelper(Handler, Responder, Args, + std::index_sequence_for()); + } + + // Call the given handler with the given arguments. + template + static typename std::enable_if< + std::is_void::ReturnType>::value, + Error>::type + run(HandlerT &Handler, ArgTs &&... Args) { + Handler(std::move(Args)...); + return Error::success(); + } + + template + static typename std::enable_if< + !std::is_void::ReturnType>::value, + typename HandlerTraits::ReturnType>::type + run(HandlerT &Handler, TArgTs... Args) { + return Handler(std::move(Args)...); + } + + // Serialize arguments to the channel. + template + static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) { + return SequenceSerialization::serialize(C, CArgs...); + } + + // Deserialize arguments from the channel. + template + static Error deserializeArgs(ChannelT &C, std::tuple &Args) { + return deserializeArgsHelper(C, Args, std::index_sequence_for()); + } + +private: + template + static Error deserializeArgsHelper(ChannelT &C, std::tuple &Args, + std::index_sequence _) { + return SequenceSerialization::deserialize( + C, std::get(Args)...); + } + + template + static typename WrappedHandlerReturn< + typename HandlerTraits::ReturnType>::Type + unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, + std::index_sequence) { + return run(Handler, std::move(std::get(Args))...); + } + + template + static typename WrappedHandlerReturn< + typename HandlerTraits::ReturnType>::Type + unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, + ArgTuple &Args, std::index_sequence) { + return run(Handler, Responder, std::move(std::get(Args))...); + } +}; + +// Handler traits for free functions. +template +class HandlerTraits + : public HandlerTraits {}; + +// Handler traits for class methods (especially call operators for lambdas). +template +class HandlerTraits + : public HandlerTraits {}; + +// Handler traits for const class methods (especially call operators for +// lambdas). +template +class HandlerTraits + : public HandlerTraits {}; + +// Utility to peel the Expected wrapper off a response handler error type. +template class ResponseHandlerArg; + +template class ResponseHandlerArg)> { +public: + using ArgType = Expected; + using UnwrappedArgType = ArgT; +}; + +template +class ResponseHandlerArg)> { +public: + using ArgType = Expected; + using UnwrappedArgType = ArgT; +}; + +template <> class ResponseHandlerArg { +public: + using ArgType = Error; +}; + +template <> class ResponseHandlerArg { +public: + using ArgType = Error; +}; + +// ResponseHandler represents a handler for a not-yet-received function call +// result. +template class ResponseHandler { +public: + virtual ~ResponseHandler() {} + + // Reads the function result off the wire and acts on it. The meaning of + // "act" will depend on how this method is implemented in any given + // ResponseHandler subclass but could, for example, mean running a + // user-specified handler or setting a promise value. + virtual Error handleResponse(ChannelT &C) = 0; + + // Abandons this outstanding result. + virtual void abandon() = 0; + + // Create an error instance representing an abandoned response. + static Error createAbandonedResponseError() { + return make_error(); + } +}; + +// ResponseHandler subclass for RPC functions with non-void returns. +template +class ResponseHandlerImpl : public ResponseHandler { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using UnwrappedArgType = typename ResponseHandlerArg< + typename HandlerTraits::Type>::UnwrappedArgType; + UnwrappedArgType Result; + if (auto Err = + SerializationTraits::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +// ResponseHandler subclass for RPC functions with void returns. +template +class ResponseHandlerImpl + : public ResponseHandler { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result (no actual value, just a notification that the function + // has completed on the remote end) by calling the user-defined handler with + // Error::success(). + Error handleResponse(ChannelT &C) override { + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(Error::success()); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +template +class ResponseHandlerImpl, HandlerT> + : public ResponseHandler { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using HandlerArgType = typename ResponseHandlerArg< + typename HandlerTraits::Type>::ArgType; + HandlerArgType Result((typename HandlerArgType::value_type())); + + if (auto Err = + SerializationTraits, + HandlerArgType>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +template +class ResponseHandlerImpl + : public ResponseHandler { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + Error Result = Error::success(); + if (auto Err = SerializationTraits::deserialize( + C, Result)) { + consumeError(std::move(Result)); + return Err; + } + if (auto Err = C.endReceiveMessage()) { + consumeError(std::move(Result)); + return Err; + } + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +// Create a ResponseHandler from a given user handler. +template +std::unique_ptr> createResponseHandler(HandlerT H) { + return std::make_unique>( + std::move(H)); +} + +// Helper for wrapping member functions up as functors. This is useful for +// installing methods as result handlers. +template +class MemberFnWrapper { +public: + using MethodT = RetT (ClassT::*)(ArgTs...); + MemberFnWrapper(ClassT &Instance, MethodT Method) + : Instance(Instance), Method(Method) {} + RetT operator()(ArgTs &&... Args) { + return (Instance.*Method)(std::move(Args)...); + } + +private: + ClassT &Instance; + MethodT Method; +}; + +// Helper that provides a Functor for deserializing arguments. +template class ReadArgs { +public: + Error operator()() { return Error::success(); } +}; + +template +class ReadArgs : public ReadArgs { +public: + ReadArgs(ArgT &Arg, ArgTs &... Args) + : ReadArgs(Args...), Arg(Arg) {} + + Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) { + this->Arg = std::move(ArgVal); + return ReadArgs::operator()(ArgVals...); + } + +private: + ArgT &Arg; +}; + +// Manage sequence numbers. +template class SequenceNumberManager { +public: + // Reset, making all sequence numbers available. + void reset() { + std::lock_guard Lock(SeqNoLock); + NextSequenceNumber = 0; + FreeSequenceNumbers.clear(); + } + + // Get the next available sequence number. Will re-use numbers that have + // been released. + SequenceNumberT getSequenceNumber() { + std::lock_guard Lock(SeqNoLock); + if (FreeSequenceNumbers.empty()) + return NextSequenceNumber++; + auto SequenceNumber = FreeSequenceNumbers.back(); + FreeSequenceNumbers.pop_back(); + return SequenceNumber; + } + + // Release a sequence number, making it available for re-use. + void releaseSequenceNumber(SequenceNumberT SequenceNumber) { + std::lock_guard Lock(SeqNoLock); + FreeSequenceNumbers.push_back(SequenceNumber); + } + +private: + std::mutex SeqNoLock; + SequenceNumberT NextSequenceNumber = 0; + std::vector FreeSequenceNumbers; +}; + +// Checks that predicate P holds for each corresponding pair of type arguments +// from T1 and T2 tuple. +template