diff options
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h | 80 | ||||
-rw-r--r-- | llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp | 137 |
2 files changed, 152 insertions, 65 deletions
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 7efe6046dc6..2bd198e7819 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -358,54 +358,83 @@ public: /// Return type for asynchronous call primitives. template <typename Func> - using AsyncCallResult = + using AsyncCallResult = std::future<typename Func::OptionalReturn>; + + /// Return type for asynchronous call-with-seq primitives. + template <typename Func> + using AsyncCallWithSeqResult = std::pair<std::future<typename Func::OptionalReturn>, SequenceNumberT>; /// Serialize Args... to channel C, but do not call C.send(). /// - /// For void functions returns a std::future<Error>. For functions that - /// return an R, returns a std::future<Optional<R>>. + /// Returns an error (on serialization failure) or a pair of: + /// (1) A future Optional<T> (or future<bool> for void functions), and + /// (2) A sequence number. + /// + /// This utility function is primarily used for single-threaded mode support, + /// where the sequence number can be used to wait for the corresponding + /// result. In multi-threaded mode the appendCallAsync method, which does not + /// return the sequence numeber, should be preferred. template <typename Func, typename... ArgTs> - ErrorOr<AsyncCallResult<Func>> appendCallAsync(ChannelT &C, - const ArgTs &... Args) { + ErrorOr<AsyncCallWithSeqResult<Func>> + appendCallAsyncWithSeq(ChannelT &C, const ArgTs &... Args) { auto SeqNo = SequenceNumberMgr.getSequenceNumber(); std::promise<typename Func::OptionalReturn> Promise; auto Result = Promise.get_future(); - OutstandingResults[SeqNo] = std::move(Promise); + OutstandingResults[SeqNo] = + createOutstandingResult<Func>(std::move(Promise)); if (auto EC = CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo, Args...)) { abandonOutstandingResults(); return EC; } else - return AsyncCallResult<Func>(std::move(Result), SeqNo); + return AsyncCallWithSeqResult<Func>(std::move(Result), SeqNo); } - /// Serialize Args... to channel C and call C.send(). + /// The same as appendCallAsyncWithSeq, except that it calls C.send() to + /// flush the channel after serializing the call. template <typename Func, typename... ArgTs> - ErrorOr<AsyncCallResult<Func>> callAsync(ChannelT &C, const ArgTs &... Args) { - auto SeqNo = SequenceNumberMgr.getSequenceNumber(); - std::promise<typename Func::OptionalReturn> Promise; - auto Result = Promise.get_future(); - OutstandingResults[SeqNo] = - createOutstandingResult<Func>(std::move(Promise)); - if (auto EC = CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo, - Args...)) { - abandonOutstandingResults(); - return EC; - } + ErrorOr<AsyncCallWithSeqResult<Func>> + callAsyncWithSeq(ChannelT &C, const ArgTs &... Args) { + auto Result = appendCallAsyncWithSeq<Func>(C, Args...); + if (!Result) + return Result; if (auto EC = C.send()) { abandonOutstandingResults(); return EC; } - return AsyncCallResult<Func>(std::move(Result), SeqNo); + return Result; + } + + /// Serialize Args... to channel C, but do not call send. + /// Returns an error if serialization fails, otherwise returns a + /// std::future<Optional<T>> (or a future<bool> for void functions). + template <typename Func, typename... ArgTs> + ErrorOr<AsyncCallResult<Func>> + appendCallAsync(ChannelT &C, const ArgTs &... Args) { + auto ResAndSeqOrErr = appendCallAsyncWithSeq<Func>(C, Args...); + if (ResAndSeqOrErr) + return std::move(ResAndSeqOrErr->first); + return ResAndSeqOrErr.getError(); + } + + /// The same as appendCallAsync, except that it calls C.send to flush the + /// channel after serializing the call. + template <typename Func, typename... ArgTs> + ErrorOr<AsyncCallResult<Func>> + callAsync(ChannelT &C, const ArgTs &... Args) { + auto ResAndSeqOrErr = callAsyncWithSeq<Func>(C, Args...); + if (ResAndSeqOrErr) + return std::move(ResAndSeqOrErr->first); + return ResAndSeqOrErr.getError(); } /// This can be used in single-threaded mode. template <typename Func, typename HandleFtor, typename... ArgTs> typename Func::ErrorReturn callSTHandling(ChannelT &C, HandleFtor &HandleOther, const ArgTs &... Args) { - if (auto ResultAndSeqNoOrErr = callAsync<Func>(C, Args...)) { + if (auto ResultAndSeqNoOrErr = callAsyncWithSeq<Func>(C, Args...)) { auto &ResultAndSeqNo = *ResultAndSeqNoOrErr; if (auto EC = waitForResult(C, ResultAndSeqNo.second, HandleOther)) return EC; @@ -491,12 +520,17 @@ public: /// Read a response from Channel. /// This should be called from the receive loop to retrieve results. - std::error_code handleResponse(ChannelT &C, SequenceNumberT &SeqNo) { + std::error_code handleResponse(ChannelT &C, + SequenceNumberT *SeqNoRet = nullptr) { + SequenceNumberT SeqNo; if (auto EC = deserialize(C, SeqNo)) { abandonOutstandingResults(); return EC; } + if (SeqNoRet) + *SeqNoRet = SeqNo; + auto I = OutstandingResults.find(SeqNo); if (I == OutstandingResults.end()) { abandonOutstandingResults(); @@ -528,7 +562,7 @@ public: return EC; if (Id == RPCFunctionIdTraits<FunctionIdT>::ResponseId) { SequenceNumberT SeqNo; - if (auto EC = handleResponse(C, SeqNo)) + if (auto EC = handleResponse(C, &SeqNo)) return EC; GotTgtResult = (SeqNo == TgtSeqNo); } else if (auto EC = HandleOther(C, Id)) diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 77632e35eb1..87c98ad1dfb 100644 --- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -17,52 +17,81 @@ using namespace llvm; using namespace llvm::orc; using namespace llvm::orc::remote; +class Queue : public std::queue<char> { +public: + std::mutex& getLock() { return Lock; } +private: + std::mutex Lock; +}; + class QueueChannel : public RPCChannel { public: - QueueChannel(std::queue<char> &Queue) : Queue(Queue) {} + QueueChannel(Queue &InQueue, Queue &OutQueue) + : InQueue(InQueue), OutQueue(OutQueue) {} std::error_code readBytes(char *Dst, unsigned Size) override { - while (Size--) { - *Dst++ = Queue.front(); - Queue.pop(); + while (Size != 0) { + // If there's nothing to read then yield. + while (InQueue.empty()) + std::this_thread::yield(); + + // Lock the channel and read what we can. + std::lock_guard<std::mutex> Lock(InQueue.getLock()); + while (!InQueue.empty() && Size) { + *Dst++ = InQueue.front(); + --Size; + InQueue.pop(); + } } return std::error_code(); } std::error_code appendBytes(const char *Src, unsigned Size) override { + std::lock_guard<std::mutex> Lock(OutQueue.getLock()); while (Size--) - Queue.push(*Src++); + OutQueue.push(*Src++); return std::error_code(); } std::error_code send() override { return std::error_code(); } private: - std::queue<char> &Queue; + Queue &InQueue; + Queue &OutQueue; }; class DummyRPC : public testing::Test, public RPC<QueueChannel> { public: - typedef Function<2, void(bool)> BasicVoid; - typedef Function<3, int32_t(bool)> BasicInt; - typedef Function<4, void(int8_t, uint8_t, int16_t, uint16_t, - int32_t, uint32_t, int64_t, uint64_t, - bool, std::string, std::vector<int>)> AllTheTypes; + + enum FuncId : uint32_t { + VoidBoolId = RPCFunctionIdTraits<FuncId>::FirstValidId, + IntIntId, + AllTheTypesId + }; + + typedef Function<VoidBoolId, void(bool)> VoidBool; + typedef Function<IntIntId, int32_t(int32_t)> IntInt; + typedef Function<AllTheTypesId, void(int8_t, uint8_t, int16_t, uint16_t, + int32_t, uint32_t, int64_t, uint64_t, + bool, std::string, std::vector<int>)> + AllTheTypes; + }; -TEST_F(DummyRPC, TestAsyncBasicVoid) { - std::queue<char> Queue; - QueueChannel C(Queue); +TEST_F(DummyRPC, TestAsyncVoidBool) { + Queue Q1, Q2; + QueueChannel C1(Q1, Q2); + QueueChannel C2(Q2, Q1); // Make an async call. - auto ResOrErr = callAsync<BasicVoid>(C, true); + auto ResOrErr = callAsyncWithSeq<VoidBool>(C1, true); EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed"; { // Expect a call to Proc1. - auto EC = expect<BasicVoid>(C, + auto EC = expect<VoidBool>(C2, [&](bool &B) { EXPECT_EQ(B, true) << "Bool serialization broken"; @@ -73,7 +102,7 @@ TEST_F(DummyRPC, TestAsyncBasicVoid) { { // Wait for the result. - auto EC = waitForResult(C, ResOrErr->second, handleNone); + auto EC = waitForResult(C1, ResOrErr->second, handleNone); EXPECT_FALSE(EC) << "Could not read result."; } @@ -82,28 +111,29 @@ TEST_F(DummyRPC, TestAsyncBasicVoid) { EXPECT_TRUE(Val) << "Remote void function failed to execute."; } -TEST_F(DummyRPC, TestAsyncBasicInt) { - std::queue<char> Queue; - QueueChannel C(Queue); +TEST_F(DummyRPC, TestAsyncIntInt) { + Queue Q1, Q2; + QueueChannel C1(Q1, Q2); + QueueChannel C2(Q2, Q1); // Make an async call. - auto ResOrErr = callAsync<BasicInt>(C, false); + auto ResOrErr = callAsyncWithSeq<IntInt>(C1, 21); EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed"; { // Expect a call to Proc1. - auto EC = expect<BasicInt>(C, - [&](bool &B) { - EXPECT_EQ(B, false) + auto EC = expect<IntInt>(C2, + [&](int32_t I) { + EXPECT_EQ(I, 21) << "Bool serialization broken"; - return 42; + return 2 * I; }); EXPECT_FALSE(EC) << "Simple expect over queue failed"; } { // Wait for the result. - auto EC = waitForResult(C, ResOrErr->second, handleNone); + auto EC = waitForResult(C1, ResOrErr->second, handleNone); EXPECT_FALSE(EC) << "Could not read result."; } @@ -114,29 +144,30 @@ TEST_F(DummyRPC, TestAsyncBasicInt) { } TEST_F(DummyRPC, TestSerialization) { - std::queue<char> Queue; - QueueChannel C(Queue); + Queue Q1, Q2; + QueueChannel C1(Q1, Q2); + QueueChannel C2(Q2, Q1); // Make a call to Proc1. std::vector<int> v({42, 7}); - auto ResOrErr = callAsync<AllTheTypes>(C, - -101, - 250, - -10000, - 10000, - -1000000000, - 1000000000, - -10000000000, - 10000000000, - true, - "foo", - v); + auto ResOrErr = callAsyncWithSeq<AllTheTypes>(C1, + -101, + 250, + -10000, + 10000, + -1000000000, + 1000000000, + -10000000000, + 10000000000, + true, + "foo", + v); EXPECT_TRUE(!!ResOrErr) << "Big (serialization test) call over queue failed"; { // Expect a call to Proc1. - auto EC = expect<AllTheTypes>(C, + auto EC = expect<AllTheTypes>(C2, [&](int8_t &s8, uint8_t &u8, int16_t &s16, @@ -178,7 +209,7 @@ TEST_F(DummyRPC, TestSerialization) { { // Wait for the result. - auto EC = waitForResult(C, ResOrErr->second, handleNone); + auto EC = waitForResult(C1, ResOrErr->second, handleNone); EXPECT_FALSE(EC) << "Could not read result."; } @@ -186,3 +217,25 @@ TEST_F(DummyRPC, TestSerialization) { auto Val = ResOrErr->first.get(); EXPECT_TRUE(Val) << "Remote void function failed to execute."; } + +// Test the synchronous call API. +TEST_F(DummyRPC, TestSynchronousCall) { + Queue Q1, Q2; + QueueChannel C1(Q1, Q2); + QueueChannel C2(Q2, Q1); + + auto ServerResult = + std::async(std::launch::async, + [&]() { + return expect<IntInt>(C2, [&](int32_t V) { return V; }); + }); + + auto ValOrErr = callST<IntInt>(C1, 42); + + EXPECT_FALSE(!!ServerResult.get()) + << "Server returned an error."; + EXPECT_TRUE(!!ValOrErr) + << "callST returned an error."; + EXPECT_EQ(*ValOrErr, 42) + << "Incorrect callST<IntInt> result"; +} |