diff options
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/OrcError.h | 1 | ||||
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h | 166 | ||||
-rw-r--r-- | llvm/lib/ExecutionEngine/Orc/OrcError.cpp | 2 | ||||
-rw-r--r-- | llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp | 71 |
4 files changed, 204 insertions, 36 deletions
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h index 8841aa77f62..b74988cce2f 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -27,6 +27,7 @@ enum class OrcErrorCode : int { RemoteMProtectAddrUnrecognized, RemoteIndirectStubsOwnerDoesNotExist, RemoteIndirectStubsOwnerIdAlreadyInUse, + RPCResponseAbandoned, UnexpectedRPCCall, UnexpectedRPCResponse, UnknownRPCFunction diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 34ae392516c..f51fbe153a4 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -364,9 +364,27 @@ public: // Call the given handler with the given arguments. template <typename HandlerT> static typename WrappedHandlerReturn<RetT>::Type - runHandler(HandlerT &Handler, ArgStorage &Args) { - return runHandlerHelper<RetT>(Handler, Args, - llvm::index_sequence_for<ArgTs...>()); + unpackAndRun(HandlerT &Handler, ArgStorage &Args) { + return unpackAndRunHelper(Handler, Args, + llvm::index_sequence_for<ArgTs...>()); + } + + // Call the given handler with the given arguments. + template <typename HandlerT> + static typename std::enable_if< + std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, + Error>::type + run(HandlerT &Handler, ArgTs &&... Args) { + Handler(std::move(Args)...); + return Error::success(); + } + + template <typename HandlerT> + static typename std::enable_if< + !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, + typename HandlerTraits<HandlerT>::ReturnType>::type + run(HandlerT &Handler, ArgTs... Args) { + return Handler(std::move(Args)...); } // Serialize arguments to the channel. @@ -383,31 +401,20 @@ public: } private: - // For non-void user handlers: unwrap the args tuple and call the handler, - // returning the result. - template <typename RetTAlt, typename HandlerT, size_t... Indexes> - static typename std::enable_if<!std::is_void<RetTAlt>::value, RetT>::type - runHandlerHelper(HandlerT &Handler, ArgStorage &Args, - llvm::index_sequence<Indexes...>) { - return Handler(std::move(std::get<Indexes>(Args))...); - } - - // For void user handlers: unwrap the args tuple and call the handler, then - // return Error::success(). - template <typename RetTAlt, typename HandlerT, size_t... Indexes> - static typename std::enable_if<std::is_void<RetTAlt>::value, Error>::type - runHandlerHelper(HandlerT &Handler, ArgStorage &Args, - llvm::index_sequence<Indexes...>) { - Handler(std::move(std::get<Indexes>(Args))...); - return Error::success(); - } - template <typename ChannelT, typename... CArgTs, size_t... Indexes> static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args, llvm::index_sequence<Indexes...> _) { return SequenceSerialization<ChannelT, ArgTs...>::deserialize( C, std::get<Indexes>(Args)...); } + + template <typename HandlerT, size_t... Indexes> + static typename WrappedHandlerReturn< + typename HandlerTraits<HandlerT>::ReturnType>::Type + unpackAndRunHelper(HandlerT &Handler, ArgStorage &Args, + llvm::index_sequence<Indexes...>) { + return run(Handler, std::move(std::get<Indexes>(Args))...); + } }; // Handler traits for class methods (especially call operators for lambdas). @@ -422,17 +429,29 @@ class HandlerTraits<RetT (Class::*)(ArgTs...) const> : public HandlerTraits<RetT(ArgTs...)> {}; // Utility to peel the Expected wrapper off a response handler error type. -template <typename HandlerT> class UnwrapResponseHandlerArg; +template <typename HandlerT> class ResponseHandlerArg; -template <typename ArgT> class UnwrapResponseHandlerArg<Error(Expected<ArgT>)> { +template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> { public: - using ArgType = ArgT; + using ArgType = Expected<ArgT>; + using UnwrappedArgType = ArgT; }; template <typename ArgT> -class UnwrapResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> { +class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> { +public: + using ArgType = Expected<ArgT>; + using UnwrappedArgType = ArgT; +}; + +template <> class ResponseHandlerArg<Error(Error)> { public: - using ArgType = ArgT; + using ArgType = Error; +}; + +template <> class ResponseHandlerArg<ErrorSuccess(Error)> { +public: + using ArgType = Error; }; // ResponseHandler represents a handler for a not-yet-received function call @@ -452,8 +471,7 @@ public: // Create an error instance representing an abandoned response. static Error createAbandonedResponseError() { - return make_error<StringError>("RPC function call failed to return", - inconvertibleErrorCode()); + return orcError(OrcErrorCode::RPCResponseAbandoned); } }; @@ -466,12 +484,12 @@ public: // Handle the result by deserializing it from the channel then passing it // to the user defined handler. Error handleResponse(ChannelT &C) override { - using ArgType = typename UnwrapResponseHandlerArg< - typename HandlerTraits<HandlerT>::Type>::ArgType; - ArgType Result; + using UnwrappedArgType = typename ResponseHandlerArg< + typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType; + UnwrappedArgType Result; if (auto Err = - SerializationTraits<ChannelT, FuncRetT, ArgType>::deserialize( - C, Result)) + SerializationTraits<ChannelT, FuncRetT, + UnwrappedArgType>::deserialize(C, Result)) return Err; if (auto Err = C.endReceiveMessage()) return Err; @@ -802,6 +820,8 @@ public: return Error::success(); } + Error sendAppendedCalls() { return C.send(); }; + template <typename Func, typename HandlerT, typename... ArgTs> Error callAsync(HandlerT Handler, const ArgTs &... Args) { if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...)) @@ -966,8 +986,8 @@ protected: SeqNo]() mutable -> Error { using HTraits = detail::HandlerTraits<HandlerT>; using FuncReturn = typename Func::ReturnType; - return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, - HTraits::runHandler(Handler, *Args)); + return detail::respond<FuncReturn>( + Channel, ResponseId, SeqNo, HTraits::unpackAndRun(Handler, *Args)); }; // If there is an explicit launch policy then use it to launch the @@ -1238,6 +1258,80 @@ public: } }; +/// \brief Allows a set of asynchrounous calls to be dispatched, and then +/// waited on as a group. +template <typename RPCClass> class ParallelCallGroup { +public: + + /// \brief Construct a parallel call group for the given RPC. + ParallelCallGroup(RPCClass &RPC) : RPC(RPC), NumOutstandingCalls(0) {} + + ParallelCallGroup(const ParallelCallGroup &) = delete; + ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; + + /// \brief Make as asynchronous call. + /// + /// Does not issue a send call to the RPC's channel. The channel may use this + /// to batch up subsequent calls. A send will automatically be sent when wait + /// is called. + template <typename Func, typename HandlerT, typename... ArgTs> + Error appendCall(HandlerT Handler, const ArgTs &... Args) { + // Increment the count of outstanding calls. This has to happen before + // we invoke the call, as the handler may (depending on scheduling) + // be run immediately on another thread, and we don't want the decrement + // in the wrapped handler below to run before the increment. + { + std::unique_lock<std::mutex> Lock(M); + ++NumOutstandingCalls; + } + + // Wrap the user handler in a lambda that will decrement the + // outstanding calls count, then poke the condition variable. + using ArgType = typename detail::ResponseHandlerArg< + typename detail::HandlerTraits<HandlerT>::Type>::ArgType; + // FIXME: Move handler into wrapped handler once we have C++14. + auto WrappedHandler = [this, Handler](ArgType Arg) { + auto Err = Handler(std::move(Arg)); + std::unique_lock<std::mutex> Lock(M); + --NumOutstandingCalls; + CV.notify_all(); + return Err; + }; + + return RPC.template appendCallAsync<Func>(std::move(WrappedHandler), + Args...); + } + + /// \brief Make an asynchronous call. + /// + /// The same as appendCall, but also calls send on the channel immediately. + /// Prefer appendCall if you are about to issue a "wait" call shortly, as + /// this may allow the channel to better batch the calls. + template <typename Func, typename HandlerT, typename... ArgTs> + Error call(HandlerT Handler, const ArgTs &... Args) { + if (auto Err = appendCall(std::move(Handler), Args...)) + return Err; + return RPC.sendAppendedCalls(); + } + + /// \brief Blocks until all calls have been completed and their return value + /// handlers run. + Error wait() { + if (auto Err = RPC.sendAppendedCalls()) + return Err; + std::unique_lock<std::mutex> Lock(M); + while (NumOutstandingCalls > 0) + CV.wait(Lock); + return Error::success(); + } + +private: + RPCClass &RPC; + std::mutex M; + std::condition_variable CV; + uint32_t NumOutstandingCalls; +}; + } // end namespace rpc } // end namespace orc } // end namespace llvm diff --git a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp index 48dcd442266..c531fe36992 100644 --- a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp +++ b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp @@ -39,6 +39,8 @@ public: return "Remote indirect stubs owner does not exist"; case OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse: return "Remote indirect stubs owner Id already in use"; + case OrcErrorCode::RPCResponseAbandoned: + return "RPC response abandoned"; case OrcErrorCode::UnexpectedRPCCall: return "Unexpected RPC call"; case OrcErrorCode::UnexpectedRPCResponse: diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index f9b65a99505..381fd103042 100644 --- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -386,3 +386,74 @@ TEST(DummyRPC, TestWithAltCustomType) { ServerThread.join(); } + +TEST(DummyRPC, TestParallelCallGroup) { + Queue Q1, Q2; + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::IntInt>( + [](int X) -> int { + return 2 * X; + }); + + // Handle the negotiate, plus three calls. + for (unsigned I = 0; I != 4; ++I) { + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to int(int)"; + } + }); + + { + int A, B, C; + ParallelCallGroup<DummyRPCEndpoint> PCG(Client); + + { + auto Err = PCG.appendCall<DummyRPCAPI::IntInt>( + [&A](Expected<int> Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + A = *Result; + return Error::success(); + }, 1); + EXPECT_FALSE(!!Err) << "First parallel call failed for int(int)"; + } + + { + auto Err = PCG.appendCall<DummyRPCAPI::IntInt>( + [&B](Expected<int> Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + B = *Result; + return Error::success(); + }, 2); + EXPECT_FALSE(!!Err) << "Second parallel call failed for int(int)"; + } + + { + auto Err = PCG.appendCall<DummyRPCAPI::IntInt>( + [&C](Expected<int> Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + C = *Result; + return Error::success(); + }, 3); + EXPECT_FALSE(!!Err) << "Third parallel call failed for int(int)"; + } + + // Handle the three int(int) results. + for (unsigned I = 0; I != 3; ++I) { + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; + } + + { + auto Err = PCG.wait(); + EXPECT_FALSE(!!Err) << "Third parallel call failed for int(int)"; + } + + EXPECT_EQ(A, 2) << "First parallel call returned bogus result"; + EXPECT_EQ(B, 4) << "Second parallel call returned bogus result"; + EXPECT_EQ(C, 6) << "Third parallel call returned bogus result"; + } + + ServerThread.join(); +} |