diff options
Diffstat (limited to 'llvm')
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h | 133 | ||||
-rw-r--r-- | llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp | 126 |
2 files changed, 257 insertions, 2 deletions
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 68ce2c05404..13556110714 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -82,6 +82,17 @@ std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex; template <typename DerivedFunc, typename RetT, typename... ArgTs> std::string Function<DerivedFunc, RetT(ArgTs...)>::Name; +/// Provides a typedef for a tuple containing the decayed argument types. +template <typename T> +class FunctionArgsTuple; + +template <typename RetT, typename... ArgTs> +class FunctionArgsTuple<RetT(ArgTs...)> { +public: + using Type = std::tuple<typename std::decay< + typename std::remove_reference<ArgTs>::type>::type...>; +}; + /// Allocates RPC function ids during autonegotiation. /// Specializations of this class must provide four members: /// @@ -349,8 +360,7 @@ public: using ReturnType = RetT; // A std::tuple wrapping the handler arguments. - using ArgStorage = std::tuple<typename std::decay< - typename std::remove_reference<ArgTs>::type>::type...>; + using ArgStorage = typename FunctionArgsTuple<RetT(ArgTs...)>::Type; // Call the given handler with the given arguments. template <typename HandlerT> @@ -589,6 +599,84 @@ private: std::vector<SequenceNumberT> FreeSequenceNumbers; }; +// Checks that predicate P holds for each corresponding pair of type arguments +// from T1 and T2 tuple. +template <template<class, class> class P, typename T1Tuple, + typename T2Tuple> +class RPCArgTypeCheckHelper; + +template <template<class, class> class P> +class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> { +public: + static const bool value = true; +}; + +template <template<class, class> class P, typename T, typename... Ts, + typename U, typename... Us> +class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> { +public: + static const bool value = + P<T, U>::value && + RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value; +}; + +template <template<class, class> class P, typename T1Sig, typename T2Sig> +class RPCArgTypeCheck { +public: + + using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type; + using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type; + + static_assert(std::tuple_size<T1Tuple>::value >= std::tuple_size<T2Tuple>::value, + "Too many arguments to RPC call"); + static_assert(std::tuple_size<T1Tuple>::value <= std::tuple_size<T2Tuple>::value, + "Too few arguments to RPC call"); + + static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value; +}; + +template <typename ChannelT, typename WireT, typename ConcreteT> +class CanSerialize { +private: + using S = SerializationTraits<ChannelT, WireT, ConcreteT>; + + template <typename T> + static std::true_type + check(typename std::enable_if< + std::is_same< + decltype(T::serialize(std::declval<ChannelT&>(), + std::declval<const ConcreteT&>())), + Error>::value, + void*>::type); + + template <typename> + static std::false_type check(...); + +public: + static const bool value = decltype(check<S>(0))::value; +}; + +template <typename ChannelT, typename WireT, typename ConcreteT> +class CanDeserialize { +private: + using S = SerializationTraits<ChannelT, WireT, ConcreteT>; + + template <typename T> + static std::true_type + check(typename std::enable_if< + std::is_same< + decltype(T::deserialize(std::declval<ChannelT&>(), + std::declval<ConcreteT&>())), + Error>::value, + void*>::type); + + template <typename> + static std::false_type check(...); + +public: + static const bool value = decltype(check<S>(0))::value; +}; + /// Contains primitive utilities for defining, calling and handling calls to /// remote procedures. ChannelT is a bidirectional stream conforming to the /// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure @@ -603,6 +691,7 @@ template <typename ImplT, typename ChannelT, typename FunctionIdT, typename SequenceNumberT> class RPCBase { protected: + class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> { public: static const char *getName() { return "__orc_rpc$invalid"; } @@ -619,6 +708,31 @@ protected: static const char *getName() { return "__orc_rpc$negotiate"; } }; + // Helper predicate for testing for the presence of SerializeTraits + // serializers. + template <typename WireT, typename ConcreteT> + class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> { + public: + using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value; + + static_assert(value, "Missing serializer for argument (Can't serialize the " + "first template type argument of CanSerializeCheck " + "from the second)"); + }; + + // Helper predicate for testing for the presence of SerializeTraits + // deserializers. + template <typename WireT, typename ConcreteT> + class CanDeserializeCheck + : detail::CanDeserialize<ChannelT, WireT, ConcreteT> { + public: + using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value; + + static_assert(value, "Missing deserializer for argument (Can't deserialize " + "the second template type argument of " + "CanDeserializeCheck from the first)"); + }; + public: /// Construct an RPC instance on a channel. RPCBase(ChannelT &C, bool LazyAutoNegotiation) @@ -643,6 +757,13 @@ public: /// with an error if the return value is abandoned due to a channel error. template <typename Func, typename HandlerT, typename... ArgTs> Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) { + + static_assert( + detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type, + void(ArgTs...)> + ::value, + ""); + // Look up the function ID. FunctionIdT FnId; if (auto FnIdOrErr = getRemoteFunctionId<Func>()) @@ -738,6 +859,14 @@ protected: /// autonegotiation and execution. template <typename Func, typename HandlerT> void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) { + + static_assert( + detail::RPCArgTypeCheck<CanDeserializeCheck, + typename Func::Type, + typename detail::HandlerTraits<HandlerT>::Type> + ::value, + ""); + FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); LocalFunctionIds[Func::getPrototype()] = NewFnId; Handlers[NewFnId] = diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 9ace46dffd8..19146968bbd 100644 --- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -58,6 +58,40 @@ private: Queue &OutQueue; }; +class RPCFoo {}; + +template <> +class RPCTypeName<RPCFoo> { +public: + static const char* getName() { return "RPCFoo"; } +}; + +template <> +class SerializationTraits<QueueChannel, RPCFoo, RPCFoo> { +public: + static Error serialize(QueueChannel&, const RPCFoo&) { + return Error::success(); + } + + static Error deserialize(QueueChannel&, RPCFoo&) { + return Error::success(); + } +}; + +class RPCBar {}; + +template <> +class SerializationTraits<QueueChannel, RPCFoo, RPCBar> { +public: + static Error serialize(QueueChannel&, const RPCBar&) { + return Error::success(); + } + + static Error deserialize(QueueChannel&, RPCBar&) { + return Error::success(); + } +}; + class DummyRPCAPI { public: @@ -79,6 +113,12 @@ public: public: static const char* getName() { return "AllTheTypes"; } }; + + class CustomType : public Function<CustomType, RPCFoo(RPCFoo)> { + public: + static const char* getName() { return "CustomType"; } + }; + }; class DummyRPCEndpoint : public DummyRPCAPI, @@ -244,3 +284,89 @@ TEST(DummyRPC, TestSerialization) { ServerThread.join(); } + +TEST(DummyRPC, TestCustomType) { + Queue Q1, Q2; + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::CustomType>( + [](RPCFoo F) {}); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the CustomType call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to RPCFoo(RPCFoo)"; + } + }); + + { + // Make an async call. + auto Err = Client.callAsync<DummyRPCAPI::CustomType>( + [](Expected<RPCFoo> FOrErr) { + EXPECT_TRUE(!!FOrErr) + << "Async RPCFoo(RPCFoo) response handler failed"; + return Error::success(); + }, RPCFoo()); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for RPCFoo(RPCFoo)"; + } + + { + // Poke the client to process the result of the RPCFoo() call. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) + << "Client failed to handle response from RPCFoo(RPCFoo)"; + } + + ServerThread.join(); +} + +TEST(DummyRPC, TestWithAltCustomType) { + Queue Q1, Q2; + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::CustomType>( + [](RPCBar F) {}); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the CustomType call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to RPCFoo(RPCFoo)"; + } + }); + + { + // Make an async call. + auto Err = Client.callAsync<DummyRPCAPI::CustomType>( + [](Expected<RPCBar> FOrErr) { + EXPECT_TRUE(!!FOrErr) + << "Async RPCFoo(RPCFoo) response handler failed"; + return Error::success(); + }, RPCBar()); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for RPCFoo(RPCFoo)"; + } + + { + // Poke the client to process the result of the RPCFoo() call. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) + << "Client failed to handle response from RPCFoo(RPCFoo)"; + } + + ServerThread.join(); +} |