summaryrefslogtreecommitdiffstats
path: root/llvm/unittests/ExecutionEngine/Orc
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/unittests/ExecutionEngine/Orc')
-rw-r--r--llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp192
1 files changed, 192 insertions, 0 deletions
diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
index a84610f5eb4..095bf25291b 100644
--- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
+++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
@@ -47,6 +47,54 @@ namespace rpc {
class RPCBar {};
+class DummyError : public ErrorInfo<DummyError> {
+public:
+
+ static char ID;
+
+ DummyError(uint32_t Val) : Val(Val) {}
+
+ std::error_code convertToErrorCode() const override {
+ // Use a nonsense error code - we want to verify that errors
+ // transmitted over the network are replaced with
+ // OrcErrorCode::UnknownErrorCodeFromRemote.
+ return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
+ }
+
+ void log(raw_ostream &OS) const override {
+ OS << "Dummy error " << Val;
+ }
+
+ uint32_t getValue() const { return Val; }
+
+public:
+ uint32_t Val;
+};
+
+char DummyError::ID = 0;
+
+template <typename ChannelT>
+void registerDummyErrorSerialization() {
+ static bool AlreadyRegistered = false;
+ if (!AlreadyRegistered) {
+ SerializationTraits<ChannelT, Error>::
+ template registerErrorType<DummyError>(
+ "DummyError",
+ [](ChannelT &C, const DummyError &DE) {
+ return serializeSeq(C, DE.getValue());
+ },
+ [](ChannelT &C, Error &Err) -> Error {
+ ErrorAsOutParameter EAO(&Err);
+ uint32_t Val;
+ if (auto Err = deserializeSeq(C, Val))
+ return Err;
+ Err = make_error<DummyError>(Val);
+ return Error::success();
+ });
+ AlreadyRegistered = true;
+ }
+}
+
namespace llvm {
namespace orc {
namespace rpc {
@@ -98,6 +146,16 @@ namespace DummyRPCAPI {
static const char* getName() { return "CustomType"; }
};
+ class ErrorFunc : public Function<ErrorFunc, Error()> {
+ public:
+ static const char* getName() { return "ErrorFunc"; }
+ };
+
+ class ExpectedFunc : public Function<ExpectedFunc, Expected<uint32_t>()> {
+ public:
+ static const char* getName() { return "ExpectedFunc"; }
+ };
+
}
class DummyRPCEndpoint : public SingleThreadedRPCEndpoint<QueueChannel> {
@@ -493,6 +551,140 @@ TEST(DummyRPC, TestWithAltCustomType) {
ServerThread.join();
}
+TEST(DummyRPC, ReturnErrorSuccess) {
+ registerDummyErrorSerialization<QueueChannel>();
+
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
+
+ std::thread ServerThread([&]() {
+ Server.addHandler<DummyRPCAPI::ErrorFunc>(
+ []() {
+ return Error::success();
+ });
+
+ // Handle the negotiate plus one call.
+ for (unsigned I = 0; I != 2; ++I)
+ cantFail(Server.handleOne());
+ });
+
+ cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>(
+ [&](Error Err) {
+ EXPECT_FALSE(!!Err) << "Expected success value";
+ return Error::success();
+ }));
+
+ cantFail(Client.handleOne());
+
+ ServerThread.join();
+}
+
+TEST(DummyRPC, ReturnErrorFailure) {
+ registerDummyErrorSerialization<QueueChannel>();
+
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
+
+ std::thread ServerThread([&]() {
+ Server.addHandler<DummyRPCAPI::ErrorFunc>(
+ []() {
+ return make_error<DummyError>(42);
+ });
+
+ // Handle the negotiate plus one call.
+ for (unsigned I = 0; I != 2; ++I)
+ cantFail(Server.handleOne());
+ });
+
+ cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>(
+ [&](Error Err) {
+ EXPECT_TRUE(Err.isA<DummyError>())
+ << "Incorrect error type";
+ return handleErrors(
+ std::move(Err),
+ [](const DummyError &DE) {
+ EXPECT_EQ(DE.getValue(), 42ULL)
+ << "Incorrect DummyError serialization";
+ });
+ }));
+
+ cantFail(Client.handleOne());
+
+ ServerThread.join();
+}
+
+TEST(DummyRPC, RPCExpectedSuccess) {
+ registerDummyErrorSerialization<QueueChannel>();
+
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
+
+ std::thread ServerThread([&]() {
+ Server.addHandler<DummyRPCAPI::ExpectedFunc>(
+ []() -> uint32_t {
+ return 42;
+ });
+
+ // Handle the negotiate plus one call.
+ for (unsigned I = 0; I != 2; ++I)
+ cantFail(Server.handleOne());
+ });
+
+ cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>(
+ [&](Expected<uint32_t> ValOrErr) {
+ EXPECT_TRUE(!!ValOrErr)
+ << "Expected success value";
+ EXPECT_EQ(*ValOrErr, 42ULL)
+ << "Incorrect Expected<uint32_t> deserialization";
+ return Error::success();
+ }));
+
+ cantFail(Client.handleOne());
+
+ ServerThread.join();
+};
+
+TEST(DummyRPC, RPCExpectedFailure) {
+ registerDummyErrorSerialization<QueueChannel>();
+
+ auto Channels = createPairedQueueChannels();
+ DummyRPCEndpoint Client(*Channels.first);
+ DummyRPCEndpoint Server(*Channels.second);
+
+ std::thread ServerThread([&]() {
+ Server.addHandler<DummyRPCAPI::ExpectedFunc>(
+ []() -> Expected<uint32_t> {
+ return make_error<DummyError>(7);
+ });
+
+ // Handle the negotiate plus one call.
+ for (unsigned I = 0; I != 2; ++I)
+ cantFail(Server.handleOne());
+ });
+
+ cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>(
+ [&](Expected<uint32_t> ValOrErr) {
+ EXPECT_FALSE(!!ValOrErr)
+ << "Expected failure value";
+ auto Err = ValOrErr.takeError();
+ EXPECT_TRUE(Err.isA<DummyError>())
+ << "Incorrect error type";
+ return handleErrors(
+ std::move(Err),
+ [](const DummyError &DE) {
+ EXPECT_EQ(DE.getValue(), 7ULL)
+ << "Incorrect DummyError serialization";
+ });
+ }));
+
+ cantFail(Client.handleOne());
+
+ ServerThread.join();
+};
+
TEST(DummyRPC, TestParallelCallGroup) {
auto Channels = createPairedQueueChannels();
DummyRPCEndpoint Client(*Channels.first);
OpenPOWER on IntegriCloud