summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--llvm/include/llvm/ExecutionEngine/Orc/OrcError.h1
-rw-r--r--llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h166
-rw-r--r--llvm/lib/ExecutionEngine/Orc/OrcError.cpp2
-rw-r--r--llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp71
4 files changed, 204 insertions, 36 deletions
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h
index 8841aa77f62..b74988cce2f 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcError.h
@@ -27,6 +27,7 @@ enum class OrcErrorCode : int {
RemoteMProtectAddrUnrecognized,
RemoteIndirectStubsOwnerDoesNotExist,
RemoteIndirectStubsOwnerIdAlreadyInUse,
+ RPCResponseAbandoned,
UnexpectedRPCCall,
UnexpectedRPCResponse,
UnknownRPCFunction
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h
index 34ae392516c..f51fbe153a4 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h
@@ -364,9 +364,27 @@ public:
// Call the given handler with the given arguments.
template <typename HandlerT>
static typename WrappedHandlerReturn<RetT>::Type
- runHandler(HandlerT &Handler, ArgStorage &Args) {
- return runHandlerHelper<RetT>(Handler, Args,
- llvm::index_sequence_for<ArgTs...>());
+ unpackAndRun(HandlerT &Handler, ArgStorage &Args) {
+ return unpackAndRunHelper(Handler, Args,
+ llvm::index_sequence_for<ArgTs...>());
+ }
+
+ // Call the given handler with the given arguments.
+ template <typename HandlerT>
+ static typename std::enable_if<
+ std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
+ Error>::type
+ run(HandlerT &Handler, ArgTs &&... Args) {
+ Handler(std::move(Args)...);
+ return Error::success();
+ }
+
+ template <typename HandlerT>
+ static typename std::enable_if<
+ !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
+ typename HandlerTraits<HandlerT>::ReturnType>::type
+ run(HandlerT &Handler, ArgTs... Args) {
+ return Handler(std::move(Args)...);
}
// Serialize arguments to the channel.
@@ -383,31 +401,20 @@ public:
}
private:
- // For non-void user handlers: unwrap the args tuple and call the handler,
- // returning the result.
- template <typename RetTAlt, typename HandlerT, size_t... Indexes>
- static typename std::enable_if<!std::is_void<RetTAlt>::value, RetT>::type
- runHandlerHelper(HandlerT &Handler, ArgStorage &Args,
- llvm::index_sequence<Indexes...>) {
- return Handler(std::move(std::get<Indexes>(Args))...);
- }
-
- // For void user handlers: unwrap the args tuple and call the handler, then
- // return Error::success().
- template <typename RetTAlt, typename HandlerT, size_t... Indexes>
- static typename std::enable_if<std::is_void<RetTAlt>::value, Error>::type
- runHandlerHelper(HandlerT &Handler, ArgStorage &Args,
- llvm::index_sequence<Indexes...>) {
- Handler(std::move(std::get<Indexes>(Args))...);
- return Error::success();
- }
-
template <typename ChannelT, typename... CArgTs, size_t... Indexes>
static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
llvm::index_sequence<Indexes...> _) {
return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
C, std::get<Indexes>(Args)...);
}
+
+ template <typename HandlerT, size_t... Indexes>
+ static typename WrappedHandlerReturn<
+ typename HandlerTraits<HandlerT>::ReturnType>::Type
+ unpackAndRunHelper(HandlerT &Handler, ArgStorage &Args,
+ llvm::index_sequence<Indexes...>) {
+ return run(Handler, std::move(std::get<Indexes>(Args))...);
+ }
};
// Handler traits for class methods (especially call operators for lambdas).
@@ -422,17 +429,29 @@ class HandlerTraits<RetT (Class::*)(ArgTs...) const>
: public HandlerTraits<RetT(ArgTs...)> {};
// Utility to peel the Expected wrapper off a response handler error type.
-template <typename HandlerT> class UnwrapResponseHandlerArg;
+template <typename HandlerT> class ResponseHandlerArg;
-template <typename ArgT> class UnwrapResponseHandlerArg<Error(Expected<ArgT>)> {
+template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
public:
- using ArgType = ArgT;
+ using ArgType = Expected<ArgT>;
+ using UnwrappedArgType = ArgT;
};
template <typename ArgT>
-class UnwrapResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
+class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
+public:
+ using ArgType = Expected<ArgT>;
+ using UnwrappedArgType = ArgT;
+};
+
+template <> class ResponseHandlerArg<Error(Error)> {
public:
- using ArgType = ArgT;
+ using ArgType = Error;
+};
+
+template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
+public:
+ using ArgType = Error;
};
// ResponseHandler represents a handler for a not-yet-received function call
@@ -452,8 +471,7 @@ public:
// Create an error instance representing an abandoned response.
static Error createAbandonedResponseError() {
- return make_error<StringError>("RPC function call failed to return",
- inconvertibleErrorCode());
+ return orcError(OrcErrorCode::RPCResponseAbandoned);
}
};
@@ -466,12 +484,12 @@ public:
// Handle the result by deserializing it from the channel then passing it
// to the user defined handler.
Error handleResponse(ChannelT &C) override {
- using ArgType = typename UnwrapResponseHandlerArg<
- typename HandlerTraits<HandlerT>::Type>::ArgType;
- ArgType Result;
+ using UnwrappedArgType = typename ResponseHandlerArg<
+ typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
+ UnwrappedArgType Result;
if (auto Err =
- SerializationTraits<ChannelT, FuncRetT, ArgType>::deserialize(
- C, Result))
+ SerializationTraits<ChannelT, FuncRetT,
+ UnwrappedArgType>::deserialize(C, Result))
return Err;
if (auto Err = C.endReceiveMessage())
return Err;
@@ -802,6 +820,8 @@ public:
return Error::success();
}
+ Error sendAppendedCalls() { return C.send(); };
+
template <typename Func, typename HandlerT, typename... ArgTs>
Error callAsync(HandlerT Handler, const ArgTs &... Args) {
if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
@@ -966,8 +986,8 @@ protected:
SeqNo]() mutable -> Error {
using HTraits = detail::HandlerTraits<HandlerT>;
using FuncReturn = typename Func::ReturnType;
- return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
- HTraits::runHandler(Handler, *Args));
+ return detail::respond<FuncReturn>(
+ Channel, ResponseId, SeqNo, HTraits::unpackAndRun(Handler, *Args));
};
// If there is an explicit launch policy then use it to launch the
@@ -1238,6 +1258,80 @@ public:
}
};
+/// \brief Allows a set of asynchrounous calls to be dispatched, and then
+/// waited on as a group.
+template <typename RPCClass> class ParallelCallGroup {
+public:
+
+ /// \brief Construct a parallel call group for the given RPC.
+ ParallelCallGroup(RPCClass &RPC) : RPC(RPC), NumOutstandingCalls(0) {}
+
+ ParallelCallGroup(const ParallelCallGroup &) = delete;
+ ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
+
+ /// \brief Make as asynchronous call.
+ ///
+ /// Does not issue a send call to the RPC's channel. The channel may use this
+ /// to batch up subsequent calls. A send will automatically be sent when wait
+ /// is called.
+ template <typename Func, typename HandlerT, typename... ArgTs>
+ Error appendCall(HandlerT Handler, const ArgTs &... Args) {
+ // Increment the count of outstanding calls. This has to happen before
+ // we invoke the call, as the handler may (depending on scheduling)
+ // be run immediately on another thread, and we don't want the decrement
+ // in the wrapped handler below to run before the increment.
+ {
+ std::unique_lock<std::mutex> Lock(M);
+ ++NumOutstandingCalls;
+ }
+
+ // Wrap the user handler in a lambda that will decrement the
+ // outstanding calls count, then poke the condition variable.
+ using ArgType = typename detail::ResponseHandlerArg<
+ typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
+ // FIXME: Move handler into wrapped handler once we have C++14.
+ auto WrappedHandler = [this, Handler](ArgType Arg) {
+ auto Err = Handler(std::move(Arg));
+ std::unique_lock<std::mutex> Lock(M);
+ --NumOutstandingCalls;
+ CV.notify_all();
+ return Err;
+ };
+
+ return RPC.template appendCallAsync<Func>(std::move(WrappedHandler),
+ Args...);
+ }
+
+ /// \brief Make an asynchronous call.
+ ///
+ /// The same as appendCall, but also calls send on the channel immediately.
+ /// Prefer appendCall if you are about to issue a "wait" call shortly, as
+ /// this may allow the channel to better batch the calls.
+ template <typename Func, typename HandlerT, typename... ArgTs>
+ Error call(HandlerT Handler, const ArgTs &... Args) {
+ if (auto Err = appendCall(std::move(Handler), Args...))
+ return Err;
+ return RPC.sendAppendedCalls();
+ }
+
+ /// \brief Blocks until all calls have been completed and their return value
+ /// handlers run.
+ Error wait() {
+ if (auto Err = RPC.sendAppendedCalls())
+ return Err;
+ std::unique_lock<std::mutex> Lock(M);
+ while (NumOutstandingCalls > 0)
+ CV.wait(Lock);
+ return Error::success();
+ }
+
+private:
+ RPCClass &RPC;
+ std::mutex M;
+ std::condition_variable CV;
+ uint32_t NumOutstandingCalls;
+};
+
} // end namespace rpc
} // end namespace orc
} // end namespace llvm
diff --git a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp
index 48dcd442266..c531fe36992 100644
--- a/llvm/lib/ExecutionEngine/Orc/OrcError.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/OrcError.cpp
@@ -39,6 +39,8 @@ public:
return "Remote indirect stubs owner does not exist";
case OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse:
return "Remote indirect stubs owner Id already in use";
+ case OrcErrorCode::RPCResponseAbandoned:
+ return "RPC response abandoned";
case OrcErrorCode::UnexpectedRPCCall:
return "Unexpected RPC call";
case OrcErrorCode::UnexpectedRPCResponse:
diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
index f9b65a99505..381fd103042 100644
--- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
+++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
@@ -386,3 +386,74 @@ TEST(DummyRPC, TestWithAltCustomType) {
ServerThread.join();
}
+
+TEST(DummyRPC, TestParallelCallGroup) {
+ Queue Q1, Q2;
+ DummyRPCEndpoint Client(Q1, Q2);
+ DummyRPCEndpoint Server(Q2, Q1);
+
+ std::thread ServerThread([&]() {
+ Server.addHandler<DummyRPCAPI::IntInt>(
+ [](int X) -> int {
+ return 2 * X;
+ });
+
+ // Handle the negotiate, plus three calls.
+ for (unsigned I = 0; I != 4; ++I) {
+ auto Err = Server.handleOne();
+ EXPECT_FALSE(!!Err) << "Server failed to handle call to int(int)";
+ }
+ });
+
+ {
+ int A, B, C;
+ ParallelCallGroup<DummyRPCEndpoint> PCG(Client);
+
+ {
+ auto Err = PCG.appendCall<DummyRPCAPI::IntInt>(
+ [&A](Expected<int> Result) {
+ EXPECT_TRUE(!!Result) << "Async int(int) response handler failed";
+ A = *Result;
+ return Error::success();
+ }, 1);
+ EXPECT_FALSE(!!Err) << "First parallel call failed for int(int)";
+ }
+
+ {
+ auto Err = PCG.appendCall<DummyRPCAPI::IntInt>(
+ [&B](Expected<int> Result) {
+ EXPECT_TRUE(!!Result) << "Async int(int) response handler failed";
+ B = *Result;
+ return Error::success();
+ }, 2);
+ EXPECT_FALSE(!!Err) << "Second parallel call failed for int(int)";
+ }
+
+ {
+ auto Err = PCG.appendCall<DummyRPCAPI::IntInt>(
+ [&C](Expected<int> Result) {
+ EXPECT_TRUE(!!Result) << "Async int(int) response handler failed";
+ C = *Result;
+ return Error::success();
+ }, 3);
+ EXPECT_FALSE(!!Err) << "Third parallel call failed for int(int)";
+ }
+
+ // Handle the three int(int) results.
+ for (unsigned I = 0; I != 3; ++I) {
+ auto Err = Client.handleOne();
+ EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)";
+ }
+
+ {
+ auto Err = PCG.wait();
+ EXPECT_FALSE(!!Err) << "Third parallel call failed for int(int)";
+ }
+
+ EXPECT_EQ(A, 2) << "First parallel call returned bogus result";
+ EXPECT_EQ(B, 4) << "Second parallel call returned bogus result";
+ EXPECT_EQ(C, 6) << "Third parallel call returned bogus result";
+ }
+
+ ServerThread.join();
+}
OpenPOWER on IntegriCloud