diff options
Diffstat (limited to 'llvm')
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h | 228 | ||||
-rw-r--r-- | llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp | 50 |
2 files changed, 198 insertions, 80 deletions
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h index be620d4f0d7..8119494b571 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -82,16 +82,6 @@ std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex; template <typename DerivedFunc, typename RetT, typename... ArgTs> std::string Function<DerivedFunc, RetT(ArgTs...)>::Name; -/// Provides a typedef for a tuple containing the decayed argument types. -template <typename T> class FunctionArgsTuple; - -template <typename RetT, typename... ArgTs> -class FunctionArgsTuple<RetT(ArgTs...)> { -public: - using Type = std::tuple<typename std::decay< - typename std::remove_reference<ArgTs>::type>::type...>; -}; - /// Allocates RPC function ids during autonegotiation. /// Specializations of this class must provide four members: /// @@ -196,6 +186,16 @@ public: #endif // _MSC_VER +/// Provides a typedef for a tuple containing the decayed argument types. +template <typename T> class FunctionArgsTuple; + +template <typename RetT, typename... ArgTs> +class FunctionArgsTuple<RetT(ArgTs...)> { +public: + using Type = std::tuple<typename std::decay< + typename std::remove_reference<ArgTs>::type>::type...>; +}; + // ResultTraits provides typedefs and utilities specific to the return type // of functions. template <typename RetT> class ResultTraits { @@ -339,6 +339,22 @@ public: using Type = Error; }; +template <typename FnT> class AsyncHandlerTraits; + +template <typename ResultT, typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Expected<ResultT>; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + // This template class provides utilities related to RPC function handlers. // The base case applies to non-function types (the template class is // specialized for function types) and inherits from the appropriate @@ -358,15 +374,20 @@ public: // Return type of the handler. using ReturnType = RetT; - // A std::tuple wrapping the handler arguments. - using ArgStorage = typename FunctionArgsTuple<RetT(ArgTs...)>::Type; - // Call the given handler with the given arguments. - template <typename HandlerT> + template <typename HandlerT, typename... TArgTs> static typename WrappedHandlerReturn<RetT>::Type - unpackAndRun(HandlerT &Handler, ArgStorage &Args) { + unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) { return unpackAndRunHelper(Handler, Args, - llvm::index_sequence_for<ArgTs...>()); + llvm::index_sequence_for<TArgTs...>()); + } + + // Call the given handler with the given arguments. + template <typename HandlerT, typename ResponderT, typename... TArgTs> + static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, + std::tuple<TArgTs...> &Args) { + return unpackAndRunAsyncHelper(Handler, Responder, Args, + llvm::index_sequence_for<TArgTs...>()); } // Call the given handler with the given arguments. @@ -379,11 +400,11 @@ public: return Error::success(); } - template <typename HandlerT> + template <typename HandlerT, typename... TArgTs> static typename std::enable_if< !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, typename HandlerTraits<HandlerT>::ReturnType>::type - run(HandlerT &Handler, ArgTs... Args) { + run(HandlerT &Handler, TArgTs... Args) { return Handler(std::move(Args)...); } @@ -408,13 +429,24 @@ private: C, std::get<Indexes>(Args)...); } - template <typename HandlerT, size_t... Indexes> + template <typename HandlerT, typename ArgTuple, size_t... Indexes> static typename WrappedHandlerReturn< typename HandlerTraits<HandlerT>::ReturnType>::Type - unpackAndRunHelper(HandlerT &Handler, ArgStorage &Args, + unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, llvm::index_sequence<Indexes...>) { return run(Handler, std::move(std::get<Indexes>(Args))...); } + + + template <typename HandlerT, typename ResponderT, typename ArgTuple, + size_t... Indexes> + static typename WrappedHandlerReturn< + typename HandlerTraits<HandlerT>::ReturnType>::Type + unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, + ArgTuple &Args, + llvm::index_sequence<Indexes...>) { + return run(Handler, Responder, std::move(std::get<Indexes>(Args))...); + } }; // Handler traits for free functions. @@ -763,8 +795,7 @@ public: auto NegotiateId = FnIdAllocator.getNegotiateId(); RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( - [this](const std::string &Name) { return handleNegotiate(Name); }, - LaunchPolicy()); + [this](const std::string &Name) { return handleNegotiate(Name); }); } @@ -919,9 +950,6 @@ public: } protected: - // The LaunchPolicy type allows a launch policy to be specified when adding - // a function handler. See addHandlerImpl. - using LaunchPolicy = std::function<Error(std::function<Error()>)>; FunctionIdT getInvalidFunctionId() const { return FnIdAllocator.getInvalidId(); @@ -930,7 +958,7 @@ protected: /// Add the given handler to the handler map and make it available for /// autonegotiation and execution. template <typename Func, typename HandlerT> - void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) { + void addHandlerImpl(HandlerT Handler) { static_assert(detail::RPCArgTypeCheck< CanDeserializeCheck, typename Func::Type, @@ -939,8 +967,22 @@ protected: FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); LocalFunctionIds[Func::getPrototype()] = NewFnId; - Handlers[NewFnId] = - wrapHandler<Func>(std::move(Handler), std::move(Launch)); + Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandlerImpl(HandlerT Handler) { + + static_assert(detail::RPCArgTypeCheck< + CanDeserializeCheck, typename Func::Type, + typename detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type + >::Type>::value, + ""); + + FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); + LocalFunctionIds[Func::getPrototype()] = NewFnId; + Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); } Error handleResponse(SequenceNumberT SeqNo) { @@ -1022,12 +1064,49 @@ protected: // Wrap the given user handler in the necessary argument-deserialization code, // result-serialization code, and call to the launch policy (if present). template <typename Func, typename HandlerT> - WrappedHandlerFn wrapHandler(HandlerT Handler, LaunchPolicy Launch) { - return [this, Handler, Launch](ChannelT &Channel, - SequenceNumberT SeqNo) mutable -> Error { + WrappedHandlerFn wrapHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { + // Start by deserializing the arguments. + using ArgsTuple = + typename detail::FunctionArgsTuple< + typename detail::HandlerTraits<HandlerT>::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + + if (auto Err = + detail::HandlerTraits<typename Func::Type>::deserializeArgs( + Channel, *Args)) + return Err; + + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)Args; + + // End receieve message, unlocking the channel for reading. + if (auto Err = Channel.endReceiveMessage()) + return Err; + + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, + HTraits::unpackAndRun(Handler, *Args)); + }; + } + + // Wrap the given user handler in the necessary argument-deserialization code, + // result-serialization code, and call to the launch policy (if present). + template <typename Func, typename HandlerT> + WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { // Start by deserializing the arguments. - auto Args = std::make_shared< - typename detail::HandlerTraits<HandlerT>::ArgStorage>(); + using AHTraits = detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type>; + using ArgsTuple = + typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + if (auto Err = detail::HandlerTraits<typename Func::Type>::deserializeArgs( Channel, *Args)) @@ -1042,22 +1121,15 @@ protected: if (auto Err = Channel.endReceiveMessage()) return Err; - // Build the handler/responder. - auto Responder = [this, Handler, Args, &Channel, - SeqNo]() mutable -> Error { - using HTraits = detail::HandlerTraits<HandlerT>; - using FuncReturn = typename Func::ReturnType; - return detail::respond<FuncReturn>( - Channel, ResponseId, SeqNo, HTraits::unpackAndRun(Handler, *Args)); - }; - - // If there is an explicit launch policy then use it to launch the - // handler. - if (Launch) - return Launch(std::move(Responder)); - - // Otherwise run the handler on the listener thread. - return Responder(); + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + auto Responder = + [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error { + return detail::respond<FuncReturn>(C, ResponseId, SeqNo, + std::move(RetVal)); + }; + + return HTraits::unpackAndRunAsync(Handler, Responder, *Args); }; } @@ -1097,40 +1169,31 @@ public: MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) : BaseClass(C, LazyAutoNegotiation) {} - /// The LaunchPolicy type allows a launch policy to be specified when adding - /// a function handler. See addHandler. - using LaunchPolicy = typename BaseClass::LaunchPolicy; - /// Add a handler for the given RPC function. /// This installs the given handler functor for the given RPC Function, and /// makes the RPC function available for negotiation/calling from the remote. - /// - /// The optional LaunchPolicy argument can be used to control how the handler - /// is run when called: - /// - /// * If no LaunchPolicy is given, the handler code will be run on the RPC - /// handler thread that is reading from the channel. This handler cannot - /// make blocking RPC calls (since it would be blocking the thread used to - /// get the result), but can make non-blocking calls. - /// - /// * If a LaunchPolicy is given, the user's handler will be wrapped in a - /// call to serialize and send the result, and the resulting functor (with - /// type 'Error()' will be passed to the LaunchPolicy. The user can then - /// choose to add the wrapped handler to a work queue, spawn a new thread, - /// or anything else. template <typename Func, typename HandlerT> - void addHandler(HandlerT Handler, LaunchPolicy Launch = LaunchPolicy()) { - return this->template addHandlerImpl<Func>(std::move(Handler), - std::move(Launch)); + void addHandler(HandlerT Handler) { + return this->template addHandlerImpl<Func>(std::move(Handler)); } /// Add a class-method as a handler. template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...), - LaunchPolicy Launch = LaunchPolicy()) { + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { addHandler<Func>( - detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method), - Launch); + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } + + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); } /// Return type for non-blocking call primitives. @@ -1220,16 +1283,13 @@ private: SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, ChannelT, FunctionIdT, SequenceNumberT>; - using LaunchPolicy = typename BaseClass::LaunchPolicy; - public: SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) : BaseClass(C, LazyAutoNegotiation) {} template <typename Func, typename HandlerT> void addHandler(HandlerT Handler) { - return this->template addHandlerImpl<Func>(std::move(Handler), - LaunchPolicy()); + return this->template addHandlerImpl<Func>(std::move(Handler)); } template <typename Func, typename ClassT, typename RetT, typename... ArgTs> @@ -1238,6 +1298,18 @@ public: detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); } + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } + + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); + } + template <typename Func, typename... ArgTs, typename AltRetT = typename Func::ReturnType> typename detail::ResultTraits<AltRetT>::ErrorReturnType diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index ae2d18e3d75..91cec1c1ede 100644 --- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -154,7 +154,7 @@ TEST(DummyRPC, TestFreeFunctionHandler) { Server.addHandler<DummyRPCAPI::VoidBool>(freeVoidBool); } -TEST(DummyRPC, TestAsyncVoidBool) { +TEST(DummyRPC, TestCallAsyncVoidBool) { Queue Q1, Q2; DummyRPCEndpoint Client(Q1, Q2); DummyRPCEndpoint Server(Q2, Q1); @@ -198,7 +198,7 @@ TEST(DummyRPC, TestAsyncVoidBool) { ServerThread.join(); } -TEST(DummyRPC, TestAsyncIntInt) { +TEST(DummyRPC, TestCallAsyncIntInt) { Queue Q1, Q2; DummyRPCEndpoint Client(Q1, Q2); DummyRPCEndpoint Server(Q2, Q1); @@ -243,6 +243,52 @@ TEST(DummyRPC, TestAsyncIntInt) { ServerThread.join(); } +TEST(DummyRPC, TestAsyncIntIntHandler) { + Queue Q1, Q2; + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); + + std::thread ServerThread([&]() { + Server.addAsyncHandler<DummyRPCAPI::IntInt>( + [](std::function<Error(Expected<int32_t>)> SendResult, + int32_t X) { + EXPECT_EQ(X, 21) << "Server int(int) receieved unexpected result"; + return SendResult(2 * X); + }); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the VoidBool call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)"; + } + }); + + { + auto Err = Client.callAsync<DummyRPCAPI::IntInt>( + [](Expected<int> Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + EXPECT_EQ(*Result, 42) + << "Async int(int) response handler received incorrect result"; + return Error::success(); + }, 21); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for int(int)"; + } + + { + // Poke the client to process the result. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; + } + + ServerThread.join(); +} + TEST(DummyRPC, TestSerialization) { Queue Q1, Q2; DummyRPCEndpoint Client(Q1, Q2); |