diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h
index 515b6be39..26873a00e 100644
--- a/include/faabric/transport/MessageEndpoint.h
+++ b/include/faabric/transport/MessageEndpoint.h
@@ -218,6 +218,33 @@ class SyncRecvMessageEndpoint final : public RecvMessageEndpoint
     void sendResponse(const uint8_t* data, int size);
 };
 
+// Receiving side of a direct channel: binds a PAIR socket for the given
+// inproc label.
+class AsyncDirectRecvEndpoint final : public RecvMessageEndpoint
+{
+  public:
+    AsyncDirectRecvEndpoint(const std::string& inprocLabel,
+                            int timeoutMs = DEFAULT_RECV_TIMEOUT_MS);
+
+    // Logs at trace level, then defers to RecvMessageEndpoint::recv.
+    std::optional<Message> recv(int size = 0) override;
+};
+
+// Sending side of a direct channel: connects a PAIR socket to
+// "inproc://<label>".
+class AsyncDirectSendEndpoint final : public MessageEndpoint
+{
+  public:
+    AsyncDirectSendEndpoint(const std::string& inprocLabel,
+                            int timeoutMs = DEFAULT_RECV_TIMEOUT_MS);
+
+    // Forwards data, dataSize and more to doSend on the PAIR socket held in
+    // the socket member below.
+    void send(const uint8_t* data, size_t dataSize, bool more = false);
+
+    zmq::socket_t socket;
+};
+
 class MessageTimeoutException final : public faabric::util::FaabricException
 {
   public:
diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp
index 119324885..054ea2730 100644
--- a/src/transport/MessageEndpoint.cpp
+++ b/src/transport/MessageEndpoint.cpp
@@ -79,6 +79,13 @@ zmq::socket_t socketFactory(zmq::socket_type socketType,
                     CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind")
                     break;
                 }
+                case zmq::socket_type::pair: {
+                    SPDLOG_TRACE("Bind socket: pair {} (timeout {}ms)",
+                                 address,
+                                 timeoutMs);
+                    CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind")
+                    break;
+                }
                 case zmq::socket_type::pub: {
                     SPDLOG_TRACE(
                       "Bind socket: pub {} (timeout {}ms)", address, timeoutMs);
@@ -123,6 +130,13 @@ zmq::socket_t socketFactory(zmq::socket_type socketType,
         }
         case (MessageEndpointConnectType::CONNECT): {
             switch (socketType) {
+                case zmq::socket_type::pair: {
+                    SPDLOG_TRACE("Connect socket: pair {} (timeout {}ms)",
+                                 address,
+                                 timeoutMs);
+                    CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect")
+                    break;
+                }
                 case zmq::socket_type::pull: {
                     SPDLOG_TRACE("Connect socket: pull {} (timeout {}ms)",
                                  address,
@@ -559,4 +573,40 @@ void SyncRecvMessageEndpoint::sendResponse(const uint8_t* data, int size)
     SPDLOG_TRACE("REP {} ({} bytes)", address, size);
     doSend(socket, data, size, false);
 }
+
+// ----------------------------------------------
+// INTERNAL DIRECT MESSAGE ENDPOINTS
+// ----------------------------------------------
+
+// Binds a PAIR socket for the given label via the RecvMessageEndpoint base
+AsyncDirectRecvEndpoint::AsyncDirectRecvEndpoint(const std::string& inprocLabel,
+                                                 int timeoutMs)
+  : RecvMessageEndpoint(inprocLabel,
+                        timeoutMs,
+                        zmq::socket_type::pair,
+                        MessageEndpointConnectType::BIND)
+{}
+
+std::optional<Message> AsyncDirectRecvEndpoint::recv(int size)
+{
+    SPDLOG_TRACE("PAIR recv {} ({} bytes)", address, size);
+    return RecvMessageEndpoint::recv(size);
+}
+
+// Connects a PAIR socket to the inproc address built from the label
+AsyncDirectSendEndpoint::AsyncDirectSendEndpoint(const std::string& inprocLabel,
+                                                 int timeoutMs)
+  : MessageEndpoint("inproc://" + inprocLabel, timeoutMs)
+{
+    socket =
+      setUpSocket(zmq::socket_type::pair, MessageEndpointConnectType::CONNECT);
+}
+
+void AsyncDirectSendEndpoint::send(const uint8_t* data,
+                                   size_t dataSize,
+                                   bool more)
+{
+    SPDLOG_TRACE("PAIR send {} ({} bytes, more {})", address, dataSize, more);
+    doSend(socket, data, dataSize, more);
+}
 }
diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp
index 80a11470f..41c1480e6 100644
--- a/tests/test/transport/test_message_endpoint_client.cpp
+++ b/tests/test/transport/test_message_endpoint_client.cpp
@@ -1,10 +1,12 @@
 #include "faabric_utils.h"
 #include
 
+#include
 #include
 #include
 
 #include
+#include
 #include
 
 using namespace faabric::transport;
@@ -224,6 +226,116 @@ TEST_CASE_METHOD(SchedulerTestFixture,
     }
 }
 
-#endif
+// A message sent before the receiver exists must still be delivered
+TEST_CASE_METHOD(SchedulerTestFixture, "Test direct messaging", "[transport]")
+{
+    std::string expected = "Direct hello";
+    const uint8_t* msg = BYTES_CONST(expected.c_str());
+
+    std::string inprocLabel = "direct-test";
+
+    AsyncDirectSendEndpoint sender(inprocLabel);
+    sender.send(msg, expected.size());
+
+    AsyncDirectRecvEndpoint receiver(inprocLabel);
+
+    std::string actual;
+    SECTION("Recv with size")
+    {
+        faabric::transport::Message recvMsg =
+          receiver.recv(expected.size()).value();
+        actual = std::string(recvMsg.data(), recvMsg.size());
+    }
+
+    SECTION("Recv no size")
+    {
+        faabric::transport::Message recvMsg = receiver.recv().value();
+        actual = std::string(recvMsg.data(), recvMsg.size());
+    }
+
+    REQUIRE(actual == expected);
+}
+
+// Concurrent sender/receiver pairs must deliver every message in order
+TEST_CASE_METHOD(SchedulerTestFixture,
+                 "Stress test direct messaging",
+                 "[transport]")
+{
+    int nMessages = 1000;
+    int nPairs = 3;
+    std::string inprocLabel = "direct-test-";
+
+    std::shared_ptr<faabric::util::Latch> startLatch =
+      faabric::util::Latch::create(nPairs + 1);
+
+    std::vector<std::thread> senders;
+    std::vector<std::thread> receivers;
+
+    for (int i = 0; i < nPairs; i++) {
+        senders.emplace_back([i, nMessages, inprocLabel, &startLatch] {
+            std::string thisLabel = inprocLabel + std::to_string(i);
+            AsyncDirectSendEndpoint sender(thisLabel);
+
+            for (int m = 0; m < nMessages; m++) {
+                std::string expected =
+                  "Direct hello " + std::to_string(i) + "_" + std::to_string(m);
+                const uint8_t* msg = BYTES_CONST(expected.c_str());
+                sender.send(msg, expected.size());
+
+                if (m % 100 == 0) {
+                    SLEEP_MS(10);
+                }
+
+                // Make main thread wait until messages are queued (to check no
+                // issue with connecting before binding)
+                if (m == 10) {
+                    startLatch->wait();
+                }
+            }
+        });
+    }
+
+    // Wait for queued messages
+    startLatch->wait();
+
+    std::atomic<bool> success = true;
+    for (int i = 0; i < nPairs; i++) {
+        receivers.emplace_back([i, nMessages, inprocLabel, &success] {
+            std::string thisLabel = inprocLabel + std::to_string(i);
+            AsyncDirectRecvEndpoint receiver(thisLabel);
+
+            // Receive messages
+            for (int m = 0; m < nMessages; m++) {
+                faabric::transport::Message recvMsg = receiver.recv().value();
+                std::string actual(recvMsg.data(), recvMsg.size());
+
+                std::string expected =
+                  "Direct hello " + std::to_string(i) + "_" + std::to_string(m);
+
+                if (actual != expected) {
+                    success.store(false);
+                }
+            }
+        });
+    }
+
+    // Join everything before asserting: the receivers are still comparing
+    // messages here, and a failing REQUIRE must not unwind past live threads
+    for (auto& t : senders) {
+        if (t.joinable()) {
+            t.join();
+        }
+    }
+
+    for (auto& t : receivers) {
+        if (t.joinable()) {
+            t.join();
+        }
+    }
+
+    REQUIRE(success.load(std::memory_order_acquire));
+}
+
+#endif // End ThreadSanitizer exclusion
 
 }