diff options
Diffstat (limited to 'llvm/unittests/ExecutionEngine/Orc')
-rw-r--r-- | llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index a84610f5eb4..095bf25291b 100644 --- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -47,6 +47,54 @@ namespace rpc { class RPCBar {}; +class DummyError : public ErrorInfo<DummyError> { +public: + + static char ID; + + DummyError(uint32_t Val) : Val(Val) {} + + std::error_code convertToErrorCode() const override { + // Use a nonsense error code - we want to verify that errors + // transmitted over the network are replaced with + // OrcErrorCode::UnknownErrorCodeFromRemote. + return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); + } + + void log(raw_ostream &OS) const override { + OS << "Dummy error " << Val; + } + + uint32_t getValue() const { return Val; } + +public: + uint32_t Val; +}; + +char DummyError::ID = 0; + +template <typename ChannelT> +void registerDummyErrorSerialization() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + SerializationTraits<ChannelT, Error>:: + template registerErrorType<DummyError>( + "DummyError", + [](ChannelT &C, const DummyError &DE) { + return serializeSeq(C, DE.getValue()); + }, + [](ChannelT &C, Error &Err) -> Error { + ErrorAsOutParameter EAO(&Err); + uint32_t Val; + if (auto Err = deserializeSeq(C, Val)) + return Err; + Err = make_error<DummyError>(Val); + return Error::success(); + }); + AlreadyRegistered = true; + } +} + namespace llvm { namespace orc { namespace rpc { @@ -98,6 +146,16 @@ namespace DummyRPCAPI { static const char* getName() { return "CustomType"; } }; + class ErrorFunc : public Function<ErrorFunc, Error()> { + public: + static const char* getName() { return "ErrorFunc"; } + }; + + class ExpectedFunc : public Function<ExpectedFunc, Expected<uint32_t>()> { + public: + static const char* getName() { return "ExpectedFunc"; } + }; + } class DummyRPCEndpoint : public SingleThreadedRPCEndpoint<QueueChannel> { @@ -493,6 +551,140 @@ TEST(DummyRPC, TestWithAltCustomType) { ServerThread.join(); } +TEST(DummyRPC, ReturnErrorSuccess) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ErrorFunc>( + []() { + return Error::success(); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>( + [&](Error Err) { + EXPECT_FALSE(!!Err) << "Expected success value"; + return Error::success(); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + +TEST(DummyRPC, ReturnErrorFailure) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ErrorFunc>( + []() { + return make_error<DummyError>(42); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>( + [&](Error Err) { + EXPECT_TRUE(Err.isA<DummyError>()) + << "Incorrect error type"; + return handleErrors( + std::move(Err), + [](const DummyError &DE) { + EXPECT_EQ(DE.getValue(), 42ULL) + << "Incorrect DummyError serialization"; + }); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + +TEST(DummyRPC, RPCExpectedSuccess) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ExpectedFunc>( + []() -> uint32_t { + return 42; + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>( + [&](Expected<uint32_t> ValOrErr) { + EXPECT_TRUE(!!ValOrErr) + << "Expected success value"; + EXPECT_EQ(*ValOrErr, 42ULL) + << "Incorrect Expected<uint32_t> deserialization"; + return Error::success(); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +}; + +TEST(DummyRPC, RPCExpectedFailure) { + registerDummyErrorSerialization<QueueChannel>(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler<DummyRPCAPI::ExpectedFunc>( + []() -> Expected<uint32_t> { + return make_error<DummyError>(7); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>( + [&](Expected<uint32_t> ValOrErr) { + EXPECT_FALSE(!!ValOrErr) + << "Expected failure value"; + auto Err = ValOrErr.takeError(); + EXPECT_TRUE(Err.isA<DummyError>()) + << "Incorrect error type"; + return handleErrors( + std::move(Err), + [](const DummyError &DE) { + EXPECT_EQ(DE.getValue(), 7ULL) + << "Incorrect DummyError serialization"; + }); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +}; + TEST(DummyRPC, TestParallelCallGroup) { auto Channels = createPairedQueueChannels(); DummyRPCEndpoint Client(*Channels.first); |