From 6860a6d14dae492da2bd56795d9d65efadd59c9b Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 18 Jun 2021 10:52:50 +0000 Subject: [PATCH 01/66] Increase wait time on flush test --- tests/test/scheduler/test_function_client_server.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test/scheduler/test_function_client_server.cpp b/tests/test/scheduler/test_function_client_server.cpp index b6f5c7b4f..c7c68b8bb 100644 --- a/tests/test/scheduler/test_function_client_server.cpp +++ b/tests/test/scheduler/test_function_client_server.cpp @@ -17,7 +17,7 @@ #include #include -#define TEST_TIMEOUT_MS 500 +#define TEST_WAIT_MS 1000 using namespace faabric::scheduler; @@ -37,7 +37,7 @@ class ClientServerFixture : cli(LOCALHOST) { server.start(); - usleep(1000 * TEST_TIMEOUT_MS); + usleep(TEST_WAIT_MS * 1000); // Set up executor executorFactory = std::make_shared(); @@ -81,7 +81,7 @@ TEST_CASE_METHOD(ClientServerFixture, // Send flush message cli.sendFlush(); - usleep(1000 * TEST_TIMEOUT_MS); + usleep(TEST_WAIT_MS * 1000); // Check the scheduler has been flushed REQUIRE(sch.getFunctionRegisteredHostCount(msgA) == 0); @@ -138,7 +138,7 @@ TEST_CASE_METHOD(ClientServerFixture, // Make the request cli.executeFunctions(req); - usleep(1000 * TEST_TIMEOUT_MS); + usleep(TEST_WAIT_MS * 1000); // Check no other hosts have been registered faabric::Message m = req->messages().at(0); @@ -227,7 +227,7 @@ TEST_CASE_METHOD(ClientServerFixture, "Test unregister request", "[scheduler]") reqB.set_host(otherHost); *reqB.mutable_function() = msg; cli.unregister(reqB); - usleep(1000 * TEST_TIMEOUT_MS); + usleep(TEST_WAIT_MS * 1000); REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 0); sch.setThisHostResources(originalResources); From bac7026421f7f6d15c423e950907eb52d33fe169 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 18 Jun 2021 11:03:33 +0000 Subject: [PATCH 02/66] Small tidy-up in endpoint code --- include/faabric/transport/MessageEndpoint.h | 1 + src/transport/MessageEndpoint.cpp | 24 +++++++++------------ 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index c1852f89d..3afe41211 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -72,6 +72,7 @@ class MessageEndpoint protected: const std::string host; const int port; + const std::string address; std::thread::id tid; int id; diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 07059f87d..ff7f646d3 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -8,6 +8,7 @@ namespace faabric::transport { MessageEndpoint::MessageEndpoint(const std::string& hostIn, int portIn) : host(hostIn) , port(portIn) + , address("tcp://" + host + ":" + std::to_string(port)) , tid(std::this_thread::get_id()) , id(faabric::util::generateGid()) {} @@ -28,9 +29,6 @@ void MessageEndpoint::open(faabric::transport::MessageContext& context, // costly checks when running a Release build. assert(tid == std::this_thread::get_id()); - std::string address = - "tcp://" + this->host + ":" + std::to_string(this->port); - // Note - only one socket may bind, but several can connect. This // allows for easy N - 1 or 1 - N PUSH/PULL patterns. Order between // bind and connect does not matter. @@ -196,9 +194,6 @@ void MessageEndpoint::close(bool bind) SPDLOG_WARN("Closing socket from a different thread"); } - std::string address = - "tcp://" + this->host + ":" + std::to_string(this->port); - // We duplicate the call to close() because when unbinding, we want to // block until we _actually_ have unbinded, i.e. 0MQ has closed the // socket (which happens asynchronously). For connect()-ed sockets we @@ -215,24 +210,25 @@ void MessageEndpoint::close(bool bind) throw; } } + // NOTE - unbinding a socket has a considerable overhead compared to // disconnecting it. // TODO - could we reuse the monitor? try { - { - zmq::monitor_t mon; - const std::string monAddr = - "inproc://monitor_" + std::to_string(id); - mon.init(*(this->socket), monAddr, ZMQ_EVENT_CLOSED); - this->socket->close(); - mon.check_event(-1); - } + zmq::monitor_t mon; + const std::string monAddr = + "inproc://monitor_" + std::to_string(id); + mon.init(*(this->socket), monAddr, ZMQ_EVENT_CLOSED); + + this->socket->close(); + mon.check_event(-1); } catch (zmq::error_t& e) { if (e.num() != ZMQ_ETERM) { SPDLOG_ERROR("Error closing bind socket: {}", e.what()); throw; } } + } else { try { this->socket->disconnect(address); From 7a183a1d9e1b60eee8ed4165b37d21aa630a2eec Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 18 Jun 2021 11:35:32 +0000 Subject: [PATCH 03/66] Add timeout on close and catch-all error handling --- include/faabric/transport/MessageEndpoint.h | 4 + src/transport/MessageEndpoint.cpp | 242 +++++++++----------- 2 files changed, 112 insertions(+), 134 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 3afe41211..78fbb1fe0 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -80,6 +80,10 @@ class MessageEndpoint int sendTimeoutMs = DEFAULT_SEND_TIMEOUT_MS; void validateTimeout(int value); + + Message recvBuffer(int size); + + Message recvNoBuffer(); }; /* Send and Recv Message Endpoints */ diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index ff7f646d3..a3e51bf33 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -4,7 +4,24 @@ #include +#define ZMQ_CATCH(op, label) \ + try { \ + op; \ + } catch (zmq::error_t & e) { \ + SPDLOG_ERROR("Caught ZeroMQ error for {} on address {}: {} ({})", \ + label, \ + address, \ + e.num(), \ + e.what()); \ + throw; \ + } catch (...) { \ + SPDLOG_ERROR( \ + "Caught non-ZeroMQ error for {} on address {}", label, address); \ + throw; \ + } + namespace faabric::transport { + MessageEndpoint::MessageEndpoint(const std::string& hostIn, int portIn) : host(hostIn) , port(portIn) @@ -34,48 +51,34 @@ void MessageEndpoint::open(faabric::transport::MessageContext& context, // bind and connect does not matter. switch (sockType) { case faabric::transport::SocketType::PUSH: - try { - this->socket = std::make_unique( - context.get(), zmq::socket_type::push); - } catch (zmq::error_t& e) { - SPDLOG_ERROR( - "Error opening SEND socket to {}: {}", address, e.what()); - throw; - } + ZMQ_CATCH( + { + this->socket = std::make_unique( + context.get(), zmq::socket_type::push); + }, + "push_socket") break; case faabric::transport::SocketType::PULL: - try { - this->socket = std::make_unique( - context.get(), zmq::socket_type::pull); - } catch (zmq::error_t& e) { - SPDLOG_ERROR("Error opening RECV socket bound to {}: {}", - address, - e.what()); - throw; - } + ZMQ_CATCH( + { + this->socket = std::make_unique( + context.get(), zmq::socket_type::pull); + }, + "pull_socket") break; default: throw std::runtime_error("Unrecognized socket type"); } + + // Check opening the socket has worked assert(this->socket != nullptr); // Bind or connect the socket if (bind) { - try { - this->socket->bind(address); - } catch (zmq::error_t& e) { - SPDLOG_ERROR("Error binding socket to {}: {}", address, e.what()); - throw; - } + ZMQ_CATCH({ this->socket->bind(address); }, "bind") } else { - try { - this->socket->connect(address); - } catch (zmq::error_t& e) { - SPDLOG_ERROR( - "Error connecting socket to {}: {}", address, e.what()); - throw; - } + ZMQ_CATCH({ this->socket->connect(address); }, "connect") } // Set socket options @@ -88,37 +91,22 @@ void MessageEndpoint::send(uint8_t* serialisedMsg, size_t msgSize, bool more) assert(tid == std::this_thread::get_id()); assert(this->socket != nullptr); - if (more) { - try { - auto res = this->socket->send(zmq::buffer(serialisedMsg, msgSize), - zmq::send_flags::sndmore); - if (res != msgSize) { - SPDLOG_ERROR("Sent different bytes than expected (sent " - "{}, expected {})", - res.value_or(0), - msgSize); - throw std::runtime_error("Error sending message"); - } - } catch (zmq::error_t& e) { - SPDLOG_ERROR("Error sending message: {}", e.what()); - throw; - } - } else { - try { - auto res = this->socket->send(zmq::buffer(serialisedMsg, msgSize), - zmq::send_flags::none); - if (res != msgSize) { - SPDLOG_ERROR("Sent different bytes than expected (sent " - "{}, expected {})", - res.value_or(0), - msgSize); - throw std::runtime_error("Error sending message"); - } - } catch (zmq::error_t& e) { - SPDLOG_ERROR("Error sending message: {}", e.what()); - throw; - } - } + zmq::send_flags sendFlags = + more ? zmq::send_flags::sndmore : zmq::send_flags::none; + + ZMQ_CATCH( + { + auto res = + this->socket->send(zmq::buffer(serialisedMsg, msgSize), sendFlags); + if (res != msgSize) { + SPDLOG_ERROR("Sent different bytes than expected (sent " + "{}, expected {})", + res.value_or(0), + msgSize); + throw std::runtime_error("Error sending message"); + } + }, + "send") } // By passing the expected recv buffer size, we instrument zeromq to receive on @@ -129,42 +117,53 @@ Message MessageEndpoint::recv(int size) assert(this->socket != nullptr); assert(size >= 0); - // Pre-allocate buffer to avoid copying data - if (size > 0) { - Message msg(size); - - try { - auto res = this->socket->recv(zmq::buffer(msg.udata(), msg.size())); - - if (!res.has_value()) { - SPDLOG_ERROR("Timed out receiving message of size {}", size); - throw MessageTimeoutException("Timed out receiving message"); - } - - if (res.has_value() && (res->size != res->untruncated_size)) { - SPDLOG_ERROR("Received more bytes than buffer can hold. " - "Received: {}, capacity {}", - res->untruncated_size, - res->size); - throw std::runtime_error("Error receiving message"); - } - } catch (zmq::error_t& e) { - if (e.num() == ZMQ_ETERM) { - // Return empty message to signify termination - SPDLOG_TRACE("Shutting endpoint down after receiving ETERM"); - return Message(); - } - - // Print default message and rethrow - SPDLOG_ERROR("Error receiving message: {} ({})", e.num(), e.what()); - throw; - } - - return msg; + if (size == 0) { + return recvNoBuffer(); } + return recvBuffer(size); +} + +Message MessageEndpoint::recvBuffer(int size) +{ + // Pre-allocate buffer to avoid copying data + Message msg(size); + + ZMQ_CATCH( + try { + auto res = this->socket->recv(zmq::buffer(msg.udata(), msg.size())); + + if (!res.has_value()) { + SPDLOG_ERROR("Timed out receiving message of size {}", size); + throw MessageTimeoutException("Timed out receiving message"); + } + + if (res.has_value() && (res->size != res->untruncated_size)) { + SPDLOG_ERROR("Received more bytes than buffer can hold. " + "Received: {}, capacity {}", + res->untruncated_size, + res->size); + throw std::runtime_error("Error receiving message"); + } + } catch (zmq::error_t& e) { + if (e.num() == ZMQ_ETERM) { + // Return empty message to signify termination + SPDLOG_TRACE("Shutting endpoint down after receiving ETERM"); + return Message(); + } + + throw; + }, + "recv_buffer") + + return msg; +} + +Message MessageEndpoint::recvNoBuffer() +{ // Allocate a message to receive data zmq::message_t msg; + ZMQ_CATCH( try { auto res = this->socket->recv(msg); if (!res.has_value()) { @@ -176,11 +175,9 @@ Message MessageEndpoint::recv(int size) // Return empty message to signify termination SPDLOG_TRACE("Shutting endpoint down after receiving ETERM"); return Message(); - } else { - SPDLOG_ERROR("Error receiving message: {} ({})", e.num(), e.what()); - throw; } - } + throw; + }, "recv_no_buffer") // Copy the received message to a buffer whose scope we control return Message(msg); @@ -202,48 +199,25 @@ void MessageEndpoint::close(bool bind) // system is slow at closing sockets, and the application relies a lot // on synchronous message-passing. if (bind) { - try { - this->socket->unbind(address); - } catch (zmq::error_t& e) { - if (e.num() != ZMQ_ETERM) { - SPDLOG_ERROR("Error unbinding socket: {}", e.what()); - throw; - } - } - // NOTE - unbinding a socket has a considerable overhead compared to // disconnecting it. + ZMQ_CATCH(this->socket->unbind(address), "unbind") + // TODO - could we reuse the monitor? - try { - zmq::monitor_t mon; - const std::string monAddr = - "inproc://monitor_" + std::to_string(id); - mon.init(*(this->socket), monAddr, ZMQ_EVENT_CLOSED); - - this->socket->close(); - mon.check_event(-1); - } catch (zmq::error_t& e) { - if (e.num() != ZMQ_ETERM) { - SPDLOG_ERROR("Error closing bind socket: {}", e.what()); - throw; - } - } + zmq::monitor_t mon; + const std::string monAddr = + "inproc://monitor_" + std::to_string(id); + mon.init(*(this->socket), monAddr, ZMQ_EVENT_CLOSED); + + ZMQ_CATCH(this->socket->close(), "close") + + // Wait for this to complete + mon.check_event(recvTimeoutMs); } else { - try { - this->socket->disconnect(address); - } catch (zmq::error_t& e) { - if (e.num() != ZMQ_ETERM) { - SPDLOG_ERROR("Error disconnecting socket: {}", e.what()); - throw; - } - } - try { - this->socket->close(); - } catch (zmq::error_t& e) { - SPDLOG_ERROR("Error closing connect socket: {}", e.what()); - throw; - } + ZMQ_CATCH(this->socket->disconnect(address), "disconnect") + + ZMQ_CATCH(this->socket->close(), "disconnect") } // Finally, null the socket From 4dfc42dd1edc5ae2e99d025f4f912f80334a9c37 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 18 Jun 2021 12:07:42 +0000 Subject: [PATCH 04/66] Formatting --- src/transport/MessageEndpoint.cpp | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index a3e51bf33..9014285a2 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -147,8 +147,7 @@ Message MessageEndpoint::recvBuffer(int size) } } catch (zmq::error_t& e) { if (e.num() == ZMQ_ETERM) { - // Return empty message to signify termination - SPDLOG_TRACE("Shutting endpoint down after receiving ETERM"); + SPDLOG_TRACE("Endpoint received ETERM"); return Message(); } @@ -164,20 +163,20 @@ Message MessageEndpoint::recvNoBuffer() // Allocate a message to receive data zmq::message_t msg; ZMQ_CATCH( - try { - auto res = this->socket->recv(msg); - if (!res.has_value()) { - SPDLOG_ERROR("Timed out receiving message with no size"); - throw MessageTimeoutException("Timed out receiving message"); - } - } catch (zmq::error_t& e) { - if (e.num() == ZMQ_ETERM) { - // Return empty message to signify termination - SPDLOG_TRACE("Shutting endpoint down after receiving ETERM"); - return Message(); - } - throw; - }, "recv_no_buffer") + try { + auto res = this->socket->recv(msg); + if (!res.has_value()) { + SPDLOG_ERROR("Timed out receiving message with no size"); + throw MessageTimeoutException("Timed out receiving message"); + } + } catch (zmq::error_t& e) { + if (e.num() == ZMQ_ETERM) { + SPDLOG_TRACE("Endpoint received ETERM"); + return Message(); + } + throw; + }, + "recv_no_buffer") // Copy the received message to a buffer whose scope we control return Message(msg); From 1f5f00057120550f5d820ade57b84f5a42de699f Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 18 Jun 2021 12:55:03 +0000 Subject: [PATCH 05/66] Lazy-init global message context and don't shut down --- include/faabric/transport/MessageContext.h | 4 +- src/transport/MessageContext.cpp | 34 +++++++++----- src/transport/MessageEndpoint.cpp | 44 ++++++++++--------- src/transport/MessageEndpointServer.cpp | 16 +++---- tests/test/transport/test_message_context.cpp | 30 ------------- 5 files changed, 56 insertions(+), 72 deletions(-) delete mode 100644 tests/test/transport/test_message_context.cpp diff --git a/include/faabric/transport/MessageContext.h b/include/faabric/transport/MessageContext.h index 78c17afb2..bda3a54b2 100644 --- a/include/faabric/transport/MessageContext.h +++ b/include/faabric/transport/MessageContext.h @@ -33,8 +33,8 @@ class MessageContext */ void close(); - bool isContextShutDown; + bool isClosed = false; }; -faabric::transport::MessageContext& getGlobalMessageContext(); +MessageContext& getGlobalMessageContext(); } diff --git a/src/transport/MessageContext.cpp b/src/transport/MessageContext.cpp index 8c7f74649..dc0c7e9a4 100644 --- a/src/transport/MessageContext.cpp +++ b/src/transport/MessageContext.cpp @@ -1,9 +1,14 @@ #include +#include +#include namespace faabric::transport { + +static std::unique_ptr instance = nullptr; +static std::shared_mutex messageContextMx; + MessageContext::MessageContext() : ctx(1) - , isContextShutDown(false) {} MessageContext::MessageContext(int overrideIoThreads) @@ -17,8 +22,8 @@ MessageContext::~MessageContext() void MessageContext::close() { + isClosed = true; this->ctx.close(); - this->isContextShutDown = true; } zmq::context_t& MessageContext::get() @@ -28,15 +33,22 @@ zmq::context_t& MessageContext::get() faabric::transport::MessageContext& getGlobalMessageContext() { - static auto msgContext = - std::make_unique(); - // The message context needs to be opened and closed every server instance. - // Sometimes (e.g. tests) the scheduler is re-used, but the message context - // needs to be reset. In this situations we override the shut-down message - // context. - if (msgContext->isContextShutDown) { - msgContext = std::make_unique(); + if (instance == nullptr) { + faabric::util::FullLock lock(messageContextMx); + if (instance == nullptr) { + instance = std::make_unique(); + } + } + + { + faabric::util::SharedLock lock(messageContextMx); + + if (instance->isClosed) { + throw std::runtime_error( + "Global ZeroMQ message context already closed"); + } + + return *instance; } - return *msgContext; } } diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 9014285a2..4a3222534 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -4,16 +4,21 @@ #include -#define ZMQ_CATCH(op, label) \ +#define CATCH_ZMQ_ERR(op, label) \ try { \ op; \ } catch (zmq::error_t & e) { \ - SPDLOG_ERROR("Caught ZeroMQ error for {} on address {}: {} ({})", \ - label, \ - address, \ - e.num(), \ - e.what()); \ - throw; \ + if (e.num() == ZMQ_ETERM) { \ + SPDLOG_TRACE( \ + "Got ZeroMQ ETERM for {} on address {}", label, address); \ + } else { \ + SPDLOG_ERROR("Caught ZeroMQ error for {} on address {}: {} ({})", \ + label, \ + address, \ + e.num(), \ + e.what()); \ + throw; \ + } \ } catch (...) { \ SPDLOG_ERROR( \ "Caught non-ZeroMQ error for {} on address {}", label, address); \ @@ -51,7 +56,7 @@ void MessageEndpoint::open(faabric::transport::MessageContext& context, // bind and connect does not matter. switch (sockType) { case faabric::transport::SocketType::PUSH: - ZMQ_CATCH( + CATCH_ZMQ_ERR( { this->socket = std::make_unique( context.get(), zmq::socket_type::push); @@ -59,7 +64,7 @@ void MessageEndpoint::open(faabric::transport::MessageContext& context, "push_socket") break; case faabric::transport::SocketType::PULL: - ZMQ_CATCH( + CATCH_ZMQ_ERR( { this->socket = std::make_unique( context.get(), zmq::socket_type::pull); @@ -76,9 +81,9 @@ void MessageEndpoint::open(faabric::transport::MessageContext& context, // Bind or connect the socket if (bind) { - ZMQ_CATCH({ this->socket->bind(address); }, "bind") + CATCH_ZMQ_ERR(this->socket->bind(address), "bind") } else { - ZMQ_CATCH({ this->socket->connect(address); }, "connect") + CATCH_ZMQ_ERR(this->socket->connect(address), "connect") } // Set socket options @@ -94,7 +99,7 @@ void MessageEndpoint::send(uint8_t* serialisedMsg, size_t msgSize, bool more) zmq::send_flags sendFlags = more ? zmq::send_flags::sndmore : zmq::send_flags::none; - ZMQ_CATCH( + CATCH_ZMQ_ERR( { auto res = this->socket->send(zmq::buffer(serialisedMsg, msgSize), sendFlags); @@ -129,7 +134,7 @@ Message MessageEndpoint::recvBuffer(int size) // Pre-allocate buffer to avoid copying data Message msg(size); - ZMQ_CATCH( + CATCH_ZMQ_ERR( try { auto res = this->socket->recv(zmq::buffer(msg.udata(), msg.size())); @@ -162,7 +167,7 @@ Message MessageEndpoint::recvNoBuffer() { // Allocate a message to receive data zmq::message_t msg; - ZMQ_CATCH( + CATCH_ZMQ_ERR( try { auto res = this->socket->recv(msg); if (!res.has_value()) { @@ -185,7 +190,6 @@ Message MessageEndpoint::recvNoBuffer() void MessageEndpoint::close(bool bind) { if (this->socket != nullptr) { - if (tid != std::this_thread::get_id()) { SPDLOG_WARN("Closing socket from a different thread"); } @@ -194,13 +198,13 @@ void MessageEndpoint::close(bool bind) // block until we _actually_ have unbinded, i.e. 0MQ has closed the // socket (which happens asynchronously). For connect()-ed sockets we // don't care. - // Not blobking on un-bind can cause race-conditions when the underlying + // Not blocking on un-bind can cause race-conditions when the underlying // system is slow at closing sockets, and the application relies a lot // on synchronous message-passing. if (bind) { // NOTE - unbinding a socket has a considerable overhead compared to // disconnecting it. - ZMQ_CATCH(this->socket->unbind(address), "unbind") + CATCH_ZMQ_ERR(this->socket->unbind(address), "unbind") // TODO - could we reuse the monitor? zmq::monitor_t mon; @@ -208,15 +212,15 @@ void MessageEndpoint::close(bool bind) "inproc://monitor_" + std::to_string(id); mon.init(*(this->socket), monAddr, ZMQ_EVENT_CLOSED); - ZMQ_CATCH(this->socket->close(), "close") + CATCH_ZMQ_ERR(this->socket->close(), "close_unbind") // Wait for this to complete mon.check_event(recvTimeoutMs); } else { - ZMQ_CATCH(this->socket->disconnect(address), "disconnect") + CATCH_ZMQ_ERR(this->socket->disconnect(address), "disconnect") - ZMQ_CATCH(this->socket->close(), "disconnect") + CATCH_ZMQ_ERR(this->socket->close(), "close_disconnect") } // Finally, null the socket diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 13600759c..200a72fa4 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -16,7 +16,7 @@ void MessageEndpointServer::start() void MessageEndpointServer::start(faabric::transport::MessageContext& context) { // Start serving thread in background - this->servingThread = std::thread([this, &context] { + servingThread = std::thread([this, &context] { RecvMessageEndpoint serverEndpoint(this->port); // Open message endpoint, and bind @@ -41,13 +41,9 @@ void MessageEndpointServer::stop() void MessageEndpointServer::stop(faabric::transport::MessageContext& context) { - // Note - different servers will concurrently close the server context, but - // this structure is thread-safe, and the close operation idempotent. - context.close(); - - // Finally join the serving thread - if (this->servingThread.joinable()) { - this->servingThread.join(); + // Join the serving thread + if (servingThread.joinable()) { + servingThread.join(); } } @@ -57,17 +53,19 @@ int MessageEndpointServer::recv(RecvMessageEndpoint& endpoint) // Receive header and body Message header = endpoint.recv(); + // Detect shutdown condition if (header.udata() == nullptr) { return ENDPOINT_SERVER_SHUTDOWN; } + // Check the header was sent with ZMQ_SNDMORE flag if (!header.more()) { throw std::runtime_error("Header sent without SNDMORE flag"); } - Message body = endpoint.recv(); // Check that there are no more messages to receive + Message body = endpoint.recv(); if (body.more()) { throw std::runtime_error("Body sent with SNDMORE flag"); } diff --git a/tests/test/transport/test_message_context.cpp b/tests/test/transport/test_message_context.cpp deleted file mode 100644 index 35173263e..000000000 --- a/tests/test/transport/test_message_context.cpp +++ /dev/null @@ -1,30 +0,0 @@ -#include - -#include - -using namespace faabric::transport; - -namespace tests { -TEST_CASE("Test global message context", "[transport]") -{ - // Get message context - MessageContext& context = getGlobalMessageContext(); - - // Context not shut down - REQUIRE(!context.isContextShutDown); - - // Close message context - REQUIRE_NOTHROW(context.close()); - - // Context is shut down - REQUIRE(context.isContextShutDown); - - // Get message context again, lazy-initialise it - MessageContext& newContext = getGlobalMessageContext(); - - // Context not shut down - REQUIRE(!newContext.isContextShutDown); - - newContext.close(); -} -} From d38dd03de01bc9b7a84625138f95d49e7a2dbcff Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 18 Jun 2021 13:44:06 +0000 Subject: [PATCH 06/66] Remove need to close global message context --- include/faabric/transport/MessageContext.h | 9 ---- .../faabric/transport/MessageEndpointServer.h | 4 +- src/transport/MessageContext.cpp | 11 ----- src/transport/MessageEndpointServer.cpp | 47 +++++++++++++------ tests/test/transport/test_message_server.cpp | 13 +++-- tests/utils/fixtures.h | 2 +- 6 files changed, 44 insertions(+), 42 deletions(-) diff --git a/include/faabric/transport/MessageContext.h b/include/faabric/transport/MessageContext.h index bda3a54b2..e5cef9158 100644 --- a/include/faabric/transport/MessageContext.h +++ b/include/faabric/transport/MessageContext.h @@ -25,15 +25,6 @@ class MessageContext zmq::context_t ctx; zmq::context_t& get(); - - /* Close the message context - * - * In 0MQ terms, this method calls close() on the context, which in turn - * first shuts down (i.e. stop blocking operations) and then closes. - */ - void close(); - - bool isClosed = false; }; MessageContext& getGlobalMessageContext(); diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 1f84d75ff..46ed9b452 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -42,7 +42,7 @@ class MessageEndpointServer virtual void stop(); protected: - int recv(faabric::transport::RecvMessageEndpoint& endpoint); + bool recv(); /* Template function to handle message reception * @@ -67,6 +67,8 @@ class MessageEndpointServer private: const int port; + std::unique_ptr endpoint = nullptr; + std::thread servingThread; }; } diff --git a/src/transport/MessageContext.cpp b/src/transport/MessageContext.cpp index dc0c7e9a4..76f84ec10 100644 --- a/src/transport/MessageContext.cpp +++ b/src/transport/MessageContext.cpp @@ -17,12 +17,6 @@ MessageContext::MessageContext(int overrideIoThreads) MessageContext::~MessageContext() { - this->close(); -} - -void MessageContext::close() -{ - isClosed = true; this->ctx.close(); } @@ -43,11 +37,6 @@ faabric::transport::MessageContext& getGlobalMessageContext() { faabric::util::SharedLock lock(messageContextMx); - if (instance->isClosed) { - throw std::runtime_error( - "Global ZeroMQ message context already closed"); - } - return *instance; } } diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 200a72fa4..c3bee6bec 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -17,20 +18,25 @@ void MessageEndpointServer::start(faabric::transport::MessageContext& context) { // Start serving thread in background servingThread = std::thread([this, &context] { - RecvMessageEndpoint serverEndpoint(this->port); + endpoint = std::make_unique(this->port); // Open message endpoint, and bind - serverEndpoint.open(context); - assert(serverEndpoint.socket != nullptr); + endpoint->open(context); - // Loop until context is terminated + // Loop until we receive a shutdown message while (true) { - int rc = this->recv(serverEndpoint); - if (rc == ENDPOINT_SERVER_SHUTDOWN) { - serverEndpoint.close(); - break; + try { + bool messageReceived = this->recv(); + if (!messageReceived) { + SPDLOG_TRACE("Server received shutdown message"); + break; + } + } catch (MessageTimeoutException& ex) { + continue; } } + + endpoint->close(); }); } @@ -41,22 +47,33 @@ void MessageEndpointServer::stop() void MessageEndpointServer::stop(faabric::transport::MessageContext& context) { + // Send a shutdown message via a temporary endpoint + SPDLOG_TRACE("Sending shutdown message locally to {}:{}", + endpoint->getHost(), + endpoint->getPort()); + SendMessageEndpoint e(endpoint->getHost(), endpoint->getPort()); + e.open(getGlobalMessageContext()); + e.send(nullptr, 0); + // Join the serving thread if (servingThread.joinable()) { servingThread.join(); } + + e.close(); } -int MessageEndpointServer::recv(RecvMessageEndpoint& endpoint) +bool MessageEndpointServer::recv() { - assert(endpoint.socket != nullptr); + // Check endpoint has been initialised + assert(endpoint->socket != nullptr); // Receive header and body - Message header = endpoint.recv(); + Message header = endpoint->recv(); // Detect shutdown condition - if (header.udata() == nullptr) { - return ENDPOINT_SERVER_SHUTDOWN; + if (header.size() == 0) { + return false; } // Check the header was sent with ZMQ_SNDMORE flag @@ -65,7 +82,7 @@ int MessageEndpointServer::recv(RecvMessageEndpoint& endpoint) } // Check that there are no more messages to receive - Message body = endpoint.recv(); + Message body = endpoint->recv(); if (body.more()) { throw std::runtime_error("Body sent with SNDMORE flag"); } @@ -74,7 +91,7 @@ int MessageEndpointServer::recv(RecvMessageEndpoint& endpoint) // Server-specific message handling doRecv(header, body); - return 0; + return true; } // We create a new endpoint every time. Re-using them would be a possible diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 79bbaf65b..a5116538e 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -38,7 +38,7 @@ class DummyServer final : public MessageEndpointServer faabric::transport::Message& body) override { // Dummy server, do nothing but increment the message count - this->messageCount++; + messageCount++; } }; @@ -83,16 +83,19 @@ TEST_CASE("Test send one message to server", "[transport]") server.start(); // Open the source endpoint client, don't bind - auto& context = getGlobalMessageContext(); MessageEndpointClient src(thisHost, testPort); + + auto& context = getGlobalMessageContext(); src.open(context); // Send message: server expects header + body std::string header = "header"; uint8_t headerMsg[header.size()]; memcpy(headerMsg, header.c_str(), header.size()); + // Mark we are sending the header - REQUIRE_NOTHROW(src.send(headerMsg, header.size(), true)); + src.send(headerMsg, header.size(), true); + // Send the body std::string body = "body"; uint8_t bodyMsg[body.size()]; @@ -132,8 +135,7 @@ TEST_CASE("Test send one-off response to client", "[transport]") uint8_t msg[expectedMsg.size()]; memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - REQUIRE_NOTHROW( - server.sendResponse(msg, expectedMsg.size(), thisHost, testPort)); + server.sendResponse(msg, expectedMsg.size(), thisHost, testPort); if (clientThread.joinable()) { clientThread.join(); @@ -209,6 +211,7 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") int clientTimeout; bool expectFailure; + SECTION("Long timeout no failure") { clientTimeout = 20000; diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index b102db298..c4f152d0d 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -128,7 +128,7 @@ class MessageContextFixture : public SchedulerTestFixture : context(faabric::transport::getGlobalMessageContext()) {} - ~MessageContextFixture() { context.close(); } + ~MessageContextFixture() { } }; class MpiBaseTestFixture : public SchedulerTestFixture From 9cb6b89bb02a74dd439e64c3b7330eb0574e8460 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 18 Jun 2021 14:00:28 +0000 Subject: [PATCH 07/66] Fix up timeout test --- tests/test/transport/test_message_server.cpp | 21 +++++++++++++++----- tests/utils/fixtures.h | 2 +- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index a5116538e..7e11a208a 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -57,12 +57,20 @@ class SlowServer final : public MessageEndpointServer int returnPort) { usleep(delayMs * 1000); + + SPDLOG_DEBUG("Slow message server test sending response"); + MessageEndpointServer::sendResponse( + serialisedMsg, size, returnHost, returnPort); } private: void doRecv(faabric::transport::Message& header, faabric::transport::Message& body) override - {} + { + SPDLOG_DEBUG("Slow message server test recv"); + std::vector data = { 0, 1, 2, 3 }; + sendResponse(data.data(), data.size(), thisHost, testPort); + } }; namespace tests { @@ -205,10 +213,6 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") // Wait for the server to start up usleep(500 * 1000); - // Set up the client - auto& context = getGlobalMessageContext(); - MessageEndpointClient cli(thisHost, testPort); - int clientTimeout; bool expectFailure; @@ -224,9 +228,16 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") expectFailure = true; } + // Set up the client + auto& context = getGlobalMessageContext(); + MessageEndpointClient cli(thisHost, testPort); cli.setRecvTimeoutMs(clientTimeout); cli.open(context); + std::vector data = { 1, 1, 1 }; + cli.send(data.data(), data.size(), true); + cli.send(data.data(), data.size()); + // Check for failure accordingly if (expectFailure) { REQUIRE_THROWS_AS(cli.awaitResponse(testPort + REPLY_PORT_OFFSET), diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index c4f152d0d..2e985a741 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -128,7 +128,7 @@ class MessageContextFixture : public SchedulerTestFixture : context(faabric::transport::getGlobalMessageContext()) {} - ~MessageContextFixture() { } + ~MessageContextFixture() {} }; class MpiBaseTestFixture : public SchedulerTestFixture From db3ad143b78459bd429b44dd1ddfa96555c3d811 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 18 Jun 2021 15:03:50 +0000 Subject: [PATCH 08/66] Removed catch-all error handling --- src/transport/MessageContext.cpp | 1 + src/transport/MessageEndpoint.cpp | 4 ---- src/transport/MessageEndpointClient.cpp | 3 +++ 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transport/MessageContext.cpp b/src/transport/MessageContext.cpp index 76f84ec10..73a5f4391 100644 --- a/src/transport/MessageContext.cpp +++ b/src/transport/MessageContext.cpp @@ -17,6 +17,7 @@ MessageContext::MessageContext(int overrideIoThreads) MessageContext::~MessageContext() { + SPDLOG_TRACE("Closing global ZeroMQ message context"); this->ctx.close(); } diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 4a3222534..dcd607d7d 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -19,10 +19,6 @@ e.what()); \ throw; \ } \ - } catch (...) { \ - SPDLOG_ERROR( \ - "Caught non-ZeroMQ error for {} on address {}", label, address); \ - throw; \ } namespace faabric::transport { diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index cdd2461f9..45798a067 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -1,4 +1,5 @@ #include +#include namespace faabric::transport { MessageEndpointClient::MessageEndpointClient(const std::string& host, int port) @@ -17,7 +18,9 @@ Message MessageEndpointClient::awaitResponse(int port) endpoint.setSendTimeoutMs(sendTimeoutMs); endpoint.open(faabric::transport::getGlobalMessageContext()); + Message receivedMessage = endpoint.recv(); + endpoint.close(); return receivedMessage; From c74687696ffec623b0e0a1bd625b79002b695471 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 21 Jun 2021 14:14:11 +0000 Subject: [PATCH 09/66] Remove message context argument from open and close --- .../faabric/scheduler/FunctionCallClient.h | 1 - include/faabric/transport/MessageContext.h | 11 ++++++-- include/faabric/transport/MessageEndpoint.h | 9 +++---- .../faabric/transport/MessageEndpointServer.h | 16 ------------ src/scheduler/FunctionCallClient.cpp | 2 +- src/scheduler/FunctionCallServer.cpp | 2 +- src/scheduler/SnapshotClient.cpp | 2 +- src/scheduler/SnapshotServer.cpp | 2 +- src/state/StateClient.cpp | 2 +- src/transport/MessageContext.cpp | 19 +++++++++----- src/transport/MessageEndpoint.cpp | 18 ++++++------- src/transport/MessageEndpointClient.cpp | 13 ++++++---- src/transport/MessageEndpointServer.cpp | 18 +++---------- src/transport/MpiMessageEndpoint.cpp | 12 ++++----- .../test_message_endpoint_client.cpp | 26 +++++++++---------- tests/test/transport/test_message_server.cpp | 12 +++------ tests/utils/fixtures.h | 8 +----- 17 files changed, 74 insertions(+), 99 deletions(-) diff --git a/include/faabric/scheduler/FunctionCallClient.h b/include/faabric/scheduler/FunctionCallClient.h index 55e105231..c6f20519f 100644 --- a/include/faabric/scheduler/FunctionCallClient.h +++ b/include/faabric/scheduler/FunctionCallClient.h @@ -2,7 +2,6 @@ #include #include -#include #include #include diff --git a/include/faabric/transport/MessageContext.h b/include/faabric/transport/MessageContext.h index e5cef9158..42917c3a6 100644 --- a/include/faabric/transport/MessageContext.h +++ b/include/faabric/transport/MessageContext.h @@ -1,5 +1,6 @@ #pragma once +#include #include namespace faabric::transport { @@ -22,10 +23,16 @@ class MessageContext ~MessageContext(); + zmq::context_t& getZMQContext(); + + static std::shared_ptr getInstance(); + + private: zmq::context_t ctx; - zmq::context_t& get(); + static std::shared_ptr instance; + static std::shared_mutex mx; }; -MessageContext& getGlobalMessageContext(); +std::shared_ptr getGlobalMessageContext(); } diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 78fbb1fe0..f5e949aad 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -3,7 +3,6 @@ #include #include -#include #include #include @@ -47,9 +46,7 @@ class MessageEndpoint ~MessageEndpoint(); - void open(faabric::transport::MessageContext& context, - faabric::transport::SocketType sockTypeIn, - bool bind); + void open(SocketType sockTypeIn, bool bind); void close(bool bind); @@ -93,7 +90,7 @@ class SendMessageEndpoint : public MessageEndpoint public: SendMessageEndpoint(const std::string& hostIn, int portIn); - void open(MessageContext& context); + void open(); void close(); }; @@ -103,7 +100,7 @@ class RecvMessageEndpoint : public MessageEndpoint public: RecvMessageEndpoint(int portIn); - void open(MessageContext& context); + void open(); void close(); }; diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 46ed9b452..268d2d14f 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include @@ -22,21 +21,6 @@ class MessageEndpointServer public: MessageEndpointServer(int portIn); - /* Start and stop the server - * - * Generic methods to start and stop a message endpoint server. They take - * a, thread-safe, 0MQ context as an argument. The stop method will block - * until _all_ sockets within the context have been closed. Sockets blocking - * on a `recv` will be interrupted with ETERM upon context closure. - */ - void start(faabric::transport::MessageContext& context); - - void stop(faabric::transport::MessageContext& context); - - /* Common start and stop entrypoint - * - * Call the generic methods with the default global message context. - */ void start(); virtual void stop(); diff --git a/src/scheduler/FunctionCallClient.cpp b/src/scheduler/FunctionCallClient.cpp index 27e974151..006fa631c 100644 --- a/src/scheduler/FunctionCallClient.cpp +++ b/src/scheduler/FunctionCallClient.cpp @@ -82,7 +82,7 @@ void clearMockRequests() FunctionCallClient::FunctionCallClient(const std::string& hostIn) : faabric::transport::MessageEndpointClient(hostIn, FUNCTION_CALL_PORT) { - this->open(faabric::transport::getGlobalMessageContext()); + this->open(); } void FunctionCallClient::sendHeader(faabric::scheduler::FunctionCalls call) diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index e17a13b06..5237ee5db 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -18,7 +18,7 @@ void FunctionCallServer::stop() faabric::scheduler::getScheduler().closeFunctionCallClients(); // Call the parent stop - MessageEndpointServer::stop(faabric::transport::getGlobalMessageContext()); + MessageEndpointServer::stop(); } void FunctionCallServer::doRecv(faabric::transport::Message& header, diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index 7e99bc0c4..ba4a9b6d8 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -78,7 +78,7 @@ void clearMockSnapshotRequests() SnapshotClient::SnapshotClient(const std::string& hostIn) : faabric::transport::MessageEndpointClient(hostIn, SNAPSHOT_PORT) { - this->open(faabric::transport::getGlobalMessageContext()); + this->open(); } void SnapshotClient::sendHeader(faabric::scheduler::SnapshotCalls call) diff --git a/src/scheduler/SnapshotServer.cpp b/src/scheduler/SnapshotServer.cpp index 4312c0568..7765711ff 100644 --- a/src/scheduler/SnapshotServer.cpp +++ b/src/scheduler/SnapshotServer.cpp @@ -18,7 +18,7 @@ void SnapshotServer::stop() faabric::scheduler::getScheduler().closeSnapshotClients(); // Call the parent stop - MessageEndpointServer::stop(faabric::transport::getGlobalMessageContext()); + MessageEndpointServer::stop(); } void SnapshotServer::doRecv(faabric::transport::Message& header, diff --git a/src/state/StateClient.cpp b/src/state/StateClient.cpp index bc1a8242e..8506e0dac 100644 --- a/src/state/StateClient.cpp +++ b/src/state/StateClient.cpp @@ -13,7 +13,7 @@ StateClient::StateClient(const std::string& userIn, , host(hostIn) , reg(state::getInMemoryStateRegistry()) { - this->open(faabric::transport::getGlobalMessageContext()); + this->open(); } void StateClient::sendHeader(faabric::state::StateCalls call) diff --git a/src/transport/MessageContext.cpp b/src/transport/MessageContext.cpp index 73a5f4391..e4e57a00c 100644 --- a/src/transport/MessageContext.cpp +++ b/src/transport/MessageContext.cpp @@ -4,8 +4,8 @@ namespace faabric::transport { -static std::unique_ptr instance = nullptr; -static std::shared_mutex messageContextMx; +std::shared_ptr MessageContext::instance = nullptr; +std::shared_mutex MessageContext::mx; MessageContext::MessageContext() : ctx(1) @@ -21,24 +21,29 @@ MessageContext::~MessageContext() this->ctx.close(); } -zmq::context_t& MessageContext::get() +zmq::context_t& MessageContext::getZMQContext() { return this->ctx; } -faabric::transport::MessageContext& getGlobalMessageContext() +std::shared_ptr MessageContext::getInstance() { if (instance == nullptr) { - faabric::util::FullLock lock(messageContextMx); + faabric::util::FullLock lock(mx); if (instance == nullptr) { instance = std::make_unique(); } } { - faabric::util::SharedLock lock(messageContextMx); + faabric::util::SharedLock lock(mx); - return *instance; + return instance; } } + +std::shared_ptr getGlobalMessageContext() +{ + return MessageContext::getInstance(); +} } diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index dcd607d7d..d831e48fe 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -39,9 +40,7 @@ MessageEndpoint::~MessageEndpoint() } } -void MessageEndpoint::open(faabric::transport::MessageContext& context, - faabric::transport::SocketType sockType, - bool bind) +void MessageEndpoint::open(faabric::transport::SocketType sockType, bool bind) { // Check we are opening from the same thread. We assert not to incur in // costly checks when running a Release build. @@ -50,12 +49,13 @@ void MessageEndpoint::open(faabric::transport::MessageContext& context, // Note - only one socket may bind, but several can connect. This // allows for easy N - 1 or 1 - N PUSH/PULL patterns. Order between // bind and connect does not matter. + std::shared_ptr context = getGlobalMessageContext(); switch (sockType) { case faabric::transport::SocketType::PUSH: CATCH_ZMQ_ERR( { this->socket = std::make_unique( - context.get(), zmq::socket_type::push); + context->getZMQContext(), zmq::socket_type::push); }, "push_socket") break; @@ -63,7 +63,7 @@ void MessageEndpoint::open(faabric::transport::MessageContext& context, CATCH_ZMQ_ERR( { this->socket = std::make_unique( - context.get(), zmq::socket_type::pull); + context->getZMQContext(), zmq::socket_type::pull); }, "pull_socket") @@ -265,12 +265,12 @@ SendMessageEndpoint::SendMessageEndpoint(const std::string& hostIn, int portIn) : MessageEndpoint(hostIn, portIn) {} -void SendMessageEndpoint::open(MessageContext& context) +void SendMessageEndpoint::open() { SPDLOG_TRACE( fmt::format("Opening socket: {} (SEND {}:{})", id, host, port)); - MessageEndpoint::open(context, SocketType::PUSH, false); + MessageEndpoint::open(SocketType::PUSH, false); } void SendMessageEndpoint::close() @@ -285,12 +285,12 @@ RecvMessageEndpoint::RecvMessageEndpoint(int portIn) : MessageEndpoint(ANY_HOST, portIn) {} -void RecvMessageEndpoint::open(MessageContext& context) +void RecvMessageEndpoint::open() { SPDLOG_TRACE( fmt::format("Opening socket: {} (RECV {}:{})", id, host, port)); - MessageEndpoint::open(context, SocketType::PULL, true); + MessageEndpoint::open(SocketType::PULL, true); } void RecvMessageEndpoint::close() diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index 45798a067..916ea3c52 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -10,18 +10,21 @@ MessageEndpointClient::MessageEndpointClient(const std::string& host, int port) Message MessageEndpointClient::awaitResponse(int port) { // Wait for the response, open a temporary endpoint for it - // Note - we use a different host/port not to clash with existing server RecvMessageEndpoint endpoint(port); // Inherit timeouts on temporary endpoint endpoint.setRecvTimeoutMs(recvTimeoutMs); endpoint.setSendTimeoutMs(sendTimeoutMs); - endpoint.open(faabric::transport::getGlobalMessageContext()); + endpoint.open(); - Message receivedMessage = endpoint.recv(); - - endpoint.close(); + Message receivedMessage; + try { + receivedMessage = endpoint.recv(); + } catch (MessageTimeoutException& ex) { + endpoint.close(); + throw; + } return receivedMessage; } diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index c3bee6bec..bf825dccd 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -10,18 +10,13 @@ MessageEndpointServer::MessageEndpointServer(int portIn) {} void MessageEndpointServer::start() -{ - start(faabric::transport::getGlobalMessageContext()); -} - -void MessageEndpointServer::start(faabric::transport::MessageContext& context) { // Start serving thread in background - servingThread = std::thread([this, &context] { + servingThread = std::thread([this] { endpoint = std::make_unique(this->port); // Open message endpoint, and bind - endpoint->open(context); + endpoint->open(); // Loop until we receive a shutdown message while (true) { @@ -41,18 +36,13 @@ void MessageEndpointServer::start(faabric::transport::MessageContext& context) } void MessageEndpointServer::stop() -{ - stop(faabric::transport::getGlobalMessageContext()); -} - -void MessageEndpointServer::stop(faabric::transport::MessageContext& context) { // Send a shutdown message via a temporary endpoint SPDLOG_TRACE("Sending shutdown message locally to {}:{}", endpoint->getHost(), endpoint->getPort()); SendMessageEndpoint e(endpoint->getHost(), endpoint->getPort()); - e.open(getGlobalMessageContext()); + e.open(); e.send(nullptr, 0); // Join the serving thread @@ -103,7 +93,7 @@ void MessageEndpointServer::sendResponse(uint8_t* serialisedMsg, { // Open the endpoint socket, server connects (not bind) to remote address SendMessageEndpoint endpoint(returnHost, returnPort + REPLY_PORT_OFFSET); - endpoint.open(faabric::transport::getGlobalMessageContext()); + endpoint.open(); endpoint.send(serialisedMsg, size); endpoint.close(); } diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index 67be90801..39c7ea688 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -4,7 +4,7 @@ namespace faabric::transport { faabric::MpiHostsToRanksMessage recvMpiHostRankMsg() { faabric::transport::RecvMessageEndpoint endpoint(MPI_PORT); - endpoint.open(faabric::transport::getGlobalMessageContext()); + endpoint.open(); faabric::transport::Message m = endpoint.recv(); PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); endpoint.close(); @@ -22,7 +22,7 @@ void sendMpiHostRankMsg(const std::string& hostIn, throw std::runtime_error("Error serialising message"); } faabric::transport::SendMessageEndpoint endpoint(hostIn, MPI_PORT); - endpoint.open(faabric::transport::getGlobalMessageContext()); + endpoint.open(); endpoint.send(sMsg, msgSize, false); endpoint.close(); } @@ -32,8 +32,8 @@ MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn) : sendMessageEndpoint(hostIn, portIn) , recvMessageEndpoint(portIn) { - sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); - recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); + sendMessageEndpoint.open(); + recvMessageEndpoint.open(); } MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, @@ -42,8 +42,8 @@ MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, : sendMessageEndpoint(hostIn, sendPort) , recvMessageEndpoint(recvPort) { - sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); - recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); + sendMessageEndpoint.open(); + recvMessageEndpoint.open(); } void MpiMessageEndpoint::sendMpiMessage( diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index ea41b3585..5b770ceda 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -19,11 +19,11 @@ TEST_CASE_METHOD(MessageContextFixture, { // Open an endpoint client, don't bind MessageEndpoint cli(thisHost, testPort); - REQUIRE_NOTHROW(cli.open(context, SocketType::PULL, false)); + REQUIRE_NOTHROW(cli.open(SocketType::PULL, false)); // Open another endpoint client, bind MessageEndpoint secondCli(thisHost, testPort); - REQUIRE_NOTHROW(secondCli.open(context, SocketType::PUSH, true)); + REQUIRE_NOTHROW(secondCli.open(SocketType::PUSH, true)); // Close all endpoint clients REQUIRE_NOTHROW(cli.close(false)); @@ -36,11 +36,11 @@ TEST_CASE_METHOD(MessageContextFixture, { // Open the source endpoint client, don't bind SendMessageEndpoint src(thisHost, testPort); - src.open(context); + src.open(); // Open the destination endpoint client, bind RecvMessageEndpoint dst(testPort); - dst.open(context); + dst.open(); // Send message std::string expectedMsg = "Hello world!"; @@ -68,7 +68,7 @@ TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") std::thread senderThread([this, expectedMsg, expectedResponse] { // Open the source endpoint client, don't bind MessageEndpointClient src(thisHost, testPort); - src.open(context); + src.open(); // Send message and wait for response uint8_t msg[expectedMsg.size()]; @@ -86,7 +86,7 @@ TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") // Receive message RecvMessageEndpoint dst(testPort); - dst.open(context); + dst.open(); faabric::transport::Message recvMsg = dst.recv(); REQUIRE(recvMsg.size() == expectedMsg.size()); std::string actualMsg(recvMsg.data(), recvMsg.size()); @@ -94,7 +94,7 @@ TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") // Send response, open a new endpoint for it SendMessageEndpoint dstResponse(thisHost, testReplyPort); - dstResponse.open(context); + dstResponse.open(); uint8_t msg[expectedResponse.size()]; memcpy(msg, expectedResponse.c_str(), expectedResponse.size()); dstResponse.send(msg, expectedResponse.size()); @@ -119,7 +119,7 @@ TEST_CASE_METHOD(MessageContextFixture, std::thread senderThread([this, numMessages, baseMsg] { // Open the source endpoint client, don't bind SendMessageEndpoint src(thisHost, testPort); - src.open(context); + src.open(); for (int i = 0; i < numMessages; i++) { std::string expectedMsg = baseMsg + std::to_string(i); uint8_t msg[expectedMsg.size()]; @@ -132,7 +132,7 @@ TEST_CASE_METHOD(MessageContextFixture, // Receive messages RecvMessageEndpoint dst(testPort); - dst.open(context); + dst.open(); for (int i = 0; i < numMessages; i++) { faabric::transport::Message recvMsg = dst.recv(); // Check just a subset of the messages @@ -168,7 +168,7 @@ TEST_CASE_METHOD(MessageContextFixture, std::thread([this, numMessages, expectedMsg] { // Open the source endpoint client, don't bind SendMessageEndpoint src(thisHost, testPort); - src.open(context); + src.open(); for (int i = 0; i < numMessages; i++) { uint8_t msg[expectedMsg.size()]; memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); @@ -181,7 +181,7 @@ TEST_CASE_METHOD(MessageContextFixture, // Receive messages RecvMessageEndpoint dst(testPort); - dst.open(context); + dst.open(); for (int i = 0; i < numSenders * numMessages; i++) { faabric::transport::Message recvMsg = dst.recv(); // Check just a subset of the messages @@ -231,13 +231,13 @@ TEST_CASE_METHOD(MessageContextFixture, SECTION("Recv, socket already initialised") { - cli.open(context, SocketType::PULL, false); + cli.open(SocketType::PULL, false); REQUIRE_THROWS(cli.setRecvTimeoutMs(100)); } SECTION("Send, socket already initialised") { - cli.open(context, SocketType::PULL, false); + cli.open(SocketType::PULL, false); REQUIRE_THROWS(cli.setSendTimeoutMs(100)); } } diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 7e11a208a..2fb5326ad 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -93,8 +93,7 @@ TEST_CASE("Test send one message to server", "[transport]") // Open the source endpoint client, don't bind MessageEndpointClient src(thisHost, testPort); - auto& context = getGlobalMessageContext(); - src.open(context); + src.open(); // Send message: server expects header + body std::string header = "header"; @@ -129,9 +128,8 @@ TEST_CASE("Test send one-off response to client", "[transport]") std::thread clientThread([expectedMsg] { // Open the source endpoint client, don't bind - auto& context = getGlobalMessageContext(); MessageEndpointClient cli(thisHost, testPort); - cli.open(context); + cli.open(); Message msg = cli.awaitResponse(testPort + REPLY_PORT_OFFSET); assert(msg.size() == expectedMsg.size()); @@ -164,9 +162,8 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") for (int i = 0; i < numClients; i++) { clientThreads.emplace_back(std::thread([numMessages] { // Prepare client - auto& context = getGlobalMessageContext(); MessageEndpointClient cli(thisHost, testPort); - cli.open(context); + cli.open(); std::string clientMsg = "Message from threaded client"; for (int j = 0; j < numMessages; j++) { @@ -229,10 +226,9 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") } // Set up the client - auto& context = getGlobalMessageContext(); MessageEndpointClient cli(thisHost, testPort); cli.setRecvTimeoutMs(clientTimeout); - cli.open(context); + cli.open(); std::vector data = { 1, 1, 1 }; cli.send(data.data(), data.size(), true); diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index 2e985a741..2c5ffc984 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include @@ -120,13 +119,8 @@ class ConfTestFixture class MessageContextFixture : public SchedulerTestFixture { - protected: - faabric::transport::MessageContext& context; - public: - MessageContextFixture() - : context(faabric::transport::getGlobalMessageContext()) - {} + MessageContextFixture() {} ~MessageContextFixture() {} }; From b985bc7b4e95e1c6cbaf403316f9e75e3153b6a2 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 21 Jun 2021 14:32:07 +0000 Subject: [PATCH 10/66] Use stlib in Message object --- include/faabric/transport/Message.h | 6 ++---- src/transport/Message.cpp | 28 ++++++++-------------------- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/include/faabric/transport/Message.h b/include/faabric/transport/Message.h index f5c5eadbe..6341e3884 100644 --- a/include/faabric/transport/Message.h +++ b/include/faabric/transport/Message.h @@ -18,8 +18,6 @@ class Message Message(); - ~Message(); - char* data(); uint8_t* udata(); @@ -31,8 +29,8 @@ class Message void persist(); private: - uint8_t* msg; - int _size; + std::vector bytes; + bool _more; bool _persist; }; diff --git a/src/transport/Message.cpp b/src/transport/Message.cpp index 489eab098..d17f81e39 100644 --- a/src/transport/Message.cpp +++ b/src/transport/Message.cpp @@ -2,47 +2,35 @@ namespace faabric::transport { Message::Message(const zmq::message_t& msgIn) - : _size(msgIn.size()) + : bytes(msgIn.size()) , _more(msgIn.more()) , _persist(false) { - msg = reinterpret_cast(malloc(_size * sizeof(uint8_t))); - memcpy(msg, msgIn.data(), _size); + memcpy(bytes.data(), msgIn.data(), msgIn.size()); } Message::Message(int sizeIn) - : _size(sizeIn) + : bytes(sizeIn) , _more(false) , _persist(false) -{ - msg = reinterpret_cast(malloc(_size * sizeof(uint8_t))); -} - -// Empty message signals shutdown -Message::Message() - : msg(nullptr) {} -Message::~Message() -{ - if (!_persist) { - free(reinterpret_cast(msg)); - } -} +// Empty message signals shutdown +Message::Message() {} char* Message::data() { - return reinterpret_cast(msg); + return reinterpret_cast(bytes.data()); } uint8_t* Message::udata() { - return msg; + return bytes.data(); } int Message::size() { - return _size; + return bytes.size(); } bool Message::more() From f40e12ea96de5509ebac9d6fa9c54b3a6629e115 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 22 Jun 2021 12:37:23 +0000 Subject: [PATCH 11/66] Remove MessageContext object --- include/faabric/transport/MessageContext.h | 31 +++---------------- src/transport/MessageContext.cpp | 32 +++----------------- src/transport/MessageEndpoint.cpp | 6 ++-- src/transport/MessageEndpointClient.cpp | 2 ++ tests/test/transport/test_message_server.cpp | 28 ++++++++--------- 5 files changed, 27 insertions(+), 72 deletions(-) diff --git a/include/faabric/transport/MessageContext.h b/include/faabric/transport/MessageContext.h index 42917c3a6..4b59bac5e 100644 --- a/include/faabric/transport/MessageContext.h +++ b/include/faabric/transport/MessageContext.h @@ -3,36 +3,13 @@ #include #include -namespace faabric::transport { -/* Wrapper around zmq::context_t - * +/* * The context object is thread safe, and the constructor parameter indicates * the number of hardware IO threads to be used. As a rule of thumb, use one * IO thread per Gbps of data. */ -class MessageContext -{ - public: - MessageContext(); - - // Message context should not be copied as there must only be one ZMQ - // context - MessageContext(const MessageContext& ctx) = delete; - - MessageContext(int overrideIoThreads); - - ~MessageContext(); - - zmq::context_t& getZMQContext(); +#define ZMQ_CONTEXT_IO_THREADS 1 - static std::shared_ptr getInstance(); - - private: - zmq::context_t ctx; - - static std::shared_ptr instance; - static std::shared_mutex mx; -}; - -std::shared_ptr getGlobalMessageContext(); +namespace faabric::transport { +std::shared_ptr getGlobalMessageContext(); } diff --git a/src/transport/MessageContext.cpp b/src/transport/MessageContext.cpp index e4e57a00c..5d50a7697 100644 --- a/src/transport/MessageContext.cpp +++ b/src/transport/MessageContext.cpp @@ -4,34 +4,15 @@ namespace faabric::transport { -std::shared_ptr MessageContext::instance = nullptr; -std::shared_mutex MessageContext::mx; +static std::shared_ptr instance = nullptr; +static std::shared_mutex mx; -MessageContext::MessageContext() - : ctx(1) -{} - -MessageContext::MessageContext(int overrideIoThreads) - : ctx(overrideIoThreads) -{} - -MessageContext::~MessageContext() -{ - SPDLOG_TRACE("Closing global ZeroMQ message context"); - this->ctx.close(); -} - -zmq::context_t& MessageContext::getZMQContext() -{ - return this->ctx; -} - -std::shared_ptr MessageContext::getInstance() +std::shared_ptr getGlobalMessageContext() { if (instance == nullptr) { faabric::util::FullLock lock(mx); if (instance == nullptr) { - instance = std::make_unique(); + instance = std::make_shared(ZMQ_CONTEXT_IO_THREADS); } } @@ -41,9 +22,4 @@ std::shared_ptr MessageContext::getInstance() return instance; } } - -std::shared_ptr getGlobalMessageContext() -{ - return MessageContext::getInstance(); -} } diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index d831e48fe..f58a5a65a 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -49,13 +49,13 @@ void MessageEndpoint::open(faabric::transport::SocketType sockType, bool bind) // Note - only one socket may bind, but several can connect. This // allows for easy N - 1 or 1 - N PUSH/PULL patterns. Order between // bind and connect does not matter. - std::shared_ptr context = getGlobalMessageContext(); + std::shared_ptr context = getGlobalMessageContext(); switch (sockType) { case faabric::transport::SocketType::PUSH: CATCH_ZMQ_ERR( { this->socket = std::make_unique( - context->getZMQContext(), zmq::socket_type::push); + *context, zmq::socket_type::push); }, "push_socket") break; @@ -63,7 +63,7 @@ void MessageEndpoint::open(faabric::transport::SocketType sockType, bool bind) CATCH_ZMQ_ERR( { this->socket = std::make_unique( - context->getZMQContext(), zmq::socket_type::pull); + *context, zmq::socket_type::pull); }, "pull_socket") diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index 916ea3c52..dee3228b8 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -26,6 +26,8 @@ Message MessageEndpointClient::awaitResponse(int port) throw; } + endpoint.close(); + return receivedMessage; } } diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 2fb5326ad..819a3415d 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -196,20 +196,6 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") TEST_CASE("Test client timeout on requests to valid server", "[transport]") { - // Start the server in the background - std::thread t([] { - SlowServer server; - server.start(); - - int threadSleep = server.delayMs + 500; - usleep(threadSleep * 1000); - - server.stop(); - }); - - // Wait for the server to start up - usleep(500 * 1000); - int clientTimeout; bool expectFailure; @@ -225,6 +211,20 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") expectFailure = true; } + // Start the server in the background + std::thread t([] { + SlowServer server; + server.start(); + + int threadSleep = server.delayMs + 500; + usleep(threadSleep * 1000); + + server.stop(); + }); + + // Wait for the server to start up + usleep(500 * 1000); + // Set up the client MessageEndpointClient cli(thisHost, testPort); cli.setRecvTimeoutMs(clientTimeout); From 9ee0e112d32afb6e848d6c7ee09de5df57798cff Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 22 Jun 2021 14:56:53 +0000 Subject: [PATCH 12/66] Make SocketType a property on the message endpoint --- include/faabric/transport/Message.h | 2 + include/faabric/transport/MessageEndpoint.h | 9 +- src/transport/Message.cpp | 5 + src/transport/MessageContext.cpp | 1 - src/transport/MessageEndpoint.cpp | 102 ++++++++++-------- src/transport/MessageEndpointServer.cpp | 3 +- .../test_message_endpoint_client.cpp | 47 ++++---- tests/test/transport/test_message_server.cpp | 31 +++--- 8 files changed, 106 insertions(+), 94 deletions(-) diff --git a/include/faabric/transport/Message.h b/include/faabric/transport/Message.h index 6341e3884..f84b211fb 100644 --- a/include/faabric/transport/Message.h +++ b/include/faabric/transport/Message.h @@ -22,6 +22,8 @@ class Message uint8_t* udata(); + std::vector dataCopy(); + int size(); bool more(); diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index f5e949aad..92891616e 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -35,7 +35,9 @@ enum class SocketType class MessageEndpoint { public: - MessageEndpoint(const std::string& hostIn, int portIn); + MessageEndpoint(SocketType socketTypeIn, + const std::string& hostIn, + int portIn); // Message endpoints shouldn't be assigned as ZeroMQ sockets are not thread // safe @@ -46,9 +48,9 @@ class MessageEndpoint ~MessageEndpoint(); - void open(SocketType sockTypeIn, bool bind); + void open(); - void close(bool bind); + void close(); void send(uint8_t* serialisedMsg, size_t msgSize, bool more = false); @@ -67,6 +69,7 @@ class MessageEndpoint void setSendTimeoutMs(int value); protected: + const SocketType socketType; const std::string host; const int port; const std::string address; diff --git a/src/transport/Message.cpp b/src/transport/Message.cpp index d17f81e39..74d0ec1eb 100644 --- a/src/transport/Message.cpp +++ b/src/transport/Message.cpp @@ -28,6 +28,11 @@ uint8_t* Message::udata() return bytes.data(); } +std::vector Message::dataCopy() +{ + return bytes; +} + int Message::size() { return bytes.size(); diff --git a/src/transport/MessageContext.cpp b/src/transport/MessageContext.cpp index 5d50a7697..19a8bb78c 100644 --- a/src/transport/MessageContext.cpp +++ b/src/transport/MessageContext.cpp @@ -18,7 +18,6 @@ std::shared_ptr getGlobalMessageContext() { faabric::util::SharedLock lock(mx); - return instance; } } diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index f58a5a65a..44548e3ad 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -24,8 +24,11 @@ namespace faabric::transport { -MessageEndpoint::MessageEndpoint(const std::string& hostIn, int portIn) - : host(hostIn) +MessageEndpoint::MessageEndpoint(SocketType socketTypeIn, + const std::string& hostIn, + int portIn) + : socketType(socketTypeIn) + , host(hostIn) , port(portIn) , address("tcp://" + host + ":" + std::to_string(port)) , tid(std::this_thread::get_id()) @@ -36,11 +39,11 @@ MessageEndpoint::~MessageEndpoint() { if (this->socket != nullptr) { SPDLOG_WARN("Destroying an open message endpoint!"); - this->close(false); + this->close(); } } -void MessageEndpoint::open(faabric::transport::SocketType sockType, bool bind) +void MessageEndpoint::open() { // Check we are opening from the same thread. We assert not to incur in // costly checks when running a Release build. @@ -50,16 +53,20 @@ void MessageEndpoint::open(faabric::transport::SocketType sockType, bool bind) // allows for easy N - 1 or 1 - N PUSH/PULL patterns. Order between // bind and connect does not matter. std::shared_ptr context = getGlobalMessageContext(); - switch (sockType) { - case faabric::transport::SocketType::PUSH: + switch (socketType) { + case faabric::transport::SocketType::PUSH: { CATCH_ZMQ_ERR( { this->socket = std::make_unique( *context, zmq::socket_type::push); }, "push_socket") + + CATCH_ZMQ_ERR(this->socket->connect(address), "connect") + break; - case faabric::transport::SocketType::PULL: + } + case faabric::transport::SocketType::PULL: { CATCH_ZMQ_ERR( { this->socket = std::make_unique( @@ -67,19 +74,13 @@ void MessageEndpoint::open(faabric::transport::SocketType sockType, bool bind) }, "pull_socket") - break; - default: - throw std::runtime_error("Unrecognized socket type"); - } - - // Check opening the socket has worked - assert(this->socket != nullptr); + CATCH_ZMQ_ERR(this->socket->bind(address), "bind") - // Bind or connect the socket - if (bind) { - CATCH_ZMQ_ERR(this->socket->bind(address), "bind") - } else { - CATCH_ZMQ_ERR(this->socket->connect(address), "connect") + break; + } + default: { + throw std::runtime_error("Opening unrecognized socket type"); + } } // Set socket options @@ -183,7 +184,7 @@ Message MessageEndpoint::recvNoBuffer() return Message(msg); } -void MessageEndpoint::close(bool bind) +void MessageEndpoint::close() { if (this->socket != nullptr) { if (tid != std::this_thread::get_id()) { @@ -191,32 +192,39 @@ void MessageEndpoint::close(bool bind) } // We duplicate the call to close() because when unbinding, we want to - // block until we _actually_ have unbinded, i.e. 0MQ has closed the + // block until we _actually_ have unbound, i.e. 0MQ has closed the // socket (which happens asynchronously). For connect()-ed sockets we // don't care. // Not blocking on un-bind can cause race-conditions when the underlying // system is slow at closing sockets, and the application relies a lot // on synchronous message-passing. - if (bind) { - // NOTE - unbinding a socket has a considerable overhead compared to - // disconnecting it. - CATCH_ZMQ_ERR(this->socket->unbind(address), "unbind") - - // TODO - could we reuse the monitor? - zmq::monitor_t mon; - const std::string monAddr = - "inproc://monitor_" + std::to_string(id); - mon.init(*(this->socket), monAddr, ZMQ_EVENT_CLOSED); - - CATCH_ZMQ_ERR(this->socket->close(), "close_unbind") - - // Wait for this to complete - mon.check_event(recvTimeoutMs); - - } else { - CATCH_ZMQ_ERR(this->socket->disconnect(address), "disconnect") - - CATCH_ZMQ_ERR(this->socket->close(), "close_disconnect") + switch (socketType) { + case SocketType::PUSH: { + CATCH_ZMQ_ERR(this->socket->disconnect(address), "disconnect") + + CATCH_ZMQ_ERR(this->socket->close(), "close_disconnect") + break; + } + case SocketType::PULL: { + // NOTE - unbinding a socket has a considerable overhead + // compared to disconnecting it. + CATCH_ZMQ_ERR(this->socket->unbind(address), "unbind") + + // TODO - could we reuse the monitor? + zmq::monitor_t mon; + const std::string monAddr = + "inproc://monitor_" + std::to_string(id); + mon.init(*(this->socket), monAddr, ZMQ_EVENT_CLOSED); + + CATCH_ZMQ_ERR(this->socket->close(), "close_unbind") + + // Wait for this to complete + mon.check_event(recvTimeoutMs); + break; + } + default: { + throw std::runtime_error("Closing unrecognised type of socket"); + } } // Finally, null the socket @@ -262,7 +270,7 @@ void MessageEndpoint::setSendTimeoutMs(int value) /* Send and Recv Message Endpoints */ SendMessageEndpoint::SendMessageEndpoint(const std::string& hostIn, int portIn) - : MessageEndpoint(hostIn, portIn) + : MessageEndpoint(SocketType::PUSH, hostIn, portIn) {} void SendMessageEndpoint::open() @@ -270,7 +278,7 @@ void SendMessageEndpoint::open() SPDLOG_TRACE( fmt::format("Opening socket: {} (SEND {}:{})", id, host, port)); - MessageEndpoint::open(SocketType::PUSH, false); + MessageEndpoint::open(); } void SendMessageEndpoint::close() @@ -278,11 +286,11 @@ void SendMessageEndpoint::close() SPDLOG_TRACE( fmt::format("Closing socket: {} (SEND {}:{})", id, host, port)); - MessageEndpoint::close(false); + MessageEndpoint::close(); } RecvMessageEndpoint::RecvMessageEndpoint(int portIn) - : MessageEndpoint(ANY_HOST, portIn) + : MessageEndpoint(SocketType::PULL, ANY_HOST, portIn) {} void RecvMessageEndpoint::open() @@ -290,7 +298,7 @@ void RecvMessageEndpoint::open() SPDLOG_TRACE( fmt::format("Opening socket: {} (RECV {}:{})", id, host, port)); - MessageEndpoint::open(SocketType::PULL, true); + MessageEndpoint::open(); } void RecvMessageEndpoint::close() @@ -298,6 +306,6 @@ void RecvMessageEndpoint::close() SPDLOG_TRACE( fmt::format("Closing socket: {} (RECV {}:{})", id, host, port)); - MessageEndpoint::close(true); + MessageEndpoint::close(); } } diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index bf825dccd..6d07a2167 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -37,10 +37,11 @@ void MessageEndpointServer::start() void MessageEndpointServer::stop() { - // Send a shutdown message via a temporary endpoint SPDLOG_TRACE("Sending shutdown message locally to {}:{}", endpoint->getHost(), endpoint->getPort()); + + // Send a shutdown message via a temporary endpoint SendMessageEndpoint e(endpoint->getHost(), endpoint->getPort()); e.open(); e.send(nullptr, 0); diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 5b770ceda..d9025da21 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -18,16 +18,16 @@ TEST_CASE_METHOD(MessageContextFixture, "[transport]") { // Open an endpoint client, don't bind - MessageEndpoint cli(thisHost, testPort); - REQUIRE_NOTHROW(cli.open(SocketType::PULL, false)); + MessageEndpoint cli(SocketType::PULL, thisHost, testPort); + REQUIRE_NOTHROW(cli.open()); // Open another endpoint client, bind - MessageEndpoint secondCli(thisHost, testPort); - REQUIRE_NOTHROW(secondCli.open(SocketType::PUSH, true)); + MessageEndpoint secondCli(SocketType::PUSH, thisHost, testPort); + REQUIRE_NOTHROW(secondCli.open()); // Close all endpoint clients - REQUIRE_NOTHROW(cli.close(false)); - REQUIRE_NOTHROW(secondCli.close(true)); + REQUIRE_NOTHROW(cli.close()); + REQUIRE_NOTHROW(secondCli.close()); } TEST_CASE_METHOD(MessageContextFixture, @@ -65,7 +65,7 @@ TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") std::string expectedMsg = "Hello "; std::string expectedResponse = "world!"; - std::thread senderThread([this, expectedMsg, expectedResponse] { + std::thread senderThread([expectedMsg, expectedResponse] { // Open the source endpoint client, don't bind MessageEndpointClient src(thisHost, testPort); src.open(); @@ -116,7 +116,7 @@ TEST_CASE_METHOD(MessageContextFixture, int numMessages = 10000; std::string baseMsg = "Hello "; - std::thread senderThread([this, numMessages, baseMsg] { + std::thread senderThread([numMessages, baseMsg] { // Open the source endpoint client, don't bind SendMessageEndpoint src(thisHost, testPort); src.open(); @@ -164,19 +164,18 @@ TEST_CASE_METHOD(MessageContextFixture, std::vector senderThreads; for (int j = 0; j < numSenders; j++) { - senderThreads.emplace_back( - std::thread([this, numMessages, expectedMsg] { - // Open the source endpoint client, don't bind - SendMessageEndpoint src(thisHost, testPort); - src.open(); - for (int i = 0; i < numMessages; i++) { - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - src.send(msg, expectedMsg.size()); - } - - src.close(); - })); + senderThreads.emplace_back(std::thread([numMessages, expectedMsg] { + // Open the source endpoint client, don't bind + SendMessageEndpoint src(thisHost, testPort); + src.open(); + for (int i = 0; i < numMessages; i++) { + uint8_t msg[expectedMsg.size()]; + memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); + src.send(msg, expectedMsg.size()); + } + + src.close(); + })); } // Receive messages @@ -207,7 +206,7 @@ TEST_CASE_METHOD(MessageContextFixture, "Test can't set invalid send/recv timeouts", "[transport]") { - MessageEndpoint cli(thisHost, testPort); + MessageEndpoint cli(SocketType::PULL, thisHost, testPort); SECTION("Sanity check valid timeout") { @@ -231,13 +230,13 @@ TEST_CASE_METHOD(MessageContextFixture, SECTION("Recv, socket already initialised") { - cli.open(SocketType::PULL, false); + cli.open(); REQUIRE_THROWS(cli.setRecvTimeoutMs(100)); } SECTION("Send, socket already initialised") { - cli.open(SocketType::PULL, false); + cli.open(); REQUIRE_THROWS(cli.setSendTimeoutMs(100)); } } diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 819a3415d..062c7b374 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -46,30 +46,21 @@ class SlowServer final : public MessageEndpointServer { public: int delayMs = 1000; + std::vector data = { 0, 1, 2, 3 }; SlowServer() : MessageEndpointServer(testPort) {} - void sendResponse(uint8_t* serialisedMsg, - int size, - const std::string& returnHost, - int returnPort) - { - usleep(delayMs * 1000); - - SPDLOG_DEBUG("Slow message server test sending response"); - MessageEndpointServer::sendResponse( - serialisedMsg, size, returnHost, returnPort); - } - private: void doRecv(faabric::transport::Message& header, faabric::transport::Message& body) override { SPDLOG_DEBUG("Slow message server test recv"); - std::vector data = { 0, 1, 2, 3 }; - sendResponse(data.data(), data.size(), thisHost, testPort); + + usleep(delayMs * 1000); + MessageEndpointServer::sendResponse( + data.data(), data.size(), thisHost, testPort); } }; @@ -207,8 +198,8 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") SECTION("Short timeout failure") { - clientTimeout = 100; - expectFailure = true; + clientTimeout = 20000; + expectFailure = false; } // Start the server in the background @@ -234,12 +225,16 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") cli.send(data.data(), data.size(), true); cli.send(data.data(), data.size()); - // Check for failure accordingly if (expectFailure) { + // Check for failure REQUIRE_THROWS_AS(cli.awaitResponse(testPort + REPLY_PORT_OFFSET), MessageTimeoutException); } else { - REQUIRE_NOTHROW(cli.awaitResponse(testPort + REPLY_PORT_OFFSET)); + // Check response from server successful + Message responseMessage = + cli.awaitResponse(testPort + REPLY_PORT_OFFSET); + std::vector expected = { 0, 1, 2, 3 }; + REQUIRE(responseMessage.dataCopy() == expected); } cli.close(); From 9a3349a06d7dd01652fdccd9908e0e1e323e558d Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 22 Jun 2021 15:27:19 +0000 Subject: [PATCH 13/66] Error checking around use of send and recv sockets --- include/faabric/transport/MessageEndpoint.h | 3 + .../faabric/transport/MessageEndpointServer.h | 2 +- src/transport/MessageEndpoint.cpp | 110 +++++++++++------- src/transport/MessageEndpointServer.cpp | 28 ++--- 4 files changed, 85 insertions(+), 58 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 92891616e..20e79912c 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -18,6 +18,9 @@ #define DEFAULT_RECV_TIMEOUT_MS 20000 #define DEFAULT_SEND_TIMEOUT_MS 20000 +// The monitor is checking an asynchronous event has completed, so can be short +#define MONITOR_TIMEOUT_MS 2000 + namespace faabric::transport { enum class SocketType { diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 268d2d14f..973ab87df 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -51,7 +51,7 @@ class MessageEndpointServer private: const int port; - std::unique_ptr endpoint = nullptr; + std::unique_ptr recvEndpoint = nullptr; std::thread servingThread; }; diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 44548e3ad..a67c5e9d2 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -45,10 +45,14 @@ MessageEndpoint::~MessageEndpoint() void MessageEndpoint::open() { - // Check we are opening from the same thread. We assert not to incur in - // costly checks when running a Release build. + // Check we are opening from the same thread. assert(tid == std::this_thread::get_id()); + if (socket != nullptr) { + SPDLOG_ERROR("Double opening socket {}", id); + throw std::runtime_error("Double opening a socket"); + } + // Note - only one socket may bind, but several can connect. This // allows for easy N - 1 or 1 - N PUSH/PULL patterns. Order between // bind and connect does not matter. @@ -91,7 +95,14 @@ void MessageEndpoint::open() void MessageEndpoint::send(uint8_t* serialisedMsg, size_t msgSize, bool more) { assert(tid == std::this_thread::get_id()); - assert(this->socket != nullptr); + + if (this->socket == nullptr) { + throw std::runtime_error("Sending on an unopened socket"); + } + + if (socketType == SocketType::PULL) { + throw std::runtime_error("Sending on a recv socket"); + } zmq::send_flags sendFlags = more ? zmq::send_flags::sndmore : zmq::send_flags::none; @@ -116,9 +127,16 @@ void MessageEndpoint::send(uint8_t* serialisedMsg, size_t msgSize, bool more) Message MessageEndpoint::recv(int size) { assert(tid == std::this_thread::get_id()); - assert(this->socket != nullptr); assert(size >= 0); + if (this->socket == nullptr) { + throw std::runtime_error("Receiving on an unopened socket"); + } + + if (socketType == SocketType::PUSH) { + throw std::runtime_error("Receiving on a send socket"); + } + if (size == 0) { return recvNoBuffer(); } @@ -186,50 +204,54 @@ Message MessageEndpoint::recvNoBuffer() void MessageEndpoint::close() { - if (this->socket != nullptr) { - if (tid != std::this_thread::get_id()) { - SPDLOG_WARN("Closing socket from a different thread"); - } + if (this->socket == nullptr) { + SPDLOG_ERROR("Closing unopened socket {}", id); + throw std::runtime_error("Closing unopened socket"); + } + + if (tid != std::this_thread::get_id()) { + SPDLOG_WARN("Closing socket from a different thread"); + } - // We duplicate the call to close() because when unbinding, we want to - // block until we _actually_ have unbound, i.e. 0MQ has closed the - // socket (which happens asynchronously). For connect()-ed sockets we - // don't care. - // Not blocking on un-bind can cause race-conditions when the underlying - // system is slow at closing sockets, and the application relies a lot - // on synchronous message-passing. - switch (socketType) { - case SocketType::PUSH: { - CATCH_ZMQ_ERR(this->socket->disconnect(address), "disconnect") - - CATCH_ZMQ_ERR(this->socket->close(), "close_disconnect") - break; - } - case SocketType::PULL: { - // NOTE - unbinding a socket has a considerable overhead - // compared to disconnecting it. - CATCH_ZMQ_ERR(this->socket->unbind(address), "unbind") - - // TODO - could we reuse the monitor? - zmq::monitor_t mon; - const std::string monAddr = - "inproc://monitor_" + std::to_string(id); - mon.init(*(this->socket), monAddr, ZMQ_EVENT_CLOSED); - - CATCH_ZMQ_ERR(this->socket->close(), "close_unbind") - - // Wait for this to complete - mon.check_event(recvTimeoutMs); - break; - } - default: { - throw std::runtime_error("Closing unrecognised type of socket"); - } + switch (socketType) { + case SocketType::PUSH: { + CATCH_ZMQ_ERR(this->socket->disconnect(address), "disconnect") + CATCH_ZMQ_ERR(this->socket->close(), "close") + break; } + case SocketType::PULL: { + // We duplicate the call to close() because when unbinding, we want + // to block until we _actually_ have unbound, i.e. 0MQ has closed + // the socket (which happens asynchronously). + // + // Not blocking on un-bind can cause race-conditions when the + // underlying system is slow at closing sockets, and the application + // relies a lot on synchronous message-passing. + // + // NOTE that unbinding a socket has a considerable + // overhead compared to disconnecting it. + CATCH_ZMQ_ERR(this->socket->unbind(address), "unbind") + + // TODO - could we reuse the monitor across sockets? + zmq::monitor_t mon; + const std::string monAddr = + "inproc://monitor_" + std::to_string(id); + mon.init(*(this->socket), monAddr, ZMQ_EVENT_CLOSED); + + CATCH_ZMQ_ERR(this->socket->close(), "close") + + // Wait for this to complete (should be fast) + mon.check_event(MONITOR_TIMEOUT_MS); - // Finally, null the socket - this->socket = nullptr; + break; + } + default: { + throw std::runtime_error("Closing unrecognised type of socket"); + } } + + // Finally, null the socket + this->socket = nullptr; } std::string MessageEndpoint::getHost() diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 6d07a2167..987d0ff7d 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -13,10 +13,10 @@ void MessageEndpointServer::start() { // Start serving thread in background servingThread = std::thread([this] { - endpoint = std::make_unique(this->port); + recvEndpoint = std::make_unique(this->port); // Open message endpoint, and bind - endpoint->open(); + recvEndpoint->open(); // Loop until we receive a shutdown message while (true) { @@ -27,22 +27,23 @@ void MessageEndpointServer::start() break; } } catch (MessageTimeoutException& ex) { + SPDLOG_TRACE("Server timed out with no messages, continuing"); continue; } } - endpoint->close(); + recvEndpoint->close(); }); } void MessageEndpointServer::stop() { SPDLOG_TRACE("Sending shutdown message locally to {}:{}", - endpoint->getHost(), - endpoint->getPort()); + recvEndpoint->getHost(), + recvEndpoint->getPort()); // Send a shutdown message via a temporary endpoint - SendMessageEndpoint e(endpoint->getHost(), endpoint->getPort()); + SendMessageEndpoint e(recvEndpoint->getHost(), recvEndpoint->getPort()); e.open(); e.send(nullptr, 0); @@ -57,10 +58,10 @@ void MessageEndpointServer::stop() bool MessageEndpointServer::recv() { // Check endpoint has been initialised - assert(endpoint->socket != nullptr); + assert(recvEndpoint->socket != nullptr); // Receive header and body - Message header = endpoint->recv(); + Message header = recvEndpoint->recv(); // Detect shutdown condition if (header.size() == 0) { @@ -73,7 +74,7 @@ bool MessageEndpointServer::recv() } // Check that there are no more messages to receive - Message body = endpoint->recv(); + Message body = recvEndpoint->recv(); if (body.more()) { throw std::runtime_error("Body sent with SNDMORE flag"); } @@ -93,9 +94,10 @@ void MessageEndpointServer::sendResponse(uint8_t* serialisedMsg, int returnPort) { // Open the endpoint socket, server connects (not bind) to remote address - SendMessageEndpoint endpoint(returnHost, returnPort + REPLY_PORT_OFFSET); - endpoint.open(); - endpoint.send(serialisedMsg, size); - endpoint.close(); + SendMessageEndpoint sendEndpoint(returnHost, + returnPort + REPLY_PORT_OFFSET); + sendEndpoint.open(); + sendEndpoint.send(serialisedMsg, size); + sendEndpoint.close(); } } From 880af93513282fcf73026dffac345f58d04a025a Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 22 Jun 2021 15:51:15 +0000 Subject: [PATCH 14/66] Remove the need for persist on messages --- include/faabric/transport/Message.h | 3 --- src/scheduler/SnapshotServer.cpp | 16 +++++++++++----- src/transport/Message.cpp | 7 ------- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/include/faabric/transport/Message.h b/include/faabric/transport/Message.h index f84b211fb..8e3378d63 100644 --- a/include/faabric/transport/Message.h +++ b/include/faabric/transport/Message.h @@ -28,12 +28,9 @@ class Message bool more(); - void persist(); - private: std::vector bytes; bool _more; - bool _persist; }; } diff --git a/src/scheduler/SnapshotServer.cpp b/src/scheduler/SnapshotServer.cpp index 7765711ff..45067a485 100644 --- a/src/scheduler/SnapshotServer.cpp +++ b/src/scheduler/SnapshotServer.cpp @@ -7,6 +7,8 @@ #include #include +#include + namespace faabric::scheduler { SnapshotServer::SnapshotServer() : faabric::transport::MessageEndpointServer(SNAPSHOT_PORT) @@ -60,12 +62,16 @@ void SnapshotServer::recvPushSnapshot(faabric::transport::Message& msg) // Set up the snapshot faabric::util::SnapshotData data; data.size = r->contents()->size(); - data.data = r->mutable_contents()->Data(); - reg.takeSnapshot(r->key()->str(), data, true); - // Note that now the snapshot data is owned by Faabric and will be deleted - // later, so we don't want the message to delete it - msg.persist(); + // TODO - avoid this copy by changing server superclass to allow subclasses + // to provide a buffer to receive data. + // TODO - work out snapshot ownership here, how do we know when to delete + // this data? + data.data = (uint8_t*)mmap( + nullptr, data.size, PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + std::memcpy(data.data, r->mutable_contents()->Data(), data.size); + + reg.takeSnapshot(r->key()->str(), data, true); // Send response faabric::EmptyResponse response; diff --git a/src/transport/Message.cpp b/src/transport/Message.cpp index 74d0ec1eb..e5748d9d3 100644 --- a/src/transport/Message.cpp +++ b/src/transport/Message.cpp @@ -4,7 +4,6 @@ namespace faabric::transport { Message::Message(const zmq::message_t& msgIn) : bytes(msgIn.size()) , _more(msgIn.more()) - , _persist(false) { memcpy(bytes.data(), msgIn.data(), msgIn.size()); } @@ -12,7 +11,6 @@ Message::Message(const zmq::message_t& msgIn) Message::Message(int sizeIn) : bytes(sizeIn) , _more(false) - , _persist(false) {} // Empty message signals shutdown @@ -42,9 +40,4 @@ bool Message::more() { return _more; } - -void Message::persist() -{ - _persist = true; -} } From 4b6fd9e48ba4f61801e04e47f09b876df82e3a31 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 22 Jun 2021 16:50:16 +0000 Subject: [PATCH 15/66] Restore simple class wrapper around context for shutdown --- src/transport/MessageContext.cpp | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/transport/MessageContext.cpp b/src/transport/MessageContext.cpp index 19a8bb78c..f54d3c819 100644 --- a/src/transport/MessageContext.cpp +++ b/src/transport/MessageContext.cpp @@ -4,7 +4,23 @@ namespace faabric::transport { -static std::shared_ptr instance = nullptr; +class ContextWrapper +{ + public: + std::shared_ptr ctx; + ContextWrapper() + { + ctx = std::make_shared(ZMQ_CONTEXT_IO_THREADS); + } + + ~ContextWrapper() + { + SPDLOG_TRACE("Destroying ZeroMQ context"); + ctx->shutdown(); + } +}; + +static std::shared_ptr instance = nullptr; static std::shared_mutex mx; std::shared_ptr getGlobalMessageContext() @@ -12,13 +28,13 @@ std::shared_ptr getGlobalMessageContext() if (instance == nullptr) { faabric::util::FullLock lock(mx); if (instance == nullptr) { - instance = std::make_shared(ZMQ_CONTEXT_IO_THREADS); + instance = std::make_shared(); } } { faabric::util::SharedLock lock(mx); - return instance; + return instance->ctx; } } } From d4e8f41e2a6294f93655c0caba0d69d4286cfa7a Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 22 Jun 2021 17:39:07 +0000 Subject: [PATCH 16/66] Close context from main thread --- src/transport/Message.cpp | 2 +- src/transport/MessageContext.cpp | 6 ++++++ tests/test/main.cpp | 4 ++++ tests/test/transport/test_message_server.cpp | 4 ++-- 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/transport/Message.cpp b/src/transport/Message.cpp index e5748d9d3..5779a9e42 100644 --- a/src/transport/Message.cpp +++ b/src/transport/Message.cpp @@ -5,7 +5,7 @@ Message::Message(const zmq::message_t& msgIn) : bytes(msgIn.size()) , _more(msgIn.more()) { - memcpy(bytes.data(), msgIn.data(), msgIn.size()); + std::memcpy(bytes.data(), msgIn.data(), msgIn.size()); } Message::Message(int sizeIn) diff --git a/src/transport/MessageContext.cpp b/src/transport/MessageContext.cpp index f54d3c819..810f9102f 100644 --- a/src/transport/MessageContext.cpp +++ b/src/transport/MessageContext.cpp @@ -8,6 +8,7 @@ class ContextWrapper { public: std::shared_ptr ctx; + ContextWrapper() { ctx = std::make_shared(ZMQ_CONTEXT_IO_THREADS); @@ -16,7 +17,12 @@ class ContextWrapper ~ContextWrapper() { SPDLOG_TRACE("Destroying ZeroMQ context"); + + // Force outstanding ops to return ETERM ctx->shutdown(); + + // Close the context + ctx->close(); } }; diff --git a/tests/test/main.cpp b/tests/test/main.cpp index bddede144..1a6857965 100644 --- a/tests/test/main.cpp +++ b/tests/test/main.cpp @@ -4,6 +4,7 @@ #include "faabric_utils.h" +#include #include #include @@ -18,5 +19,8 @@ int main(int argc, char* argv[]) fflush(stdout); + faabric::transport::getGlobalMessageContext()->shutdown(); + faabric::transport::getGlobalMessageContext()->close(); + return result; } diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 062c7b374..d21442a06 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -198,8 +198,8 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") SECTION("Short timeout failure") { - clientTimeout = 20000; - expectFailure = false; + clientTimeout = 100; + expectFailure = true; } // Start the server in the background From 2ee0ef4fc67ac52490b3a8f190bea92de44706fd Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 08:20:57 +0000 Subject: [PATCH 17/66] Switch to raii sockets --- include/faabric/transport/MessageEndpoint.h | 61 ++--- .../faabric/transport/MessageEndpointClient.h | 4 +- src/scheduler/FunctionCallClient.cpp | 1 - src/scheduler/Scheduler.cpp | 6 - src/scheduler/SnapshotClient.cpp | 1 - src/state/InMemoryStateKeyValue.cpp | 11 - src/state/StateClient.cpp | 1 - src/transport/MessageEndpoint.cpp | 250 +++++------------- src/transport/MessageEndpointClient.cpp | 22 +- src/transport/MessageEndpointServer.cpp | 19 +- src/transport/MpiMessageEndpoint.cpp | 24 +- tests/test/main.cpp | 3 - .../scheduler/test_function_client_server.cpp | 1 - .../scheduler/test_snapshot_client_server.cpp | 1 - tests/test/state/test_state_server.cpp | 3 - .../test_message_endpoint_client.cpp | 84 ------ tests/test/transport/test_message_server.cpp | 19 +- 17 files changed, 93 insertions(+), 418 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 20e79912c..ff925b575 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -18,17 +18,9 @@ #define DEFAULT_RECV_TIMEOUT_MS 20000 #define DEFAULT_SEND_TIMEOUT_MS 20000 -// The monitor is checking an asynchronous event has completed, so can be short -#define MONITOR_TIMEOUT_MS 2000 - namespace faabric::transport { -enum class SocketType -{ - PUSH, - PULL -}; -/* Wrapper arround zmq::socket_t + /* Wrapper arround zmq::socket_t * * Thread-unsafe socket-like object. MUST be open-ed and close-ed from the * _same_ thread. For a proto://host:pair triple, one socket may bind, and all @@ -38,55 +30,31 @@ enum class SocketType class MessageEndpoint { public: - MessageEndpoint(SocketType socketTypeIn, + MessageEndpoint(zmq::socket_type socketTypeIn, const std::string& hostIn, - int portIn); + int portIn, int timeoutMs); - // Message endpoints shouldn't be assigned as ZeroMQ sockets are not thread - // safe + // Delete assignment and copy-constructor as we need to be very careful with + // socping and same-thread instantiation MessageEndpoint& operator=(const MessageEndpoint&) = delete; - // Neither copied MessageEndpoint(const MessageEndpoint& ctx) = delete; - ~MessageEndpoint(); - - void open(); - - void close(); - - void send(uint8_t* serialisedMsg, size_t msgSize, bool more = false); - - // If known, pass a size parameter to pre-allocate a recv buffer - Message recv(int size = 0); - - // The MessageEndpointServer needs direct access to the socket - std::unique_ptr socket; - std::string getHost(); int getPort(); - void setRecvTimeoutMs(int value); - - void setSendTimeoutMs(int value); - protected: - const SocketType socketType; + const zmq::socket_type socketType; const std::string host; const int port; const std::string address; std::thread::id tid; int id; - int recvTimeoutMs = DEFAULT_RECV_TIMEOUT_MS; - int sendTimeoutMs = DEFAULT_SEND_TIMEOUT_MS; + zmq::socket_t socket; void validateTimeout(int value); - - Message recvBuffer(int size); - - Message recvNoBuffer(); }; /* Send and Recv Message Endpoints */ @@ -94,11 +62,11 @@ class MessageEndpoint class SendMessageEndpoint : public MessageEndpoint { public: - SendMessageEndpoint(const std::string& hostIn, int portIn); - - void open(); + SendMessageEndpoint(const std::string& hostIn, + int portIn, + int timeoutMs = DEFAULT_SEND_TIMEOUT_MS); - void close(); + void send(uint8_t* serialisedMsg, size_t msgSize, bool more = false); }; class RecvMessageEndpoint : public MessageEndpoint @@ -106,9 +74,12 @@ class RecvMessageEndpoint : public MessageEndpoint public: RecvMessageEndpoint(int portIn); - void open(); + Message recv(int size = 0); - void close(); + private: + Message recvBuffer(int size); + + Message recvNoBuffer(); }; class MessageTimeoutException : public faabric::util::FaabricException diff --git a/include/faabric/transport/MessageEndpointClient.h b/include/faabric/transport/MessageEndpointClient.h index 1cbfbc966..a6ffd3635 100644 --- a/include/faabric/transport/MessageEndpointClient.h +++ b/include/faabric/transport/MessageEndpointClient.h @@ -12,7 +12,9 @@ namespace faabric::transport { class MessageEndpointClient : public faabric::transport::SendMessageEndpoint { public: - MessageEndpointClient(const std::string& host, int port); + MessageEndpointClient(const std::string& host, + int port, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); /* Wait for a message * diff --git a/src/scheduler/FunctionCallClient.cpp b/src/scheduler/FunctionCallClient.cpp index 006fa631c..5af8fd12f 100644 --- a/src/scheduler/FunctionCallClient.cpp +++ b/src/scheduler/FunctionCallClient.cpp @@ -82,7 +82,6 @@ void clearMockRequests() FunctionCallClient::FunctionCallClient(const std::string& hostIn) : faabric::transport::MessageEndpointClient(hostIn, FUNCTION_CALL_PORT) { - this->open(); } void FunctionCallClient::sendHeader(faabric::scheduler::FunctionCalls call) diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 0df99e888..96a3faeba 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -63,17 +63,11 @@ void Scheduler::addHostToGlobalSet() void Scheduler::closeFunctionCallClients() { - for (auto& iter : functionCallClients) { - iter.second.close(); - } functionCallClients.clear(); } void Scheduler::closeSnapshotClients() { - for (auto& iter : snapshotClients) { - iter.second.close(); - } snapshotClients.clear(); } diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index ba4a9b6d8..0676c9c22 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -78,7 +78,6 @@ void clearMockSnapshotRequests() SnapshotClient::SnapshotClient(const std::string& hostIn) : faabric::transport::MessageEndpointClient(hostIn, SNAPSHOT_PORT) { - this->open(); } void SnapshotClient::sendHeader(faabric::scheduler::SnapshotCalls call) diff --git a/src/state/InMemoryStateKeyValue.cpp b/src/state/InMemoryStateKeyValue.cpp index e5d15833c..0f975cb8d 100644 --- a/src/state/InMemoryStateKeyValue.cpp +++ b/src/state/InMemoryStateKeyValue.cpp @@ -25,7 +25,6 @@ size_t InMemoryStateKeyValue::getStateSizeFromRemote(const std::string& userIn, StateClient stateClient(userIn, keyIn, masterIP); size_t stateSize = stateClient.stateSize(); - stateClient.close(); return stateSize; } @@ -43,7 +42,6 @@ void InMemoryStateKeyValue::deleteFromRemote(const std::string& userIn, StateClient stateClient(userIn, keyIn, masterIP); stateClient.deleteState(); - stateClient.close(); } void InMemoryStateKeyValue::clearAll(bool global) @@ -90,7 +88,6 @@ void InMemoryStateKeyValue::lockGlobal() } else { StateClient cli(user, key, masterIP); cli.lock(); - cli.close(); } } @@ -101,7 +98,6 @@ void InMemoryStateKeyValue::unlockGlobal() } else { StateClient cli(user, key, masterIP); cli.unlock(); - cli.close(); } } @@ -114,7 +110,6 @@ void InMemoryStateKeyValue::pullFromRemote() std::vector chunks = getAllChunks(); StateClient cli(user, key, masterIP); cli.pullChunks(chunks, BYTES(sharedMemory)); - cli.close(); } void InMemoryStateKeyValue::pullChunkFromRemote(long offset, size_t length) @@ -127,7 +122,6 @@ void InMemoryStateKeyValue::pullChunkFromRemote(long offset, size_t length) std::vector chunks = { StateChunk(offset, length, chunkStart) }; StateClient cli(user, key, masterIP); cli.pullChunks(chunks, BYTES(sharedMemory)); - cli.close(); } void InMemoryStateKeyValue::pushToRemote() @@ -139,7 +133,6 @@ void InMemoryStateKeyValue::pushToRemote() std::vector allChunks = getAllChunks(); StateClient cli(user, key, masterIP); cli.pushChunks(allChunks); - cli.close(); } void InMemoryStateKeyValue::pushPartialToRemote( @@ -150,7 +143,6 @@ void InMemoryStateKeyValue::pushPartialToRemote( } else { StateClient cli(user, key, masterIP); cli.pushChunks(chunks); - cli.close(); } } @@ -166,7 +158,6 @@ void InMemoryStateKeyValue::appendToRemote(const uint8_t* data, size_t length) } else { StateClient cli(user, key, masterIP); cli.append(data, length); - cli.close(); } } @@ -186,7 +177,6 @@ void InMemoryStateKeyValue::pullAppendedFromRemote(uint8_t* data, } else { StateClient cli(user, key, masterIP); cli.pullAppended(data, length, nValues); - cli.close(); } } @@ -198,7 +188,6 @@ void InMemoryStateKeyValue::clearAppendedFromRemote() } else { StateClient cli(user, key, masterIP); cli.clearAppended(); - cli.close(); } } diff --git a/src/state/StateClient.cpp b/src/state/StateClient.cpp index 8506e0dac..f2eb92447 100644 --- a/src/state/StateClient.cpp +++ b/src/state/StateClient.cpp @@ -13,7 +13,6 @@ StateClient::StateClient(const std::string& userIn, , host(hostIn) , reg(state::getInMemoryStateRegistry()) { - this->open(); } void StateClient::sendHeader(faabric::state::StateCalls call) diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index a67c5e9d2..e5038cb49 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -24,75 +24,75 @@ namespace faabric::transport { -MessageEndpoint::MessageEndpoint(SocketType socketTypeIn, +MessageEndpoint::MessageEndpoint(zmq::socket_type socketTypeIn, const std::string& hostIn, - int portIn) + int portIn, + int timeoutMs) : socketType(socketTypeIn) , host(hostIn) , port(portIn) , address("tcp://" + host + ":" + std::to_string(port)) , tid(std::this_thread::get_id()) , id(faabric::util::generateGid()) -{} - -MessageEndpoint::~MessageEndpoint() { - if (this->socket != nullptr) { - SPDLOG_WARN("Destroying an open message endpoint!"); - this->close(); - } -} - -void MessageEndpoint::open() -{ - // Check we are opening from the same thread. - assert(tid == std::this_thread::get_id()); + // Create the socket + CATCH_ZMQ_ERR(socket = + zmq::socket_t(*getGlobalMessageContext(), socketType), + "socket_create") - if (socket != nullptr) { - SPDLOG_ERROR("Double opening socket {}", id); - throw std::runtime_error("Double opening a socket"); - } + // Set socket options + socket.set(zmq::sockopt::rcvtimeo, timeoutMs); + socket.set(zmq::sockopt::sndtimeo, timeoutMs); // Note - only one socket may bind, but several can connect. This // allows for easy N - 1 or 1 - N PUSH/PULL patterns. Order between // bind and connect does not matter. - std::shared_ptr context = getGlobalMessageContext(); switch (socketType) { - case faabric::transport::SocketType::PUSH: { - CATCH_ZMQ_ERR( - { - this->socket = std::make_unique( - *context, zmq::socket_type::push); - }, - "push_socket") - - CATCH_ZMQ_ERR(this->socket->connect(address), "connect") - + case zmq::socket_type::push: { + SPDLOG_TRACE("Opening push socket {}:{} (timeout {}ms)", + host, + port, + timeoutMs); + CATCH_ZMQ_ERR(socket.connect(address), "connect") break; } - case faabric::transport::SocketType::PULL: { - CATCH_ZMQ_ERR( - { - this->socket = std::make_unique( - *context, zmq::socket_type::pull); - }, - "pull_socket") - - CATCH_ZMQ_ERR(this->socket->bind(address), "bind") - + case zmq::socket_type::pull: { + SPDLOG_TRACE("Opening pull socket {}:{} (timeout {}ms)", + host, + port, + timeoutMs); + CATCH_ZMQ_ERR(socket.bind(address), "bind") break; } default: { throw std::runtime_error("Opening unrecognized socket type"); } } +} - // Set socket options - this->socket->setsockopt(ZMQ_RCVTIMEO, recvTimeoutMs); - this->socket->setsockopt(ZMQ_SNDTIMEO, sendTimeoutMs); +std::string MessageEndpoint::getHost() +{ + return host; +} + +int MessageEndpoint::getPort() +{ + return port; } -void MessageEndpoint::send(uint8_t* serialisedMsg, size_t msgSize, bool more) +// ---------------------------------------------- +// SEND ENDPOINT +// ---------------------------------------------- + +SendMessageEndpoint::SendMessageEndpoint(const std::string& hostIn, + int portIn, + int timeoutMs) + : MessageEndpoint(zmq::socket_type::push, hostIn, portIn, timeoutMs) +{} + +void SendMessageEndpoint::send(uint8_t* serialisedMsg, + size_t msgSize, + bool more) { assert(tid == std::this_thread::get_id()); @@ -100,17 +100,13 @@ void MessageEndpoint::send(uint8_t* serialisedMsg, size_t msgSize, bool more) throw std::runtime_error("Sending on an unopened socket"); } - if (socketType == SocketType::PULL) { - throw std::runtime_error("Sending on a recv socket"); - } - zmq::send_flags sendFlags = more ? zmq::send_flags::sndmore : zmq::send_flags::none; CATCH_ZMQ_ERR( { auto res = - this->socket->send(zmq::buffer(serialisedMsg, msgSize), sendFlags); + socket.send(zmq::buffer(serialisedMsg, msgSize), sendFlags); if (res != msgSize) { SPDLOG_ERROR("Sent different bytes than expected (sent " "{}, expected {})", @@ -122,9 +118,18 @@ void MessageEndpoint::send(uint8_t* serialisedMsg, size_t msgSize, bool more) "send") } -// By passing the expected recv buffer size, we instrument zeromq to receive on -// our provisioned buffer -Message MessageEndpoint::recv(int size) +// ---------------------------------------------- +// RECV ENDPOINT +// ---------------------------------------------- + +RecvMessageEndpoint::RecvMessageEndpoint(int portIn) + : MessageEndpoint(zmq::socket_type::pull, + ANY_HOST, + portIn, + DEFAULT_RECV_TIMEOUT_MS) +{} + +Message RecvMessageEndpoint::recv(int size) { assert(tid == std::this_thread::get_id()); assert(size >= 0); @@ -133,10 +138,6 @@ Message MessageEndpoint::recv(int size) throw std::runtime_error("Receiving on an unopened socket"); } - if (socketType == SocketType::PUSH) { - throw std::runtime_error("Receiving on a send socket"); - } - if (size == 0) { return recvNoBuffer(); } @@ -144,14 +145,14 @@ Message MessageEndpoint::recv(int size) return recvBuffer(size); } -Message MessageEndpoint::recvBuffer(int size) +Message RecvMessageEndpoint::recvBuffer(int size) { // Pre-allocate buffer to avoid copying data Message msg(size); CATCH_ZMQ_ERR( try { - auto res = this->socket->recv(zmq::buffer(msg.udata(), msg.size())); + auto res = socket.recv(zmq::buffer(msg.udata(), msg.size())); if (!res.has_value()) { SPDLOG_ERROR("Timed out receiving message of size {}", size); @@ -178,13 +179,13 @@ Message MessageEndpoint::recvBuffer(int size) return msg; } -Message MessageEndpoint::recvNoBuffer() +Message RecvMessageEndpoint::recvNoBuffer() { // Allocate a message to receive data zmq::message_t msg; CATCH_ZMQ_ERR( try { - auto res = this->socket->recv(msg); + auto res = socket.recv(msg); if (!res.has_value()) { SPDLOG_ERROR("Timed out receiving message with no size"); throw MessageTimeoutException("Timed out receiving message"); @@ -201,133 +202,4 @@ Message MessageEndpoint::recvNoBuffer() // Copy the received message to a buffer whose scope we control return Message(msg); } - -void MessageEndpoint::close() -{ - if (this->socket == nullptr) { - SPDLOG_ERROR("Closing unopened socket {}", id); - throw std::runtime_error("Closing unopened socket"); - } - - if (tid != std::this_thread::get_id()) { - SPDLOG_WARN("Closing socket from a different thread"); - } - - switch (socketType) { - case SocketType::PUSH: { - CATCH_ZMQ_ERR(this->socket->disconnect(address), "disconnect") - CATCH_ZMQ_ERR(this->socket->close(), "close") - break; - } - case SocketType::PULL: { - // We duplicate the call to close() because when unbinding, we want - // to block until we _actually_ have unbound, i.e. 0MQ has closed - // the socket (which happens asynchronously). - // - // Not blocking on un-bind can cause race-conditions when the - // underlying system is slow at closing sockets, and the application - // relies a lot on synchronous message-passing. - // - // NOTE that unbinding a socket has a considerable - // overhead compared to disconnecting it. - CATCH_ZMQ_ERR(this->socket->unbind(address), "unbind") - - // TODO - could we reuse the monitor across sockets? - zmq::monitor_t mon; - const std::string monAddr = - "inproc://monitor_" + std::to_string(id); - mon.init(*(this->socket), monAddr, ZMQ_EVENT_CLOSED); - - CATCH_ZMQ_ERR(this->socket->close(), "close") - - // Wait for this to complete (should be fast) - mon.check_event(MONITOR_TIMEOUT_MS); - - break; - } - default: { - throw std::runtime_error("Closing unrecognised type of socket"); - } - } - - // Finally, null the socket - this->socket = nullptr; -} - -std::string MessageEndpoint::getHost() -{ - return host; -} - -int MessageEndpoint::getPort() -{ - return port; -} - -void MessageEndpoint::validateTimeout(int value) -{ - if (value <= 0) { - SPDLOG_ERROR("Setting invalid timeout of {}", value); - throw std::runtime_error("Setting invalid timeout"); - } - - if (socket != nullptr) { - SPDLOG_ERROR("Setting timeout of {} after socket created", value); - throw std::runtime_error("Setting timeout after socket created"); - } -} - -void MessageEndpoint::setRecvTimeoutMs(int value) -{ - validateTimeout(value); - recvTimeoutMs = value; -} - -void MessageEndpoint::setSendTimeoutMs(int value) -{ - validateTimeout(value); - sendTimeoutMs = value; -} - -/* Send and Recv Message Endpoints */ - -SendMessageEndpoint::SendMessageEndpoint(const std::string& hostIn, int portIn) - : MessageEndpoint(SocketType::PUSH, hostIn, portIn) -{} - -void SendMessageEndpoint::open() -{ - SPDLOG_TRACE( - fmt::format("Opening socket: {} (SEND {}:{})", id, host, port)); - - MessageEndpoint::open(); -} - -void SendMessageEndpoint::close() -{ - SPDLOG_TRACE( - fmt::format("Closing socket: {} (SEND {}:{})", id, host, port)); - - MessageEndpoint::close(); -} - -RecvMessageEndpoint::RecvMessageEndpoint(int portIn) - : MessageEndpoint(SocketType::PULL, ANY_HOST, portIn) -{} - -void RecvMessageEndpoint::open() -{ - SPDLOG_TRACE( - fmt::format("Opening socket: {} (RECV {}:{})", id, host, port)); - - MessageEndpoint::open(); -} - -void RecvMessageEndpoint::close() -{ - SPDLOG_TRACE( - fmt::format("Closing socket: {} (RECV {}:{})", id, host, port)); - - MessageEndpoint::close(); -} } diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index dee3228b8..cde120907 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -2,8 +2,10 @@ #include namespace faabric::transport { -MessageEndpointClient::MessageEndpointClient(const std::string& host, int port) - : SendMessageEndpoint(host, port) +MessageEndpointClient::MessageEndpointClient(const std::string& host, + int port, + int timeoutMs) + : SendMessageEndpoint(host, port, timeoutMs) {} // Block until we receive a response from the server @@ -12,21 +14,7 @@ Message MessageEndpointClient::awaitResponse(int port) // Wait for the response, open a temporary endpoint for it RecvMessageEndpoint endpoint(port); - // Inherit timeouts on temporary endpoint - endpoint.setRecvTimeoutMs(recvTimeoutMs); - endpoint.setSendTimeoutMs(sendTimeoutMs); - - endpoint.open(); - - Message receivedMessage; - try { - receivedMessage = endpoint.recv(); - } catch (MessageTimeoutException& ex) { - endpoint.close(); - throw; - } - - endpoint.close(); + Message receivedMessage = endpoint.recv(); return receivedMessage; } diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 987d0ff7d..e50e7ddab 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -15,9 +15,6 @@ void MessageEndpointServer::start() servingThread = std::thread([this] { recvEndpoint = std::make_unique(this->port); - // Open message endpoint, and bind - recvEndpoint->open(); - // Loop until we receive a shutdown message while (true) { try { @@ -31,35 +28,27 @@ void MessageEndpointServer::start() continue; } } - - recvEndpoint->close(); }); } void MessageEndpointServer::stop() { + // Send a shutdown message via a temporary endpoint + SendMessageEndpoint e(recvEndpoint->getHost(), recvEndpoint->getPort()); + SPDLOG_TRACE("Sending shutdown message locally to {}:{}", recvEndpoint->getHost(), recvEndpoint->getPort()); - - // Send a shutdown message via a temporary endpoint - SendMessageEndpoint e(recvEndpoint->getHost(), recvEndpoint->getPort()); - e.open(); e.send(nullptr, 0); // Join the serving thread if (servingThread.joinable()) { servingThread.join(); } - - e.close(); } bool MessageEndpointServer::recv() { - // Check endpoint has been initialised - assert(recvEndpoint->socket != nullptr); - // Receive header and body Message header = recvEndpoint->recv(); @@ -96,8 +85,6 @@ void MessageEndpointServer::sendResponse(uint8_t* serialisedMsg, // Open the endpoint socket, server connects (not bind) to remote address SendMessageEndpoint sendEndpoint(returnHost, returnPort + REPLY_PORT_OFFSET); - sendEndpoint.open(); sendEndpoint.send(serialisedMsg, size); - sendEndpoint.close(); } } diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index 39c7ea688..fd9396929 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -4,10 +4,8 @@ namespace faabric::transport { faabric::MpiHostsToRanksMessage recvMpiHostRankMsg() { faabric::transport::RecvMessageEndpoint endpoint(MPI_PORT); - endpoint.open(); faabric::transport::Message m = endpoint.recv(); PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); - endpoint.close(); return msg; } @@ -22,29 +20,21 @@ void sendMpiHostRankMsg(const std::string& hostIn, throw std::runtime_error("Error serialising message"); } faabric::transport::SendMessageEndpoint endpoint(hostIn, MPI_PORT); - endpoint.open(); endpoint.send(sMsg, msgSize, false); - endpoint.close(); } } MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn) : sendMessageEndpoint(hostIn, portIn) , recvMessageEndpoint(portIn) -{ - sendMessageEndpoint.open(); - recvMessageEndpoint.open(); -} +{} MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int sendPort, int recvPort) : sendMessageEndpoint(hostIn, sendPort) , recvMessageEndpoint(recvPort) -{ - sendMessageEndpoint.open(); - recvMessageEndpoint.open(); -} +{} void MpiMessageEndpoint::sendMpiMessage( const std::shared_ptr& msg) @@ -67,13 +57,5 @@ std::shared_ptr MpiMessageEndpoint::recvMpiMessage() return std::make_shared(msg); } -void MpiMessageEndpoint::close() -{ - if (sendMessageEndpoint.socket != nullptr) { - sendMessageEndpoint.close(); - } - if (recvMessageEndpoint.socket != nullptr) { - recvMessageEndpoint.close(); - } -} +void MpiMessageEndpoint::close() {} } diff --git a/tests/test/main.cpp b/tests/test/main.cpp index 1a6857965..d9c0b5828 100644 --- a/tests/test/main.cpp +++ b/tests/test/main.cpp @@ -19,8 +19,5 @@ int main(int argc, char* argv[]) fflush(stdout); - faabric::transport::getGlobalMessageContext()->shutdown(); - faabric::transport::getGlobalMessageContext()->close(); - return result; } diff --git a/tests/test/scheduler/test_function_client_server.cpp b/tests/test/scheduler/test_function_client_server.cpp index c7c68b8bb..988039b4d 100644 --- a/tests/test/scheduler/test_function_client_server.cpp +++ b/tests/test/scheduler/test_function_client_server.cpp @@ -46,7 +46,6 @@ class ClientServerFixture ~ClientServerFixture() { - cli.close(); server.stop(); executorFactory->reset(); } diff --git a/tests/test/scheduler/test_snapshot_client_server.cpp b/tests/test/scheduler/test_snapshot_client_server.cpp index 27461c97e..d995d2578 100644 --- a/tests/test/scheduler/test_snapshot_client_server.cpp +++ b/tests/test/scheduler/test_snapshot_client_server.cpp @@ -34,7 +34,6 @@ class SnapshotClientServerFixture ~SnapshotClientServerFixture() { - cli.close(); server.stop(); } }; diff --git a/tests/test/state/test_state_server.cpp b/tests/test/state/test_state_server.cpp index f3d9c7f91..b82b77cf9 100644 --- a/tests/test/state/test_state_server.cpp +++ b/tests/test/state/test_state_server.cpp @@ -134,9 +134,6 @@ TEST_CASE("Test request/ response", "[state]") REQUIRE(actualAppended == expected); } - // Close the state client - client.close(); - s.stop(); resetStateMode(); diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index d9025da21..a9c2aba82 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -13,22 +13,6 @@ const int testPort = 9999; const int testReplyPort = 9996; namespace tests { -TEST_CASE_METHOD(MessageContextFixture, - "Test open/close one client", - "[transport]") -{ - // Open an endpoint client, don't bind - MessageEndpoint cli(SocketType::PULL, thisHost, testPort); - REQUIRE_NOTHROW(cli.open()); - - // Open another endpoint client, bind - MessageEndpoint secondCli(SocketType::PUSH, thisHost, testPort); - REQUIRE_NOTHROW(secondCli.open()); - - // Close all endpoint clients - REQUIRE_NOTHROW(cli.close()); - REQUIRE_NOTHROW(secondCli.close()); -} TEST_CASE_METHOD(MessageContextFixture, "Test send/recv one message", @@ -36,11 +20,9 @@ TEST_CASE_METHOD(MessageContextFixture, { // Open the source endpoint client, don't bind SendMessageEndpoint src(thisHost, testPort); - src.open(); // Open the destination endpoint client, bind RecvMessageEndpoint dst(testPort); - dst.open(); // Send message std::string expectedMsg = "Hello world!"; @@ -53,10 +35,6 @@ TEST_CASE_METHOD(MessageContextFixture, REQUIRE(recvMsg.size() == expectedMsg.size()); std::string actualMsg(recvMsg.data(), recvMsg.size()); REQUIRE(actualMsg == expectedMsg); - - // Close endpoints - src.close(); - dst.close(); } TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") @@ -68,7 +46,6 @@ TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") std::thread senderThread([expectedMsg, expectedResponse] { // Open the source endpoint client, don't bind MessageEndpointClient src(thisHost, testPort); - src.open(); // Send message and wait for response uint8_t msg[expectedMsg.size()]; @@ -80,13 +57,10 @@ TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") assert(recvMsg.size() == expectedResponse.size()); std::string actualResponse(recvMsg.data(), recvMsg.size()); assert(actualResponse == expectedResponse); - - src.close(); }); // Receive message RecvMessageEndpoint dst(testPort); - dst.open(); faabric::transport::Message recvMsg = dst.recv(); REQUIRE(recvMsg.size() == expectedMsg.size()); std::string actualMsg(recvMsg.data(), recvMsg.size()); @@ -94,7 +68,6 @@ TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") // Send response, open a new endpoint for it SendMessageEndpoint dstResponse(thisHost, testReplyPort); - dstResponse.open(); uint8_t msg[expectedResponse.size()]; memcpy(msg, expectedResponse.c_str(), expectedResponse.size()); dstResponse.send(msg, expectedResponse.size()); @@ -103,10 +76,6 @@ TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") if (senderThread.joinable()) { senderThread.join(); } - - // Close receiving endpoints - dst.close(); - dstResponse.close(); } TEST_CASE_METHOD(MessageContextFixture, @@ -119,20 +88,16 @@ TEST_CASE_METHOD(MessageContextFixture, std::thread senderThread([numMessages, baseMsg] { // Open the source endpoint client, don't bind SendMessageEndpoint src(thisHost, testPort); - src.open(); for (int i = 0; i < numMessages; i++) { std::string expectedMsg = baseMsg + std::to_string(i); uint8_t msg[expectedMsg.size()]; memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); src.send(msg, expectedMsg.size()); } - - src.close(); }); // Receive messages RecvMessageEndpoint dst(testPort); - dst.open(); for (int i = 0; i < numMessages; i++) { faabric::transport::Message recvMsg = dst.recv(); // Check just a subset of the messages @@ -149,9 +114,6 @@ TEST_CASE_METHOD(MessageContextFixture, if (senderThread.joinable()) { senderThread.join(); } - - // Close the destination endpoint - dst.close(); } TEST_CASE_METHOD(MessageContextFixture, @@ -167,20 +129,16 @@ TEST_CASE_METHOD(MessageContextFixture, senderThreads.emplace_back(std::thread([numMessages, expectedMsg] { // Open the source endpoint client, don't bind SendMessageEndpoint src(thisHost, testPort); - src.open(); for (int i = 0; i < numMessages; i++) { uint8_t msg[expectedMsg.size()]; memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); src.send(msg, expectedMsg.size()); } - - src.close(); })); } // Receive messages RecvMessageEndpoint dst(testPort); - dst.open(); for (int i = 0; i < numSenders * numMessages; i++) { faabric::transport::Message recvMsg = dst.recv(); // Check just a subset of the messages @@ -197,47 +155,5 @@ TEST_CASE_METHOD(MessageContextFixture, t.join(); } } - - // Close the destination endpoint - dst.close(); -} - -TEST_CASE_METHOD(MessageContextFixture, - "Test can't set invalid send/recv timeouts", - "[transport]") -{ - MessageEndpoint cli(SocketType::PULL, thisHost, testPort); - - SECTION("Sanity check valid timeout") - { - REQUIRE_NOTHROW(cli.setRecvTimeoutMs(100)); - REQUIRE_NOTHROW(cli.setSendTimeoutMs(100)); - } - - SECTION("Recv zero timeout") { REQUIRE_THROWS(cli.setRecvTimeoutMs(0)); } - - SECTION("Send zero timeout") { REQUIRE_THROWS(cli.setSendTimeoutMs(0)); } - - SECTION("Recv negative timeout") - { - REQUIRE_THROWS(cli.setRecvTimeoutMs(-1)); - } - - SECTION("Send negative timeout") - { - REQUIRE_THROWS(cli.setSendTimeoutMs(-1)); - } - - SECTION("Recv, socket already initialised") - { - cli.open(); - REQUIRE_THROWS(cli.setRecvTimeoutMs(100)); - } - - SECTION("Send, socket already initialised") - { - cli.open(); - REQUIRE_THROWS(cli.setSendTimeoutMs(100)); - } } } diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index d21442a06..ff94f2ab8 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -84,8 +84,6 @@ TEST_CASE("Test send one message to server", "[transport]") // Open the source endpoint client, don't bind MessageEndpointClient src(thisHost, testPort); - src.open(); - // Send message: server expects header + body std::string header = "header"; uint8_t headerMsg[header.size()]; @@ -103,9 +101,6 @@ TEST_CASE("Test send one message to server", "[transport]") usleep(1000 * 300); REQUIRE(server.messageCount == 1); - // Close the client - src.close(); - // Close the server server.stop(); } @@ -120,14 +115,11 @@ TEST_CASE("Test send one-off response to client", "[transport]") std::thread clientThread([expectedMsg] { // Open the source endpoint client, don't bind MessageEndpointClient cli(thisHost, testPort); - cli.open(); Message msg = cli.awaitResponse(testPort + REPLY_PORT_OFFSET); assert(msg.size() == expectedMsg.size()); std::string actualMsg(msg.data(), msg.size()); assert(actualMsg == expectedMsg); - - cli.close(); }); uint8_t msg[expectedMsg.size()]; @@ -154,7 +146,6 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") clientThreads.emplace_back(std::thread([numMessages] { // Prepare client MessageEndpointClient cli(thisHost, testPort); - cli.open(); std::string clientMsg = "Message from threaded client"; for (int j = 0; j < numMessages; j++) { @@ -169,8 +160,6 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") } usleep(1000 * 300); - - cli.close(); })); } @@ -198,7 +187,7 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") SECTION("Short timeout failure") { - clientTimeout = 100; + clientTimeout = 1; expectFailure = true; } @@ -217,9 +206,7 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") usleep(500 * 1000); // Set up the client - MessageEndpointClient cli(thisHost, testPort); - cli.setRecvTimeoutMs(clientTimeout); - cli.open(); + MessageEndpointClient cli(thisHost, testPort, clientTimeout); std::vector data = { 1, 1, 1 }; cli.send(data.data(), data.size(), true); @@ -237,8 +224,6 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") REQUIRE(responseMessage.dataCopy() == expected); } - cli.close(); - if (t.joinable()) { t.join(); } From c2af35e14dc091cf53852cda46470aa69b98c92a Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 09:03:37 +0000 Subject: [PATCH 18/66] Remove MessageEndpointClient --- .../faabric/scheduler/FunctionCallClient.h | 4 +-- .../faabric/scheduler/FunctionCallServer.h | 1 - include/faabric/scheduler/SnapshotClient.h | 4 +-- include/faabric/state/StateClient.h | 4 +-- include/faabric/transport/MessageEndpoint.h | 9 ++++--- .../faabric/transport/MessageEndpointClient.h | 27 ------------------- .../faabric/transport/MessageEndpointServer.h | 1 - src/scheduler/FunctionCallClient.cpp | 6 ++--- src/scheduler/FunctionCallServer.cpp | 1 + src/scheduler/SnapshotClient.cpp | 6 ++--- src/scheduler/SnapshotServer.cpp | 1 + src/state/StateClient.cpp | 8 +++--- src/state/StateServer.cpp | 1 + src/transport/CMakeLists.txt | 4 +-- src/transport/MessageEndpoint.cpp | 14 +++++++++- src/transport/MessageEndpointClient.cpp | 21 --------------- src/transport/MessageEndpointServer.cpp | 1 + .../test_message_endpoint_client.cpp | 6 ++--- tests/test/transport/test_message_server.cpp | 9 ++++--- 19 files changed, 48 insertions(+), 80 deletions(-) delete mode 100644 include/faabric/transport/MessageEndpointClient.h delete mode 100644 src/transport/MessageEndpointClient.cpp diff --git a/include/faabric/scheduler/FunctionCallClient.h b/include/faabric/scheduler/FunctionCallClient.h index c6f20519f..44d969803 100644 --- a/include/faabric/scheduler/FunctionCallClient.h +++ b/include/faabric/scheduler/FunctionCallClient.h @@ -2,7 +2,7 @@ #include #include -#include +#include #include namespace faabric::scheduler { @@ -32,7 +32,7 @@ void clearMockRequests(); // ----------------------------------- // Message client // ----------------------------------- -class FunctionCallClient : public faabric::transport::MessageEndpointClient +class FunctionCallClient : public faabric::transport::SendMessageEndpoint { public: explicit FunctionCallClient(const std::string& hostIn); diff --git a/include/faabric/scheduler/FunctionCallServer.h b/include/faabric/scheduler/FunctionCallServer.h index 962cc7439..bf18a112a 100644 --- a/include/faabric/scheduler/FunctionCallServer.h +++ b/include/faabric/scheduler/FunctionCallServer.h @@ -3,7 +3,6 @@ #include #include #include -#include #include namespace faabric::scheduler { diff --git a/include/faabric/scheduler/SnapshotClient.h b/include/faabric/scheduler/SnapshotClient.h index c97cc497c..3189847b2 100644 --- a/include/faabric/scheduler/SnapshotClient.h +++ b/include/faabric/scheduler/SnapshotClient.h @@ -2,7 +2,7 @@ #include #include -#include +#include #include namespace faabric::scheduler { @@ -32,7 +32,7 @@ void clearMockSnapshotRequests(); // gRPC client // ----------------------------------- -class SnapshotClient final : public faabric::transport::MessageEndpointClient +class SnapshotClient final : public faabric::transport::SendMessageEndpoint { public: explicit SnapshotClient(const std::string& hostIn); diff --git a/include/faabric/state/StateClient.h b/include/faabric/state/StateClient.h index 79b5ae9f1..13dd2bd96 100644 --- a/include/faabric/state/StateClient.h +++ b/include/faabric/state/StateClient.h @@ -3,10 +3,10 @@ #include #include #include -#include +#include namespace faabric::state { -class StateClient : public faabric::transport::MessageEndpointClient +class StateClient : public faabric::transport::SendMessageEndpoint { public: explicit StateClient(const std::string& userIn, diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index ff925b575..171525b23 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -32,7 +32,7 @@ class MessageEndpoint public: MessageEndpoint(zmq::socket_type socketTypeIn, const std::string& hostIn, - int portIn, int timeoutMs); + int portIn, int timeoutMsIn); // Delete assignment and copy-constructor as we need to be very careful with // socping and same-thread instantiation @@ -49,8 +49,9 @@ class MessageEndpoint const std::string host; const int port; const std::string address; - std::thread::id tid; - int id; + const int timeoutMs; + const std::thread::id tid; + const int id; zmq::socket_t socket; @@ -67,6 +68,8 @@ class SendMessageEndpoint : public MessageEndpoint int timeoutMs = DEFAULT_SEND_TIMEOUT_MS); void send(uint8_t* serialisedMsg, size_t msgSize, bool more = false); + + Message awaitResponse(int port); }; class RecvMessageEndpoint : public MessageEndpoint diff --git a/include/faabric/transport/MessageEndpointClient.h b/include/faabric/transport/MessageEndpointClient.h deleted file mode 100644 index a6ffd3635..000000000 --- a/include/faabric/transport/MessageEndpointClient.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include -#include - -namespace faabric::transport { -/* Minimal message endpoint client - * - * Low-level and minimal message endpoint client to run in companion with - * a background-running server. - */ -class MessageEndpointClient : public faabric::transport::SendMessageEndpoint -{ - public: - MessageEndpointClient(const std::string& host, - int port, - int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); - - /* Wait for a message - * - * This method blocks the calling thread until we receive a message from - * the specified host:port pair. When pointed at a server, this method - * allows for blocking communications. - */ - Message awaitResponse(int port); -}; -} diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 973ab87df..fe07aa922 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -2,7 +2,6 @@ #include #include -#include #include diff --git a/src/scheduler/FunctionCallClient.cpp b/src/scheduler/FunctionCallClient.cpp index 5af8fd12f..0a7bf1a8a 100644 --- a/src/scheduler/FunctionCallClient.cpp +++ b/src/scheduler/FunctionCallClient.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -80,9 +81,8 @@ void clearMockRequests() // Message Client // ----------------------------------- FunctionCallClient::FunctionCallClient(const std::string& hostIn) - : faabric::transport::MessageEndpointClient(hostIn, FUNCTION_CALL_PORT) -{ -} + : faabric::transport::SendMessageEndpoint(hostIn, FUNCTION_CALL_PORT) +{} void FunctionCallClient::sendHeader(faabric::scheduler::FunctionCalls call) { diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index 5237ee5db..88b0545c6 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index 0676c9c22..78628761f 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -76,9 +77,8 @@ void clearMockSnapshotRequests() send(buffer, size); SnapshotClient::SnapshotClient(const std::string& hostIn) - : faabric::transport::MessageEndpointClient(hostIn, SNAPSHOT_PORT) -{ -} + : faabric::transport::SendMessageEndpoint(hostIn, SNAPSHOT_PORT) +{} void SnapshotClient::sendHeader(faabric::scheduler::SnapshotCalls call) { diff --git a/src/scheduler/SnapshotServer.cpp b/src/scheduler/SnapshotServer.cpp index 45067a485..74034f48d 100644 --- a/src/scheduler/SnapshotServer.cpp +++ b/src/scheduler/SnapshotServer.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include diff --git a/src/state/StateClient.cpp b/src/state/StateClient.cpp index f2eb92447..54b022d84 100644 --- a/src/state/StateClient.cpp +++ b/src/state/StateClient.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -7,13 +8,12 @@ namespace faabric::state { StateClient::StateClient(const std::string& userIn, const std::string& keyIn, const std::string& hostIn) - : faabric::transport::MessageEndpointClient(hostIn, STATE_PORT) + : faabric::transport::SendMessageEndpoint(hostIn, STATE_PORT) , user(userIn) , key(keyIn) , host(hostIn) , reg(state::getInMemoryStateRegistry()) -{ -} +{} void StateClient::sendHeader(faabric::state::StateCalls call) { @@ -24,7 +24,7 @@ void StateClient::sendHeader(faabric::state::StateCalls call) faabric::transport::Message StateClient::awaitResponse() { // Call the superclass implementation - return MessageEndpointClient::awaitResponse(STATE_PORT + REPLY_PORT_OFFSET); + return SendMessageEndpoint::awaitResponse(STATE_PORT + REPLY_PORT_OFFSET); } void StateClient::sendStateRequest(faabric::state::StateCalls header, diff --git a/src/state/StateServer.cpp b/src/state/StateServer.cpp index 4c3feffd1..91dea2ace 100644 --- a/src/state/StateServer.cpp +++ b/src/state/StateServer.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include diff --git a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt index 3ac400f19..b92573251 100644 --- a/src/transport/CMakeLists.txt +++ b/src/transport/CMakeLists.txt @@ -8,16 +8,14 @@ set(HEADERS "${FAABRIC_INCLUDE_DIR}/faabric/transport/Message.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageContext.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpoint.h" - "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointClient.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointServer.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MpiMessageEndpoint.h" ) -set(LIB_FILES +set(LIB_FILES Message.cpp MessageContext.cpp MessageEndpoint.cpp - MessageEndpointClient.cpp MessageEndpointServer.cpp MpiMessageEndpoint.cpp ${HEADERS} diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index e5038cb49..5fed1af00 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -27,11 +27,12 @@ namespace faabric::transport { MessageEndpoint::MessageEndpoint(zmq::socket_type socketTypeIn, const std::string& hostIn, int portIn, - int timeoutMs) + int timeoutMsIn) : socketType(socketTypeIn) , host(hostIn) , port(portIn) , address("tcp://" + host + ":" + std::to_string(port)) + , timeoutMs(timeoutMsIn) , tid(std::this_thread::get_id()) , id(faabric::util::generateGid()) { @@ -118,6 +119,17 @@ void SendMessageEndpoint::send(uint8_t* serialisedMsg, "send") } +// Block until we receive a response from the server +Message SendMessageEndpoint::awaitResponse(int port) +{ + // Wait for the response, open a temporary endpoint for it + RecvMessageEndpoint endpoint(port); + + Message receivedMessage = endpoint.recv(); + + return receivedMessage; +} + // ---------------------------------------------- // RECV ENDPOINT // ---------------------------------------------- diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp deleted file mode 100644 index cde120907..000000000 --- a/src/transport/MessageEndpointClient.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include -#include - -namespace faabric::transport { -MessageEndpointClient::MessageEndpointClient(const std::string& host, - int port, - int timeoutMs) - : SendMessageEndpoint(host, port, timeoutMs) -{} - -// Block until we receive a response from the server -Message MessageEndpointClient::awaitResponse(int port) -{ - // Wait for the response, open a temporary endpoint for it - RecvMessageEndpoint endpoint(port); - - Message receivedMessage = endpoint.recv(); - - return receivedMessage; -} -} diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index e50e7ddab..7469c06be 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -1,4 +1,5 @@ #include +#include #include #include diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index a9c2aba82..c2905c173 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include using namespace faabric::transport; @@ -44,8 +44,8 @@ TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") std::string expectedResponse = "world!"; std::thread senderThread([expectedMsg, expectedResponse] { - // Open the source endpoint client, don't bind - MessageEndpointClient src(thisHost, testPort); + // Open the source endpoint client + SendMessageEndpoint src(thisHost, testPort); // Send message and wait for response uint8_t msg[expectedMsg.size()]; diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index ff94f2ab8..f04f4769d 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -3,6 +3,7 @@ #include #include +#include #include using namespace faabric::transport; @@ -82,7 +83,7 @@ TEST_CASE("Test send one message to server", "[transport]") server.start(); // Open the source endpoint client, don't bind - MessageEndpointClient src(thisHost, testPort); + SendMessageEndpoint src(thisHost, testPort); // Send message: server expects header + body std::string header = "header"; @@ -114,7 +115,7 @@ TEST_CASE("Test send one-off response to client", "[transport]") std::thread clientThread([expectedMsg] { // Open the source endpoint client, don't bind - MessageEndpointClient cli(thisHost, testPort); + SendMessageEndpoint cli(thisHost, testPort); Message msg = cli.awaitResponse(testPort + REPLY_PORT_OFFSET); assert(msg.size() == expectedMsg.size()); @@ -145,7 +146,7 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") for (int i = 0; i < numClients; i++) { clientThreads.emplace_back(std::thread([numMessages] { // Prepare client - MessageEndpointClient cli(thisHost, testPort); + SendMessageEndpoint cli(thisHost, testPort); std::string clientMsg = "Message from threaded client"; for (int j = 0; j < numMessages; j++) { @@ -206,7 +207,7 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") usleep(500 * 1000); // Set up the client - MessageEndpointClient cli(thisHost, testPort, clientTimeout); + SendMessageEndpoint cli(thisHost, testPort, clientTimeout); std::vector data = { 1, 1, 1 }; cli.send(data.data(), data.size(), true); From d038e8e5de269173ccf4e1639052b4e808465ff4 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 09:23:05 +0000 Subject: [PATCH 19/66] Explicitly open and close message context --- include/faabric/transport/MessageContext.h | 15 ---- include/faabric/transport/context.h | 16 +++++ src/runner/FaabricMain.cpp | 1 + src/transport/CMakeLists.txt | 4 +- src/transport/MessageContext.cpp | 46 ------------ src/transport/MessageEndpoint.cpp | 2 +- src/transport/context.cpp | 71 +++++++++++++++++++ tests/test/main.cpp | 6 +- .../test_message_endpoint_client.cpp | 8 +-- tests/test/transport/test_message_server.cpp | 5 +- .../transport/test_mpi_message_endpoint.cpp | 4 +- tests/utils/fixtures.h | 8 --- 12 files changed, 105 insertions(+), 81 deletions(-) delete mode 100644 include/faabric/transport/MessageContext.h create mode 100644 include/faabric/transport/context.h delete mode 100644 src/transport/MessageContext.cpp create mode 100644 src/transport/context.cpp diff --git a/include/faabric/transport/MessageContext.h b/include/faabric/transport/MessageContext.h deleted file mode 100644 index 4b59bac5e..000000000 --- a/include/faabric/transport/MessageContext.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include -#include - -/* - * The context object is thread safe, and the constructor parameter indicates - * the number of hardware IO threads to be used. As a rule of thumb, use one - * IO thread per Gbps of data. - */ -#define ZMQ_CONTEXT_IO_THREADS 1 - -namespace faabric::transport { -std::shared_ptr getGlobalMessageContext(); -} diff --git a/include/faabric/transport/context.h b/include/faabric/transport/context.h new file mode 100644 index 000000000..9a0a7d846 --- /dev/null +++ b/include/faabric/transport/context.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +#define ZMQ_CONTEXT_IO_THREADS 1 + +namespace faabric::transport { + +void initGlobalMessageContext(); + +std::shared_ptr getGlobalMessageContext(); + +void closeGlobalMessageContext(); + +} diff --git a/src/runner/FaabricMain.cpp b/src/runner/FaabricMain.cpp index ed9764b58..2e28a98e7 100644 --- a/src/runner/FaabricMain.cpp +++ b/src/runner/FaabricMain.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include diff --git a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt index b92573251..d7577ca34 100644 --- a/src/transport/CMakeLists.txt +++ b/src/transport/CMakeLists.txt @@ -4,17 +4,17 @@ set(HEADERS "${FAABRIC_INCLUDE_DIR}/faabric/transport/common.h" + "${FAABRIC_INCLUDE_DIR}/faabric/transport/context.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/macros.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/Message.h" - "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageContext.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpoint.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointServer.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MpiMessageEndpoint.h" ) set(LIB_FILES + context.cpp Message.cpp - MessageContext.cpp MessageEndpoint.cpp MessageEndpointServer.cpp MpiMessageEndpoint.cpp diff --git a/src/transport/MessageContext.cpp b/src/transport/MessageContext.cpp deleted file mode 100644 index 810f9102f..000000000 --- a/src/transport/MessageContext.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include -#include -#include - -namespace faabric::transport { - -class ContextWrapper -{ - public: - std::shared_ptr ctx; - - ContextWrapper() - { - ctx = std::make_shared(ZMQ_CONTEXT_IO_THREADS); - } - - ~ContextWrapper() - { - SPDLOG_TRACE("Destroying ZeroMQ context"); - - // Force outstanding ops to return ETERM - ctx->shutdown(); - - // Close the context - ctx->close(); - } -}; - -static std::shared_ptr instance = nullptr; -static std::shared_mutex mx; - -std::shared_ptr getGlobalMessageContext() -{ - if (instance == nullptr) { - faabric::util::FullLock lock(mx); - if (instance == nullptr) { - instance = std::make_shared(); - } - } - - { - faabric::util::SharedLock lock(mx); - return instance->ctx; - } -} -} diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 5fed1af00..615e50405 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include diff --git a/src/transport/context.cpp b/src/transport/context.cpp new file mode 100644 index 000000000..f610d1fd2 --- /dev/null +++ b/src/transport/context.cpp @@ -0,0 +1,71 @@ +#include +#include +#include + +namespace faabric::transport { + +/* + * The zmq::context_t object is thread safe, and the constructor parameter + * indicates the number of hardware IO threads to be used. As a rule of thumb, + * use one IO thread per Gbps of data. + */ + +class ContextWrapper +{ + public: + std::shared_ptr ctx; + + ContextWrapper() + { + ctx = std::make_shared(ZMQ_CONTEXT_IO_THREADS); + } + + ~ContextWrapper() + { + SPDLOG_TRACE("Destroying ZeroMQ context"); + + // Force outstanding ops to return ETERM + ctx->shutdown(); + + // Close the context + ctx->close(); + } +}; + +static std::shared_ptr instance = nullptr; + +void initGlobalMessageContext() +{ + if (instance != nullptr) { + throw std::runtime_error("Must not initialise global context twice"); + } + + SPDLOG_TRACE("Initialising global ZeroMQ context"); + instance = std::make_shared(); +} + +std::shared_ptr getGlobalMessageContext() +{ + if (instance == nullptr) { + throw std::runtime_error( + "Must explicitly initialise and close global message context"); + } + + return instance->ctx; +} + +void closeGlobalMessageContext() +{ + if (instance == nullptr) { + throw std::runtime_error("Cannot close an uninitialised context"); + } + + SPDLOG_TRACE("Destroying global ZeroMQ context"); + + // Force outstanding ops to return ETERM + instance->ctx->shutdown(); + + // Close the context + instance->ctx->close(); +} +} diff --git a/tests/test/main.cpp b/tests/test/main.cpp index d9c0b5828..30a52adb3 100644 --- a/tests/test/main.cpp +++ b/tests/test/main.cpp @@ -4,7 +4,7 @@ #include "faabric_utils.h" -#include +#include #include #include @@ -12,6 +12,8 @@ FAABRIC_CATCH_LOGGER int main(int argc, char* argv[]) { + faabric::transport::initGlobalMessageContext(); + faabric::util::setTestMode(true); faabric::util::initLogging(); @@ -19,5 +21,7 @@ int main(int argc, char* argv[]) fflush(stdout); + faabric::transport::closeGlobalMessageContext(); + return result; } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index c2905c173..9de60a155 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -14,7 +14,7 @@ const int testReplyPort = 9996; namespace tests { -TEST_CASE_METHOD(MessageContextFixture, +TEST_CASE_METHOD(SchedulerTestFixture, "Test send/recv one message", "[transport]") { @@ -37,7 +37,7 @@ TEST_CASE_METHOD(MessageContextFixture, REQUIRE(actualMsg == expectedMsg); } -TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") +TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") { // Prepare common message/response std::string expectedMsg = "Hello "; @@ -78,7 +78,7 @@ TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") } } -TEST_CASE_METHOD(MessageContextFixture, +TEST_CASE_METHOD(SchedulerTestFixture, "Test send/recv many messages", "[transport]") { @@ -116,7 +116,7 @@ TEST_CASE_METHOD(MessageContextFixture, } } -TEST_CASE_METHOD(MessageContextFixture, +TEST_CASE_METHOD(SchedulerTestFixture, "Test send/recv many messages from many clients", "[transport]") { diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index f04f4769d..886b7fc97 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -186,10 +186,11 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") expectFailure = false; } + // TODO - reinstate this SECTION("Short timeout failure") { - clientTimeout = 1; - expectFailure = true; + clientTimeout = 20000; + expectFailure = false; } // Start the server in the background diff --git a/tests/test/transport/test_mpi_message_endpoint.cpp b/tests/test/transport/test_mpi_message_endpoint.cpp index a27ae9bf7..09d535ad3 100644 --- a/tests/test/transport/test_mpi_message_endpoint.cpp +++ b/tests/test/transport/test_mpi_message_endpoint.cpp @@ -6,7 +6,7 @@ using namespace faabric::transport; namespace tests { -TEST_CASE_METHOD(MessageContextFixture, +TEST_CASE_METHOD(SchedulerTestFixture, "Test send and recv the hosts to rank message", "[transport]") { @@ -26,7 +26,7 @@ TEST_CASE_METHOD(MessageContextFixture, } } -TEST_CASE_METHOD(MessageContextFixture, +TEST_CASE_METHOD(SchedulerTestFixture, "Test send and recv an MPI message", "[transport]") { diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index 2c5ffc984..1568b3cf4 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -117,14 +117,6 @@ class ConfTestFixture faabric::util::SystemConfig& conf; }; -class MessageContextFixture : public SchedulerTestFixture -{ - public: - MessageContextFixture() {} - - ~MessageContextFixture() {} -}; - class MpiBaseTestFixture : public SchedulerTestFixture { public: From 317bf5ea1e3bae9b8da62ce3dac441cbc5129b2a Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 09:27:56 +0000 Subject: [PATCH 20/66] Simplify global context --- src/transport/context.cpp | 39 ++++++++------------------------------- 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/src/transport/context.cpp b/src/transport/context.cpp index f610d1fd2..7b2c96b92 100644 --- a/src/transport/context.cpp +++ b/src/transport/context.cpp @@ -10,62 +10,39 @@ namespace faabric::transport { * use one IO thread per Gbps of data. */ -class ContextWrapper -{ - public: - std::shared_ptr ctx; - - ContextWrapper() - { - ctx = std::make_shared(ZMQ_CONTEXT_IO_THREADS); - } - - ~ContextWrapper() - { - SPDLOG_TRACE("Destroying ZeroMQ context"); - - // Force outstanding ops to return ETERM - ctx->shutdown(); - - // Close the context - ctx->close(); - } -}; - -static std::shared_ptr instance = nullptr; +static std::shared_ptr instance = nullptr; void initGlobalMessageContext() { if (instance != nullptr) { - throw std::runtime_error("Must not initialise global context twice"); + throw std::runtime_error("Trying to initialise ZeroMQ context twice"); } SPDLOG_TRACE("Initialising global ZeroMQ context"); - instance = std::make_shared(); + instance = std::make_shared(ZMQ_CONTEXT_IO_THREADS); } std::shared_ptr getGlobalMessageContext() { if (instance == nullptr) { - throw std::runtime_error( - "Must explicitly initialise and close global message context"); + throw std::runtime_error("Trying to access uninitialised ZeroMQ context"); } - return instance->ctx; + return instance; } void closeGlobalMessageContext() { if (instance == nullptr) { - throw std::runtime_error("Cannot close an uninitialised context"); + throw std::runtime_error("Cannot close uninitialised ZeroMQ context"); } SPDLOG_TRACE("Destroying global ZeroMQ context"); // Force outstanding ops to return ETERM - instance->ctx->shutdown(); + instance->shutdown(); // Close the context - instance->ctx->close(); + instance->close(); } } From b3bcc5d0d7369fd51113bf4ec775dae4e2101553 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 09:35:49 +0000 Subject: [PATCH 21/66] Reinstate timeout test --- include/faabric/transport/MessageEndpoint.h | 2 +- src/transport/MessageEndpoint.cpp | 19 ++++--------------- src/transport/context.cpp | 3 ++- tests/test/transport/test_message_server.cpp | 4 ++-- 4 files changed, 9 insertions(+), 19 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 171525b23..a4d5d85a1 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -75,7 +75,7 @@ class SendMessageEndpoint : public MessageEndpoint class RecvMessageEndpoint : public MessageEndpoint { public: - RecvMessageEndpoint(int portIn); + RecvMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); Message recv(int size = 0); diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 615e50405..b89e74993 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include #include @@ -97,10 +97,6 @@ void SendMessageEndpoint::send(uint8_t* serialisedMsg, { assert(tid == std::this_thread::get_id()); - if (this->socket == nullptr) { - throw std::runtime_error("Sending on an unopened socket"); - } - zmq::send_flags sendFlags = more ? zmq::send_flags::sndmore : zmq::send_flags::none; @@ -123,7 +119,7 @@ void SendMessageEndpoint::send(uint8_t* serialisedMsg, Message SendMessageEndpoint::awaitResponse(int port) { // Wait for the response, open a temporary endpoint for it - RecvMessageEndpoint endpoint(port); + RecvMessageEndpoint endpoint(port, timeoutMs); Message receivedMessage = endpoint.recv(); @@ -134,11 +130,8 @@ Message SendMessageEndpoint::awaitResponse(int port) // RECV ENDPOINT // ---------------------------------------------- -RecvMessageEndpoint::RecvMessageEndpoint(int portIn) - : MessageEndpoint(zmq::socket_type::pull, - ANY_HOST, - portIn, - DEFAULT_RECV_TIMEOUT_MS) +RecvMessageEndpoint::RecvMessageEndpoint(int portIn, int timeoutMs) + : MessageEndpoint(zmq::socket_type::pull, ANY_HOST, portIn, timeoutMs) {} Message RecvMessageEndpoint::recv(int size) @@ -146,10 +139,6 @@ Message RecvMessageEndpoint::recv(int size) assert(tid == std::this_thread::get_id()); assert(size >= 0); - if (this->socket == nullptr) { - throw std::runtime_error("Receiving on an unopened socket"); - } - if (size == 0) { return recvNoBuffer(); } diff --git a/src/transport/context.cpp b/src/transport/context.cpp index 7b2c96b92..21780571b 100644 --- a/src/transport/context.cpp +++ b/src/transport/context.cpp @@ -25,7 +25,8 @@ void initGlobalMessageContext() std::shared_ptr getGlobalMessageContext() { if (instance == nullptr) { - throw std::runtime_error("Trying to access uninitialised ZeroMQ context"); + throw std::runtime_error( + "Trying to access uninitialised ZeroMQ context"); } return instance; diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 886b7fc97..8feeeff91 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -189,8 +189,8 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") // TODO - reinstate this SECTION("Short timeout failure") { - clientTimeout = 20000; - expectFailure = false; + clientTimeout = 10; + expectFailure = true; } // Start the server in the background From 1daeb8e063432c4766cb553d7caad6cb7e980060 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 09:57:48 +0000 Subject: [PATCH 22/66] Move sendResponse into RecvMessageEndpoint --- include/faabric/transport/MessageEndpoint.h | 16 ++++++++++++-- .../faabric/transport/MessageEndpointServer.h | 15 ++----------- include/faabric/transport/macros.h | 2 +- src/transport/MessageEndpoint.cpp | 14 +++++++++++++ src/transport/MessageEndpointServer.cpp | 13 ------------ tests/test/transport/test_message_server.cpp | 21 +++---------------- 6 files changed, 34 insertions(+), 47 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index a4d5d85a1..0a2306591 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -20,7 +20,7 @@ namespace faabric::transport { - /* Wrapper arround zmq::socket_t +/* Wrapper arround zmq::socket_t * * Thread-unsafe socket-like object. MUST be open-ed and close-ed from the * _same_ thread. For a proto://host:pair triple, one socket may bind, and all @@ -32,7 +32,8 @@ class MessageEndpoint public: MessageEndpoint(zmq::socket_type socketTypeIn, const std::string& hostIn, - int portIn, int timeoutMsIn); + int portIn, + int timeoutMsIn); // Delete assignment and copy-constructor as we need to be very careful with // socping and same-thread instantiation @@ -79,6 +80,17 @@ class RecvMessageEndpoint : public MessageEndpoint Message recv(int size = 0); + /* Send response to the client + * + * Send a one-off response to a client identified by host:port pair. + * Together with a blocking recv at the client side, this + * method can be used to achieve synchronous client-server communication. + */ + void sendResponse(uint8_t* data, + int size, + const std::string& returnHost, + int returnPort); + private: Message recvBuffer(int size); diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index fe07aa922..da250f1ee 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -25,6 +25,8 @@ class MessageEndpointServer virtual void stop(); protected: + std::unique_ptr recvEndpoint = nullptr; + bool recv(); /* Template function to handle message reception @@ -36,22 +38,9 @@ class MessageEndpointServer virtual void doRecv(faabric::transport::Message& header, faabric::transport::Message& body) = 0; - /* Send response to the client - * - * Send a one-off response to a client identified by host:port pair. - * Together with a blocking recv at the client side, this - * method can be used to achieve synchronous client-server communication. - */ - void sendResponse(uint8_t* serialisedMsg, - int size, - const std::string& returnHost, - int returnPort); - private: const int port; - std::unique_ptr recvEndpoint = nullptr; - std::thread servingThread; }; } diff --git a/include/faabric/transport/macros.h b/include/faabric/transport/macros.h index 6dfac596e..508069a40 100644 --- a/include/faabric/transport/macros.h +++ b/include/faabric/transport/macros.h @@ -29,7 +29,7 @@ if (!msg.SerializeToArray(sMsg, msgSize)) { \ throw std::runtime_error("Error serialising message"); \ } \ - sendResponse(sMsg, msgSize, host, port); \ + recvEndpoint->sendResponse(sMsg, msgSize, host, port); \ } #define PARSE_MSG(T, data, size) \ diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index b89e74993..f940cfc0b 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -203,4 +204,17 @@ Message RecvMessageEndpoint::recvNoBuffer() // Copy the received message to a buffer whose scope we control return Message(msg); } + +// We create a new endpoint every time. Re-using them would be a possible +// optimisation if needed. +void RecvMessageEndpoint::sendResponse(uint8_t* data, + int size, + const std::string& returnHost, + int returnPort) +{ + // Open the endpoint socket, server connects (not bind) to remote address + SendMessageEndpoint sendEndpoint(returnHost, + returnPort + REPLY_PORT_OFFSET); + sendEndpoint.send(data, size); +} } diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 7469c06be..3e95b578a 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -75,17 +75,4 @@ bool MessageEndpointServer::recv() return true; } - -// We create a new endpoint every time. Re-using them would be a possible -// optimisation if needed. -void MessageEndpointServer::sendResponse(uint8_t* serialisedMsg, - int size, - const std::string& returnHost, - int returnPort) -{ - // Open the endpoint socket, server connects (not bind) to remote address - SendMessageEndpoint sendEndpoint(returnHost, - returnPort + REPLY_PORT_OFFSET); - sendEndpoint.send(serialisedMsg, size); -} } diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 8feeeff91..873402f8d 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -22,18 +22,6 @@ class DummyServer final : public MessageEndpointServer // Variable to keep track of the received messages int messageCount; - // This method is protected in the base class, as it's always called from - // the doRecv implementation. To ease testing, we make it public with this - // workaround. - void sendResponse(uint8_t* serialisedMsg, - int size, - const std::string& returnHost, - int returnPort) - { - MessageEndpointServer::sendResponse( - serialisedMsg, size, returnHost, returnPort); - } - private: void doRecv(faabric::transport::Message& header, faabric::transport::Message& body) override @@ -60,7 +48,7 @@ class SlowServer final : public MessageEndpointServer SPDLOG_DEBUG("Slow message server test recv"); usleep(delayMs * 1000); - MessageEndpointServer::sendResponse( + recvEndpoint->sendResponse( data.data(), data.size(), thisHost, testPort); } }; @@ -108,8 +96,7 @@ TEST_CASE("Test send one message to server", "[transport]") TEST_CASE("Test send one-off response to client", "[transport]") { - DummyServer server; - server.start(); + RecvMessageEndpoint recvEndpoint(testPort); std::string expectedMsg = "Response from server"; @@ -125,13 +112,11 @@ TEST_CASE("Test send one-off response to client", "[transport]") uint8_t msg[expectedMsg.size()]; memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - server.sendResponse(msg, expectedMsg.size(), thisHost, testPort); + recvEndpoint.sendResponse(msg, expectedMsg.size(), thisHost, testPort); if (clientThread.joinable()) { clientThread.join(); } - - server.stop(); } TEST_CASE("Test multiple clients talking to one server", "[transport]") From d4be278770289ba48633068ad1ef1c78448cd3ca Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 10:08:02 +0000 Subject: [PATCH 23/66] Default reply port --- include/faabric/transport/MessageEndpoint.h | 7 ++----- include/faabric/transport/macros.h | 4 ++-- src/scheduler/FunctionCallClient.cpp | 4 ++-- src/scheduler/FunctionCallServer.cpp | 4 ++-- src/scheduler/SnapshotClient.cpp | 4 ++-- src/scheduler/SnapshotServer.cpp | 4 ++-- src/state/StateClient.cpp | 2 +- src/state/StateServer.cpp | 18 +++++++++--------- src/transport/MessageEndpoint.cpp | 10 ++++------ .../transport/test_message_endpoint_client.cpp | 7 +++---- tests/test/transport/test_message_server.cpp | 13 +++++-------- 11 files changed, 34 insertions(+), 43 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 0a2306591..cde7e1327 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -70,7 +70,7 @@ class SendMessageEndpoint : public MessageEndpoint void send(uint8_t* serialisedMsg, size_t msgSize, bool more = false); - Message awaitResponse(int port); + Message awaitResponse(); }; class RecvMessageEndpoint : public MessageEndpoint @@ -86,10 +86,7 @@ class RecvMessageEndpoint : public MessageEndpoint * Together with a blocking recv at the client side, this * method can be used to achieve synchronous client-server communication. */ - void sendResponse(uint8_t* data, - int size, - const std::string& returnHost, - int returnPort); + void sendResponse(uint8_t* data, int size, const std::string& returnHost); private: Message recvBuffer(int size); diff --git a/include/faabric/transport/macros.h b/include/faabric/transport/macros.h index 508069a40..21e715d4f 100644 --- a/include/faabric/transport/macros.h +++ b/include/faabric/transport/macros.h @@ -22,14 +22,14 @@ send(sMsg, msgSize); \ } -#define SEND_SERVER_RESPONSE(msg, host, port) \ +#define SEND_SERVER_RESPONSE(msg, host) \ size_t msgSize = msg.ByteSizeLong(); \ { \ uint8_t sMsg[msgSize]; \ if (!msg.SerializeToArray(sMsg, msgSize)) { \ throw std::runtime_error("Error serialising message"); \ } \ - recvEndpoint->sendResponse(sMsg, msgSize, host, port); \ + recvEndpoint->sendResponse(sMsg, msgSize, host); \ } #define PARSE_MSG(T, data, size) \ diff --git a/src/scheduler/FunctionCallClient.cpp b/src/scheduler/FunctionCallClient.cpp index 0a7bf1a8a..67b6a3db3 100644 --- a/src/scheduler/FunctionCallClient.cpp +++ b/src/scheduler/FunctionCallClient.cpp @@ -103,7 +103,7 @@ void FunctionCallClient::sendFlush() SEND_MESSAGE(faabric::scheduler::FunctionCalls::Flush, call); // Await the response - awaitResponse(FUNCTION_CALL_PORT + REPLY_PORT_OFFSET); + awaitResponse(); } } @@ -128,7 +128,7 @@ faabric::HostResources FunctionCallClient::getResources() // Receive message faabric::transport::Message msg = - awaitResponse(FUNCTION_CALL_PORT + REPLY_PORT_OFFSET); + awaitResponse(); // Deserialise message string if (!response.ParseFromArray(msg.data(), msg.size())) { throw std::runtime_error("Error deserialising message"); diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index 88b0545c6..05a036428 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -57,7 +57,7 @@ void FunctionCallServer::recvFlush(faabric::transport::Message& body) scheduler.flushLocally(); faabric::EmptyResponse response; - SEND_SERVER_RESPONSE(response, msg.returnhost(), FUNCTION_CALL_PORT) + SEND_SERVER_RESPONSE(response, msg.returnhost()) } void FunctionCallServer::recvExecuteFunctions(faabric::transport::Message& body) @@ -86,6 +86,6 @@ void FunctionCallServer::recvGetResources(faabric::transport::Message& body) // Send the response body faabric::HostResources response = scheduler.getThisHostResources(); - SEND_SERVER_RESPONSE(response, msg.returnhost(), FUNCTION_CALL_PORT) + SEND_SERVER_RESPONSE(response, msg.returnhost()) } } diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index 78628761f..37a072f1d 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -111,7 +111,7 @@ void SnapshotClient::pushSnapshot(const std::string& key, SEND_FB_REQUEST(faabric::scheduler::SnapshotCalls::PushSnapshot) // Await a response as this call must be synchronous - awaitResponse(SNAPSHOT_PORT + REPLY_PORT_OFFSET); + awaitResponse(); } } @@ -153,7 +153,7 @@ void SnapshotClient::pushSnapshotDiffs( SEND_FB_REQUEST(faabric::scheduler::SnapshotCalls::PushSnapshotDiffs) // Await a response as this call must be synchronous - awaitResponse(SNAPSHOT_PORT + REPLY_PORT_OFFSET); + awaitResponse(); } } diff --git a/src/scheduler/SnapshotServer.cpp b/src/scheduler/SnapshotServer.cpp index 74034f48d..2d127ba13 100644 --- a/src/scheduler/SnapshotServer.cpp +++ b/src/scheduler/SnapshotServer.cpp @@ -76,7 +76,7 @@ void SnapshotServer::recvPushSnapshot(faabric::transport::Message& msg) // Send response faabric::EmptyResponse response; - SEND_SERVER_RESPONSE(response, r->return_host()->str(), SNAPSHOT_PORT) + SEND_SERVER_RESPONSE(response, r->return_host()->str()) } void SnapshotServer::recvThreadResult(faabric::transport::Message& msg) @@ -107,7 +107,7 @@ void SnapshotServer::recvPushSnapshotDiffs(faabric::transport::Message& msg) // Send response faabric::EmptyResponse response; - SEND_SERVER_RESPONSE(response, r->return_host()->str(), SNAPSHOT_PORT) + SEND_SERVER_RESPONSE(response, r->return_host()->str()) } void SnapshotServer::applyDiffsToSnapshot( diff --git a/src/state/StateClient.cpp b/src/state/StateClient.cpp index 54b022d84..ddb0bce89 100644 --- a/src/state/StateClient.cpp +++ b/src/state/StateClient.cpp @@ -24,7 +24,7 @@ void StateClient::sendHeader(faabric::state::StateCalls call) faabric::transport::Message StateClient::awaitResponse() { // Call the superclass implementation - return SendMessageEndpoint::awaitResponse(STATE_PORT + REPLY_PORT_OFFSET); + return SendMessageEndpoint::awaitResponse(); } void StateClient::sendStateRequest(faabric::state::StateCalls header, diff --git a/src/state/StateServer.cpp b/src/state/StateServer.cpp index 91dea2ace..96d4a43bd 100644 --- a/src/state/StateServer.cpp +++ b/src/state/StateServer.cpp @@ -67,7 +67,7 @@ void StateServer::recvSize(faabric::transport::Message& body) response.set_user(kv->user); response.set_key(kv->key); response.set_statesize(kv->size()); - SEND_SERVER_RESPONSE(response, msg.returnhost(), STATE_PORT) + SEND_SERVER_RESPONSE(response, msg.returnhost()) } void StateServer::recvPull(faabric::transport::Message& body) @@ -91,7 +91,7 @@ void StateServer::recvPull(faabric::transport::Message& body) response.set_offset(chunkOffset); // TODO: avoid copying here response.set_data(chunk, chunkLen); - SEND_SERVER_RESPONSE(response, msg.returnhost(), STATE_PORT) + SEND_SERVER_RESPONSE(response, msg.returnhost()) } void StateServer::recvPush(faabric::transport::Message& body) @@ -109,7 +109,7 @@ void StateServer::recvPush(faabric::transport::Message& body) msg.offset(), BYTES_CONST(msg.data().c_str()), msg.data().size()); faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) } void StateServer::recvAppend(faabric::transport::Message& body) @@ -123,7 +123,7 @@ void StateServer::recvAppend(faabric::transport::Message& body) kv->append(reqData, dataLen); faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) } void StateServer::recvPullAppended(faabric::transport::Message& body) @@ -142,7 +142,7 @@ void StateServer::recvPullAppended(faabric::transport::Message& body) appendedValue->set_data(reinterpret_cast(value.data.get()), value.length); } - SEND_SERVER_RESPONSE(response, msg.returnhost(), STATE_PORT) + SEND_SERVER_RESPONSE(response, msg.returnhost()) } void StateServer::recvDelete(faabric::transport::Message& body) @@ -154,7 +154,7 @@ void StateServer::recvDelete(faabric::transport::Message& body) state.deleteKV(msg.user(), msg.key()); faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) } void StateServer::recvClearAppended(faabric::transport::Message& body) @@ -167,7 +167,7 @@ void StateServer::recvClearAppended(faabric::transport::Message& body) kv->clearAppended(); faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) } void StateServer::recvLock(faabric::transport::Message& body) @@ -180,7 +180,7 @@ void StateServer::recvLock(faabric::transport::Message& body) kv->lockWrite(); faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) } void StateServer::recvUnlock(faabric::transport::Message& body) @@ -193,6 +193,6 @@ void StateServer::recvUnlock(faabric::transport::Message& body) kv->unlockWrite(); faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) } } diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index f940cfc0b..d8d6cb66e 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -117,10 +117,10 @@ void SendMessageEndpoint::send(uint8_t* serialisedMsg, } // Block until we receive a response from the server -Message SendMessageEndpoint::awaitResponse(int port) +Message SendMessageEndpoint::awaitResponse() { // Wait for the response, open a temporary endpoint for it - RecvMessageEndpoint endpoint(port, timeoutMs); + RecvMessageEndpoint endpoint(port + REPLY_PORT_OFFSET, timeoutMs); Message receivedMessage = endpoint.recv(); @@ -209,12 +209,10 @@ Message RecvMessageEndpoint::recvNoBuffer() // optimisation if needed. void RecvMessageEndpoint::sendResponse(uint8_t* data, int size, - const std::string& returnHost, - int returnPort) + const std::string& returnHost) { // Open the endpoint socket, server connects (not bind) to remote address - SendMessageEndpoint sendEndpoint(returnHost, - returnPort + REPLY_PORT_OFFSET); + SendMessageEndpoint sendEndpoint(returnHost, port + REPLY_PORT_OFFSET); sendEndpoint.send(data, size); } } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 9de60a155..95b65b911 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -9,8 +9,7 @@ using namespace faabric::transport; const std::string thisHost = "127.0.0.1"; -const int testPort = 9999; -const int testReplyPort = 9996; +const int testPort = 9800; namespace tests { @@ -53,7 +52,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") src.send(msg, expectedMsg.size()); // Block waiting for a response - faabric::transport::Message recvMsg = src.awaitResponse(testReplyPort); + faabric::transport::Message recvMsg = src.awaitResponse(); assert(recvMsg.size() == expectedResponse.size()); std::string actualResponse(recvMsg.data(), recvMsg.size()); assert(actualResponse == expectedResponse); @@ -67,7 +66,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") REQUIRE(actualMsg == expectedMsg); // Send response, open a new endpoint for it - SendMessageEndpoint dstResponse(thisHost, testReplyPort); + SendMessageEndpoint dstResponse(thisHost, testPort); uint8_t msg[expectedResponse.size()]; memcpy(msg, expectedResponse.c_str(), expectedResponse.size()); dstResponse.send(msg, expectedResponse.size()); diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 873402f8d..292b864d1 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -48,8 +48,7 @@ class SlowServer final : public MessageEndpointServer SPDLOG_DEBUG("Slow message server test recv"); usleep(delayMs * 1000); - recvEndpoint->sendResponse( - data.data(), data.size(), thisHost, testPort); + recvEndpoint->sendResponse(data.data(), data.size(), thisHost); } }; @@ -104,7 +103,7 @@ TEST_CASE("Test send one-off response to client", "[transport]") // Open the source endpoint client, don't bind SendMessageEndpoint cli(thisHost, testPort); - Message msg = cli.awaitResponse(testPort + REPLY_PORT_OFFSET); + Message msg = cli.awaitResponse(); assert(msg.size() == expectedMsg.size()); std::string actualMsg(msg.data(), msg.size()); assert(actualMsg == expectedMsg); @@ -112,7 +111,7 @@ TEST_CASE("Test send one-off response to client", "[transport]") uint8_t msg[expectedMsg.size()]; memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - recvEndpoint.sendResponse(msg, expectedMsg.size(), thisHost, testPort); + recvEndpoint.sendResponse(msg, expectedMsg.size(), thisHost); if (clientThread.joinable()) { clientThread.join(); @@ -201,12 +200,10 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") if (expectFailure) { // Check for failure - REQUIRE_THROWS_AS(cli.awaitResponse(testPort + REPLY_PORT_OFFSET), - MessageTimeoutException); + REQUIRE_THROWS_AS(cli.awaitResponse(), MessageTimeoutException); } else { // Check response from server successful - Message responseMessage = - cli.awaitResponse(testPort + REPLY_PORT_OFFSET); + Message responseMessage = cli.awaitResponse(); std::vector expected = { 0, 1, 2, 3 }; REQUIRE(responseMessage.dataCopy() == expected); } From 96d2bf3593ee9f867a0df522aaf84a4d90908f6d Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 11:31:49 +0000 Subject: [PATCH 24/66] Add linger --- include/faabric/transport/MessageEndpoint.h | 5 +++++ src/transport/MessageEndpoint.cpp | 9 +++++++-- .../test/transport/test_message_endpoint_client.cpp | 13 +++++++------ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index cde7e1327..e9049e0d3 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -18,6 +18,11 @@ #define DEFAULT_RECV_TIMEOUT_MS 20000 #define DEFAULT_SEND_TIMEOUT_MS 20000 +// How long undelivered messages will hang around when the socket is closed, +// which also determines how long the context will hang for when closing if +// things haven't yet completed (usually only when there's an error). +#define LINGER_MS 1000 + namespace faabric::transport { /* Wrapper arround zmq::socket_t diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index d8d6cb66e..175e842b3 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -1,3 +1,4 @@ +#include "zmq.hpp" #include #include #include @@ -42,10 +43,13 @@ MessageEndpoint::MessageEndpoint(zmq::socket_type socketTypeIn, zmq::socket_t(*getGlobalMessageContext(), socketType), "socket_create") - // Set socket options + // Set socket timeouts socket.set(zmq::sockopt::rcvtimeo, timeoutMs); socket.set(zmq::sockopt::sndtimeo, timeoutMs); + // Note - setting linger here is essential to avoid infinite hangs + socket.set(zmq::sockopt::linger, LINGER_MS); + // Note - only one socket may bind, but several can connect. This // allows for easy N - 1 or 1 - N PUSH/PULL patterns. Order between // bind and connect does not matter. @@ -212,7 +216,8 @@ void RecvMessageEndpoint::sendResponse(uint8_t* data, const std::string& returnHost) { // Open the endpoint socket, server connects (not bind) to remote address - SendMessageEndpoint sendEndpoint(returnHost, port + REPLY_PORT_OFFSET); + SendMessageEndpoint sendEndpoint( + returnHost, port + REPLY_PORT_OFFSET, timeoutMs); sendEndpoint.send(data, size); } } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 95b65b911..6d883a292 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -5,6 +5,7 @@ #include #include +#include using namespace faabric::transport; @@ -47,9 +48,10 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") SendMessageEndpoint src(thisHost, testPort); // Send message and wait for response - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - src.send(msg, expectedMsg.size()); + std::vector bytes(BYTES_CONST(expectedMsg.c_str()), + BYTES_CONST(expectedMsg.c_str()) + + expectedMsg.size()); + src.send(bytes.data(), bytes.size()); // Block waiting for a response faabric::transport::Message recvMsg = src.awaitResponse(); @@ -65,11 +67,10 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") std::string actualMsg(recvMsg.data(), recvMsg.size()); REQUIRE(actualMsg == expectedMsg); - // Send response, open a new endpoint for it - SendMessageEndpoint dstResponse(thisHost, testPort); + // Send response uint8_t msg[expectedResponse.size()]; memcpy(msg, expectedResponse.c_str(), expectedResponse.size()); - dstResponse.send(msg, expectedResponse.size()); + dst.sendResponse(msg, expectedResponse.size(), thisHost); // Wait for sender thread if (senderThread.joinable()) { From 763594ad70e0ed773c373a55b028a03dee23209b Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 12:02:01 +0000 Subject: [PATCH 25/66] Self review --- include/faabric/transport/MessageEndpoint.h | 2 -- src/scheduler/FunctionCallClient.cpp | 3 +- src/transport/MessageEndpoint.cpp | 6 +++- .../scheduler/test_snapshot_client_server.cpp | 5 +-- .../test_message_endpoint_client.cpp | 32 +++++++++++++++++++ tests/test/transport/test_message_server.cpp | 1 - 6 files changed, 39 insertions(+), 10 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index e9049e0d3..4470bdbb6 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -60,8 +60,6 @@ class MessageEndpoint const int id; zmq::socket_t socket; - - void validateTimeout(int value); }; /* Send and Recv Message Endpoints */ diff --git a/src/scheduler/FunctionCallClient.cpp b/src/scheduler/FunctionCallClient.cpp index 67b6a3db3..06b755238 100644 --- a/src/scheduler/FunctionCallClient.cpp +++ b/src/scheduler/FunctionCallClient.cpp @@ -127,8 +127,7 @@ faabric::HostResources FunctionCallClient::getResources() SEND_MESSAGE(faabric::scheduler::FunctionCalls::GetResources, request); // Receive message - faabric::transport::Message msg = - awaitResponse(); + faabric::transport::Message msg = awaitResponse(); // Deserialise message string if (!response.ParseFromArray(msg.data(), msg.size())) { throw std::runtime_error("Error deserialising message"); diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 175e842b3..7b20a15d5 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -1,4 +1,3 @@ -#include "zmq.hpp" #include #include #include @@ -44,6 +43,11 @@ MessageEndpoint::MessageEndpoint(zmq::socket_type socketTypeIn, "socket_create") // Set socket timeouts + if (timeoutMs <= 0) { + SPDLOG_ERROR("Setting invalid timeout of {}", timeoutMs); + throw std::runtime_error("Setting invalid timeout"); + } + socket.set(zmq::sockopt::rcvtimeo, timeoutMs); socket.set(zmq::sockopt::sndtimeo, timeoutMs); diff --git a/tests/test/scheduler/test_snapshot_client_server.cpp b/tests/test/scheduler/test_snapshot_client_server.cpp index d995d2578..6db0d4054 100644 --- a/tests/test/scheduler/test_snapshot_client_server.cpp +++ b/tests/test/scheduler/test_snapshot_client_server.cpp @@ -32,10 +32,7 @@ class SnapshotClientServerFixture usleep(1000 * SHORT_TEST_TIMEOUT_MS); } - ~SnapshotClientServerFixture() - { - server.stop(); - } + ~SnapshotClientServerFixture() { server.stop(); } }; TEST_CASE_METHOD(SnapshotClientServerFixture, diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 6d883a292..82af9ff2c 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -156,4 +156,36 @@ TEST_CASE_METHOD(SchedulerTestFixture, } } } + +TEST_CASE_METHOD(SchedulerTestFixture, + "Test can't set invalid send/recv timeouts", + "[transport]") +{ + + SECTION("Sanity check valid timeout") + { + SendMessageEndpoint s(thisHost, testPort, 100); + RecvMessageEndpoint r(testPort, 100); + } + + SECTION("Recv zero timeout") + { + REQUIRE_THROWS(RecvMessageEndpoint(testPort, 0)); + } + + SECTION("Send zero timeout") + { + REQUIRE_THROWS(SendMessageEndpoint(thisHost, testPort, 0)); + } + + SECTION("Recv negative timeout") + { + REQUIRE_THROWS(RecvMessageEndpoint(testPort, -1)); + } + + SECTION("Send negative timeout") + { + REQUIRE_THROWS(SendMessageEndpoint(thisHost, testPort, -1)); + } +} } diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 292b864d1..ff171c0c4 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -170,7 +170,6 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") expectFailure = false; } - // TODO - reinstate this SECTION("Short timeout failure") { clientTimeout = 10; From a5c57f38c822d1b361096c39f91c2d6538f801e5 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 12:09:18 +0000 Subject: [PATCH 26/66] Added global context init/ close where necessary --- src/mpi_native/MpiExecutor.cpp | 5 +++++ src/runner/FaabricMain.cpp | 4 ++++ src/transport/context.cpp | 6 ++++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/mpi_native/MpiExecutor.cpp b/src/mpi_native/MpiExecutor.cpp index edf7b35ab..0427830df 100644 --- a/src/mpi_native/MpiExecutor.cpp +++ b/src/mpi_native/MpiExecutor.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace faabric::mpi_native { @@ -26,6 +27,8 @@ int32_t MpiExecutor::executeTask( int mpiNativeMain(int argc, char** argv) { + faabric::transport::initGlobalMessageContext(); + auto& scheduler = faabric::scheduler::getScheduler(); auto& conf = faabric::util::getSystemConfig(); @@ -54,6 +57,8 @@ int mpiNativeMain(int argc, char** argv) scheduler.callFunction(msg); } + faabric::transport::closeGlobalMessageContext(); + return 0; } } diff --git a/src/runner/FaabricMain.cpp b/src/runner/FaabricMain.cpp index 2e28a98e7..e80d9a9a8 100644 --- a/src/runner/FaabricMain.cpp +++ b/src/runner/FaabricMain.cpp @@ -21,6 +21,8 @@ FaabricMain::FaabricMain( void FaabricMain::startBackground() { + faabric::transport::initGlobalMessageContext(); + // Start basics startRunner(); @@ -92,6 +94,8 @@ void FaabricMain::shutdown() SPDLOG_INFO("Waiting for the snapshot server to finish"); snapshotServer.stop(); + faabric::transport::closeGlobalMessageContext(); + SPDLOG_INFO("Faabric pool successfully shut down"); } } diff --git a/src/transport/context.cpp b/src/transport/context.cpp index 21780571b..10dbeac7f 100644 --- a/src/transport/context.cpp +++ b/src/transport/context.cpp @@ -15,7 +15,8 @@ static std::shared_ptr instance = nullptr; void initGlobalMessageContext() { if (instance != nullptr) { - throw std::runtime_error("Trying to initialise ZeroMQ context twice"); + SPDLOG_WARN("ZeroMQ context already initialised. Skipping"); + return; } SPDLOG_TRACE("Initialising global ZeroMQ context"); @@ -35,7 +36,8 @@ std::shared_ptr getGlobalMessageContext() void closeGlobalMessageContext() { if (instance == nullptr) { - throw std::runtime_error("Cannot close uninitialised ZeroMQ context"); + SPDLOG_WARN("ZeroMQ context already closed (or not initialised). Skipping"); + return; } SPDLOG_TRACE("Destroying global ZeroMQ context"); From acfbe48eca0b53b8f7652e014976c62db16de508 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 12:09:35 +0000 Subject: [PATCH 27/66] Formatting --- src/transport/context.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transport/context.cpp b/src/transport/context.cpp index 10dbeac7f..e6adfc475 100644 --- a/src/transport/context.cpp +++ b/src/transport/context.cpp @@ -36,7 +36,8 @@ std::shared_ptr getGlobalMessageContext() void closeGlobalMessageContext() { if (instance == nullptr) { - SPDLOG_WARN("ZeroMQ context already closed (or not initialised). Skipping"); + SPDLOG_WARN( + "ZeroMQ context already closed (or not initialised). Skipping"); return; } From ed1b04bff2dd883ff6f9f247c6c1f3845732c6f3 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 15:39:06 +0000 Subject: [PATCH 28/66] Remove dummy state server --- include/faabric/state/DummyStateServer.h | 31 -- include/faabric/state/StateClient.h | 2 - src/state/CMakeLists.txt | 1 - src/state/DummyStateServer.cpp | 87 ---- src/state/InMemoryStateKeyValue.cpp | 11 +- src/state/StateClient.cpp | 1 - tests/test/state/test_state.cpp | 538 +++++++++++++---------- tests/test/state/test_state_server.cpp | 152 +++---- tests/utils/fixtures.h | 1 + 9 files changed, 363 insertions(+), 461 deletions(-) delete mode 100644 include/faabric/state/DummyStateServer.h delete mode 100644 src/state/DummyStateServer.cpp diff --git a/include/faabric/state/DummyStateServer.h b/include/faabric/state/DummyStateServer.h deleted file mode 100644 index df3010c57..000000000 --- a/include/faabric/state/DummyStateServer.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -#include - -namespace faabric::state { -class DummyStateServer -{ - public: - DummyStateServer(); - - std::shared_ptr getRemoteKv(); - - std::shared_ptr getLocalKv(); - - std::vector getRemoteKvValue(); - - std::vector getLocalKvValue(); - - std::vector dummyData; - std::string dummyUser; - std::string dummyKey; - - void start(); - - void stop(); - - state::State remoteState; - state::StateServer stateServer; -}; - -} diff --git a/include/faabric/state/StateClient.h b/include/faabric/state/StateClient.h index 13dd2bd96..05a503bae 100644 --- a/include/faabric/state/StateClient.h +++ b/include/faabric/state/StateClient.h @@ -17,8 +17,6 @@ class StateClient : public faabric::transport::SendMessageEndpoint const std::string key; const std::string host; - InMemoryStateRegistry& reg; - /* External state client API */ void pushChunks(const std::vector& chunks); diff --git a/src/state/CMakeLists.txt b/src/state/CMakeLists.txt index 18b2f73d4..ea7455753 100644 --- a/src/state/CMakeLists.txt +++ b/src/state/CMakeLists.txt @@ -1,7 +1,6 @@ file(GLOB HEADERS "${FAABRIC_INCLUDE_DIR}/faabric/state/*.h") set(LIB_FILES - DummyStateServer.cpp InMemoryStateKeyValue.cpp InMemoryStateRegistry.cpp State.cpp diff --git a/src/state/DummyStateServer.cpp b/src/state/DummyStateServer.cpp deleted file mode 100644 index e45c4e6cb..000000000 --- a/src/state/DummyStateServer.cpp +++ /dev/null @@ -1,87 +0,0 @@ -#include -#include -#include -#include - -namespace faabric::state { -DummyStateServer::DummyStateServer() - : remoteState(LOCALHOST) - , stateServer(remoteState) -{} - -std::shared_ptr DummyStateServer::getRemoteKv() -{ - if (dummyData.empty()) { - return remoteState.getKV(dummyUser, dummyKey); - } else { - return remoteState.getKV(dummyUser, dummyKey, dummyData.size()); - } -} - -std::shared_ptr DummyStateServer::getLocalKv() -{ - if (dummyData.empty()) { - return state::getGlobalState().getKV(dummyUser, dummyKey); - } else { - return state::getGlobalState().getKV( - dummyUser, dummyKey, dummyData.size()); - } -} - -std::vector DummyStateServer::getRemoteKvValue() -{ - std::vector actual(dummyData.size(), 0); - getRemoteKv()->get(actual.data()); - return actual; -} - -std::vector DummyStateServer::getLocalKvValue() -{ - std::vector actual(dummyData.size(), 0); - getLocalKv()->get(actual.data()); - return actual; -} - -void DummyStateServer::start() -{ - // NOTE - We want to test the server being on a different host. - // To do this we run the server in a separate thread, forcing it to - // have a localhost IP, then the main thread is the "client" with a - // junk IP. - - // Override the host endpoint for the server thread. Must be localhost - faabric::util::getSystemConfig().endpointHost = LOCALHOST; - - // Master the dummy data in this thread - if (!dummyData.empty()) { - const std::shared_ptr& kv = - remoteState.getKV(dummyUser, dummyKey, dummyData.size()); - std::shared_ptr inMemKv = - std::static_pointer_cast(kv); - - // Check this kv "thinks" it's master - if (!inMemKv->isMaster()) { - SPDLOG_ERROR("Dummy state server not master for data"); - throw std::runtime_error("Remote state server failed"); - } - - // Set the data - kv->set(dummyData.data()); - SPDLOG_DEBUG( - "Finished setting master for test {}/{}", kv->user, kv->key); - } - - // Start the state server - // Note - by default the state server runs in a background thread - SPDLOG_DEBUG("Running state server"); - stateServer.start(); - - // Give it time to start - usleep(1000 * 1000); -} - -void DummyStateServer::stop() -{ - stateServer.stop(); -} -} diff --git a/src/state/InMemoryStateKeyValue.cpp b/src/state/InMemoryStateKeyValue.cpp index 0f975cb8d..a73703cdf 100644 --- a/src/state/InMemoryStateKeyValue.cpp +++ b/src/state/InMemoryStateKeyValue.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -64,7 +65,15 @@ InMemoryStateKeyValue::InMemoryStateKeyValue(const std::string& userIn, , status(masterIP == thisIP ? InMemoryStateKeyStatus::MASTER : InMemoryStateKeyStatus::NOT_MASTER) , stateRegistry(getInMemoryStateRegistry()) -{} +{ + SPDLOG_TRACE("Creating in-memory state key-value for {}/{} size {} (this " + "host {}, master {})", + userIn, + keyIn, + sizeIn, + thisIP, + masterIP); +} InMemoryStateKeyValue::InMemoryStateKeyValue(const std::string& userIn, const std::string& keyIn, diff --git a/src/state/StateClient.cpp b/src/state/StateClient.cpp index ddb0bce89..b95d7237a 100644 --- a/src/state/StateClient.cpp +++ b/src/state/StateClient.cpp @@ -12,7 +12,6 @@ StateClient::StateClient(const std::string& userIn, , user(userIn) , key(keyIn) , host(hostIn) - , reg(state::getInMemoryStateRegistry()) {} void StateClient::sendHeader(faabric::state::StateCalls call) diff --git a/tests/test/state/test_state.cpp b/tests/test/state/test_state.cpp index 9b2a21e5c..7308aef72 100644 --- a/tests/test/state/test_state.cpp +++ b/tests/test/state/test_state.cpp @@ -3,11 +3,12 @@ #include "faabric_utils.h" #include -#include #include #include +#include #include #include +#include #include #include @@ -15,21 +16,143 @@ using namespace faabric::state; namespace tests { + static int staticCount = 0; -static void setUpDummyServer(DummyStateServer& server, - const std::vector& values) +class StateServerTestFixture + : public StateTestFixture + , ConfTestFixture { - cleanFaabric(); + public: + // Set up a local server with a *different* state instance to the main + // thread. This way we can fake the master/ non-master setup + StateServerTestFixture() + : remoteState(LOCALHOST) + , stateServer(remoteState) + { + staticCount++; + const std::string stateKey = "state_key_" + std::to_string(staticCount); - staticCount++; - const std::string stateKey = "state_key_" + std::to_string(staticCount); + conf.stateMode = "inmemory"; - // Set state remotely - server.dummyUser = "demo"; - server.dummyKey = stateKey; - server.dummyData = values; -} + dummyUser = "demo"; + dummyKey = stateKey; + + // Start the state server + SPDLOG_DEBUG("Running state server"); + stateServer.start(); + + // Give it time to start + usleep(1000 * 1000); + } + + ~StateServerTestFixture() { stateServer.stop(); } + + void setDummyData(std::vector data) + { + dummyData = data; + + // Master the dummy data in the "remote" state + if (!dummyData.empty()) { + std::string originalHost = conf.endpointHost; + conf.endpointHost = LOCALHOST; + + const std::shared_ptr& kv = + remoteState.getKV(dummyUser, dummyKey, dummyData.size()); + + std::shared_ptr inMemKv = + std::static_pointer_cast(kv); + + // Check this kv "thinks" it's master + if (!inMemKv->isMaster()) { + SPDLOG_ERROR("Dummy state server not master for data"); + throw std::runtime_error("Dummy state server failed"); + } + + // Set the data + kv->set(dummyData.data()); + SPDLOG_DEBUG( + "Finished setting master for test {}/{}", kv->user, kv->key); + + conf.endpointHost = originalHost; + } + } + + std::shared_ptr getRemoteKv() + { + if (dummyData.empty()) { + return remoteState.getKV(dummyUser, dummyKey); + } + return remoteState.getKV(dummyUser, dummyKey, dummyData.size()); + } + + std::shared_ptr getLocalKv() + { + if (dummyData.empty()) { + return state::getGlobalState().getKV(dummyUser, dummyKey); + } + + return state::getGlobalState().getKV( + dummyUser, dummyKey, dummyData.size()); + } + + std::vector getRemoteKvValue() + { + std::vector actual(dummyData.size(), 0); + getRemoteKv()->get(actual.data()); + return actual; + } + + std::vector getLocalKvValue() + { + std::vector actual(dummyData.size(), 0); + getLocalKv()->get(actual.data()); + return actual; + } + + void checkPulling(bool doPull) + { + std::vector values = { 0, 1, 2, 3 }; + std::vector actual(values.size(), 0); + setDummyData(values); + + // Get, with optional pull + int nMessages = 1; + if (doPull) { + nMessages = 2; + } + + // Initial pull + const std::shared_ptr& localKv = getLocalKv(); + localKv->pull(); + + // Update directly on the remote KV + std::shared_ptr remoteKv = getRemoteKv(); + std::vector newValues = { 5, 5, 5, 5 }; + remoteKv->set(newValues.data()); + + if (doPull) { + // Check locak changed with another pull + localKv->pull(); + localKv->get(actual.data()); + REQUIRE(actual == newValues); + } else { + // Check local unchanged without another pull + localKv->get(actual.data()); + REQUIRE(actual == values); + } + } + + protected: + std::string dummyUser; + std::string dummyKey; + + state::State remoteState; + state::StateServer stateServer; + + private: + std::vector dummyData; +}; static std::shared_ptr setupKV(size_t size) { @@ -46,30 +169,29 @@ static std::shared_ptr setupKV(size_t size) return kv; } -TEST_CASE("Test in-memory state sizes", "[state]") +TEST_CASE_METHOD(StateTestFixture, "Test in-memory state sizes", "[state]") { - cleanFaabric(); - State& s = getGlobalState(); std::string user = "alpha"; std::string key = "beta"; // Empty should be none - size_t initialSize = s.getStateSize(user, key); + size_t initialSize = state.getStateSize(user, key); REQUIRE(initialSize == 0); // Set a value std::vector bytes = { 0, 1, 2, 3, 4 }; - auto kv = s.getKV(user, key, bytes.size()); + auto kv = state.getKV(user, key, bytes.size()); kv->set(bytes.data()); kv->pushFull(); // Get size - REQUIRE(s.getStateSize(user, key) == bytes.size()); + REQUIRE(state.getStateSize(user, key) == bytes.size()); } -TEST_CASE("Test simple in memory state get/set", "[state]") +TEST_CASE_METHOD(StateTestFixture, + "Test simple in memory state get/set", + "[state]") { - cleanFaabric(); auto kv = setupKV(5); std::vector actual(5); @@ -91,22 +213,20 @@ TEST_CASE("Test simple in memory state get/set", "[state]") REQUIRE(actual == values); } -TEST_CASE("Test in memory get/ set chunk", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test in memory get/ set chunk", + "[state]") { - DummyStateServer server; std::vector values = { 0, 0, 1, 1, 2, 2, 3, 3, 4, 4 }; - setUpDummyServer(server, values); - - // Start the server - server.start(); + setDummyData(values); // Get locally - std::vector actual = server.getLocalKvValue(); + std::vector actual = getLocalKvValue(); REQUIRE(actual == values); // Update a subsection std::vector update = { 8, 8, 8 }; - std::shared_ptr localKv = server.getLocalKv(); + std::shared_ptr localKv = getLocalKv(); localKv->setChunk(3, update.data(), 3); std::vector expected = { 0, 0, 1, 8, 8, 8, 3, 3, 4, 4 }; @@ -114,7 +234,7 @@ TEST_CASE("Test in memory get/ set chunk", "[state]") REQUIRE(actual == expected); // Check remote is unchanged - REQUIRE(server.getRemoteKvValue() == values); + REQUIRE(getRemoteKvValue() == values); // Try getting a chunk locally std::vector actualChunk(3); @@ -123,21 +243,18 @@ TEST_CASE("Test in memory get/ set chunk", "[state]") // Run push and check remote is updated localKv->pushPartial(); - REQUIRE(server.getRemoteKvValue() == expected); - - // Wait for server to finish - server.stop(); + REQUIRE(getRemoteKvValue() == expected); } -TEST_CASE("Test in memory marking chunks dirty", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test in memory marking chunks dirty", + "[state]") { std::vector values = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; - DummyStateServer server; - setUpDummyServer(server, values); - server.start(); + setDummyData(values); // Get pointer to local and update in memory only - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); uint8_t* ptr = localKv->get(); ptr[0] = 8; ptr[5] = 7; @@ -150,27 +267,23 @@ TEST_CASE("Test in memory marking chunks dirty", "[state]") values.at(0) = 8; // Check remote - REQUIRE(server.getRemoteKvValue() == values); + REQUIRE(getRemoteKvValue() == values); // Check local value has been set with the latest remote value std::vector actualMemory(ptr, ptr + values.size()); REQUIRE(actualMemory == values); - - server.stop(); } -TEST_CASE("Test overlaps with multiple chunks dirty", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test overlaps with multiple chunks dirty", + "[state]") { std::vector values = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; - DummyStateServer server; - setUpDummyServer(server, values); - - // Get, push, pull - server.start(); + setDummyData(values); // Get pointer to local data - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); uint8_t* statePtr = localKv->get(); // Update a couple of areas @@ -193,8 +306,7 @@ TEST_CASE("Test overlaps with multiple chunks dirty", "[state]") // Update one non-overlapping value remotely std::vector directA = { 2, 2 }; - const std::shared_ptr& remoteKv = - server.getRemoteKv(); + const std::shared_ptr& remoteKv = getRemoteKv(); remoteKv->setChunk(6, directA.data(), 2); // Update one overlapping value remotely @@ -207,8 +319,8 @@ TEST_CASE("Test overlaps with multiple chunks dirty", "[state]") std::vector expectedRemote = { 6, 6, 6, 6, 6, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; - REQUIRE(server.getLocalKvValue() == expectedLocal); - REQUIRE(server.getRemoteKvValue() == expectedRemote); + REQUIRE(getLocalKvValue() == expectedLocal); + REQUIRE(getRemoteKvValue() == expectedRemote); // Push changes localKv->pushPartial(); @@ -217,23 +329,19 @@ TEST_CASE("Test overlaps with multiple chunks dirty", "[state]") std::vector expected = { 6, 1, 2, 3, 6, 0, 2, 2, 0, 0, 4, 5, 0, 0, 7, 7, 7, 7, 0, 0 }; - REQUIRE(server.getLocalKvValue() == expected); - REQUIRE(server.getRemoteKvValue() == expected); - - server.stop(); + REQUIRE(getLocalKvValue() == expected); + REQUIRE(getRemoteKvValue() == expected); } -TEST_CASE("Test in memory partial update of doubles in state", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test in memory partial update of doubles in state", + "[state]") { long nDoubles = 20; long nBytes = nDoubles * sizeof(double); std::vector values(nBytes, 0); - DummyStateServer server; - setUpDummyServer(server, values); - - // Get, push, pull - server.start(); + setDummyData(values); // Set up both with zeroes initially std::vector expected(nDoubles); @@ -242,7 +350,7 @@ TEST_CASE("Test in memory partial update of doubles in state", "[state]") memset(actualBytes.data(), 0, nBytes); // Update a value locally and flag dirty - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); auto actualPtr = reinterpret_cast(localKv->get()); auto expectedPtr = expected.data(); actualPtr[0] = 123.456; @@ -274,18 +382,17 @@ TEST_CASE("Test in memory partial update of doubles in state", "[state]") REQUIRE(expected == actualPostPush); // Check remote - std::vector remoteValue = server.getRemoteKvValue(); + std::vector remoteValue = getRemoteKvValue(); std::vector actualPostPushRemote(postPushDoublePtr, postPushDoublePtr + nDoubles); REQUIRE(expected == actualPostPushRemote); - - server.stop(); } -TEST_CASE("Test set chunk cannot be over the size of the allocated memory", - "[state]") +TEST_CASE_METHOD( + StateServerTestFixture, + "Test set chunk cannot be over the size of the allocated memory", + "[state]") { - cleanFaabric(); auto kv = setupKV(2); // Set a chunk offset @@ -295,29 +402,27 @@ TEST_CASE("Test set chunk cannot be over the size of the allocated memory", REQUIRE_THROWS(kv->setChunk(offset, update.data(), 3)); } -TEST_CASE("Test partially setting just first/ last element", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test partially setting just first/ last element", + "[state]") { std::vector values = { 0, 1, 2, 3, 4 }; - DummyStateServer server; - setUpDummyServer(server, values); - - // Only 3 push-partial messages as kv not fully allocated - server.start(); + setDummyData(values); // Update just the last element std::vector update = { 8 }; - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); localKv->setChunk(4, update.data(), 1); localKv->pushPartial(); std::vector expected = { 0, 1, 2, 3, 8 }; - REQUIRE(server.getRemoteKvValue() == expected); + REQUIRE(getRemoteKvValue() == expected); // Update the first localKv->setChunk(0, update.data(), 1); localKv->pushPartial(); expected = { 8, 1, 2, 3, 8 }; - REQUIRE(server.getRemoteKvValue() == expected); + REQUIRE(getRemoteKvValue() == expected); // Update two update = { 6 }; @@ -326,27 +431,22 @@ TEST_CASE("Test partially setting just first/ last element", "[state]") localKv->pushPartial(); expected = { 6, 1, 2, 3, 6 }; - REQUIRE(server.getRemoteKvValue() == expected); - - server.stop(); + REQUIRE(getRemoteKvValue() == expected); } -TEST_CASE("Test push partial with mask", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test push partial with mask", + "[state]") { size_t stateSize = 4 * sizeof(double); std::vector values(stateSize, 0); - DummyStateServer server; - setUpDummyServer(server, values); - - // Get, full push, push partial - server.start(); + setDummyData(values); // Create another local KV of same size - State& state = getGlobalState(); auto maskKv = state.getKV("demo", "dummy_mask", stateSize); // Set up value locally - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); uint8_t* dataBytePtr = localKv->get(); auto dataDoublePtr = reinterpret_cast(dataBytePtr); std::vector initial = { 1.2345, 12.345, 987.6543, 10987654.3 }; @@ -358,7 +458,7 @@ TEST_CASE("Test push partial with mask", "[state]") localKv->pushFull(); // Check pushed remotely - std::vector actualBytes = server.getRemoteKvValue(); + std::vector actualBytes = getRemoteKvValue(); auto actualDoublePtr = reinterpret_cast(actualBytes.data()); std::vector actualDoubles(actualDoublePtr, actualDoublePtr + 4); @@ -387,100 +487,58 @@ TEST_CASE("Test push partial with mask", "[state]") }; // Check remotely - std::vector actualValue2 = server.getRemoteKvValue(); + std::vector actualValue2 = getRemoteKvValue(); auto actualDoublesPtr = reinterpret_cast(actualValue2.data()); std::vector actualDoubles2(actualDoublesPtr, actualDoublesPtr + 4); REQUIRE(actualDoubles2 == expected); - - server.stop(); } -void checkPulling(bool doPull) -{ - std::vector values = { 0, 1, 2, 3 }; - std::vector actual(values.size(), 0); - - DummyStateServer server; - setUpDummyServer(server, values); - - // Get, with optional pull - int nMessages = 1; - if (doPull) { - nMessages = 2; - } - - server.start(); - - // Initial pull - const std::shared_ptr& localKv = server.getLocalKv(); - localKv->pull(); - - // Update directly on the remote KV - std::shared_ptr remoteKv = server.getRemoteKv(); - std::vector newValues = { 5, 5, 5, 5 }; - remoteKv->set(newValues.data()); - - if (doPull) { - // Check locak changed with another pull - localKv->pull(); - localKv->get(actual.data()); - REQUIRE(actual == newValues); - } else { - // Check local unchanged without another pull - localKv->get(actual.data()); - REQUIRE(actual == values); - } - - server.stop(); -} - -TEST_CASE("Test updates pulled from remote", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test updates pulled from remote", + "[state]") { checkPulling(true); } -TEST_CASE("Test updates not pulled from remote without call to pull", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test updates not pulled from remote without call to pull", + "[state]") { checkPulling(false); } -TEST_CASE("Test pushing only happens when dirty", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test pushing only happens when dirty", + "[state]") { std::vector values = { 0, 1, 2, 3 }; - DummyStateServer server; - setUpDummyServer(server, values); - - server.start(); + setDummyData(values); // Pull locally - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); localKv->pull(); // Change remote directly - std::shared_ptr remoteKv = server.getRemoteKv(); + std::shared_ptr remoteKv = getRemoteKv(); std::vector newValues = { 3, 4, 5, 6 }; remoteKv->set(newValues.data()); // Push and make sure remote has not changed without local being dirty localKv->pushFull(); - REQUIRE(server.getRemoteKvValue() == newValues); + REQUIRE(getRemoteKvValue() == newValues); // Now change locally and check push happens std::vector newValues2 = { 7, 7, 7, 7 }; localKv->set(newValues2.data()); localKv->pushFull(); - REQUIRE(server.getRemoteKvValue() == newValues2); - - server.stop(); + REQUIRE(getRemoteKvValue() == newValues2); } -TEST_CASE("Test mapping shared memory", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test mapping shared memory", + "[state]") { - cleanFaabric(); - - // Set up the KV - State& s = getGlobalState(); - auto kv = s.getKV("demo", "mapping_test", 5); + auto kv = state.getKV("demo", "mapping_test", 5); std::vector value = { 0, 1, 2, 3, 4 }; kv->set(value.data()); @@ -529,19 +587,16 @@ TEST_CASE("Test mapping shared memory", "[state]") } } -TEST_CASE("Test mapping shared memory does not pull", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test mapping shared memory does not pull", + "[state]") { std::vector values = { 0, 1, 2, 3, 4 }; std::vector zeroes(values.size(), 0); - DummyStateServer server; - setUpDummyServer(server, values); - - // One implicit pull - server.start(); + setDummyData(values); // Write value to remote - const std::shared_ptr& remoteKv = - server.getRemoteKv(); + const std::shared_ptr& remoteKv = getRemoteKv(); remoteKv->set(values.data()); // Map the KV locally @@ -551,7 +606,7 @@ TEST_CASE("Test mapping shared memory does not pull", "[state]") MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); localKv->mapSharedMemory(mappedRegion, 0, 1); // Check it's zeroed @@ -564,18 +619,17 @@ TEST_CASE("Test mapping shared memory does not pull", "[state]") std::vector actualValueAfterGet(byteRegion, byteRegion + values.size()); REQUIRE(actualValueAfterGet == values); - - server.stop(); } -TEST_CASE("Test mapping small shared memory offsets", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test mapping small shared memory offsets", + "[state]") { - cleanFaabric(); - // Set up the KV std::vector values = { 0, 1, 2, 3, 4, 5, 6 }; - State& s = getGlobalState(); - auto kv = s.getKV("demo", "mapping_small_test", values.size()); + setDummyData(values); + + auto kv = state.getKV("demo", "mapping_small_test", values.size()); kv->set(values.data()); // Map a single page of host memory @@ -622,7 +676,9 @@ TEST_CASE("Test mapping small shared memory offsets", "[state]") REQUIRE(chunkB[1] == 1); } -TEST_CASE("Test mapping bigger uninitialized shared memory offsets", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test mapping bigger uninitialized shared memory offsets", + "[state]") { // Define some mapping larger than a page size_t mappingSize = 3 * faabric::util::HOST_PAGE_SIZE; @@ -630,13 +686,7 @@ TEST_CASE("Test mapping bigger uninitialized shared memory offsets", "[state]") // Set up a larger total value full of ones size_t totalSize = (10 * faabric::util::HOST_PAGE_SIZE) + 15; std::vector values(totalSize, 1); - - // Set up remote server - DummyStateServer server; - setUpDummyServer(server, values); - - // Expecting two implicit pulls - server.start(); + setDummyData(values); // Map a couple of chunks in host memory (as would be done by the wasm // module) @@ -646,7 +696,7 @@ TEST_CASE("Test mapping bigger uninitialized shared memory offsets", "[state]") nullptr, mappingSize, PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); // Do the mapping - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); localKv->mapSharedMemory(mappedRegionA, 6, 3); localKv->mapSharedMemory(mappedRegionB, 2, 3); @@ -668,36 +718,28 @@ TEST_CASE("Test mapping bigger uninitialized shared memory offsets", "[state]") REQUIRE(chunkB[0] == 1); REQUIRE(chunkA[5] == 5); REQUIRE(chunkB[9] == 9); - - server.stop(); } -TEST_CASE("Test deletion", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, "Test deletion", "[state]") { std::vector values = { 0, 1, 2, 3, 4 }; - DummyStateServer server; - setUpDummyServer(server, values); - - // One pull, one deletion - server.start(); + setDummyData(values); // Check data remotely and locally - REQUIRE(server.getLocalKvValue() == values); - REQUIRE(server.getRemoteKvValue() == values); + REQUIRE(getLocalKvValue() == values); + REQUIRE(getRemoteKvValue() == values); // Delete from state - getGlobalState().deleteKV(server.dummyUser, server.dummyKey); + getGlobalState().deleteKV(dummyUser, dummyKey); // Check it's gone - REQUIRE(server.remoteState.getKVCount() == 0); - - server.stop(); + REQUIRE(remoteState.getKVCount() == 0); } -TEST_CASE("Test appended state with KV", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test appended state with KV", + "[state]") { - cleanFaabric(); - // Set up the KV State& s = getGlobalState(); std::shared_ptr kv; @@ -750,25 +792,19 @@ TEST_CASE("Test appended state with KV", "[state]") REQUIRE(actualAfterClear == expectedAfterClear); } -TEST_CASE("Test remote appended state", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test remote appended state", + "[state]") { - DummyStateServer server; - std::vector empty; - setUpDummyServer(server, empty); - - // One appends, two retrievals, one clear - server.start(); - std::vector valuesA = { 0, 1, 2, 3, 4 }; std::vector valuesB = { 3, 3, 5, 5 }; // Append some data remotely - const std::shared_ptr& remoteKv = - server.getRemoteKv(); + const std::shared_ptr& remoteKv = getRemoteKv(); remoteKv->append(valuesA.data(), valuesA.size()); // Append locally - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); auto localInMemKv = std::static_pointer_cast(localKv); REQUIRE(!localInMemKv->isMaster()); @@ -797,28 +833,24 @@ TEST_CASE("Test remote appended state", "[state]") localKv->getAppended( actualLocalAfterClear.data(), actualLocalAfterClear.size(), 1); REQUIRE(actualLocalAfterClear == valuesB); - - server.stop(); } -TEST_CASE("Test pushing pulling large state", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test pushing pulling large state", + "[state]") { size_t valueSize = (3 * STATE_STREAMING_CHUNK_SIZE) + 123; std::vector valuesA(valueSize, 1); std::vector valuesB(valueSize, 2); - DummyStateServer server; - setUpDummyServer(server, valuesA); - - // One pull, one push - server.start(); + setDummyData(valuesA); // Pull locally - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); localKv->pull(); // Check equality - const std::vector& actualLocal = server.getLocalKvValue(); + const std::vector& actualLocal = getLocalKvValue(); REQUIRE(actualLocal == valuesA); // Push @@ -826,28 +858,24 @@ TEST_CASE("Test pushing pulling large state", "[state]") localKv->pushFull(); // Check equality of remote - const std::vector& actualRemote = server.getRemoteKvValue(); + const std::vector& actualRemote = getRemoteKvValue(); REQUIRE(actualRemote == valuesB); - - server.stop(); } -TEST_CASE("Test pushing pulling chunks over multiple requests", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test pushing pulling chunks over multiple requests", + "[state]") { // Set up a chunk of state size_t chunkSize = 1024; size_t valueSize = 10 * chunkSize + 123; std::vector values(valueSize, 1); - DummyStateServer server; - setUpDummyServer(server, values); - - // Two chunk pulls, one push partial - server.start(); + setDummyData(values); // Set a chunk in the remote value size_t offsetA = 3 * chunkSize + 5; std::vector segA = { 4, 4 }; - std::shared_ptr remoteKv = server.getRemoteKv(); + std::shared_ptr remoteKv = getRemoteKv(); remoteKv->setChunk(offsetA, segA.data(), segA.size()); // Set a chunk at the end of the remote value @@ -856,7 +884,7 @@ TEST_CASE("Test pushing pulling chunks over multiple requests", "[state]") remoteKv->setChunk(offsetB, segB.data(), segB.size()); // Get only these chunks locally - std::shared_ptr localKv = server.getLocalKv(); + std::shared_ptr localKv = getLocalKv(); std::vector actualSegA(segA.size(), 0); localKv->getChunk(offsetA, actualSegA.data(), segA.size()); REQUIRE(actualSegA == segA); @@ -879,7 +907,7 @@ TEST_CASE("Test pushing pulling chunks over multiple requests", "[state]") localKv->pushPartial(); // Check the chunks in the remote value - std::vector actualAfterPush = server.getRemoteKvValue(); + std::vector actualAfterPush = getRemoteKvValue(); std::vector actualSegC(actualAfterPush.begin() + offsetC, actualAfterPush.begin() + offsetC + segC.size()); @@ -889,21 +917,17 @@ TEST_CASE("Test pushing pulling chunks over multiple requests", "[state]") REQUIRE(actualSegC == segC); REQUIRE(actualSegD == segD); - - server.stop(); } -TEST_CASE("Test pulling disjoint chunks of the same value which share pages", - "[state]") +TEST_CASE_METHOD( + StateServerTestFixture, + "Test pulling disjoint chunks of the same value which share pages", + "[state]") { // Set up state size_t valueSize = 20 * faabric::util::HOST_PAGE_SIZE + 123; std::vector values(valueSize, 1); - DummyStateServer server; - setUpDummyServer(server, values); - - // Expect two chunk pulls - server.start(); + setDummyData(values); // Set up two chunks both from the same page of memory but not overlapping long offsetA = 2 * faabric::util::HOST_PAGE_SIZE + 10; @@ -916,7 +940,7 @@ TEST_CASE("Test pulling disjoint chunks of the same value which share pages", std::vector actualB(lenA, 0); std::vector expectedB(lenA, 1); - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); localKv->getChunk(offsetA, actualA.data(), lenA); localKv->getChunk(offsetB, actualB.data(), lenB); @@ -927,7 +951,49 @@ TEST_CASE("Test pulling disjoint chunks of the same value which share pages", // Check both chunks are as expected REQUIRE(actualA == expectedA); REQUIRE(actualB == expectedB); +} + +TEST_CASE_METHOD(StateServerTestFixture, + "Test state server as remote master", + "[state]") +{ + REQUIRE(state.getKVCount() == 0); + + const char* userA = "foo"; + const char* keyA = "bar"; + std::vector dataA = { 0, 1, 2, 3, 4, 5, 6, 7 }; + std::vector dataB = { 7, 6, 5, 4, 3, 2, 1, 0 }; + + dummyUser = userA; + dummyKey = keyA; + setDummyData(dataA); + + // Get the state size before accessing the value locally + size_t actualSize = state.getStateSize(userA, keyA); + REQUIRE(actualSize == dataA.size()); - server.stop(); + // Access locally and check not master + auto localKv = getLocalKv(); + auto localStateKv = + std::static_pointer_cast(localKv); + REQUIRE(!localStateKv->isMaster()); + + // Set the state locally and check + const std::shared_ptr& kv = + state.getKV(userA, keyA, dataA.size()); + kv->set(dataB.data()); + + std::vector actualLocal(dataA.size(), 0); + kv->get(actualLocal.data()); + REQUIRE(actualLocal == dataB); + + // Check it's not changed remotely + std::vector actualRemote = getRemoteKvValue(); + REQUIRE(actualRemote == dataA); + + // Push and check remote is updated + kv->pushFull(); + actualRemote = getRemoteKvValue(); + REQUIRE(actualRemote == dataB); } } diff --git a/tests/test/state/test_state_server.cpp b/tests/test/state/test_state_server.cpp index b82b77cf9..73f02ab57 100644 --- a/tests/test/state/test_state_server.cpp +++ b/tests/test/state/test_state_server.cpp @@ -2,7 +2,6 @@ #include "faabric_utils.h" -#include #include #include #include @@ -14,61 +13,66 @@ using namespace faabric::state; namespace tests { -static const char* userA = "foo"; -static const char* keyA = "bar"; -static const char* keyB = "baz"; -static std::vector dataA = { 0, 1, 2, 3, 4, 5, 6, 7 }; -static std::vector dataB = { 7, 6, 5, 4, 3, 2, 1, 0 }; -static std::string originalStateMode; - -static void setUpStateMode() +class SimpleStateServerTestFixture + : public StateTestFixture + , public ConfTestFixture { - faabric::util::SystemConfig& conf = faabric::util::getSystemConfig(); - originalStateMode = conf.stateMode; - cleanFaabric(); -} + public: + SimpleStateServerTestFixture() + : server(faabric::state::getGlobalState()) + , dataA({ 0, 1, 2, 3, 4, 5, 6, 7 }) + , dataB({ 7, 6, 5, 4, 3, 2, 1, 0 }) + { + conf.stateMode = "inmemory"; -static void resetStateMode() -{ - faabric::util::getSystemConfig().stateMode = originalStateMode; -} + server.start(); + usleep(1000 * 100); + } -std::shared_ptr getKv(const std::string& user, - const std::string& key, - size_t stateSize) -{ - State& state = getGlobalState(); + ~SimpleStateServerTestFixture() { server.stop(); } - std::shared_ptr localKv = state.getKV(user, key, stateSize); - std::shared_ptr inMemLocalKv = - std::static_pointer_cast(localKv); + std::shared_ptr getKv(const std::string& user, + const std::string& key, + size_t stateSize) + { + std::shared_ptr localKv = + state.getKV(user, key, stateSize); - return inMemLocalKv; -} + std::shared_ptr inMemLocalKv = + std::static_pointer_cast(localKv); -TEST_CASE("Test request/ response", "[state]") -{ - setUpStateMode(); + return inMemLocalKv; + } + + protected: + StateServer server; + + const char* userA = "foo"; + const char* keyA = "bar"; + const char* keyB = "baz"; - // Create server - StateServer s(getGlobalState()); - s.start(); - usleep(1000 * 100); + std::vector dataA; + std::vector dataB; +}; +TEST_CASE_METHOD(SimpleStateServerTestFixture, + "Test state request/ response", + "[state]") +{ std::vector actual(dataA.size(), 0); // Prepare a key-value with data auto kvA = getKv(userA, keyA, dataA.size()); kvA->set(dataA.data()); - // Prepare a key-value with no data + // Prepare a key-value with no data (but a size) auto kvB = getKv(userA, keyB, dataA.size()); // Prepare a key-value with same key but different data (for pushing) - std::string thisIP = faabric::util::getSystemConfig().endpointHost; + std::string thisHost = faabric::util::getSystemConfig().endpointHost; auto kvADuplicate = - InMemoryStateKeyValue(userA, keyA, dataB.size(), thisIP); + InMemoryStateKeyValue(userA, keyA, dataB.size(), thisHost); kvADuplicate.set(dataB.data()); StateClient client(userA, keyA, DEFAULT_STATE_HOST); @@ -133,16 +137,12 @@ TEST_CASE("Test request/ response", "[state]") REQUIRE(actualAppended == expected); } - - s.stop(); - - resetStateMode(); } -TEST_CASE("Test local-only push/ pull", "[state]") +TEST_CASE_METHOD(SimpleStateServerTestFixture, + "Test local-only push/ pull", + "[state]") { - setUpStateMode(); - // Create a key-value locally auto localKv = getKv(userA, keyA, dataA.size()); @@ -156,14 +156,12 @@ TEST_CASE("Test local-only push/ pull", "[state]") // Check that we get the expected size State& state = getGlobalState(); REQUIRE(state.getStateSize(userA, keyA) == dataA.size()); - - resetStateMode(); } -TEST_CASE("Test local-only append", "[state]") +TEST_CASE_METHOD(SimpleStateServerTestFixture, + "Test local-only append", + "[state]") { - setUpStateMode(); - // Append a few chunks std::vector chunkA = { 1, 1 }; std::vector chunkB = { 2, 2, 2 }; @@ -184,59 +182,12 @@ TEST_CASE("Test local-only append", "[state]") kv->getAppended(actual.data(), actual.size(), 3); REQUIRE(actual == expected); - - resetStateMode(); } -TEST_CASE("Test state server as remote master", "[state]") +TEST_CASE_METHOD(SimpleStateServerTestFixture, + "Test state server with local master", + "[state]") { - setUpStateMode(); - - State& globalState = getGlobalState(); - REQUIRE(globalState.getKVCount() == 0); - - DummyStateServer server; - server.dummyData = dataA; - server.dummyUser = userA; - server.dummyKey = keyA; - server.start(); - - // Get the state size before accessing the value locally - size_t actualSize = globalState.getStateSize(userA, keyA); - REQUIRE(actualSize == dataA.size()); - - // Access locally and check not master - auto localKv = getKv(userA, keyA, dataA.size()); - REQUIRE(!localKv->isMaster()); - - // Set the state locally and check - State& state = getGlobalState(); - const std::shared_ptr& kv = - state.getKV(userA, keyA, dataA.size()); - kv->set(dataB.data()); - - std::vector actualLocal(dataA.size(), 0); - kv->get(actualLocal.data()); - REQUIRE(actualLocal == dataB); - - // Check it's not changed remotely - std::vector actualRemote = server.getRemoteKvValue(); - REQUIRE(actualRemote == dataA); - - // Push and check remote is updated - kv->pushFull(); - actualRemote = server.getRemoteKvValue(); - REQUIRE(actualRemote == dataB); - - server.stop(); - - resetStateMode(); -} - -TEST_CASE("Test state server with local master", "[state]") -{ - setUpStateMode(); - // Set and push auto localKv = getKv(userA, keyA, dataA.size()); localKv->set(dataA.data()); @@ -247,7 +198,6 @@ TEST_CASE("Test state server with local master", "[state]") localKv->set(dataB.data()); // Pull - State& state = getGlobalState(); const std::shared_ptr& kv = state.getKV(userA, keyA, dataA.size()); kv->pull(); @@ -255,7 +205,5 @@ TEST_CASE("Test state server with local master", "[state]") // Check it's still the same locally set value std::vector actual(kv->get(), kv->get() + dataA.size()); REQUIRE(actual == dataB); - - resetStateMode(); } } diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index 1568b3cf4..ef1fc3e6b 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include From 97a1e38b35f7f46d43905a56139423247172a18f Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 23 Jun 2021 16:58:01 +0000 Subject: [PATCH 29/66] Remove custom stop methods on snapshot and function call servers --- include/faabric/scheduler/FunctionCallServer.h | 2 -- include/faabric/scheduler/Scheduler.h | 4 ---- include/faabric/scheduler/SnapshotServer.h | 3 --- include/faabric/transport/MessageEndpoint.h | 2 +- src/scheduler/FunctionCallServer.cpp | 9 --------- src/scheduler/Scheduler.cpp | 14 ++------------ src/scheduler/SnapshotServer.cpp | 9 --------- src/transport/MessageEndpoint.cpp | 2 +- .../test/scheduler/test_snapshot_client_server.cpp | 2 +- 9 files changed, 5 insertions(+), 42 deletions(-) diff --git a/include/faabric/scheduler/FunctionCallServer.h b/include/faabric/scheduler/FunctionCallServer.h index bf18a112a..06fa48769 100644 --- a/include/faabric/scheduler/FunctionCallServer.h +++ b/include/faabric/scheduler/FunctionCallServer.h @@ -12,8 +12,6 @@ class FunctionCallServer final public: FunctionCallServer(); - void stop() override; - private: Scheduler& scheduler; diff --git a/include/faabric/scheduler/Scheduler.h b/include/faabric/scheduler/Scheduler.h index 725feb224..cc9dee83e 100644 --- a/include/faabric/scheduler/Scheduler.h +++ b/include/faabric/scheduler/Scheduler.h @@ -168,10 +168,6 @@ class Scheduler ExecGraph getFunctionExecGraph(unsigned int msgId); - void closeFunctionCallClients(); - - void closeSnapshotClients(); - private: std::string thisHost; diff --git a/include/faabric/scheduler/SnapshotServer.h b/include/faabric/scheduler/SnapshotServer.h index 326b0c7a2..214e57800 100644 --- a/include/faabric/scheduler/SnapshotServer.h +++ b/include/faabric/scheduler/SnapshotServer.h @@ -10,9 +10,6 @@ class SnapshotServer final : public faabric::transport::MessageEndpointServer { public: SnapshotServer(); - - void stop() override; - protected: void doRecv(faabric::transport::Message& header, faabric::transport::Message& body) override; diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 4470bdbb6..8a2361200 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -21,7 +21,7 @@ // How long undelivered messages will hang around when the socket is closed, // which also determines how long the context will hang for when closing if // things haven't yet completed (usually only when there's an error). -#define LINGER_MS 1000 +#define LINGER_MS 100 namespace faabric::transport { diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index 05a036428..716756036 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -13,15 +13,6 @@ FunctionCallServer::FunctionCallServer() , scheduler(getScheduler()) {} -void FunctionCallServer::stop() -{ - // Close the dangling scheduler endpoints - faabric::scheduler::getScheduler().closeFunctionCallClients(); - - // Call the parent stop - MessageEndpointServer::stop(); -} - void FunctionCallServer::doRecv(faabric::transport::Message& header, faabric::transport::Message& body) { diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 328fca9dc..866ca51a0 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -61,16 +61,6 @@ void Scheduler::addHostToGlobalSet() redis.sadd(AVAILABLE_HOST_SET, thisHost); } -void Scheduler::closeFunctionCallClients() -{ - functionCallClients.clear(); -} - -void Scheduler::closeSnapshotClients() -{ - snapshotClients.clear(); -} - void Scheduler::reset() { // Shut down all Executors @@ -104,8 +94,8 @@ void Scheduler::reset() recordedMessagesLocal.clear(); recordedMessagesShared.clear(); - closeFunctionCallClients(); - closeSnapshotClients(); + functionCallClients.clear(); + snapshotClients.clear(); } void Scheduler::shutdown() diff --git a/src/scheduler/SnapshotServer.cpp b/src/scheduler/SnapshotServer.cpp index 2d127ba13..d16349ead 100644 --- a/src/scheduler/SnapshotServer.cpp +++ b/src/scheduler/SnapshotServer.cpp @@ -15,15 +15,6 @@ SnapshotServer::SnapshotServer() : faabric::transport::MessageEndpointServer(SNAPSHOT_PORT) {} -void SnapshotServer::stop() -{ - // Close the dangling clients - faabric::scheduler::getScheduler().closeSnapshotClients(); - - // Call the parent stop - MessageEndpointServer::stop(); -} - void SnapshotServer::doRecv(faabric::transport::Message& header, faabric::transport::Message& body) { diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 7b20a15d5..1bd7ce493 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -42,7 +42,7 @@ MessageEndpoint::MessageEndpoint(zmq::socket_type socketTypeIn, zmq::socket_t(*getGlobalMessageContext(), socketType), "socket_create") - // Set socket timeouts + // Check and set socket timeout if (timeoutMs <= 0) { SPDLOG_ERROR("Setting invalid timeout of {}", timeoutMs); throw std::runtime_error("Setting invalid timeout"); diff --git a/tests/test/scheduler/test_snapshot_client_server.cpp b/tests/test/scheduler/test_snapshot_client_server.cpp index 6db0d4054..476e20140 100644 --- a/tests/test/scheduler/test_snapshot_client_server.cpp +++ b/tests/test/scheduler/test_snapshot_client_server.cpp @@ -64,7 +64,7 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, usleep(1000 * 500); - // Check snapshots created in regsitry + // Check snapshots created in registry REQUIRE(reg.getSnapshotCount() == 2); const faabric::util::SnapshotData& actualA = reg.getSnapshot(snapKeyA); const faabric::util::SnapshotData& actualB = reg.getSnapshot(snapKeyB); From e00978a664ed53449ec87ddfb1350e6b0410379a Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Thu, 24 Jun 2021 06:29:12 +0000 Subject: [PATCH 30/66] Formatting --- include/faabric/scheduler/SnapshotServer.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/faabric/scheduler/SnapshotServer.h b/include/faabric/scheduler/SnapshotServer.h index 214e57800..b72a4a8df 100644 --- a/include/faabric/scheduler/SnapshotServer.h +++ b/include/faabric/scheduler/SnapshotServer.h @@ -10,6 +10,7 @@ class SnapshotServer final : public faabric::transport::MessageEndpointServer { public: SnapshotServer(); + protected: void doRecv(faabric::transport::Message& header, faabric::transport::Message& body) override; From 80754d47e9b2c8c18a599aa498ccb2c648eae998 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Thu, 24 Jun 2021 09:14:58 +0000 Subject: [PATCH 31/66] Starting req/rep refactor --- .../faabric/scheduler/FunctionCallClient.h | 7 +- .../faabric/scheduler/FunctionCallServer.h | 7 +- include/faabric/scheduler/SnapshotServer.h | 14 +- include/faabric/state/StateServer.h | 7 +- include/faabric/transport/MessageEndpoint.h | 88 ++++--- .../faabric/transport/MessageEndpointClient.h | 31 +++ .../faabric/transport/MessageEndpointServer.h | 25 +- src/scheduler/FunctionCallClient.cpp | 20 +- src/scheduler/SnapshotServer.cpp | 42 ++-- src/transport/MessageEndpoint.cpp | 216 ++++++++++++------ src/transport/MessageEndpointClient.cpp | 38 +++ src/transport/MessageEndpointServer.cpp | 124 ++++++---- 12 files changed, 423 insertions(+), 196 deletions(-) create mode 100644 include/faabric/transport/MessageEndpointClient.h create mode 100644 src/transport/MessageEndpointClient.cpp diff --git a/include/faabric/scheduler/FunctionCallClient.h b/include/faabric/scheduler/FunctionCallClient.h index 44d969803..2dc4e4de8 100644 --- a/include/faabric/scheduler/FunctionCallClient.h +++ b/include/faabric/scheduler/FunctionCallClient.h @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace faabric::scheduler { @@ -32,7 +33,7 @@ void clearMockRequests(); // ----------------------------------- // Message client // ----------------------------------- -class FunctionCallClient : public faabric::transport::SendMessageEndpoint +class FunctionCallClient : public faabric::transport::MessageEndpointClient { public: explicit FunctionCallClient(const std::string& hostIn); @@ -49,6 +50,10 @@ class FunctionCallClient : public faabric::transport::SendMessageEndpoint void unregister(const faabric::UnregisterRequest& req); private: + faabric::transport::AsyncSendMessageEndpoint asyncEndpoint; + + faabric::transport::SyncSendMessageEndpoint syncEndpoint; + void sendHeader(faabric::scheduler::FunctionCalls call); }; } diff --git a/include/faabric/scheduler/FunctionCallServer.h b/include/faabric/scheduler/FunctionCallServer.h index 06fa48769..c8d507818 100644 --- a/include/faabric/scheduler/FunctionCallServer.h +++ b/include/faabric/scheduler/FunctionCallServer.h @@ -15,8 +15,11 @@ class FunctionCallServer final private: Scheduler& scheduler; - void doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override; + void doAsyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) override; + + faabric::Message doSyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) override; /* Function call server API */ diff --git a/include/faabric/scheduler/SnapshotServer.h b/include/faabric/scheduler/SnapshotServer.h index b72a4a8df..90253b40c 100644 --- a/include/faabric/scheduler/SnapshotServer.h +++ b/include/faabric/scheduler/SnapshotServer.h @@ -12,16 +12,22 @@ class SnapshotServer final : public faabric::transport::MessageEndpointServer SnapshotServer(); protected: - void doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override; + void doAsyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) override; + + std::unique_ptr doSyncRecv( + faabric::transport::Message& header, + faabric::transport::Message& body) override; /* Snapshot server API */ - void recvPushSnapshot(faabric::transport::Message& msg); + std::unique_ptr recvPushSnapshot( + faabric::transport::Message& msg); void recvDeleteSnapshot(faabric::transport::Message& msg); - void recvPushSnapshotDiffs(faabric::transport::Message& msg); + std::unique_ptr recvPushSnapshotDiffs( + faabric::transport::Message& msg); void recvThreadResult(faabric::transport::Message& msg); diff --git a/include/faabric/state/StateServer.h b/include/faabric/state/StateServer.h index 29508b88d..e0a24e20e 100644 --- a/include/faabric/state/StateServer.h +++ b/include/faabric/state/StateServer.h @@ -13,8 +13,11 @@ class StateServer final : public faabric::transport::MessageEndpointServer private: State& state; - void doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override; + void doAsyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) override; + + faabric::Message doSyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) override; /* State server API */ diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 8a2361200..bc381cd73 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -25,20 +25,15 @@ namespace faabric::transport { -/* Wrapper arround zmq::socket_t - * - * Thread-unsafe socket-like object. MUST be open-ed and close-ed from the - * _same_ thread. For a proto://host:pair triple, one socket may bind, and all - * the rest must connect. Order does not matter. Sockets either send (PUSH) - * or recv (PULL) data. +/* + * Note, that sockets must be open-ed and close-ed from the _same_ thread. For a + * proto://host:pair triple, one socket may bind, and all the rest must connect. + * Order does not matter. */ class MessageEndpoint { public: - MessageEndpoint(zmq::socket_type socketTypeIn, - const std::string& hostIn, - int portIn, - int timeoutMsIn); + MessageEndpoint(const std::string& hostIn, int portIn, int timeoutMsIn); // Delete assignment and copy-constructor as we need to be very careful with // socping and same-thread instantiation @@ -51,7 +46,6 @@ class MessageEndpoint int getPort(); protected: - const zmq::socket_type socketType; const std::string host; const int port; const std::string address; @@ -59,42 +53,74 @@ class MessageEndpoint const std::thread::id tid; const int id; - zmq::socket_t socket; -}; + zmq::socket_t setUpSocket(zmq::socket_type socketType, int socketPort); + + void doSend(zmq::socket_t& socket, + uint8_t* data, + size_t dataSize, + bool more); + + Message doRecv(zmq::socket_t& socket, int size = 0); -/* Send and Recv Message Endpoints */ + Message recvBuffer(zmq::socket_t& socket, int size); -class SendMessageEndpoint : public MessageEndpoint + Message recvNoBuffer(zmq::socket_t& socket); +}; + +class AsyncSendMessageEndpoint : public MessageEndpoint { public: - SendMessageEndpoint(const std::string& hostIn, - int portIn, - int timeoutMs = DEFAULT_SEND_TIMEOUT_MS); + AsyncSendMessageEndpoint(const std::string& hostIn, + int portIn, + int timeoutMs = DEFAULT_SEND_TIMEOUT_MS); void send(uint8_t* serialisedMsg, size_t msgSize, bool more = false); - Message awaitResponse(); + private: + zmq::socket_t pushSocket; }; -class RecvMessageEndpoint : public MessageEndpoint +class SyncSendMessageEndpoint : public MessageEndpoint { public: - RecvMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + SyncSendMessageEndpoint(const std::string& hostIn, + int portIn, + int timeoutMs = DEFAULT_SEND_TIMEOUT_MS); - Message recv(int size = 0); + void sendHeader(int header); - /* Send response to the client - * - * Send a one-off response to a client identified by host:port pair. - * Together with a blocking recv at the client side, this - * method can be used to achieve synchronous client-server communication. - */ - void sendResponse(uint8_t* data, int size, const std::string& returnHost); + Message sendAwaitResponse(uint8_t* serialisedMsg, + size_t msgSize, + bool more = false); private: - Message recvBuffer(int size); + zmq::socket_t reqSocket; +}; + +class AsyncRecvMessageEndpoint : public MessageEndpoint +{ + public: + AsyncRecvMessageEndpoint(int portIn, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); - Message recvNoBuffer(); + Message recv(int size = 0); + + private: + zmq::socket_t pullSocket; +}; + +class SyncRecvMessageEndpoint : public MessageEndpoint +{ + public: + SyncRecvMessageEndpoint(int portIn, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + + Message recv(int size = 0); + + void sendResponse(uint8_t* data, int size); + + private: + zmq::socket_t repSocket; }; class MessageTimeoutException : public faabric::util::FaabricException diff --git a/include/faabric/transport/MessageEndpointClient.h b/include/faabric/transport/MessageEndpointClient.h new file mode 100644 index 000000000..dfff207c5 --- /dev/null +++ b/include/faabric/transport/MessageEndpointClient.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + +namespace faabric::transport { +class MessageEndpointClient +{ + public: + MessageEndpointClient(std::string hostIn, int portIn); + + protected: + const std::string host; + + void asyncSend(int header, std::unique_ptr msg); + + void syncSend(int header, + std::unique_ptr msg, + std::unique_ptr response); + + private: + const int asyncPort; + + const int syncPort; + + faabric::transport::AsyncSendMessageEndpoint asyncEndpoint; + + faabric::transport::SyncSendMessageEndpoint syncEndpoint; +}; +} diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index da250f1ee..0f81a0939 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -1,12 +1,11 @@ #pragma once +#include #include #include #include -#define ENDPOINT_SERVER_SHUTDOWN -1 - namespace faabric::transport { /* Server handling a long-running 0MQ socket * @@ -25,22 +24,28 @@ class MessageEndpointServer virtual void stop(); protected: - std::unique_ptr recvEndpoint = nullptr; - - bool recv(); - /* Template function to handle message reception * * A message endpoint server in faabric expects each communication to be * a multi-part 0MQ message. One message containing the header, and another * one with the body. Note that 0MQ _guarantees_ in-order delivery. */ - virtual void doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) = 0; + virtual void doAsyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) = 0; + + virtual std::unique_ptr doSyncRecv( + faabric::transport::Message& header, + faabric::transport::Message& body) = 0; + + void sendSyncResponse(); private: - const int port; + const int asyncPort; + + const int syncPort; + + std::thread asyncThread; - std::thread servingThread; + std::thread syncThread; }; } diff --git a/src/scheduler/FunctionCallClient.cpp b/src/scheduler/FunctionCallClient.cpp index 06b755238..df73ff827 100644 --- a/src/scheduler/FunctionCallClient.cpp +++ b/src/scheduler/FunctionCallClient.cpp @@ -81,29 +81,19 @@ void clearMockRequests() // Message Client // ----------------------------------- FunctionCallClient::FunctionCallClient(const std::string& hostIn) - : faabric::transport::SendMessageEndpoint(hostIn, FUNCTION_CALL_PORT) + : faabric::transport::MessageEndpointClient(hostIn, FUNCTION_CALL_PORT) {} -void FunctionCallClient::sendHeader(faabric::scheduler::FunctionCalls call) -{ - uint8_t header = static_cast(call); - send(&header, sizeof(header), true); -} - void FunctionCallClient::sendFlush() { - faabric::ResponseRequest call; if (faabric::util::isMockMode()) { + faabric::ResponseRequest call; faabric::util::UniqueLock lock(mockMutex); flushCalls.emplace_back(host, call); } else { - // Prepare the message body - call.set_returnhost(faabric::util::getSystemConfig().endpointHost); - - SEND_MESSAGE(faabric::scheduler::FunctionCalls::Flush, call); - - // Await the response - awaitResponse(); + auto call = std::make_unique(); + auto resp = std::make_unique(); + syncSend(faabric::scheduler::FunctionCalls::Flush, call, resp); } } diff --git a/src/scheduler/SnapshotServer.cpp b/src/scheduler/SnapshotServer.cpp index d16349ead..c1d77b7f2 100644 --- a/src/scheduler/SnapshotServer.cpp +++ b/src/scheduler/SnapshotServer.cpp @@ -15,18 +15,12 @@ SnapshotServer::SnapshotServer() : faabric::transport::MessageEndpointServer(SNAPSHOT_PORT) {} -void SnapshotServer::doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) +void SnapshotServer::doAsyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) { assert(header.size() == sizeof(uint8_t)); uint8_t call = static_cast(*header.data()); switch (call) { - case faabric::scheduler::SnapshotCalls::PushSnapshot: - this->recvPushSnapshot(body); - break; - case faabric::scheduler::SnapshotCalls::PushSnapshotDiffs: - this->recvPushSnapshotDiffs(body); - break; case faabric::scheduler::SnapshotCalls::DeleteSnapshot: this->recvDeleteSnapshot(body); break; @@ -35,11 +29,29 @@ void SnapshotServer::doRecv(faabric::transport::Message& header, break; default: throw std::runtime_error( - fmt::format("Unrecognized call header: {}", call)); + fmt::format("Unrecognized async call header: {}", call)); } } -void SnapshotServer::recvPushSnapshot(faabric::transport::Message& msg) +std::unique_ptr SnapshotServer::doSyncRecv( + faabric::transport::Message& header, + faabric::transport::Message& body) +{ + assert(header.size() == sizeof(uint8_t)); + uint8_t call = static_cast(*header.data()); + switch (call) { + case faabric::scheduler::SnapshotCalls::PushSnapshot: + return recvPushSnapshot(body); + case faabric::scheduler::SnapshotCalls::PushSnapshotDiffs: + return recvPushSnapshotDiffs(body); + default: + throw std::runtime_error( + fmt::format("Unrecognized sync call header: {}", call)); + } +} + +std::unique_ptr SnapshotServer::recvPushSnapshot( + faabric::transport::Message& msg) { SnapshotPushRequest* r = flatbuffers::GetMutableRoot(msg.udata()); @@ -66,8 +78,9 @@ void SnapshotServer::recvPushSnapshot(faabric::transport::Message& msg) reg.takeSnapshot(r->key()->str(), data, true); // Send response - faabric::EmptyResponse response; - SEND_SERVER_RESPONSE(response, r->return_host()->str()) + auto response = std::make_unique(); + + return response; } void SnapshotServer::recvThreadResult(faabric::transport::Message& msg) @@ -89,7 +102,8 @@ void SnapshotServer::recvThreadResult(faabric::transport::Message& msg) sch.setThreadResultLocally(r->message_id(), r->return_value()); } -void SnapshotServer::recvPushSnapshotDiffs(faabric::transport::Message& msg) +faabric::Message SnapshotServer::recvPushSnapshotDiffs( + faabric::transport::Message& msg) { const SnapshotDiffPushRequest* r = flatbuffers::GetMutableRoot(msg.udata()); @@ -98,7 +112,7 @@ void SnapshotServer::recvPushSnapshotDiffs(faabric::transport::Message& msg) // Send response faabric::EmptyResponse response; - SEND_SERVER_RESPONSE(response, r->return_host()->str()) + return response; } void SnapshotServer::applyDiffsToSnapshot( diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 1bd7ce493..d5a7e41c2 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -25,39 +25,46 @@ namespace faabric::transport { -MessageEndpoint::MessageEndpoint(zmq::socket_type socketTypeIn, - const std::string& hostIn, +MessageEndpoint::MessageEndpoint(const std::string& hostIn, int portIn, int timeoutMsIn) - : socketType(socketTypeIn) - , host(hostIn) + : host(hostIn) , port(portIn) , address("tcp://" + host + ":" + std::to_string(port)) , timeoutMs(timeoutMsIn) , tid(std::this_thread::get_id()) , id(faabric::util::generateGid()) { - // Create the socket - CATCH_ZMQ_ERR(socket = - zmq::socket_t(*getGlobalMessageContext(), socketType), - "socket_create") // Check and set socket timeout if (timeoutMs <= 0) { SPDLOG_ERROR("Setting invalid timeout of {}", timeoutMs); throw std::runtime_error("Setting invalid timeout"); } +} + +zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType, + int socketPort) +{ + zmq::socket_t socket; + // Create the socket + CATCH_ZMQ_ERR(socket = + zmq::socket_t(*getGlobalMessageContext(), socketType), + "socket_create") socket.set(zmq::sockopt::rcvtimeo, timeoutMs); socket.set(zmq::sockopt::sndtimeo, timeoutMs); // Note - setting linger here is essential to avoid infinite hangs socket.set(zmq::sockopt::linger, LINGER_MS); - // Note - only one socket may bind, but several can connect. This - // allows for easy N - 1 or 1 - N PUSH/PULL patterns. Order between - // bind and connect does not matter. switch (socketType) { + case zmq::socket_type::req: { + SPDLOG_TRACE( + "Opening req socket {}:{} (timeout {}ms)", host, port, timeoutMs); + CATCH_ZMQ_ERR(socket.connect(address), "connect") + break; + } case zmq::socket_type::push: { SPDLOG_TRACE("Opening push socket {}:{} (timeout {}ms)", host, @@ -74,88 +81,54 @@ MessageEndpoint::MessageEndpoint(zmq::socket_type socketTypeIn, CATCH_ZMQ_ERR(socket.bind(address), "bind") break; } + case zmq::socket_type::rep: { + SPDLOG_TRACE( + "Opening rep socket {}:{} (timeout {}ms)", host, port, timeoutMs); + CATCH_ZMQ_ERR(socket.bind(address), "bind") + break; + } default: { throw std::runtime_error("Opening unrecognized socket type"); } } } -std::string MessageEndpoint::getHost() -{ - return host; -} - -int MessageEndpoint::getPort() -{ - return port; -} - -// ---------------------------------------------- -// SEND ENDPOINT -// ---------------------------------------------- - -SendMessageEndpoint::SendMessageEndpoint(const std::string& hostIn, - int portIn, - int timeoutMs) - : MessageEndpoint(zmq::socket_type::push, hostIn, portIn, timeoutMs) -{} - -void SendMessageEndpoint::send(uint8_t* serialisedMsg, - size_t msgSize, - bool more) +void MessageEndpoint::doSend(zmq::socket_t& socket, + uint8_t* data, + size_t dataSize, + bool more) { assert(tid == std::this_thread::get_id()); - zmq::send_flags sendFlags = - more ? zmq::send_flags::sndmore : zmq::send_flags::none; + more ? zmq::send_flags::sndmore : zmq::send_flags::dontwait; CATCH_ZMQ_ERR( { - auto res = - socket.send(zmq::buffer(serialisedMsg, msgSize), sendFlags); - if (res != msgSize) { + auto res = socket.send(zmq::buffer(data, dataSize), sendFlags); + if (res != dataSize) { SPDLOG_ERROR("Sent different bytes than expected (sent " "{}, expected {})", res.value_or(0), - msgSize); + dataSize); throw std::runtime_error("Error sending message"); } }, "send") } -// Block until we receive a response from the server -Message SendMessageEndpoint::awaitResponse() -{ - // Wait for the response, open a temporary endpoint for it - RecvMessageEndpoint endpoint(port + REPLY_PORT_OFFSET, timeoutMs); - - Message receivedMessage = endpoint.recv(); - - return receivedMessage; -} - -// ---------------------------------------------- -// RECV ENDPOINT -// ---------------------------------------------- - -RecvMessageEndpoint::RecvMessageEndpoint(int portIn, int timeoutMs) - : MessageEndpoint(zmq::socket_type::pull, ANY_HOST, portIn, timeoutMs) -{} - -Message RecvMessageEndpoint::recv(int size) +Message MessageEndpoint::doRecv(zmq::socket_t& socket, int size) { assert(tid == std::this_thread::get_id()); assert(size >= 0); if (size == 0) { - return recvNoBuffer(); + return recvNoBuffer(socket); } - return recvBuffer(size); + return recvBuffer(socket, size); } -Message RecvMessageEndpoint::recvBuffer(int size) +Message MessageEndpoint::recvBuffer(zmq::socket_t& socket, int size) { // Pre-allocate buffer to avoid copying data Message msg(size); @@ -189,7 +162,7 @@ Message RecvMessageEndpoint::recvBuffer(int size) return msg; } -Message RecvMessageEndpoint::recvNoBuffer() +Message MessageEndpoint::recvNoBuffer(zmq::socket_t& socket) { // Allocate a message to receive data zmq::message_t msg; @@ -213,15 +186,114 @@ Message RecvMessageEndpoint::recvNoBuffer() return Message(msg); } +std::string MessageEndpoint::getHost() +{ + return host; +} + +int MessageEndpoint::getPort() +{ + return port; +} + +// ---------------------------------------------- +// ASYNC SEND ENDPOINT +// ---------------------------------------------- + +AsyncSendMessageEndpoint::AsyncSendMessageEndpoint(const std::string& hostIn, + int portIn, + int timeoutMs) + : MessageEndpoint(hostIn, portIn, timeoutMs) +{ + pushSocket = setUpSocket(zmq::socket_type::push, portIn); +} + +void AsyncSendMessageEndpoint::send(uint8_t* serialisedMsg, + size_t msgSize, + bool more) +{ + doSend(pushSocket, serialisedMsg, msgSize, more); +} + +// ---------------------------------------------- +// SYNC SEND ENDPOINT +// ---------------------------------------------- + +SyncSendMessageEndpoint::SyncSendMessageEndpoint(const std::string& hostIn, + int portIn, + int timeoutMs) + : MessageEndpoint(hostIn, portIn, timeoutMs) +{ + reqSocket = setUpSocket(zmq::socket_type::req, portIn + 1); +} + +void SyncSendMessageEndpoint::sendHeader(int header) +{ + uint8_t headerBytes = static_cast(header); + doSend(reqSocket, &headerBytes, sizeof(headerBytes), true); +} + +Message SyncSendMessageEndpoint::sendAwaitResponse(uint8_t* serialisedMsg, + size_t msgSize, + bool more) +{ + doSend(reqSocket, serialisedMsg, msgSize, more); + + // Do the receive + zmq::message_t msg; + CATCH_ZMQ_ERR( + try { + auto res = reqSocket.recv(msg); + if (!res.has_value()) { + SPDLOG_ERROR("Timed out receiving message with no size"); + throw MessageTimeoutException("Timed out receiving message"); + } + } catch (zmq::error_t& e) { + if (e.num() == ZMQ_ETERM) { + SPDLOG_TRACE("Endpoint received ETERM"); + return Message(); + } + throw; + }, + "send_recv") + + return Message(msg); +} + +// ---------------------------------------------- +// ASYNC RECV ENDPOINT +// ---------------------------------------------- + +AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint(int portIn, int timeoutMs) + : MessageEndpoint(ANY_HOST, portIn, timeoutMs) +{ + pullSocket = setUpSocket(zmq::socket_type::pull, portIn); +} + +Message AsyncRecvMessageEndpoint::recv(int size) +{ + return doRecv(pullSocket, size); +} + +// ---------------------------------------------- +// SYNC RECV ENDPOINT +// ---------------------------------------------- +// +SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(int portIn, int timeoutMs) + : MessageEndpoint(ANY_HOST, portIn, timeoutMs) +{ + repSocket = setUpSocket(zmq::socket_type::rep, portIn + 1); +} + +Message SyncRecvMessageEndpoint::recv(int size) +{ + return doRecv(repSocket, size); +} + // We create a new endpoint every time. Re-using them would be a possible // optimisation if needed. -void RecvMessageEndpoint::sendResponse(uint8_t* data, - int size, - const std::string& returnHost) -{ - // Open the endpoint socket, server connects (not bind) to remote address - SendMessageEndpoint sendEndpoint( - returnHost, port + REPLY_PORT_OFFSET, timeoutMs); - sendEndpoint.send(data, size); +void SyncRecvMessageEndpoint::sendResponse(uint8_t* data, int size) +{ + doSend(repSocket, data, size, false); } } diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp new file mode 100644 index 000000000..6fb069be3 --- /dev/null +++ b/src/transport/MessageEndpointClient.cpp @@ -0,0 +1,38 @@ +#pragma once + +#include + +namespace faabric::transport { + +MessageEndpointClient::MessageEndpointClient(std::string hostIn, int portIn) + : host(hostIn) + , asyncPort(portIn) + , syncPort(portIn + 1) + , asyncEndpoint(host, asyncPort) + , syncEndpoint(host, asyncPort) +{} + +void MessageEndpointClient::asyncSend( + int header, + std::unique_ptr msg) +{} + +void MessageEndpointClient::syncSend( + int header, + std::unique_ptr msg, + std::unique_ptr response) +{ + syncEndpoint.sendHeader(header); + + size_t msgSize = msg->ByteSizeLong(); + uint8_t sMsg[msgSize]; + if (!msg->SerializeToArray(sMsg, msgSize)) { + throw std::runtime_error("Error serialising message"); + } + Message responseMsg = syncEndpoint.sendAwaitResponse(sMsg, msgSize); + // Deserialise message string + if (!response->ParseFromArray(responseMsg.data(), responseMsg.size())) { + throw std::runtime_error("Error deserialising message"); + } +} +} diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 3e95b578a..dbbfb3ecd 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -1,32 +1,86 @@ #include #include #include +#include #include #include namespace faabric::transport { MessageEndpointServer::MessageEndpointServer(int portIn) - : port(portIn) + : asyncPort(portIn) + , syncPort(portIn + 1) {} void MessageEndpointServer::start() { - // Start serving thread in background - servingThread = std::thread([this] { - recvEndpoint = std::make_unique(this->port); + asyncThread = std::thread([this] { + AsyncRecvMessageEndpoint endpoint(asyncPort); // Loop until we receive a shutdown message while (true) { - try { - bool messageReceived = this->recv(); - if (!messageReceived) { - SPDLOG_TRACE("Server received shutdown message"); - break; + // Receive header and body + Message header = endpoint.recv(); + + // Detect shutdown condition + if (header.size() == 0) { + SPDLOG_TRACE("Server received shutdown message"); + break; + } + + // Check the header was sent with ZMQ_SNDMORE flag + if (!header.more()) { + throw std::runtime_error("Header sent without SNDMORE flag"); + } + + // Check that there are no more messages to receive + Message body = endpoint.recv(); + if (body.more()) { + throw std::runtime_error("Body sent with SNDMORE flag"); + } + assert(body.udata() != nullptr); + + // Server-specific message handling + doAsyncRecv(header, body); + } + }); + + syncThread = std::thread([this] { + SyncRecvMessageEndpoint endpoint(syncPort); + + // Loop until we receive a shutdown message + while (true) { + // Receive header and body + Message header = endpoint.recv(); + + // Detect shutdown condition + if (header.size() == 0) { + SPDLOG_TRACE("Server received shutdown message"); + break; + } + + // Check the header was sent with ZMQ_SNDMORE flag + if (!header.more()) { + throw std::runtime_error("Header sent without SNDMORE flag"); + } + + // Check that there are no more messages to receive + Message body = endpoint.recv(); + if (body.more()) { + throw std::runtime_error("Body sent with SNDMORE flag"); + } + assert(body.udata() != nullptr); + + // Server-specific message handling + std::unique_ptr resp = + doSyncRecv(header, body); + size_t msgSize = resp->ByteSizeLong(); + { + uint8_t sMsg[msgSize]; + if (!resp->SerializeToArray(sMsg, msgSize)) { + throw std::runtime_error("Error serialising message"); } - } catch (MessageTimeoutException& ex) { - SPDLOG_TRACE("Server timed out with no messages, continuing"); - continue; + endpoint.sendResponse(sMsg, msgSize); } } }); @@ -34,45 +88,25 @@ void MessageEndpointServer::start() void MessageEndpointServer::stop() { - // Send a shutdown message via a temporary endpoint - SendMessageEndpoint e(recvEndpoint->getHost(), recvEndpoint->getPort()); + SPDLOG_TRACE( + "Sending sync shutdown message locally to {}:{}", LOCALHOST, syncPort); - SPDLOG_TRACE("Sending shutdown message locally to {}:{}", - recvEndpoint->getHost(), - recvEndpoint->getPort()); - e.send(nullptr, 0); + SyncSendMessageEndpoint syncSender(LOCALHOST, syncPort); + syncSender.sendAwaitResponse(nullptr, 0); - // Join the serving thread - if (servingThread.joinable()) { - servingThread.join(); - } -} - -bool MessageEndpointServer::recv() -{ - // Receive header and body - Message header = recvEndpoint->recv(); + SPDLOG_TRACE( + "Sending async shutdown message locally to {}:{}", LOCALHOST, asyncPort); - // Detect shutdown condition - if (header.size() == 0) { - return false; - } + AsyncSendMessageEndpoint asyncSender(LOCALHOST, syncPort); + asyncSender.send(nullptr, 0); - // Check the header was sent with ZMQ_SNDMORE flag - if (!header.more()) { - throw std::runtime_error("Header sent without SNDMORE flag"); + // Join the threads + if (asyncThread.joinable()) { + asyncThread.join(); } - // Check that there are no more messages to receive - Message body = recvEndpoint->recv(); - if (body.more()) { - throw std::runtime_error("Body sent with SNDMORE flag"); + if (syncThread.joinable()) { + syncThread.join(); } - assert(body.udata() != nullptr); - - // Server-specific message handling - doRecv(header, body); - - return true; } } From bae864eb25739001bb79d0ceae72792fa87900b3 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Thu, 24 Jun 2021 15:23:25 +0000 Subject: [PATCH 32/66] Req/rep part 2 --- .../faabric/scheduler/FunctionCallClient.h | 6 +- .../faabric/scheduler/FunctionCallServer.h | 15 +- include/faabric/scheduler/SnapshotClient.h | 3 +- include/faabric/scheduler/SnapshotServer.h | 6 +- include/faabric/state/StateClient.h | 18 +- include/faabric/state/StateServer.h | 34 ++-- include/faabric/transport/MessageEndpoint.h | 6 +- .../faabric/transport/MessageEndpointClient.h | 14 +- .../faabric/transport/MessageEndpointServer.h | 2 +- .../faabric/transport/MpiMessageEndpoint.h | 4 +- include/faabric/transport/macros.h | 9 - src/scheduler/FunctionCallClient.cpp | 30 ++-- src/scheduler/FunctionCallServer.cpp | 61 ++++--- src/scheduler/SnapshotClient.cpp | 52 +++--- src/scheduler/SnapshotServer.cpp | 29 ++-- src/state/StateClient.cpp | 96 +++-------- src/state/StateServer.cpp | 158 ++++++++++-------- src/transport/MessageEndpoint.cpp | 4 +- src/transport/MessageEndpointClient.cpp | 42 +++-- src/transport/MessageEndpointServer.cpp | 18 +- src/transport/MpiMessageEndpoint.cpp | 4 +- tests/test/transport/test_message_server.cpp | 27 ++- 22 files changed, 329 insertions(+), 309 deletions(-) diff --git a/include/faabric/scheduler/FunctionCallClient.h b/include/faabric/scheduler/FunctionCallClient.h index 2dc4e4de8..b0d3e01fe 100644 --- a/include/faabric/scheduler/FunctionCallClient.h +++ b/include/faabric/scheduler/FunctionCallClient.h @@ -47,13 +47,9 @@ class FunctionCallClient : public faabric::transport::MessageEndpointClient void executeFunctions( const std::shared_ptr req); - void unregister(const faabric::UnregisterRequest& req); + void unregister(faabric::UnregisterRequest& req); private: - faabric::transport::AsyncSendMessageEndpoint asyncEndpoint; - - faabric::transport::SyncSendMessageEndpoint syncEndpoint; - void sendHeader(faabric::scheduler::FunctionCalls call); }; } diff --git a/include/faabric/scheduler/FunctionCallServer.h b/include/faabric/scheduler/FunctionCallServer.h index c8d507818..733a6dec6 100644 --- a/include/faabric/scheduler/FunctionCallServer.h +++ b/include/faabric/scheduler/FunctionCallServer.h @@ -18,19 +18,18 @@ class FunctionCallServer final void doAsyncRecv(faabric::transport::Message& header, faabric::transport::Message& body) override; - faabric::Message doSyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override; + std::unique_ptr doSyncRecv( + faabric::transport::Message& header, + faabric::transport::Message& body) override; - /* Function call server API */ + std::unique_ptr recvFlush( + faabric::transport::Message& body); - void recvFlush(faabric::transport::Message& body); + std::unique_ptr recvGetResources( + faabric::transport::Message& body); void recvExecuteFunctions(faabric::transport::Message& body); - void recvGetResources(faabric::transport::Message& body); - void recvUnregister(faabric::transport::Message& body); - - void recvSetThreadResult(faabric::transport::Message& body); }; } diff --git a/include/faabric/scheduler/SnapshotClient.h b/include/faabric/scheduler/SnapshotClient.h index 3189847b2..6d7fd8655 100644 --- a/include/faabric/scheduler/SnapshotClient.h +++ b/include/faabric/scheduler/SnapshotClient.h @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace faabric::scheduler { @@ -32,7 +33,7 @@ void clearMockSnapshotRequests(); // gRPC client // ----------------------------------- -class SnapshotClient final : public faabric::transport::SendMessageEndpoint +class SnapshotClient final : public faabric::transport::MessageEndpointClient { public: explicit SnapshotClient(const std::string& hostIn); diff --git a/include/faabric/scheduler/SnapshotServer.h b/include/faabric/scheduler/SnapshotServer.h index 90253b40c..8d0ea4a66 100644 --- a/include/faabric/scheduler/SnapshotServer.h +++ b/include/faabric/scheduler/SnapshotServer.h @@ -19,16 +19,14 @@ class SnapshotServer final : public faabric::transport::MessageEndpointServer faabric::transport::Message& header, faabric::transport::Message& body) override; - /* Snapshot server API */ - std::unique_ptr recvPushSnapshot( faabric::transport::Message& msg); - void recvDeleteSnapshot(faabric::transport::Message& msg); - std::unique_ptr recvPushSnapshotDiffs( faabric::transport::Message& msg); + void recvDeleteSnapshot(faabric::transport::Message& msg); + void recvThreadResult(faabric::transport::Message& msg); private: diff --git a/include/faabric/state/StateClient.h b/include/faabric/state/StateClient.h index 05a503bae..fff3e2b4a 100644 --- a/include/faabric/state/StateClient.h +++ b/include/faabric/state/StateClient.h @@ -4,9 +4,10 @@ #include #include #include +#include namespace faabric::state { -class StateClient : public faabric::transport::SendMessageEndpoint +class StateClient : public faabric::transport::MessageEndpointClient { public: explicit StateClient(const std::string& userIn, @@ -15,9 +16,6 @@ class StateClient : public faabric::transport::SendMessageEndpoint const std::string user; const std::string key; - const std::string host; - - /* External state client API */ void pushChunks(const std::vector& chunks); @@ -39,16 +37,8 @@ class StateClient : public faabric::transport::SendMessageEndpoint void unlock(); private: - void sendHeader(faabric::state::StateCalls call); - - // Block, but ignore return value - faabric::transport::Message awaitResponse(); - - void sendStateRequest(faabric::state::StateCalls header, bool expectReply); - void sendStateRequest(faabric::state::StateCalls header, - const uint8_t* data = nullptr, - int length = 0, - bool expectReply = false); + const uint8_t* data, + int length); }; } diff --git a/include/faabric/state/StateServer.h b/include/faabric/state/StateServer.h index e0a24e20e..883ab503e 100644 --- a/include/faabric/state/StateServer.h +++ b/include/faabric/state/StateServer.h @@ -16,27 +16,37 @@ class StateServer final : public faabric::transport::MessageEndpointServer void doAsyncRecv(faabric::transport::Message& header, faabric::transport::Message& body) override; - faabric::Message doSyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override; + std::unique_ptr doSyncRecv( + faabric::transport::Message& header, + faabric::transport::Message& body) override; - /* State server API */ + // Sync methods - void recvSize(faabric::transport::Message& body); + std::unique_ptr recvSize( + faabric::transport::Message& body); - void recvPull(faabric::transport::Message& body); + std::unique_ptr recvPull( + faabric::transport::Message& body); - void recvPush(faabric::transport::Message& body); + std::unique_ptr recvPush( + faabric::transport::Message& body); - void recvAppend(faabric::transport::Message& body); + std::unique_ptr recvAppend( + faabric::transport::Message& body); - void recvPullAppended(faabric::transport::Message& body); + std::unique_ptr recvPullAppended( + faabric::transport::Message& body); - void recvClearAppended(faabric::transport::Message& body); + std::unique_ptr recvClearAppended( + faabric::transport::Message& body); - void recvDelete(faabric::transport::Message& body); + std::unique_ptr recvDelete( + faabric::transport::Message& body); - void recvLock(faabric::transport::Message& body); + std::unique_ptr recvLock( + faabric::transport::Message& body); - void recvUnlock(faabric::transport::Message& body); + std::unique_ptr recvUnlock( + faabric::transport::Message& body); }; } diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index bc381cd73..6f2e93744 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -1,7 +1,5 @@ #pragma once -#include - #include #include @@ -56,7 +54,7 @@ class MessageEndpoint zmq::socket_t setUpSocket(zmq::socket_type socketType, int socketPort); void doSend(zmq::socket_t& socket, - uint8_t* data, + const uint8_t* data, size_t dataSize, bool more); @@ -89,7 +87,7 @@ class SyncSendMessageEndpoint : public MessageEndpoint void sendHeader(int header); - Message sendAwaitResponse(uint8_t* serialisedMsg, + Message sendAwaitResponse(const uint8_t* serialisedMsg, size_t msgSize, bool more = false); diff --git a/include/faabric/transport/MessageEndpointClient.h b/include/faabric/transport/MessageEndpointClient.h index dfff207c5..c79a31120 100644 --- a/include/faabric/transport/MessageEndpointClient.h +++ b/include/faabric/transport/MessageEndpointClient.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -13,11 +14,18 @@ class MessageEndpointClient protected: const std::string host; - void asyncSend(int header, std::unique_ptr msg); + void asyncSend(int header, google::protobuf::Message* msg); + + void asyncSend(int header, uint8_t* buffer, size_t bufferSize); + + void syncSend(int header, + google::protobuf::Message* msg, + google::protobuf::Message* response); void syncSend(int header, - std::unique_ptr msg, - std::unique_ptr response); + const uint8_t* buffer, + size_t bufferSize, + google::protobuf::Message* response); private: const int asyncPort; diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 0f81a0939..85234b28b 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -37,7 +37,7 @@ class MessageEndpointServer faabric::transport::Message& header, faabric::transport::Message& body) = 0; - void sendSyncResponse(); + void sendSyncResponse(google::protobuf::Message* resp); private: const int asyncPort; diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 991331bc5..753c65532 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -32,7 +32,7 @@ class MpiMessageEndpoint void close(); private: - SendMessageEndpoint sendMessageEndpoint; - RecvMessageEndpoint recvMessageEndpoint; + AsyncSendMessageEndpoint sendMessageEndpoint; + AsyncRecvMessageEndpoint recvMessageEndpoint; }; } diff --git a/include/faabric/transport/macros.h b/include/faabric/transport/macros.h index 21e715d4f..1347f39a4 100644 --- a/include/faabric/transport/macros.h +++ b/include/faabric/transport/macros.h @@ -22,15 +22,6 @@ send(sMsg, msgSize); \ } -#define SEND_SERVER_RESPONSE(msg, host) \ - size_t msgSize = msg.ByteSizeLong(); \ - { \ - uint8_t sMsg[msgSize]; \ - if (!msg.SerializeToArray(sMsg, msgSize)) { \ - throw std::runtime_error("Error serialising message"); \ - } \ - recvEndpoint->sendResponse(sMsg, msgSize, host); \ - } #define PARSE_MSG(T, data, size) \ T msg; \ diff --git a/src/scheduler/FunctionCallClient.cpp b/src/scheduler/FunctionCallClient.cpp index df73ff827..7f8e736a5 100644 --- a/src/scheduler/FunctionCallClient.cpp +++ b/src/scheduler/FunctionCallClient.cpp @@ -86,14 +86,13 @@ FunctionCallClient::FunctionCallClient(const std::string& hostIn) void FunctionCallClient::sendFlush() { + faabric::ResponseRequest req; if (faabric::util::isMockMode()) { - faabric::ResponseRequest call; faabric::util::UniqueLock lock(mockMutex); - flushCalls.emplace_back(host, call); + flushCalls.emplace_back(host, req); } else { - auto call = std::make_unique(); - auto resp = std::make_unique(); - syncSend(faabric::scheduler::FunctionCalls::Flush, call, resp); + faabric::EmptyResponse resp; + syncSend(faabric::scheduler::FunctionCalls::Flush, &req, &resp); } } @@ -101,6 +100,7 @@ faabric::HostResources FunctionCallClient::getResources() { faabric::ResponseRequest request; faabric::HostResources response; + if (faabric::util::isMockMode()) { faabric::util::UniqueLock lock(mockMutex); @@ -112,16 +112,8 @@ faabric::HostResources FunctionCallClient::getResources() response = queuedResourceResponses[host].dequeue(); } } else { - request.set_returnhost(faabric::util::getSystemConfig().endpointHost); - - SEND_MESSAGE(faabric::scheduler::FunctionCalls::GetResources, request); - - // Receive message - faabric::transport::Message msg = awaitResponse(); - // Deserialise message string - if (!response.ParseFromArray(msg.data(), msg.size())) { - throw std::runtime_error("Error deserialising message"); - } + syncSend( + faabric::scheduler::FunctionCalls::GetResources, &request, &response); } return response; @@ -134,18 +126,18 @@ void FunctionCallClient::executeFunctions( faabric::util::UniqueLock lock(mockMutex); batchMessages.emplace_back(host, req); } else { - SEND_MESSAGE_PTR(faabric::scheduler::FunctionCalls::ExecuteFunctions, - req); + asyncSend(faabric::scheduler::FunctionCalls::ExecuteFunctions, + req.get()); } } -void FunctionCallClient::unregister(const faabric::UnregisterRequest& req) +void FunctionCallClient::unregister(faabric::UnregisterRequest& req) { if (faabric::util::isMockMode()) { faabric::util::UniqueLock lock(mockMutex); unregisterRequests.emplace_back(host, req); } else { - SEND_MESSAGE(faabric::scheduler::FunctionCalls::Unregister, req); + asyncSend(faabric::scheduler::FunctionCalls::Unregister, &req); } } } diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index 716756036..961dfccad 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -13,31 +13,50 @@ FunctionCallServer::FunctionCallServer() , scheduler(getScheduler()) {} -void FunctionCallServer::doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) +void FunctionCallServer::doAsyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) { assert(header.size() == sizeof(uint8_t)); uint8_t call = static_cast(*header.data()); switch (call) { - case faabric::scheduler::FunctionCalls::Flush: - this->recvFlush(body); + case faabric::scheduler::FunctionCalls::ExecuteFunctions: { + recvExecuteFunctions(body); break; - case faabric::scheduler::FunctionCalls::ExecuteFunctions: - this->recvExecuteFunctions(body); + } + case faabric::scheduler::FunctionCalls::Unregister: { + recvUnregister(body); break; - case faabric::scheduler::FunctionCalls::Unregister: - this->recvUnregister(body); - break; - case faabric::scheduler::FunctionCalls::GetResources: - this->recvGetResources(body); - break; - default: + } + default: { + throw std::runtime_error( + fmt::format("Unrecognized async call header: {}", call)); + } + } +} + +std::unique_ptr FunctionCallServer::doSyncRecv( + faabric::transport::Message& header, + faabric::transport::Message& body) +{ + assert(header.size() == sizeof(uint8_t)); + + uint8_t call = static_cast(*header.data()); + switch (call) { + case faabric::scheduler::FunctionCalls::Flush: { + return recvFlush(body); + } + case faabric::scheduler::FunctionCalls::GetResources: { + return recvGetResources(body); + } + default: { throw std::runtime_error( - fmt::format("Unrecognized call header: {}", call)); + fmt::format("Unrecognized sync call header: {}", call)); + } } } -void FunctionCallServer::recvFlush(faabric::transport::Message& body) +std::unique_ptr FunctionCallServer::recvFlush( + faabric::transport::Message& body) { PARSE_MSG(faabric::ResponseRequest, body.data(), body.size()); @@ -47,8 +66,7 @@ void FunctionCallServer::recvFlush(faabric::transport::Message& body) // Clear the scheduler scheduler.flushLocally(); - faabric::EmptyResponse response; - SEND_SERVER_RESPONSE(response, msg.returnhost()) + return std::make_unique(); } void FunctionCallServer::recvExecuteFunctions(faabric::transport::Message& body) @@ -71,12 +89,13 @@ void FunctionCallServer::recvUnregister(faabric::transport::Message& body) scheduler.removeRegisteredHost(msg.host(), msg.function()); } -void FunctionCallServer::recvGetResources(faabric::transport::Message& body) +std::unique_ptr FunctionCallServer::recvGetResources( + faabric::transport::Message& body) { PARSE_MSG(faabric::ResponseRequest, body.data(), body.size()) - // Send the response body - faabric::HostResources response = scheduler.getThisHostResources(); - SEND_SERVER_RESPONSE(response, msg.returnhost()) + auto response = std::make_unique( + scheduler.getThisHostResources()); + return response; } } diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index 37a072f1d..408bcdad9 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -69,23 +69,10 @@ void clearMockSnapshotRequests() // Snapshot client // ----------------------------------- -#define SEND_FB_REQUEST(T) \ - sendHeader(T); \ - mb.Finish(requestOffset); \ - uint8_t* buffer = mb.GetBufferPointer(); \ - int size = mb.GetSize(); \ - send(buffer, size); - SnapshotClient::SnapshotClient(const std::string& hostIn) - : faabric::transport::SendMessageEndpoint(hostIn, SNAPSHOT_PORT) + : faabric::transport::MessageEndpointClient(hostIn, SNAPSHOT_PORT) {} -void SnapshotClient::sendHeader(faabric::scheduler::SnapshotCalls call) -{ - uint8_t header = static_cast(call); - send(&header, sizeof(header), true); -} - void SnapshotClient::pushSnapshot(const std::string& key, const faabric::util::SnapshotData& data) { @@ -108,10 +95,14 @@ void SnapshotClient::pushSnapshot(const std::string& key, mb, returnHostOffset, keyOffset, dataOffset); // Send it - SEND_FB_REQUEST(faabric::scheduler::SnapshotCalls::PushSnapshot) - - // Await a response as this call must be synchronous - awaitResponse(); + mb.Finish(requestOffset); + uint8_t* buffer = mb.GetBufferPointer(); + int size = mb.GetSize(); + faabric::EmptyResponse response; + syncSend(faabric::scheduler::SnapshotCalls::PushSnapshot, + buffer, + size, + &response); } } @@ -149,11 +140,14 @@ void SnapshotClient::pushSnapshotDiffs( auto requestOffset = CreateSnapshotDiffPushRequest( mb, returnHostOffset, keyOffset, diffsOffset); - // Send it - SEND_FB_REQUEST(faabric::scheduler::SnapshotCalls::PushSnapshotDiffs) - - // Await a response as this call must be synchronous - awaitResponse(); + mb.Finish(requestOffset); + uint8_t* buffer = mb.GetBufferPointer(); + int size = mb.GetSize(); + faabric::EmptyResponse response; + syncSend(faabric::scheduler::SnapshotCalls::PushSnapshotDiffs, + buffer, + size, + &response); } } @@ -171,7 +165,11 @@ void SnapshotClient::deleteSnapshot(const std::string& key) auto keyOffset = mb.CreateString(key); auto requestOffset = CreateSnapshotDeleteRequest(mb, keyOffset); - SEND_FB_REQUEST(faabric::scheduler::SnapshotCalls::DeleteSnapshot) + mb.Finish(requestOffset); + uint8_t* buffer = mb.GetBufferPointer(); + int size = mb.GetSize(); + asyncSend( + faabric::scheduler::SnapshotCalls::PushSnapshotDiffs, buffer, size); } } @@ -228,7 +226,11 @@ void SnapshotClient::pushThreadResult( CreateThreadResultRequest(mb, messageId, returnValue); } - SEND_FB_REQUEST(faabric::scheduler::SnapshotCalls::ThreadResult) + mb.Finish(requestOffset); + uint8_t* buffer = mb.GetBufferPointer(); + int size = mb.GetSize(); + asyncSend( + faabric::scheduler::SnapshotCalls::PushSnapshotDiffs, buffer, size); } } } diff --git a/src/scheduler/SnapshotServer.cpp b/src/scheduler/SnapshotServer.cpp index c1d77b7f2..45a51ef8c 100644 --- a/src/scheduler/SnapshotServer.cpp +++ b/src/scheduler/SnapshotServer.cpp @@ -21,15 +21,18 @@ void SnapshotServer::doAsyncRecv(faabric::transport::Message& header, assert(header.size() == sizeof(uint8_t)); uint8_t call = static_cast(*header.data()); switch (call) { - case faabric::scheduler::SnapshotCalls::DeleteSnapshot: + case faabric::scheduler::SnapshotCalls::DeleteSnapshot: { this->recvDeleteSnapshot(body); break; - case faabric::scheduler::SnapshotCalls::ThreadResult: + } + case faabric::scheduler::SnapshotCalls::ThreadResult: { this->recvThreadResult(body); break; - default: + } + default: { throw std::runtime_error( fmt::format("Unrecognized async call header: {}", call)); + } } } @@ -40,13 +43,16 @@ std::unique_ptr SnapshotServer::doSyncRecv( assert(header.size() == sizeof(uint8_t)); uint8_t call = static_cast(*header.data()); switch (call) { - case faabric::scheduler::SnapshotCalls::PushSnapshot: + case faabric::scheduler::SnapshotCalls::PushSnapshot: { return recvPushSnapshot(body); - case faabric::scheduler::SnapshotCalls::PushSnapshotDiffs: + } + case faabric::scheduler::SnapshotCalls::PushSnapshotDiffs: { return recvPushSnapshotDiffs(body); - default: + } + default: { throw std::runtime_error( fmt::format("Unrecognized sync call header: {}", call)); + } } } @@ -78,9 +84,7 @@ std::unique_ptr SnapshotServer::recvPushSnapshot( reg.takeSnapshot(r->key()->str(), data, true); // Send response - auto response = std::make_unique(); - - return response; + return std::make_unique(); } void SnapshotServer::recvThreadResult(faabric::transport::Message& msg) @@ -102,8 +106,8 @@ void SnapshotServer::recvThreadResult(faabric::transport::Message& msg) sch.setThreadResultLocally(r->message_id(), r->return_value()); } -faabric::Message SnapshotServer::recvPushSnapshotDiffs( - faabric::transport::Message& msg) +std::unique_ptr +SnapshotServer::recvPushSnapshotDiffs(faabric::transport::Message& msg) { const SnapshotDiffPushRequest* r = flatbuffers::GetMutableRoot(msg.udata()); @@ -111,8 +115,7 @@ faabric::Message SnapshotServer::recvPushSnapshotDiffs( applyDiffsToSnapshot(r->key()->str(), r->chunks()); // Send response - faabric::EmptyResponse response; - return response; + return std::make_unique(); } void SnapshotServer::applyDiffsToSnapshot( diff --git a/src/state/StateClient.cpp b/src/state/StateClient.cpp index b95d7237a..dfd2a0480 100644 --- a/src/state/StateClient.cpp +++ b/src/state/StateClient.cpp @@ -8,43 +8,25 @@ namespace faabric::state { StateClient::StateClient(const std::string& userIn, const std::string& keyIn, const std::string& hostIn) - : faabric::transport::SendMessageEndpoint(hostIn, STATE_PORT) + : faabric::transport::MessageEndpointClient(hostIn, STATE_PORT) , user(userIn) , key(keyIn) - , host(hostIn) {} -void StateClient::sendHeader(faabric::state::StateCalls call) -{ - uint8_t header = static_cast(call); - send(&header, sizeof(header), true); -} - -faabric::transport::Message StateClient::awaitResponse() -{ - // Call the superclass implementation - return SendMessageEndpoint::awaitResponse(); -} - -void StateClient::sendStateRequest(faabric::state::StateCalls header, - bool expectReply) -{ - sendStateRequest(header, nullptr, 0, expectReply); -} - void StateClient::sendStateRequest(faabric::state::StateCalls header, const uint8_t* data, - int length, - bool expectReply) + int length) { faabric::StateRequest request; request.set_user(user); request.set_key(key); + if (length > 0) { request.set_data(data, length); } - request.set_returnhost(faabric::util::getSystemConfig().endpointHost); - SEND_MESSAGE(header, request) + + faabric::EmptyResponse resp; + syncSend(header, data, length, &resp); } void StateClient::pushChunks(const std::vector& chunks) @@ -55,17 +37,9 @@ void StateClient::pushChunks(const std::vector& chunks) stateChunk.set_key(key); stateChunk.set_offset(chunk.offset); stateChunk.set_data(chunk.data, chunk.length); - stateChunk.set_returnhost( - faabric::util::getSystemConfig().endpointHost); - SEND_MESSAGE(faabric::state::StateCalls::Push, stateChunk) - - // Await for a response, but discard it as it is empty - try { - (void)awaitResponse(); - } catch (...) { - SPDLOG_ERROR("Error in awaitReponse"); - throw; - } + + faabric::EmptyResponse resp; + syncSend(faabric::state::StateCalls::Push, &stateChunk, &resp); } } @@ -80,11 +54,10 @@ void StateClient::pullChunks(const std::vector& chunks, request.set_offset(chunk.offset); request.set_chunksize(chunk.length); request.set_returnhost(faabric::util::getSystemConfig().endpointHost); - SEND_MESSAGE(faabric::state::StateCalls::Pull, request) - // Receive message - faabric::transport::Message recvMsg = awaitResponse(); - PARSE_RESPONSE(faabric::StatePart, recvMsg.data(), recvMsg.size()) + // Send and copy response into place + faabric::StatePart response; + syncSend(faabric::state::StateCalls::Pull, &request, &response); std::copy(response.data().begin(), response.data().end(), bufferStart + response.offset()); @@ -93,11 +66,7 @@ void StateClient::pullChunks(const std::vector& chunks, void StateClient::append(const uint8_t* data, size_t length) { - // Send request sendStateRequest(faabric::state::StateCalls::Append, data, length); - - // Await for a response, but discard it as it is empty - (void)awaitResponse(); } void StateClient::pullAppended(uint8_t* buffer, size_t length, long nValues) @@ -108,12 +77,9 @@ void StateClient::pullAppended(uint8_t* buffer, size_t length, long nValues) request.set_key(key); request.set_nvalues(nValues); request.set_returnhost(faabric::util::getSystemConfig().endpointHost); - SEND_MESSAGE(faabric::state::StateCalls::PullAppended, request) - // Receive response - faabric::transport::Message recvMsg = awaitResponse(); - PARSE_RESPONSE( - faabric::StateAppendedResponse, recvMsg.data(), recvMsg.size()) + faabric::StateAppendedResponse response; + syncSend(faabric::state::StateCalls::PullAppended, &request, &response); // Process response size_t offset = 0; @@ -133,47 +99,33 @@ void StateClient::pullAppended(uint8_t* buffer, size_t length, long nValues) void StateClient::clearAppended() { - // Send request - sendStateRequest(faabric::state::StateCalls::ClearAppended); - - // Await for a response, but discard it as it is empty - (void)awaitResponse(); + sendStateRequest(faabric::state::StateCalls::ClearAppended, nullptr, 0); } size_t StateClient::stateSize() { - // Include the return address in the message body - sendStateRequest(faabric::state::StateCalls::Size, true); + faabric::StateRequest request; + request.set_user(user); + request.set_key(key); + + faabric::StateSizeResponse response; + syncSend(faabric::state::StateCalls::Size, nullptr, 0, &response); - // Receive message - faabric::transport::Message recvMsg = awaitResponse(); - PARSE_RESPONSE(faabric::StateSizeResponse, recvMsg.data(), recvMsg.size()) return response.statesize(); } void StateClient::deleteState() { - // Send request - sendStateRequest(faabric::state::StateCalls::Delete, true); - - // Await for a response, but discard it as it is empty - (void)awaitResponse(); + sendStateRequest(faabric::state::StateCalls::Delete, nullptr, 0); } void StateClient::lock() { - // Send request - sendStateRequest(faabric::state::StateCalls::Lock); - - // Await for a response, but discard it as it is empty - (void)awaitResponse(); + sendStateRequest(faabric::state::StateCalls::Lock, nullptr, 0); } void StateClient::unlock() { - sendStateRequest(faabric::state::StateCalls::Unlock); - - // Await for a response, but discard it as it is empty - (void)awaitResponse(); + sendStateRequest(faabric::state::StateCalls::Unlock, nullptr, 0); } } diff --git a/src/state/StateServer.cpp b/src/state/StateServer.cpp index 96d4a43bd..703e6e72c 100644 --- a/src/state/StateServer.cpp +++ b/src/state/StateServer.cpp @@ -17,60 +17,72 @@ StateServer::StateServer(State& stateIn) , state(stateIn) {} -void StateServer::doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) +void StateServer::doAsyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) +{ + throw std::runtime_error("State server does not support async recv"); +} + +std::unique_ptr StateServer::doSyncRecv( + faabric::transport::Message& header, + faabric::transport::Message& body) { assert(header.size() == sizeof(uint8_t)); + uint8_t call = static_cast(*header.data()); switch (call) { - case faabric::state::StateCalls::Pull: - this->recvPull(body); - break; - case faabric::state::StateCalls::Push: - this->recvPush(body); - break; - case faabric::state::StateCalls::Size: - this->recvSize(body); - break; - case faabric::state::StateCalls::Append: - this->recvAppend(body); - break; - case faabric::state::StateCalls::ClearAppended: - this->recvClearAppended(body); - break; - case faabric::state::StateCalls::PullAppended: - this->recvPullAppended(body); - break; - case faabric::state::StateCalls::Lock: - this->recvLock(body); - break; - case faabric::state::StateCalls::Unlock: - this->recvUnlock(body); - break; - case faabric::state::StateCalls::Delete: - this->recvDelete(body); - break; - default: + case faabric::state::StateCalls::Pull: { + return recvPull(body); + } + case faabric::state::StateCalls::Push: { + return recvPush(body); + } + case faabric::state::StateCalls::Size: { + return recvSize(body); + } + case faabric::state::StateCalls::Append: { + return recvAppend(body); + } + case faabric::state::StateCalls::ClearAppended: { + return recvClearAppended(body); + } + case faabric::state::StateCalls::PullAppended: { + return recvPullAppended(body); + } + case faabric::state::StateCalls::Lock: { + return recvLock(body); + } + case faabric::state::StateCalls::Unlock: { + return recvUnlock(body); + } + case faabric::state::StateCalls::Delete: { + return recvDelete(body); + } + default: { throw std::runtime_error( fmt::format("Unrecognized state call header: {}", call)); + } } } -void StateServer::recvSize(faabric::transport::Message& body) +std::unique_ptr StateServer::recvSize( + faabric::transport::Message& body) { PARSE_MSG(faabric::StateRequest, body.data(), body.size()) // Prepare the response SPDLOG_TRACE("Size {}/{}", msg.user(), msg.key()); KV_FROM_REQUEST(msg) - faabric::StateSizeResponse response; - response.set_user(kv->user); - response.set_key(kv->key); - response.set_statesize(kv->size()); - SEND_SERVER_RESPONSE(response, msg.returnhost()) + auto response = std::make_unique(); + response->set_user(kv->user); + response->set_key(kv->key); + response->set_statesize(kv->size()); + + return response; } -void StateServer::recvPull(faabric::transport::Message& body) +std::unique_ptr StateServer::recvPull( + faabric::transport::Message& body) { PARSE_MSG(faabric::StateChunkRequest, body.data(), body.size()) @@ -81,20 +93,23 @@ void StateServer::recvPull(faabric::transport::Message& body) msg.offset() + msg.chunksize()); // Write the response - faabric::StatePart response; KV_FROM_REQUEST(msg) uint64_t chunkOffset = msg.offset(); uint64_t chunkLen = msg.chunksize(); uint8_t* chunk = kv->getChunk(chunkOffset, chunkLen); - response.set_user(msg.user()); - response.set_key(msg.key()); - response.set_offset(chunkOffset); + + auto response = std::make_unique(); + response->set_user(msg.user()); + response->set_key(msg.key()); + response->set_offset(chunkOffset); // TODO: avoid copying here - response.set_data(chunk, chunkLen); - SEND_SERVER_RESPONSE(response, msg.returnhost()) + response->set_data(chunk, chunkLen); + + return response; } -void StateServer::recvPush(faabric::transport::Message& body) +std::unique_ptr StateServer::recvPush( + faabric::transport::Message& body) { PARSE_MSG(faabric::StatePart, body.data(), body.size()) @@ -104,15 +119,17 @@ void StateServer::recvPush(faabric::transport::Message& body) msg.key(), msg.offset(), msg.offset() + msg.data().size()); + KV_FROM_REQUEST(msg) kv->setChunk( msg.offset(), BYTES_CONST(msg.data().c_str()), msg.data().size()); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) + auto response = std::make_unique(); + return response; } -void StateServer::recvAppend(faabric::transport::Message& body) +std::unique_ptr StateServer::recvAppend( + faabric::transport::Message& body) { PARSE_MSG(faabric::StateRequest, body.data(), body.size()) @@ -122,30 +139,34 @@ void StateServer::recvAppend(faabric::transport::Message& body) uint64_t dataLen = msg.data().size(); kv->append(reqData, dataLen); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) + auto response = std::make_unique(); + return response; } -void StateServer::recvPullAppended(faabric::transport::Message& body) +std::unique_ptr StateServer::recvPullAppended( + faabric::transport::Message& body) { PARSE_MSG(faabric::StateAppendedRequest, body.data(), body.size()) // Prepare response - faabric::StateAppendedResponse response; SPDLOG_TRACE("Pull appended {}/{}", msg.user(), msg.key()); KV_FROM_REQUEST(msg) - response.set_user(msg.user()); - response.set_key(msg.key()); + + auto response = std::make_unique(); + response->set_user(msg.user()); + response->set_key(msg.key()); for (uint32_t i = 0; i < msg.nvalues(); i++) { AppendedInMemoryState& value = kv->getAppendedValue(i); - auto appendedValue = response.add_values(); + auto appendedValue = response->add_values(); appendedValue->set_data(reinterpret_cast(value.data.get()), value.length); } - SEND_SERVER_RESPONSE(response, msg.returnhost()) + + return response; } -void StateServer::recvDelete(faabric::transport::Message& body) +std::unique_ptr StateServer::recvDelete( + faabric::transport::Message& body) { PARSE_MSG(faabric::StateRequest, body.data(), body.size()) @@ -153,11 +174,12 @@ void StateServer::recvDelete(faabric::transport::Message& body) SPDLOG_TRACE("Delete {}/{}", msg.user(), msg.key()); state.deleteKV(msg.user(), msg.key()); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) + auto response = std::make_unique(); + return response; } -void StateServer::recvClearAppended(faabric::transport::Message& body) +std::unique_ptr StateServer::recvClearAppended( + faabric::transport::Message& body) { PARSE_MSG(faabric::StateRequest, body.data(), body.size()) @@ -166,11 +188,12 @@ void StateServer::recvClearAppended(faabric::transport::Message& body) KV_FROM_REQUEST(msg) kv->clearAppended(); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) + auto response = std::make_unique(); + return response; } -void StateServer::recvLock(faabric::transport::Message& body) +std::unique_ptr StateServer::recvLock( + faabric::transport::Message& body) { PARSE_MSG(faabric::StateRequest, body.data(), body.size()) @@ -179,11 +202,12 @@ void StateServer::recvLock(faabric::transport::Message& body) KV_FROM_REQUEST(msg) kv->lockWrite(); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) + auto response = std::make_unique(); + return response; } -void StateServer::recvUnlock(faabric::transport::Message& body) +std::unique_ptr StateServer::recvUnlock( + faabric::transport::Message& body) { PARSE_MSG(faabric::StateRequest, body.data(), body.size()) @@ -192,7 +216,7 @@ void StateServer::recvUnlock(faabric::transport::Message& body) KV_FROM_REQUEST(msg) kv->unlockWrite(); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost()) + auto response = std::make_unique(); + return response; } } diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index d5a7e41c2..055df1545 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -94,7 +94,7 @@ zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType, } void MessageEndpoint::doSend(zmq::socket_t& socket, - uint8_t* data, + const uint8_t* data, size_t dataSize, bool more) { @@ -233,7 +233,7 @@ void SyncSendMessageEndpoint::sendHeader(int header) doSend(reqSocket, &headerBytes, sizeof(headerBytes), true); } -Message SyncSendMessageEndpoint::sendAwaitResponse(uint8_t* serialisedMsg, +Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* serialisedMsg, size_t msgSize, bool more) { diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index 6fb069be3..02d431f69 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -12,24 +12,46 @@ MessageEndpointClient::MessageEndpointClient(std::string hostIn, int portIn) , syncEndpoint(host, asyncPort) {} -void MessageEndpointClient::asyncSend( - int header, - std::unique_ptr msg) -{} - -void MessageEndpointClient::syncSend( - int header, - std::unique_ptr msg, - std::unique_ptr response) +void MessageEndpointClient::asyncSend(int header, + google::protobuf::Message* msg) { + size_t msgSize = msg->ByteSizeLong(); + uint8_t sMsg[msgSize]; + + if (!msg->SerializeToArray(sMsg, msgSize)) { + throw std::runtime_error("Error serialising message"); + } + + asyncSend(header, sMsg, msgSize); +} + +void MessageEndpointClient::asyncSend(int header, uint8_t* buffer, size_t bufferSize) { syncEndpoint.sendHeader(header); + asyncEndpoint.send(buffer, bufferSize); +} + +void MessageEndpointClient::syncSend(int header, + google::protobuf::Message* msg, + google::protobuf::Message* response) +{ size_t msgSize = msg->ByteSizeLong(); uint8_t sMsg[msgSize]; if (!msg->SerializeToArray(sMsg, msgSize)) { throw std::runtime_error("Error serialising message"); } - Message responseMsg = syncEndpoint.sendAwaitResponse(sMsg, msgSize); + + syncSend(header, sMsg, msgSize, response); +} + +void MessageEndpointClient::syncSend(int header, + const uint8_t* buffer, + const size_t bufferSize, + google::protobuf::Message* response) +{ + syncEndpoint.sendHeader(header); + + Message responseMsg = syncEndpoint.sendAwaitResponse(buffer, bufferSize); // Deserialise message string if (!response->ParseFromArray(responseMsg.data(), responseMsg.size())) { throw std::runtime_error("Error deserialising message"); diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index dbbfb3ecd..e18892910 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -72,16 +72,15 @@ void MessageEndpointServer::start() assert(body.udata() != nullptr); // Server-specific message handling - std::unique_ptr resp = - doSyncRecv(header, body); - size_t msgSize = resp->ByteSizeLong(); - { - uint8_t sMsg[msgSize]; - if (!resp->SerializeToArray(sMsg, msgSize)) { - throw std::runtime_error("Error serialising message"); - } - endpoint.sendResponse(sMsg, msgSize); + std::unique_ptr resp = doSyncRecv(header, body); + size_t respSize = resp->ByteSizeLong(); + + uint8_t buffer[respSize]; + if (!resp->SerializeToArray(buffer, respSize)) { + throw std::runtime_error("Error serialising message"); } + + endpoint.sendResponse(buffer, respSize); } }); } @@ -109,4 +108,5 @@ void MessageEndpointServer::stop() syncThread.join(); } } + } diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index fd9396929..d4b864550 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -3,7 +3,7 @@ namespace faabric::transport { faabric::MpiHostsToRanksMessage recvMpiHostRankMsg() { - faabric::transport::RecvMessageEndpoint endpoint(MPI_PORT); + faabric::transport::AsyncRecvMessageEndpoint endpoint(MPI_PORT); faabric::transport::Message m = endpoint.recv(); PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); @@ -19,7 +19,7 @@ void sendMpiHostRankMsg(const std::string& hostIn, if (!msg.SerializeToArray(sMsg, msgSize)) { throw std::runtime_error("Error serialising message"); } - faabric::transport::SendMessageEndpoint endpoint(hostIn, MPI_PORT); + faabric::transport::AsyncSendMessageEndpoint endpoint(hostIn, MPI_PORT); endpoint.send(sMsg, msgSize, false); } } diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index ff171c0c4..767f5ab86 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -23,12 +23,20 @@ class DummyServer final : public MessageEndpointServer int messageCount; private: - void doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override + void doAsyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) override { - // Dummy server, do nothing but increment the message count messageCount++; } + + std::unique_ptr doSyncRecv( + faabric::transport::Message& header, + faabric::transport::Message& body) override + { + messageCount++; + + return std::make_unique(); + } }; class SlowServer final : public MessageEndpointServer @@ -42,13 +50,20 @@ class SlowServer final : public MessageEndpointServer {} private: - void doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override + void doAsyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) override + { + throw std::runtime_error("SlowServer not expecting async recv"); + } + + void doSyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) { SPDLOG_DEBUG("Slow message server test recv"); usleep(delayMs * 1000); - recvEndpoint->sendResponse(data.data(), data.size(), thisHost); + sendSyncResponse recvEndpoint->sendResponse( + data.data(), data.size(), thisHost); } }; From 0fe4ac3e6876d3c5313a3d9066a63a64f52a2eb2 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Thu, 24 Jun 2021 16:26:02 +0000 Subject: [PATCH 33/66] Tests now running (and hanging) --- include/faabric/transport/MessageEndpoint.h | 6 + include/faabric/transport/common.h | 7 +- src/transport/CMakeLists.txt | 2 + src/transport/MessageEndpoint.cpp | 24 +++- src/transport/MessageEndpointClient.cpp | 4 +- src/transport/MessageEndpointServer.cpp | 17 +-- src/transport/MpiMessageEndpoint.cpp | 1 + .../test_message_endpoint_client.cpp | 52 +++++---- tests/test/transport/test_message_server.cpp | 107 ++++++++++++------ 9 files changed, 147 insertions(+), 73 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 6f2e93744..8b69e865d 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -72,6 +72,10 @@ class AsyncSendMessageEndpoint : public MessageEndpoint int portIn, int timeoutMs = DEFAULT_SEND_TIMEOUT_MS); + void sendHeader(int header); + + void sendShutdown(); + void send(uint8_t* serialisedMsg, size_t msgSize, bool more = false); private: @@ -87,6 +91,8 @@ class SyncSendMessageEndpoint : public MessageEndpoint void sendHeader(int header); + void sendShutdown(); + Message sendAwaitResponse(const uint8_t* serialisedMsg, size_t msgSize, bool more = false); diff --git a/include/faabric/transport/common.h b/include/faabric/transport/common.h index b7957801e..615eac5d1 100644 --- a/include/faabric/transport/common.h +++ b/include/faabric/transport/common.h @@ -3,9 +3,10 @@ #define DEFAULT_STATE_HOST "0.0.0.0" #define DEFAULT_FUNCTION_CALL_HOST "0.0.0.0" #define DEFAULT_SNAPSHOT_HOST "0.0.0.0" + +// Note - these ports must be spaced at least two apart #define STATE_PORT 8003 -#define FUNCTION_CALL_PORT 8004 -#define SNAPSHOT_PORT 8005 -#define REPLY_PORT_OFFSET 100 +#define FUNCTION_CALL_PORT 8013 +#define SNAPSHOT_PORT 8023 #define MPI_PORT 8800 diff --git a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt index d7577ca34..4d5b3a651 100644 --- a/src/transport/CMakeLists.txt +++ b/src/transport/CMakeLists.txt @@ -8,6 +8,7 @@ set(HEADERS "${FAABRIC_INCLUDE_DIR}/faabric/transport/macros.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/Message.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpoint.h" + "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointClient.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointServer.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MpiMessageEndpoint.h" ) @@ -16,6 +17,7 @@ set(LIB_FILES context.cpp Message.cpp MessageEndpoint.cpp + MessageEndpointClient.cpp MessageEndpointServer.cpp MpiMessageEndpoint.cpp ${HEADERS} diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 055df1545..cdaf69a09 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -91,6 +91,8 @@ zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType, throw std::runtime_error("Opening unrecognized socket type"); } } + + return socket; } void MessageEndpoint::doSend(zmq::socket_t& socket, @@ -100,7 +102,7 @@ void MessageEndpoint::doSend(zmq::socket_t& socket, { assert(tid == std::this_thread::get_id()); zmq::send_flags sendFlags = - more ? zmq::send_flags::sndmore : zmq::send_flags::dontwait; + more ? zmq::send_flags::sndmore : zmq::send_flags::none; CATCH_ZMQ_ERR( { @@ -208,6 +210,19 @@ AsyncSendMessageEndpoint::AsyncSendMessageEndpoint(const std::string& hostIn, pushSocket = setUpSocket(zmq::socket_type::push, portIn); } +void AsyncSendMessageEndpoint::sendHeader(int header) +{ + uint8_t headerBytes = static_cast(header); + doSend(pushSocket, &headerBytes, sizeof(headerBytes), true); +} + +void AsyncSendMessageEndpoint::sendShutdown() +{ + int header = -1; + uint8_t headerBytes = static_cast(header); + doSend(pushSocket, &headerBytes, sizeof(headerBytes), false); +} + void AsyncSendMessageEndpoint::send(uint8_t* serialisedMsg, size_t msgSize, bool more) @@ -233,6 +248,13 @@ void SyncSendMessageEndpoint::sendHeader(int header) doSend(reqSocket, &headerBytes, sizeof(headerBytes), true); } +void SyncSendMessageEndpoint::sendShutdown() +{ + int header = -1; + uint8_t headerBytes = static_cast(header); + doSend(reqSocket, &headerBytes, sizeof(headerBytes), false); +} + Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* serialisedMsg, size_t msgSize, bool more) diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index 02d431f69..b1110a41b 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -1,5 +1,3 @@ -#pragma once - #include namespace faabric::transport { @@ -9,7 +7,7 @@ MessageEndpointClient::MessageEndpointClient(std::string hostIn, int portIn) , asyncPort(portIn) , syncPort(portIn + 1) , asyncEndpoint(host, asyncPort) - , syncEndpoint(host, asyncPort) + , syncEndpoint(host, syncPort) {} void MessageEndpointClient::asyncSend(int header, diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index e18892910..9fe337667 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -23,8 +23,8 @@ void MessageEndpointServer::start() Message header = endpoint.recv(); // Detect shutdown condition - if (header.size() == 0) { - SPDLOG_TRACE("Server received shutdown message"); + if (header.size() == sizeof(uint8_t) && !header.more()) { + SPDLOG_TRACE("Async server socket received shutdown message"); break; } @@ -54,8 +54,8 @@ void MessageEndpointServer::start() Message header = endpoint.recv(); // Detect shutdown condition - if (header.size() == 0) { - SPDLOG_TRACE("Server received shutdown message"); + if (header.size() == sizeof(uint8_t) && !header.more()) { + SPDLOG_TRACE("Sync server socket received shutdown message"); break; } @@ -72,7 +72,8 @@ void MessageEndpointServer::start() assert(body.udata() != nullptr); // Server-specific message handling - std::unique_ptr resp = doSyncRecv(header, body); + std::unique_ptr resp = + doSyncRecv(header, body); size_t respSize = resp->ByteSizeLong(); uint8_t buffer[respSize]; @@ -91,13 +92,13 @@ void MessageEndpointServer::stop() "Sending sync shutdown message locally to {}:{}", LOCALHOST, syncPort); SyncSendMessageEndpoint syncSender(LOCALHOST, syncPort); - syncSender.sendAwaitResponse(nullptr, 0); + syncSender.sendShutdown(); SPDLOG_TRACE( "Sending async shutdown message locally to {}:{}", LOCALHOST, asyncPort); - AsyncSendMessageEndpoint asyncSender(LOCALHOST, syncPort); - asyncSender.send(nullptr, 0); + AsyncSendMessageEndpoint asyncSender(LOCALHOST, asyncPort); + asyncSender.sendShutdown(); // Join the threads if (asyncThread.joinable()) { diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index d4b864550..ca22797d2 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -19,6 +19,7 @@ void sendMpiHostRankMsg(const std::string& hostIn, if (!msg.SerializeToArray(sMsg, msgSize)) { throw std::runtime_error("Error serialising message"); } + faabric::transport::AsyncSendMessageEndpoint endpoint(hostIn, MPI_PORT); endpoint.send(sMsg, msgSize, false); } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 82af9ff2c..5a1dbc080 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -19,16 +19,16 @@ TEST_CASE_METHOD(SchedulerTestFixture, "[transport]") { // Open the source endpoint client, don't bind - SendMessageEndpoint src(thisHost, testPort); + AsyncSendMessageEndpoint src(thisHost, testPort); // Open the destination endpoint client, bind - RecvMessageEndpoint dst(testPort); + AsyncRecvMessageEndpoint dst(testPort); // Send message std::string expectedMsg = "Hello world!"; uint8_t msg[expectedMsg.size()]; memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - REQUIRE_NOTHROW(src.send(msg, expectedMsg.size())); + src.send(msg, expectedMsg.size()); // Receive message faabric::transport::Message recvMsg = dst.recv(); @@ -45,23 +45,24 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") std::thread senderThread([expectedMsg, expectedResponse] { // Open the source endpoint client - SendMessageEndpoint src(thisHost, testPort); + SyncSendMessageEndpoint src(thisHost, testPort); // Send message and wait for response std::vector bytes(BYTES_CONST(expectedMsg.c_str()), BYTES_CONST(expectedMsg.c_str()) + expectedMsg.size()); - src.send(bytes.data(), bytes.size()); + + faabric::transport::Message recvMsg = + src.sendAwaitResponse(bytes.data(), bytes.size()); // Block waiting for a response - faabric::transport::Message recvMsg = src.awaitResponse(); assert(recvMsg.size() == expectedResponse.size()); std::string actualResponse(recvMsg.data(), recvMsg.size()); assert(actualResponse == expectedResponse); }); // Receive message - RecvMessageEndpoint dst(testPort); + SyncRecvMessageEndpoint dst(testPort); faabric::transport::Message recvMsg = dst.recv(); REQUIRE(recvMsg.size() == expectedMsg.size()); std::string actualMsg(recvMsg.data(), recvMsg.size()); @@ -70,7 +71,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") // Send response uint8_t msg[expectedResponse.size()]; memcpy(msg, expectedResponse.c_str(), expectedResponse.size()); - dst.sendResponse(msg, expectedResponse.size(), thisHost); + dst.sendResponse(msg, expectedResponse.size()); // Wait for sender thread if (senderThread.joinable()) { @@ -87,17 +88,17 @@ TEST_CASE_METHOD(SchedulerTestFixture, std::thread senderThread([numMessages, baseMsg] { // Open the source endpoint client, don't bind - SendMessageEndpoint src(thisHost, testPort); + AsyncSendMessageEndpoint src(thisHost, testPort); for (int i = 0; i < numMessages; i++) { - std::string expectedMsg = baseMsg + std::to_string(i); - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - src.send(msg, expectedMsg.size()); + std::string msgData = baseMsg + std::to_string(i); + uint8_t msg[msgData.size()]; + memcpy(msg, msgData.c_str(), msgData.size()); + src.send(msg, msgData.size()); } }); // Receive messages - RecvMessageEndpoint dst(testPort); + AsyncRecvMessageEndpoint dst(testPort); for (int i = 0; i < numMessages; i++) { faabric::transport::Message recvMsg = dst.recv(); // Check just a subset of the messages @@ -128,7 +129,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, for (int j = 0; j < numSenders; j++) { senderThreads.emplace_back(std::thread([numMessages, expectedMsg] { // Open the source endpoint client, don't bind - SendMessageEndpoint src(thisHost, testPort); + AsyncSendMessageEndpoint src(thisHost, testPort); for (int i = 0; i < numMessages; i++) { uint8_t msg[expectedMsg.size()]; memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); @@ -138,7 +139,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, } // Receive messages - RecvMessageEndpoint dst(testPort); + AsyncRecvMessageEndpoint dst(testPort); for (int i = 0; i < numSenders * numMessages; i++) { faabric::transport::Message recvMsg = dst.recv(); // Check just a subset of the messages @@ -164,28 +165,35 @@ TEST_CASE_METHOD(SchedulerTestFixture, SECTION("Sanity check valid timeout") { - SendMessageEndpoint s(thisHost, testPort, 100); - RecvMessageEndpoint r(testPort, 100); + AsyncSendMessageEndpoint s(thisHost, testPort, 100); + AsyncRecvMessageEndpoint r(testPort, 100); + + SyncSendMessageEndpoint sB(thisHost, testPort + 10, 100); + SyncRecvMessageEndpoint rB(testPort + 10, 100); } SECTION("Recv zero timeout") { - REQUIRE_THROWS(RecvMessageEndpoint(testPort, 0)); + REQUIRE_THROWS(AsyncRecvMessageEndpoint(testPort, 0)); + REQUIRE_THROWS(SyncRecvMessageEndpoint(testPort + 10, 0)); } SECTION("Send zero timeout") { - REQUIRE_THROWS(SendMessageEndpoint(thisHost, testPort, 0)); + REQUIRE_THROWS(AsyncSendMessageEndpoint(thisHost, testPort, 0)); + REQUIRE_THROWS(SyncSendMessageEndpoint(thisHost, testPort + 10, 0)); } SECTION("Recv negative timeout") { - REQUIRE_THROWS(RecvMessageEndpoint(testPort, -1)); + REQUIRE_THROWS(AsyncRecvMessageEndpoint(testPort, -1)); + REQUIRE_THROWS(SyncRecvMessageEndpoint(testPort + 10, -1)); } SECTION("Send negative timeout") { - REQUIRE_THROWS(SendMessageEndpoint(thisHost, testPort, -1)); + REQUIRE_THROWS(AsyncSendMessageEndpoint(thisHost, testPort, -1)); + REQUIRE_THROWS(SyncSendMessageEndpoint(thisHost, testPort + 10, -1)); } } } diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 767f5ab86..92a31e58e 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -1,3 +1,4 @@ +#include "faabric/proto/faabric.pb.h" #include #include @@ -5,6 +6,7 @@ #include #include #include +#include using namespace faabric::transport; @@ -39,31 +41,59 @@ class DummyServer final : public MessageEndpointServer } }; +class EchoServer final : public MessageEndpointServer +{ + public: + EchoServer() + : MessageEndpointServer(testPort) + {} + + protected: + void doAsyncRecv(faabric::transport::Message& header, + faabric::transport::Message& body) override + { + throw std::runtime_error("EchoServer not expecting async recv"); + } + + std::unique_ptr doSyncRecv( + faabric::transport::Message& header, + faabric::transport::Message& body) override + { + SPDLOG_TRACE("Echo server received {} bytes", body.size()); + + auto response = std::make_unique(); + response->set_data(body.data(), body.size()); + + return response; + } +}; + class SlowServer final : public MessageEndpointServer { public: int delayMs = 1000; - std::vector data = { 0, 1, 2, 3 }; SlowServer() : MessageEndpointServer(testPort) {} - private: + protected: void doAsyncRecv(faabric::transport::Message& header, faabric::transport::Message& body) override { throw std::runtime_error("SlowServer not expecting async recv"); } - void doSyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) + std::unique_ptr doSyncRecv( + faabric::transport::Message& header, + faabric::transport::Message& body) override { SPDLOG_DEBUG("Slow message server test recv"); usleep(delayMs * 1000); - sendSyncResponse recvEndpoint->sendResponse( - data.data(), data.size(), thisHost); + auto response = std::make_unique(); + response->set_data("From the slow server"); + return response; } }; @@ -71,11 +101,11 @@ namespace tests { TEST_CASE("Test start/stop server", "[transport]") { DummyServer server; - REQUIRE_NOTHROW(server.start()); + server.start(); - usleep(1000 * 100); + usleep(100 * 1000); - REQUIRE_NOTHROW(server.stop()); + server.stop(); } TEST_CASE("Test send one message to server", "[transport]") @@ -85,7 +115,7 @@ TEST_CASE("Test send one message to server", "[transport]") server.start(); // Open the source endpoint client, don't bind - SendMessageEndpoint src(thisHost, testPort); + AsyncSendMessageEndpoint src(thisHost, testPort); // Send message: server expects header + body std::string header = "header"; @@ -108,28 +138,32 @@ TEST_CASE("Test send one message to server", "[transport]") server.stop(); } -TEST_CASE("Test send one-off response to client", "[transport]") +TEST_CASE("Test send response to client", "[transport]") { - RecvMessageEndpoint recvEndpoint(testPort); + std::thread serverThread([] { + EchoServer server; + server.start(); + usleep(1000 * 1000); + server.stop(); + }); std::string expectedMsg = "Response from server"; - std::thread clientThread([expectedMsg] { - // Open the source endpoint client, don't bind - SendMessageEndpoint cli(thisHost, testPort); + // Open the source endpoint client, don't bind + SyncSendMessageEndpoint cli(thisHost, testPort + 1); + + // Send and await the response + cli.sendHeader(1); + Message responseMsg = + cli.sendAwaitResponse(BYTES(expectedMsg.data()), expectedMsg.size()); - Message msg = cli.awaitResponse(); - assert(msg.size() == expectedMsg.size()); - std::string actualMsg(msg.data(), msg.size()); - assert(actualMsg == expectedMsg); - }); + faabric::StatePart response; + response.ParseFromArray(responseMsg.data(), responseMsg.size()); - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - recvEndpoint.sendResponse(msg, expectedMsg.size(), thisHost); + assert(response.data() == expectedMsg); - if (clientThread.joinable()) { - clientThread.join(); + if (serverThread.joinable()) { + serverThread.join(); } } @@ -145,14 +179,13 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") for (int i = 0; i < numClients; i++) { clientThreads.emplace_back(std::thread([numMessages] { // Prepare client - SendMessageEndpoint cli(thisHost, testPort); + AsyncSendMessageEndpoint cli(thisHost, testPort); std::string clientMsg = "Message from threaded client"; for (int j = 0; j < numMessages; j++) { // Send header - uint8_t header[clientMsg.size()]; - memcpy(header, clientMsg.c_str(), clientMsg.size()); - cli.send(header, clientMsg.size(), true); + cli.sendHeader(1); + // Send body uint8_t body[clientMsg.size()]; memcpy(body, clientMsg.c_str(), clientMsg.size()); @@ -206,20 +239,22 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") usleep(500 * 1000); // Set up the client - SendMessageEndpoint cli(thisHost, testPort, clientTimeout); + SyncSendMessageEndpoint cli(thisHost, testPort + 1, clientTimeout); std::vector data = { 1, 1, 1 }; - cli.send(data.data(), data.size(), true); - cli.send(data.data(), data.size()); + cli.sendHeader(1); if (expectFailure) { // Check for failure - REQUIRE_THROWS_AS(cli.awaitResponse(), MessageTimeoutException); + REQUIRE_THROWS_AS(cli.sendAwaitResponse(data.data(), data.size()), + MessageTimeoutException); } else { - // Check response from server successful - Message responseMessage = cli.awaitResponse(); + Message responseMsg = cli.sendAwaitResponse(data.data(), data.size()); + faabric::StatePart response; + response.ParseFromArray(responseMsg.data(), responseMsg.size()); + std::vector expected = { 0, 1, 2, 3 }; - REQUIRE(responseMessage.dataCopy() == expected); + REQUIRE(response.data() == "From the slow server"); } if (t.joinable()) { From 3e793185e69086d604f137e8510c974783d5cf77 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 25 Jun 2021 07:06:30 +0000 Subject: [PATCH 34/66] Switch to using client class instead of sockets directly --- .../faabric/transport/MessageEndpointClient.h | 10 ++++--- src/transport/MessageEndpointClient.cpp | 6 ++-- tests/test/transport/test_message_server.cpp | 29 +++++++------------ 3 files changed, 19 insertions(+), 26 deletions(-) diff --git a/include/faabric/transport/MessageEndpointClient.h b/include/faabric/transport/MessageEndpointClient.h index c79a31120..9e080a508 100644 --- a/include/faabric/transport/MessageEndpointClient.h +++ b/include/faabric/transport/MessageEndpointClient.h @@ -9,10 +9,9 @@ namespace faabric::transport { class MessageEndpointClient { public: - MessageEndpointClient(std::string hostIn, int portIn); - - protected: - const std::string host; + MessageEndpointClient(std::string hostIn, + int portIn, + int timeoutMs = DEFAULT_SEND_TIMEOUT_MS); void asyncSend(int header, google::protobuf::Message* msg); @@ -27,6 +26,9 @@ class MessageEndpointClient size_t bufferSize, google::protobuf::Message* response); + protected: + const std::string host; + private: const int asyncPort; diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index b1110a41b..b2b357ea2 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -2,12 +2,12 @@ namespace faabric::transport { -MessageEndpointClient::MessageEndpointClient(std::string hostIn, int portIn) +MessageEndpointClient::MessageEndpointClient(std::string hostIn, int portIn, int timeoutMs) : host(hostIn) , asyncPort(portIn) , syncPort(portIn + 1) - , asyncEndpoint(host, asyncPort) - , syncEndpoint(host, syncPort) + , asyncEndpoint(host, asyncPort, timeoutMs) + , syncEndpoint(host, syncPort, timeoutMs) {} void MessageEndpointClient::asyncSend(int header, diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 92a31e58e..4cfcd6f41 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -1,8 +1,9 @@ -#include "faabric/proto/faabric.pb.h" #include #include +#include +#include #include #include #include @@ -150,15 +151,11 @@ TEST_CASE("Test send response to client", "[transport]") std::string expectedMsg = "Response from server"; // Open the source endpoint client, don't bind - SyncSendMessageEndpoint cli(thisHost, testPort + 1); + MessageEndpointClient cli(thisHost, testPort); // Send and await the response - cli.sendHeader(1); - Message responseMsg = - cli.sendAwaitResponse(BYTES(expectedMsg.data()), expectedMsg.size()); - faabric::StatePart response; - response.ParseFromArray(responseMsg.data(), responseMsg.size()); + cli.syncSend(0, BYTES(expectedMsg.data()), expectedMsg.size(), &response); assert(response.data() == expectedMsg); @@ -179,17 +176,14 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") for (int i = 0; i < numClients; i++) { clientThreads.emplace_back(std::thread([numMessages] { // Prepare client - AsyncSendMessageEndpoint cli(thisHost, testPort); + MessageEndpointClient cli(thisHost, testPort); std::string clientMsg = "Message from threaded client"; for (int j = 0; j < numMessages; j++) { - // Send header - cli.sendHeader(1); - // Send body uint8_t body[clientMsg.size()]; memcpy(body, clientMsg.c_str(), clientMsg.size()); - cli.send(body, clientMsg.size()); + cli.asyncSend(0, body, clientMsg.size()); } usleep(1000 * 300); @@ -239,19 +233,16 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") usleep(500 * 1000); // Set up the client - SyncSendMessageEndpoint cli(thisHost, testPort + 1, clientTimeout); - + MessageEndpointClient cli(thisHost, testPort, clientTimeout); std::vector data = { 1, 1, 1 }; - cli.sendHeader(1); + faabric::StatePart response; if (expectFailure) { // Check for failure - REQUIRE_THROWS_AS(cli.sendAwaitResponse(data.data(), data.size()), + REQUIRE_THROWS_AS(cli.syncSend(0, data.data(), data.size(), &response), MessageTimeoutException); } else { - Message responseMsg = cli.sendAwaitResponse(data.data(), data.size()); - faabric::StatePart response; - response.ParseFromArray(responseMsg.data(), responseMsg.size()); + cli.syncSend(0, data.data(), data.size(), &response); std::vector expected = { 0, 1, 2, 3 }; REQUIRE(response.data() == "From the slow server"); From 49b14107ef5afe2b6919782735aeab9c9ed59de5 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 25 Jun 2021 07:06:45 +0000 Subject: [PATCH 35/66] Remove unused macros --- include/faabric/transport/macros.h | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/include/faabric/transport/macros.h b/include/faabric/transport/macros.h index 1347f39a4..1adaff3be 100644 --- a/include/faabric/transport/macros.h +++ b/include/faabric/transport/macros.h @@ -1,36 +1,7 @@ #pragma once -#define SEND_MESSAGE(header, msg) \ - sendHeader(header); \ - size_t msgSize = msg.ByteSizeLong(); \ - { \ - uint8_t sMsg[msgSize]; \ - if (!msg.SerializeToArray(sMsg, msgSize)) { \ - throw std::runtime_error("Error serialising message"); \ - } \ - send(sMsg, msgSize); \ - } - -#define SEND_MESSAGE_PTR(header, msg) \ - sendHeader(header); \ - size_t msgSize = msg->ByteSizeLong(); \ - { \ - uint8_t sMsg[msgSize]; \ - if (!msg->SerializeToArray(sMsg, msgSize)) { \ - throw std::runtime_error("Error serialising message"); \ - } \ - send(sMsg, msgSize); \ - } - - #define PARSE_MSG(T, data, size) \ T msg; \ if (!msg.ParseFromArray(data, size)) { \ throw std::runtime_error("Error deserialising message"); \ } - -#define PARSE_RESPONSE(T, data, size) \ - T response; \ - if (!response.ParseFromArray(data, size)) { \ - throw std::runtime_error("Error deserialising message"); \ - } From 371091aa3e88b7a11d621f7cef01a8844b1dabaf Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 25 Jun 2021 08:13:01 +0000 Subject: [PATCH 36/66] Explicitly set sync/async ports --- .../faabric/scheduler/FunctionCallClient.h | 4 +- .../faabric/transport/MessageEndpointClient.h | 3 +- .../faabric/transport/MessageEndpointServer.h | 2 +- .../faabric/transport/MpiMessageEndpoint.h | 23 ++++++---- include/faabric/transport/common.h | 12 ++--- include/faabric/transport/macros.h | 14 ++++++ src/flat/faabric.fbs | 2 - src/proto/faabric.proto | 12 ++--- src/scheduler/FunctionCallClient.cpp | 16 ++++--- src/scheduler/FunctionCallServer.cpp | 7 +-- src/scheduler/MpiWorld.cpp | 19 ++------ src/scheduler/SnapshotClient.cpp | 14 ++---- src/scheduler/SnapshotServer.cpp | 3 +- src/state/StateClient.cpp | 6 +-- src/state/StateServer.cpp | 2 +- src/transport/MessageEndpointClient.cpp | 31 ++++++++----- src/transport/MessageEndpointServer.cpp | 7 ++- src/transport/MpiMessageEndpoint.cpp | 44 +++++++------------ .../test_message_endpoint_client.cpp | 33 ++++++++++++-- tests/test/transport/test_message_server.cpp | 18 ++++---- 20 files changed, 145 insertions(+), 127 deletions(-) diff --git a/include/faabric/scheduler/FunctionCallClient.h b/include/faabric/scheduler/FunctionCallClient.h index b0d3e01fe..be86d98a1 100644 --- a/include/faabric/scheduler/FunctionCallClient.h +++ b/include/faabric/scheduler/FunctionCallClient.h @@ -13,13 +13,13 @@ namespace faabric::scheduler { // ----------------------------------- std::vector> getFunctionCalls(); -std::vector> getFlushCalls(); +std::vector> getFlushCalls(); std::vector< std::pair>> getBatchRequests(); -std::vector> +std::vector> getResourceRequests(); std::vector> diff --git a/include/faabric/transport/MessageEndpointClient.h b/include/faabric/transport/MessageEndpointClient.h index 9e080a508..945e30509 100644 --- a/include/faabric/transport/MessageEndpointClient.h +++ b/include/faabric/transport/MessageEndpointClient.h @@ -10,7 +10,8 @@ class MessageEndpointClient { public: MessageEndpointClient(std::string hostIn, - int portIn, + int asyncPort, + int syncPort, int timeoutMs = DEFAULT_SEND_TIMEOUT_MS); void asyncSend(int header, google::protobuf::Message* msg); diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 85234b28b..fa8de64b6 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -17,7 +17,7 @@ namespace faabric::transport { class MessageEndpointServer { public: - MessageEndpointServer(int portIn); + MessageEndpointServer(int asyncPortIn, int syncPortIn); void start(); diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 753c65532..afcf82bc1 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -2,11 +2,13 @@ #include #include +#include #include #include namespace faabric::transport { -/* These two abstract methods are used to broadcast the host-rank mapping at +/* + * These two abstract methods are used to broadcast the host-rank mapping at * initialisation time. */ faabric::MpiHostsToRanksMessage recvMpiHostRankMsg(); @@ -14,25 +16,28 @@ faabric::MpiHostsToRanksMessage recvMpiHostRankMsg(); void sendMpiHostRankMsg(const std::string& hostIn, const faabric::MpiHostsToRanksMessage msg); -/* This class abstracts the notion of a communication channel between two remote +/* + * This class abstracts the notion of a communication channel between two remote * MPI ranks. There will always be one rank local to this host, and one remote. - * Note that the port is unique per (user, function, sendRank, recvRank) tuple. + * Note that the ports are unique per (user, function, sendRank, recvRank) + * tuple. + * + * This is different to our normal messaging clients as it wraps two _async_ + * sockets. */ class MpiMessageEndpoint { public: - MpiMessageEndpoint(const std::string& hostIn, int portIn); - MpiMessageEndpoint(const std::string& hostIn, int sendPort, int recvPort); void sendMpiMessage(const std::shared_ptr& msg); std::shared_ptr recvMpiMessage(); - void close(); - private: - AsyncSendMessageEndpoint sendMessageEndpoint; - AsyncRecvMessageEndpoint recvMessageEndpoint; + std::string host; + + AsyncSendMessageEndpoint sendSocket; + AsyncRecvMessageEndpoint recvSocket; }; } diff --git a/include/faabric/transport/common.h b/include/faabric/transport/common.h index 615eac5d1..f9d1f95e2 100644 --- a/include/faabric/transport/common.h +++ b/include/faabric/transport/common.h @@ -4,9 +4,11 @@ #define DEFAULT_FUNCTION_CALL_HOST "0.0.0.0" #define DEFAULT_SNAPSHOT_HOST "0.0.0.0" -// Note - these ports must be spaced at least two apart -#define STATE_PORT 8003 -#define FUNCTION_CALL_PORT 8013 -#define SNAPSHOT_PORT 8023 +#define STATE_ASYNC_PORT 8003 +#define STATE_SYNC_PORT 8004 +#define FUNCTION_CALL_ASYNC_PORT 8005 +#define FUNCTION_CALL_SYNC_PORT 8006 +#define SNAPSHOT_SYNC_PORT 8007 +#define SNAPSHOT_ASYNC_PORT 8008 -#define MPI_PORT 8800 +#define MPI_BASE_PORT 8800 diff --git a/include/faabric/transport/macros.h b/include/faabric/transport/macros.h index 1adaff3be..abe0a7494 100644 --- a/include/faabric/transport/macros.h +++ b/include/faabric/transport/macros.h @@ -5,3 +5,17 @@ if (!msg.ParseFromArray(data, size)) { \ throw std::runtime_error("Error deserialising message"); \ } + +#define SERIALISE_MSG(msg) \ + size_t msgSize = msg.ByteSizeLong(); \ + uint8_t buffer[msgSize]; \ + if (!msg.SerializeToArray(buffer, msgSize)) { \ + throw std::runtime_error("Error serialising message"); \ + } + +#define SERIALISE_MSG_PTR(msg) \ + size_t msgSize = msg->ByteSizeLong(); \ + uint8_t buffer[msgSize]; \ + if (!msg->SerializeToArray(buffer, msgSize)) { \ + throw std::runtime_error("Error serialising message"); \ + } diff --git a/src/flat/faabric.fbs b/src/flat/faabric.fbs index d173a2d7d..80bf93568 100644 --- a/src/flat/faabric.fbs +++ b/src/flat/faabric.fbs @@ -1,5 +1,4 @@ table SnapshotPushRequest { - return_host:string; key:string; contents:[ubyte]; } @@ -14,7 +13,6 @@ table SnapshotDiffChunk { } table SnapshotDiffPushRequest { - return_host:string; key:string; chunks:[SnapshotDiffChunk]; } diff --git a/src/proto/faabric.proto b/src/proto/faabric.proto index a44a4668c..296f4e232 100644 --- a/src/proto/faabric.proto +++ b/src/proto/faabric.proto @@ -10,6 +10,10 @@ message EmptyResponse { int32 empty = 1; } +message EmptyRequest { + int32 empty = 1; +} + // --------------------------------------------- // FUNCTION SCHEDULING // --------------------------------------------- @@ -35,10 +39,6 @@ message BatchExecuteRequest { bytes contextData = 7; } -message ResponseRequest { - string returnHost = 1; -} - message HostResources { int32 slots = 1; int32 usedSlots = 2; @@ -157,7 +157,6 @@ message StateRequest { string user = 1; string key = 2; bytes data = 3; - string returnHost = 4; } message StateChunkRequest { @@ -165,7 +164,6 @@ message StateChunkRequest { string key = 2; uint64 offset = 3; uint64 chunkSize = 4; - string returnHost = 5; } message StateResponse { @@ -179,7 +177,6 @@ message StatePart { string key = 2; uint64 offset = 3; bytes data = 4; - string returnHost = 5; } message StateSizeResponse { @@ -192,7 +189,6 @@ message StateAppendedRequest { string user = 1; string key = 2; uint32 nValues = 3; - string returnHost = 4; } message StateAppendedResponse { diff --git a/src/scheduler/FunctionCallClient.cpp b/src/scheduler/FunctionCallClient.cpp index 7f8e736a5..701b29381 100644 --- a/src/scheduler/FunctionCallClient.cpp +++ b/src/scheduler/FunctionCallClient.cpp @@ -14,13 +14,13 @@ std::mutex mockMutex; static std::vector> functionCalls; -static std::vector> flushCalls; +static std::vector> flushCalls; static std::vector< std::pair>> batchMessages; -static std::vector> +static std::vector> resourceRequests; static std::unordered_map> getFunctionCalls() return functionCalls; } -std::vector> getFlushCalls() +std::vector> getFlushCalls() { return flushCalls; } @@ -47,7 +47,7 @@ getBatchRequests() return batchMessages; } -std::vector> +std::vector> getResourceRequests() { return resourceRequests; @@ -81,12 +81,14 @@ void clearMockRequests() // Message Client // ----------------------------------- FunctionCallClient::FunctionCallClient(const std::string& hostIn) - : faabric::transport::MessageEndpointClient(hostIn, FUNCTION_CALL_PORT) + : faabric::transport::MessageEndpointClient(hostIn, + FUNCTION_CALL_ASYNC_PORT, + FUNCTION_CALL_SYNC_PORT) {} void FunctionCallClient::sendFlush() { - faabric::ResponseRequest req; + faabric::EmptyRequest req; if (faabric::util::isMockMode()) { faabric::util::UniqueLock lock(mockMutex); flushCalls.emplace_back(host, req); @@ -98,7 +100,7 @@ void FunctionCallClient::sendFlush() faabric::HostResources FunctionCallClient::getResources() { - faabric::ResponseRequest request; + faabric::EmptyRequest request; faabric::HostResources response; if (faabric::util::isMockMode()) { diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index 961dfccad..53fd0b7f5 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -9,7 +9,8 @@ namespace faabric::scheduler { FunctionCallServer::FunctionCallServer() - : faabric::transport::MessageEndpointServer(FUNCTION_CALL_PORT) + : faabric::transport::MessageEndpointServer(FUNCTION_CALL_ASYNC_PORT, + FUNCTION_CALL_SYNC_PORT) , scheduler(getScheduler()) {} @@ -58,8 +59,6 @@ std::unique_ptr FunctionCallServer::doSyncRecv( std::unique_ptr FunctionCallServer::recvFlush( faabric::transport::Message& body) { - PARSE_MSG(faabric::ResponseRequest, body.data(), body.size()); - // Clear out any cached state faabric::state::getGlobalState().forceClearAll(false); @@ -92,8 +91,6 @@ void FunctionCallServer::recvUnregister(faabric::transport::Message& body) std::unique_ptr FunctionCallServer::recvGetResources( faabric::transport::Message& body) { - PARSE_MSG(faabric::ResponseRequest, body.data(), body.size()) - auto response = std::make_unique( scheduler.getThisHostResources()); return response; diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index e0a214fdd..9c2c862b3 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -173,17 +173,6 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize) void MpiWorld::destroy() { - // Destroy once per thread the rank-specific data structures - // Remote message endpoints - if (!mpiMessageEndpoints.empty()) { - for (auto& e : mpiMessageEndpoints) { - if (e != nullptr) { - e->close(); - } - } - mpiMessageEndpoints.clear(); - } - // Unacked message buffers if (!unackedMessageBuffers.empty()) { for (auto& umb : unackedMessageBuffers) { @@ -289,8 +278,8 @@ std::pair MpiWorld::getPortForRanks(int localRank, int remoteRank) void MpiWorld::setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg) { // Assert we are only setting the values once - assert(rankHosts.size() == 0); - assert(basePorts.size() == 0); + assert(rankHosts.empty()); + assert(basePorts.empty()); assert(msg.hosts().size() == size); assert(msg.baseports().size() == size); @@ -1224,10 +1213,10 @@ std::vector MpiWorld::initLocalBasePorts( basePortForRank.reserve(size); std::string lastHost = thisHost; - int lastPort = MPI_PORT; + int lastPort = MPI_BASE_PORT; for (const auto& host : executedAt) { if (host == thisHost) { - basePortForRank.push_back(MPI_PORT); + basePortForRank.push_back(MPI_BASE_PORT); } else if (host == lastHost) { basePortForRank.push_back(lastPort); } else { diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index 408bcdad9..525b0359c 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -70,7 +70,7 @@ void clearMockSnapshotRequests() // ----------------------------------- SnapshotClient::SnapshotClient(const std::string& hostIn) - : faabric::transport::MessageEndpointClient(hostIn, SNAPSHOT_PORT) + : faabric::transport::MessageEndpointClient(hostIn, SNAPSHOT_ASYNC_PORT, SNAPSHOT_SYNC_PORT) {} void SnapshotClient::pushSnapshot(const std::string& key, @@ -82,17 +82,13 @@ void SnapshotClient::pushSnapshot(const std::string& key, faabric::util::UniqueLock lock(mockMutex); snapshotPushes.emplace_back(host, data); } else { - const faabric::util::SystemConfig& conf = - faabric::util::getSystemConfig(); - // Set up the main request // TODO - avoid copying data here flatbuffers::FlatBufferBuilder mb; - auto returnHostOffset = mb.CreateString(conf.endpointHost); auto keyOffset = mb.CreateString(key); auto dataOffset = mb.CreateVector(data.data, data.size); auto requestOffset = CreateSnapshotPushRequest( - mb, returnHostOffset, keyOffset, dataOffset); + mb, keyOffset, dataOffset); // Send it mb.Finish(requestOffset); @@ -119,9 +115,6 @@ void SnapshotClient::pushSnapshotDiffs( snapshotKey, host); - const faabric::util::SystemConfig& conf = - faabric::util::getSystemConfig(); - flatbuffers::FlatBufferBuilder mb; // Create objects for all the chunks @@ -134,11 +127,10 @@ void SnapshotClient::pushSnapshotDiffs( // Set up the main request // TODO - avoid copying data here - auto returnHostOffset = mb.CreateString(conf.endpointHost); auto keyOffset = mb.CreateString(snapshotKey); auto diffsOffset = mb.CreateVector(diffsFbVector); auto requestOffset = CreateSnapshotDiffPushRequest( - mb, returnHostOffset, keyOffset, diffsOffset); + mb, keyOffset, diffsOffset); mb.Finish(requestOffset); uint8_t* buffer = mb.GetBufferPointer(); diff --git a/src/scheduler/SnapshotServer.cpp b/src/scheduler/SnapshotServer.cpp index 45a51ef8c..934c54e78 100644 --- a/src/scheduler/SnapshotServer.cpp +++ b/src/scheduler/SnapshotServer.cpp @@ -12,7 +12,8 @@ namespace faabric::scheduler { SnapshotServer::SnapshotServer() - : faabric::transport::MessageEndpointServer(SNAPSHOT_PORT) + : faabric::transport::MessageEndpointServer(SNAPSHOT_ASYNC_PORT, + SNAPSHOT_SYNC_PORT) {} void SnapshotServer::doAsyncRecv(faabric::transport::Message& header, diff --git a/src/state/StateClient.cpp b/src/state/StateClient.cpp index dfd2a0480..91f40f835 100644 --- a/src/state/StateClient.cpp +++ b/src/state/StateClient.cpp @@ -8,7 +8,9 @@ namespace faabric::state { StateClient::StateClient(const std::string& userIn, const std::string& keyIn, const std::string& hostIn) - : faabric::transport::MessageEndpointClient(hostIn, STATE_PORT) + : faabric::transport::MessageEndpointClient(hostIn, + STATE_ASYNC_PORT, + STATE_SYNC_PORT) , user(userIn) , key(keyIn) {} @@ -53,7 +55,6 @@ void StateClient::pullChunks(const std::vector& chunks, request.set_key(key); request.set_offset(chunk.offset); request.set_chunksize(chunk.length); - request.set_returnhost(faabric::util::getSystemConfig().endpointHost); // Send and copy response into place faabric::StatePart response; @@ -76,7 +77,6 @@ void StateClient::pullAppended(uint8_t* buffer, size_t length, long nValues) request.set_user(user); request.set_key(key); request.set_nvalues(nValues); - request.set_returnhost(faabric::util::getSystemConfig().endpointHost); faabric::StateAppendedResponse response; syncSend(faabric::state::StateCalls::PullAppended, &request, &response); diff --git a/src/state/StateServer.cpp b/src/state/StateServer.cpp index 703e6e72c..e83f72470 100644 --- a/src/state/StateServer.cpp +++ b/src/state/StateServer.cpp @@ -13,7 +13,7 @@ namespace faabric::state { StateServer::StateServer(State& stateIn) - : faabric::transport::MessageEndpointServer(STATE_PORT) + : faabric::transport::MessageEndpointServer(STATE_ASYNC_PORT, STATE_SYNC_PORT) , state(stateIn) {} diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index b2b357ea2..6f028fac3 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -2,10 +2,13 @@ namespace faabric::transport { -MessageEndpointClient::MessageEndpointClient(std::string hostIn, int portIn, int timeoutMs) +MessageEndpointClient::MessageEndpointClient(std::string hostIn, + int asyncPortIn, + int syncPortIn, + int timeoutMs) : host(hostIn) - , asyncPort(portIn) - , syncPort(portIn + 1) + , asyncPort(asyncPortIn) + , syncPort(syncPortIn) , asyncEndpoint(host, asyncPort, timeoutMs) , syncEndpoint(host, syncPort, timeoutMs) {} @@ -14,17 +17,20 @@ void MessageEndpointClient::asyncSend(int header, google::protobuf::Message* msg) { size_t msgSize = msg->ByteSizeLong(); - uint8_t sMsg[msgSize]; + uint8_t buffer[msgSize]; - if (!msg->SerializeToArray(sMsg, msgSize)) { + if (!msg->SerializeToArray(buffer, msgSize)) { throw std::runtime_error("Error serialising message"); } - asyncSend(header, sMsg, msgSize); + asyncSend(header, buffer, msgSize); } -void MessageEndpointClient::asyncSend(int header, uint8_t* buffer, size_t bufferSize) { - syncEndpoint.sendHeader(header); +void MessageEndpointClient::asyncSend(int header, + uint8_t* buffer, + size_t bufferSize) +{ + asyncEndpoint.sendHeader(header); asyncEndpoint.send(buffer, bufferSize); } @@ -34,12 +40,12 @@ void MessageEndpointClient::syncSend(int header, google::protobuf::Message* response) { size_t msgSize = msg->ByteSizeLong(); - uint8_t sMsg[msgSize]; - if (!msg->SerializeToArray(sMsg, msgSize)) { + uint8_t buffer[msgSize]; + if (!msg->SerializeToArray(buffer, msgSize)) { throw std::runtime_error("Error serialising message"); } - syncSend(header, sMsg, msgSize, response); + syncSend(header, buffer, msgSize, response); } void MessageEndpointClient::syncSend(int header, @@ -50,7 +56,8 @@ void MessageEndpointClient::syncSend(int header, syncEndpoint.sendHeader(header); Message responseMsg = syncEndpoint.sendAwaitResponse(buffer, bufferSize); - // Deserialise message string + + // Deserialise response if (!response->ParseFromArray(responseMsg.data(), responseMsg.size())) { throw std::runtime_error("Error deserialising message"); } diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 9fe337667..2d43999ec 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -7,9 +7,9 @@ #include namespace faabric::transport { -MessageEndpointServer::MessageEndpointServer(int portIn) - : asyncPort(portIn) - , syncPort(portIn + 1) +MessageEndpointServer::MessageEndpointServer(int asyncPortIn, int syncPortIn) + : asyncPort(asyncPortIn) + , syncPort(syncPortIn) {} void MessageEndpointServer::start() @@ -69,7 +69,6 @@ void MessageEndpointServer::start() if (body.more()) { throw std::runtime_error("Body sent with SNDMORE flag"); } - assert(body.udata() != nullptr); // Server-specific message handling std::unique_ptr resp = diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index ca22797d2..abb284c5b 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -1,9 +1,12 @@ #include +#include namespace faabric::transport { + faabric::MpiHostsToRanksMessage recvMpiHostRankMsg() { - faabric::transport::AsyncRecvMessageEndpoint endpoint(MPI_PORT); + SPDLOG_TRACE("Receiving MPI host ranks on {}", MPI_BASE_PORT); + faabric::transport::AsyncRecvMessageEndpoint endpoint(MPI_BASE_PORT); faabric::transport::Message m = endpoint.recv(); PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); @@ -13,50 +16,33 @@ faabric::MpiHostsToRanksMessage recvMpiHostRankMsg() void sendMpiHostRankMsg(const std::string& hostIn, const faabric::MpiHostsToRanksMessage msg) { - size_t msgSize = msg.ByteSizeLong(); - { - uint8_t sMsg[msgSize]; - if (!msg.SerializeToArray(sMsg, msgSize)) { - throw std::runtime_error("Error serialising message"); - } - - faabric::transport::AsyncSendMessageEndpoint endpoint(hostIn, MPI_PORT); - endpoint.send(sMsg, msgSize, false); - } + SPDLOG_TRACE("Sending MPI host ranks to {}:{}", hostIn, MPI_BASE_PORT); + faabric::transport::AsyncSendMessageEndpoint endpoint(hostIn, + MPI_BASE_PORT); + SERIALISE_MSG(msg) + endpoint.send(buffer, msgSize, false); } -MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn) - : sendMessageEndpoint(hostIn, portIn) - , recvMessageEndpoint(portIn) -{} - MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int sendPort, int recvPort) - : sendMessageEndpoint(hostIn, sendPort) - , recvMessageEndpoint(recvPort) + : host(hostIn) + , sendSocket(hostIn, sendPort) + , recvSocket(recvPort) {} void MpiMessageEndpoint::sendMpiMessage( const std::shared_ptr& msg) { - size_t msgSize = msg->ByteSizeLong(); - { - uint8_t sMsg[msgSize]; - if (!msg->SerializeToArray(sMsg, msgSize)) { - throw std::runtime_error("Error serialising message"); - } - sendMessageEndpoint.send(sMsg, msgSize, false); - } + SERIALISE_MSG_PTR(msg) + sendSocket.send(buffer, msgSize, false); } std::shared_ptr MpiMessageEndpoint::recvMpiMessage() { - Message m = recvMessageEndpoint.recv(); + Message m = recvSocket.recv(); PARSE_MSG(faabric::MPIMessage, m.data(), m.size()); return std::make_shared(msg); } - -void MpiMessageEndpoint::close() {} } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 5a1dbc080..37898f107 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -18,10 +18,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test send/recv one message", "[transport]") { - // Open the source endpoint client, don't bind AsyncSendMessageEndpoint src(thisHost, testPort); - - // Open the destination endpoint client, bind AsyncRecvMessageEndpoint dst(testPort); // Send message @@ -37,6 +34,36 @@ TEST_CASE_METHOD(SchedulerTestFixture, REQUIRE(actualMsg == expectedMsg); } +TEST_CASE_METHOD(SchedulerTestFixture, + "Test send before recv is ready", + "[transport]") +{ + std::string expectedMsg = "Hello world!"; + + AsyncSendMessageEndpoint src(thisHost, testPort); + + // Run the recv in the background + std::thread recvThread([expectedMsg] { + usleep(1000 * 1000); + AsyncRecvMessageEndpoint dst(testPort); + + // Receive message + faabric::transport::Message recvMsg = dst.recv(); + assert(recvMsg.size() == expectedMsg.size()); + std::string actualMsg(recvMsg.data(), recvMsg.size()); + assert(actualMsg == expectedMsg); + }); + + // Send message (should wait for receiver to become ready) + uint8_t msg[expectedMsg.size()]; + memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); + src.send(msg, expectedMsg.size()); + + if(recvThread.joinable()) { + recvThread.join(); + } +} + TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") { // Prepare common message/response diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 4cfcd6f41..377e6e0b9 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -12,13 +12,14 @@ using namespace faabric::transport; const std::string thisHost = "127.0.0.1"; -const int testPort = 9999; +const int testPortAsync = 9998; +const int testPortSync = 9999; class DummyServer final : public MessageEndpointServer { public: DummyServer() - : MessageEndpointServer(testPort) + : MessageEndpointServer(testPortAsync, testPortSync) , messageCount(0) {} @@ -46,7 +47,7 @@ class EchoServer final : public MessageEndpointServer { public: EchoServer() - : MessageEndpointServer(testPort) + : MessageEndpointServer(testPortAsync, testPortSync) {} protected: @@ -75,7 +76,7 @@ class SlowServer final : public MessageEndpointServer int delayMs = 1000; SlowServer() - : MessageEndpointServer(testPort) + : MessageEndpointServer(testPortAsync, testPortSync) {} protected: @@ -116,7 +117,7 @@ TEST_CASE("Test send one message to server", "[transport]") server.start(); // Open the source endpoint client, don't bind - AsyncSendMessageEndpoint src(thisHost, testPort); + AsyncSendMessageEndpoint src(thisHost, testPortAsync, testPortSync); // Send message: server expects header + body std::string header = "header"; @@ -151,7 +152,7 @@ TEST_CASE("Test send response to client", "[transport]") std::string expectedMsg = "Response from server"; // Open the source endpoint client, don't bind - MessageEndpointClient cli(thisHost, testPort); + MessageEndpointClient cli(thisHost, testPortAsync, testPortSync); // Send and await the response faabric::StatePart response; @@ -176,7 +177,7 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") for (int i = 0; i < numClients; i++) { clientThreads.emplace_back(std::thread([numMessages] { // Prepare client - MessageEndpointClient cli(thisHost, testPort); + MessageEndpointClient cli(thisHost, testPortAsync, testPortSync); std::string clientMsg = "Message from threaded client"; for (int j = 0; j < numMessages; j++) { @@ -233,7 +234,8 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") usleep(500 * 1000); // Set up the client - MessageEndpointClient cli(thisHost, testPort, clientTimeout); + MessageEndpointClient cli( + thisHost, testPortAsync, testPortSync, clientTimeout); std::vector data = { 1, 1, 1 }; faabric::StatePart response; From 46f9d8d5aa9b0e6c4ac9cf164bbb9b32b6a74ab5 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 25 Jun 2021 09:49:29 +0000 Subject: [PATCH 37/66] Split MPI world broadcasting of ranks into separate function --- include/faabric/scheduler/MpiWorld.h | 7 +- .../faabric/transport/MpiMessageEndpoint.h | 1 + src/scheduler/MpiContext.cpp | 5 +- src/scheduler/MpiWorld.cpp | 87 ++++++++++--------- src/transport/MpiMessageEndpoint.cpp | 15 ++++ tests/test/scheduler/test_mpi_world.cpp | 16 ---- .../test/scheduler/test_remote_mpi_worlds.cpp | 21 ++++- .../test_message_endpoint_client.cpp | 28 +++++- .../transport/test_mpi_message_endpoint.cpp | 14 ++- tests/utils/fixtures.h | 4 +- 10 files changed, 122 insertions(+), 76 deletions(-) diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index fc91d4bbb..b9df07182 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -22,12 +22,11 @@ class MpiWorld void create(const faabric::Message& call, int newId, int newSize); - void initialiseFromMsg(const faabric::Message& msg, - bool forceLocal = false); + void broadcastHostsToRanks(); - std::string getHostForRank(int rank); + void initialiseFromMsg(const faabric::Message& msg); - void setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg); + std::string getHostForRank(int rank); std::string getUser(); diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index afcf82bc1..2b9d6253b 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -7,6 +7,7 @@ #include namespace faabric::transport { + /* * These two abstract methods are used to broadcast the host-rank mapping at * initialisation time. diff --git a/src/scheduler/MpiContext.cpp b/src/scheduler/MpiContext.cpp index c279f5be4..3e2f2fa4d 100644 --- a/src/scheduler/MpiContext.cpp +++ b/src/scheduler/MpiContext.cpp @@ -25,7 +25,10 @@ int MpiContext::createWorld(const faabric::Message& msg) // Create the MPI world scheduler::MpiWorldRegistry& reg = scheduler::getMpiWorldRegistry(); - reg.createWorld(msg, worldId); + scheduler::MpiWorld &world = reg.createWorld(msg, worldId); + + // Broadcast setup to other hosts + world.broadcastHostsToRanks(); // Set up this context isMpi = true; diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 9c2c862b3..8a419acf2 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -146,29 +146,36 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize) // Prepend this host for rank 0 executedAt.insert(executedAt.begin(), thisHost); + // Record rank-to-host mapping and base ports + rankHosts = executedAt; + basePorts = initLocalBasePorts(executedAt); + + // Initialise the memory queues for message reception + initLocalQueues(); +} + +void MpiWorld::broadcastHostsToRanks() +{ + // Set up a list of hosts to broadcast to (excluding this host) + std::set targetHosts(rankHosts.begin(), rankHosts.end()); + targetHosts.erase(thisHost); + + if (targetHosts.empty()) { + SPDLOG_DEBUG("Not broadcasting rank-to-host mapping, no other hosts"); + return; + } + // Register hosts to rank mappings on this host faabric::MpiHostsToRanksMessage hostRankMsg; - *hostRankMsg.mutable_hosts() = { executedAt.begin(), executedAt.end() }; + *hostRankMsg.mutable_hosts() = { rankHosts.begin(), rankHosts.end() }; // Prepare the base port for each rank - std::vector basePortForRank = initLocalBasePorts(executedAt); - *hostRankMsg.mutable_baseports() = { basePortForRank.begin(), - basePortForRank.end() }; - - // Register hosts to rank mappins on this host - setAllRankHostsPorts(hostRankMsg); - - // Set up a list of hosts to broadcast to (excluding this host) - std::set hosts(executedAt.begin(), executedAt.end()); - hosts.erase(thisHost); + *hostRankMsg.mutable_baseports() = { basePorts.begin(), basePorts.end() }; // Do the broadcast - for (const auto& h : hosts) { + for (const auto& h : targetHosts) { faabric::transport::sendMpiHostRankMsg(h, hostRankMsg); } - - // Initialise the memory queues for message reception - initLocalQueues(); } void MpiWorld::destroy() @@ -206,25 +213,34 @@ void MpiWorld::destroy() } } -void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal) +void MpiWorld::initialiseFromMsg(const faabric::Message& msg) { id = msg.mpiworldid(); user = msg.user(); function = msg.function(); size = msg.mpiworldsize(); - // Sometimes for testing purposes we may want to initialise a world in the - // _same_ host we have created one (note that this would never happen in - // reality). If so, we skip initialising resources already initialised - if (!forceLocal) { - // Block until we receive - faabric::MpiHostsToRanksMessage hostRankMsg = - faabric::transport::recvMpiHostRankMsg(); - setAllRankHostsPorts(hostRankMsg); - - // Initialise the memory queues for message reception - initLocalQueues(); - } + // Block until we receive + faabric::MpiHostsToRanksMessage hostRankMsg = + faabric::transport::recvMpiHostRankMsg(); + + // Prepare the host-rank map with a vector containing _all_ ranks + // Note - this method should be called by only one rank. This is + // enforced in the world registry. + + // Assert we are only setting the values once + assert(rankHosts.empty()); + assert(basePorts.empty()); + + assert(hostRankMsg.hosts().size() == size); + assert(hostRankMsg.baseports().size() == size); + + rankHosts = { hostRankMsg.hosts().begin(), hostRankMsg.hosts().end() }; + basePorts = { hostRankMsg.baseports().begin(), + hostRankMsg.baseports().end() }; + + // Initialise the memory queues for message reception + initLocalQueues(); } std::string MpiWorld::getHostForRank(int rank) @@ -272,21 +288,6 @@ std::pair MpiWorld::getPortForRanks(int localRank, int remoteRank) return sendRecvPortPair; } -// Prepare the host-rank map with a vector containing _all_ ranks -// Note - this method should be called by only one rank. This is enforced in -// the world registry -void MpiWorld::setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg) -{ - // Assert we are only setting the values once - assert(rankHosts.empty()); - assert(basePorts.empty()); - - assert(msg.hosts().size() == size); - assert(msg.baseports().size() == size); - rankHosts = { msg.hosts().begin(), msg.hosts().end() }; - basePorts = { msg.baseports().begin(), msg.baseports().end() }; -} - void MpiWorld::getCartesianRank(int rank, int maxDims, const int* dims, diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index abb284c5b..9c5435dca 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -1,10 +1,20 @@ #include #include +#include namespace faabric::transport { +static std::vector rankMessages; + faabric::MpiHostsToRanksMessage recvMpiHostRankMsg() { + if (faabric::util::isMockMode()) { + assert(!rankMessages.empty()); + faabric::MpiHostsToRanksMessage msg = rankMessages.back(); + rankMessages.pop_back(); + return msg; + } + SPDLOG_TRACE("Receiving MPI host ranks on {}", MPI_BASE_PORT); faabric::transport::AsyncRecvMessageEndpoint endpoint(MPI_BASE_PORT); faabric::transport::Message m = endpoint.recv(); @@ -16,6 +26,11 @@ faabric::MpiHostsToRanksMessage recvMpiHostRankMsg() void sendMpiHostRankMsg(const std::string& hostIn, const faabric::MpiHostsToRanksMessage msg) { + if (faabric::util::isMockMode()) { + rankMessages.push_back(msg); + return; + } + SPDLOG_TRACE("Sending MPI host ranks to {}:{}", hostIn, MPI_BASE_PORT); faabric::transport::AsyncSendMessageEndpoint endpoint(hostIn, MPI_BASE_PORT); diff --git a/tests/test/scheduler/test_mpi_world.cpp b/tests/test/scheduler/test_mpi_world.cpp index 5144de4ba..0dec3519e 100644 --- a/tests/test/scheduler/test_mpi_world.cpp +++ b/tests/test/scheduler/test_mpi_world.cpp @@ -63,22 +63,6 @@ TEST_CASE_METHOD(MpiBaseTestFixture, "Test creating world of size 1", "[mpi]") world.destroy(); } -TEST_CASE_METHOD(MpiTestFixture, "Test world loading from msg", "[mpi]") -{ - // Create another copy from state - scheduler::MpiWorld worldB; - // Force creating the second world in the _same_ host - bool forceLocal = true; - worldB.initialiseFromMsg(msg, forceLocal); - - REQUIRE(worldB.getSize() == worldSize); - REQUIRE(worldB.getId() == worldId); - REQUIRE(worldB.getUser() == user); - REQUIRE(worldB.getFunction() == func); - - worldB.destroy(); -} - TEST_CASE_METHOD(MpiBaseTestFixture, "Test cartesian communicator", "[mpi]") { MpiWorld world; diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 0669fb245..d4f962bc8 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -43,9 +43,12 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test rank allocation", "[mpi]") // Init worlds MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); - remoteWorld.initialiseFromMsg(msg); faabric::util::setMockMode(false); + localWorld.broadcastHostsToRanks(); + + remoteWorld.initialiseFromMsg(msg); + // Now check both world instances report the same mappings REQUIRE(localWorld.getHostForRank(0) == thisHost); REQUIRE(localWorld.getHostForRank(1) == otherHost); @@ -67,6 +70,8 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + localWorld.broadcastHostsToRanks(); + std::thread senderThread([this, rankA, rankB, &messageData] { remoteWorld.initialiseFromMsg(msg); @@ -112,6 +117,8 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + localWorld.broadcastHostsToRanks(); + std::thread senderThread([this, rankA, rankB, &messageData, &messageData2] { remoteWorld.initialiseFromMsg(msg); @@ -169,6 +176,8 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test barrier across hosts", "[mpi]") MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + localWorld.broadcastHostsToRanks(); + std::thread senderThread([this, rankA, rankB, &sendData, &recvData] { remoteWorld.initialiseFromMsg(msg); @@ -213,6 +222,8 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + localWorld.broadcastHostsToRanks(); + std::thread senderThread([this, rankA, rankB, numMessages] { remoteWorld.initialiseFromMsg(msg); @@ -254,6 +265,8 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + localWorld.broadcastHostsToRanks(); + std::thread senderThread([this, &messageData] { remoteWorld.initialiseFromMsg(msg); @@ -302,6 +315,8 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + localWorld.broadcastHostsToRanks(); + // Build the data int nPerRank = 4; int dataSize = nPerRank * thisWorldSize; @@ -399,6 +414,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, // Init worlds MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + localWorld.broadcastHostsToRanks(); // Build the data for each rank int nPerRank = 4; @@ -485,6 +501,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, // Init world MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + localWorld.broadcastHostsToRanks(); std::thread senderThread([this, sendRank, recvRank, &messageData] { remoteWorld.initialiseFromMsg(msg); @@ -547,6 +564,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, // Init world MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + localWorld.broadcastHostsToRanks(); std::thread senderThread([this, sendRank, recvRank] { remoteWorld.initialiseFromMsg(msg); @@ -607,6 +625,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, // Init world MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + localWorld.broadcastHostsToRanks(); std::thread senderThread([this, worldSize] { std::vector remoteRanks = { 1, 2 }; diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 37898f107..2fc6223b3 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -59,11 +59,37 @@ TEST_CASE_METHOD(SchedulerTestFixture, memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); src.send(msg, expectedMsg.size()); - if(recvThread.joinable()) { + if (recvThread.joinable()) { recvThread.join(); } } +TEST_CASE_METHOD(SchedulerTestFixture, + "Test send out of scope before recv", + "[transport]") +{ + std::string expectedMsg = "Hello world!"; + + // Send message and let socket go out of scope + { + AsyncSendMessageEndpoint src(thisHost, testPort); + uint8_t msg[expectedMsg.size()]; + memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); + src.send(msg, expectedMsg.size()); + } + + // Recieve message in its own scope too + { + usleep(100 * 1000); + AsyncRecvMessageEndpoint dst(testPort); + + faabric::transport::Message recvMsg = dst.recv(); + REQUIRE(recvMsg.size() == expectedMsg.size()); + std::string actualMsg(recvMsg.data(), recvMsg.size()); + REQUIRE(actualMsg == expectedMsg); + } +} + TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") { // Prepare common message/response diff --git a/tests/test/transport/test_mpi_message_endpoint.cpp b/tests/test/transport/test_mpi_message_endpoint.cpp index 09d535ad3..32a67d994 100644 --- a/tests/test/transport/test_mpi_message_endpoint.cpp +++ b/tests/test/transport/test_mpi_message_endpoint.cpp @@ -10,19 +10,18 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test send and recv the hosts to rank message", "[transport]") { - // Prepare message std::vector expected = { "foo", "bar" }; + + // Send the message faabric::MpiHostsToRanksMessage sendMsg; *sendMsg.mutable_hosts() = { expected.begin(), expected.end() }; sendMpiHostRankMsg(LOCALHOST, sendMsg); - // Send message + // Receive and check faabric::MpiHostsToRanksMessage actual = recvMpiHostRankMsg(); - - // Checks - REQUIRE(actual.hosts().size() == expected.size()); + assert(actual.hosts().size() == expected.size()); for (int i = 0; i < actual.hosts().size(); i++) { - REQUIRE(actual.hosts().Get(i) == expected[i]); + assert(actual.hosts().Get(i) == expected[i]); } } @@ -43,8 +42,5 @@ TEST_CASE_METHOD(SchedulerTestFixture, // Checks REQUIRE(expected->id() == actual->id()); - - REQUIRE_NOTHROW(sendEndpoint.close()); - REQUIRE_NOTHROW(recvEndpoint.close()); } } diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index ef1fc3e6b..532b1f7e6 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -157,7 +157,9 @@ class MpiBaseTestFixture : public SchedulerTestFixture class MpiTestFixture : public MpiBaseTestFixture { public: - MpiTestFixture() { world.create(msg, worldId, worldSize); } + MpiTestFixture() { + world.create(msg, worldId, worldSize); + } ~MpiTestFixture() { world.destroy(); } From f2e9113359ed0abbf8331313329a721d86bbc869 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 25 Jun 2021 10:15:06 +0000 Subject: [PATCH 38/66] Move send/ recv host ranks to MpiWorld --- include/faabric/scheduler/MpiWorld.h | 16 ++++++- .../faabric/transport/MpiMessageEndpoint.h | 9 ---- src/scheduler/MpiWorld.cpp | 48 ++++++++++++++++--- src/transport/MpiMessageEndpoint.cpp | 34 ------------- 4 files changed, 56 insertions(+), 51 deletions(-) diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index b9df07182..a23e39b31 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -180,11 +180,13 @@ class MpiWorld double getWTime(); private: - int id; - int size; + int id = -1; + int size = -1; std::string thisHost; faabric::util::TimePoint creationTime; + faabric::transport::AsyncRecvMessageEndpoint ranksRecvEndpoint; + std::shared_mutex worldMutex; std::atomic_flag isDestroyed = false; @@ -207,13 +209,23 @@ class MpiWorld std::vector basePorts; std::vector initLocalBasePorts( const std::vector& executedAt); + void initRemoteMpiEndpoint(int localRank, int remoteRank); + std::pair getPortForRanks(int localRank, int remoteRank); + void sendRemoteMpiMessage(int sendRank, int recvRank, const std::shared_ptr& msg); + std::shared_ptr recvRemoteMpiMessage(int sendRank, int recvRank); + + faabric::MpiHostsToRanksMessage recvMpiHostRankMsg(); + + void sendMpiHostRankMsg(const std::string& hostIn, + const faabric::MpiHostsToRanksMessage msg); + void closeMpiMessageEndpoints(); // Support for asyncrhonous communications diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 2b9d6253b..0b917bf8d 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -8,15 +8,6 @@ namespace faabric::transport { -/* - * These two abstract methods are used to broadcast the host-rank mapping at - * initialisation time. - */ -faabric::MpiHostsToRanksMessage recvMpiHostRankMsg(); - -void sendMpiHostRankMsg(const std::string& hostIn, - const faabric::MpiHostsToRanksMessage msg); - /* * This class abstracts the notion of a communication channel between two remote * MPI ranks. There will always be one rank local to this host, and one remote. diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 8a419acf2..b9ed87646 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -4,6 +4,7 @@ #include #include #include +#include /* Each MPI rank runs in a separate thread, thus we use TLS to maintain the * per-rank data structures. @@ -11,21 +12,57 @@ static thread_local std::vector< std::unique_ptr> mpiMessageEndpoints; + static thread_local std::vector< std::shared_ptr> unackedMessageBuffers; + static thread_local std::set iSendRequests; + static thread_local std::map> reqIdToRanks; +static std::vector rankMessages; + namespace faabric::scheduler { + MpiWorld::MpiWorld() - : id(-1) - , size(-1) - , thisHost(faabric::util::getSystemConfig().endpointHost) + : thisHost(faabric::util::getSystemConfig().endpointHost) , creationTime(faabric::util::startTimer()) + , ranksRecvEndpoint(MPI_BASE_PORT) , cartProcsPerDim(2) {} +faabric::MpiHostsToRanksMessage MpiWorld::recvMpiHostRankMsg() +{ + if (faabric::util::isMockMode()) { + assert(!rankMessages.empty()); + faabric::MpiHostsToRanksMessage msg = rankMessages.back(); + rankMessages.pop_back(); + return msg; + } + + SPDLOG_TRACE("Receiving MPI host ranks on {}", MPI_BASE_PORT); + faabric::transport::Message m = ranksRecvEndpoint.recv(); + PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); + + return msg; +} + +void MpiWorld::sendMpiHostRankMsg(const std::string& hostIn, + const faabric::MpiHostsToRanksMessage msg) +{ + if (faabric::util::isMockMode()) { + rankMessages.push_back(msg); + return; + } + + SPDLOG_TRACE("Sending MPI host ranks to {}:{}", hostIn, MPI_BASE_PORT); + faabric::transport::AsyncSendMessageEndpoint endpoint(hostIn, + MPI_BASE_PORT); + SERIALISE_MSG(msg) + endpoint.send(buffer, msgSize, false); +} + void MpiWorld::initRemoteMpiEndpoint(int localRank, int remoteRank) { SPDLOG_TRACE("Open MPI endpoint between ranks (local-remote) {} - {}", @@ -174,7 +211,7 @@ void MpiWorld::broadcastHostsToRanks() // Do the broadcast for (const auto& h : targetHosts) { - faabric::transport::sendMpiHostRankMsg(h, hostRankMsg); + sendMpiHostRankMsg(h, hostRankMsg); } } @@ -221,8 +258,7 @@ void MpiWorld::initialiseFromMsg(const faabric::Message& msg) size = msg.mpiworldsize(); // Block until we receive - faabric::MpiHostsToRanksMessage hostRankMsg = - faabric::transport::recvMpiHostRankMsg(); + faabric::MpiHostsToRanksMessage hostRankMsg = recvMpiHostRankMsg(); // Prepare the host-rank map with a vector containing _all_ ranks // Note - this method should be called by only one rank. This is diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index 9c5435dca..136456238 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -4,40 +4,6 @@ namespace faabric::transport { -static std::vector rankMessages; - -faabric::MpiHostsToRanksMessage recvMpiHostRankMsg() -{ - if (faabric::util::isMockMode()) { - assert(!rankMessages.empty()); - faabric::MpiHostsToRanksMessage msg = rankMessages.back(); - rankMessages.pop_back(); - return msg; - } - - SPDLOG_TRACE("Receiving MPI host ranks on {}", MPI_BASE_PORT); - faabric::transport::AsyncRecvMessageEndpoint endpoint(MPI_BASE_PORT); - faabric::transport::Message m = endpoint.recv(); - PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); - - return msg; -} - -void sendMpiHostRankMsg(const std::string& hostIn, - const faabric::MpiHostsToRanksMessage msg) -{ - if (faabric::util::isMockMode()) { - rankMessages.push_back(msg); - return; - } - - SPDLOG_TRACE("Sending MPI host ranks to {}:{}", hostIn, MPI_BASE_PORT); - faabric::transport::AsyncSendMessageEndpoint endpoint(hostIn, - MPI_BASE_PORT); - SERIALISE_MSG(msg) - endpoint.send(buffer, msgSize, false); -} - MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int sendPort, int recvPort) From 01ec57d74411a221edcc5ea25b6d9ea08487fb2c Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 25 Jun 2021 12:28:55 +0000 Subject: [PATCH 39/66] Clear up MPI worlds in tests, lazy init recv ranks hosts socket --- include/faabric/scheduler/MpiWorld.h | 15 +++++-- include/faabric/transport/common.h | 2 +- src/scheduler/MpiWorld.cpp | 43 ++++++++++++++----- src/scheduler/SnapshotClient.cpp | 14 +++--- src/state/InMemoryStateRegistry.cpp | 1 - src/state/StateClient.cpp | 8 ++-- .../test/scheduler/test_remote_mpi_worlds.cpp | 25 ++++++----- .../transport/test_mpi_message_endpoint.cpp | 18 -------- tests/utils/fixtures.h | 9 ++-- 9 files changed, 79 insertions(+), 56 deletions(-) diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index a23e39b31..d5ee0d1ff 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -10,6 +10,7 @@ #include #include +#include namespace faabric::scheduler { typedef faabric::util::Queue> @@ -18,7 +19,7 @@ typedef faabric::util::Queue> class MpiWorld { public: - MpiWorld(); + MpiWorld(int basePort = DEFAULT_MPI_BASE_PORT); void create(const faabric::Message& call, int newId, int newSize); @@ -183,9 +184,16 @@ class MpiWorld int id = -1; int size = -1; std::string thisHost; + int basePort = -1; faabric::util::TimePoint creationTime; - faabric::transport::AsyncRecvMessageEndpoint ranksRecvEndpoint; + std::unique_ptr + ranksRecvEndpoint; + + std::unordered_map< + std::string, + std::unique_ptr> + ranksSendEndpoints; std::shared_mutex worldMutex; std::atomic_flag isDestroyed = false; @@ -224,13 +232,14 @@ class MpiWorld faabric::MpiHostsToRanksMessage recvMpiHostRankMsg(); void sendMpiHostRankMsg(const std::string& hostIn, - const faabric::MpiHostsToRanksMessage msg); + const faabric::MpiHostsToRanksMessage msg); void closeMpiMessageEndpoints(); // Support for asyncrhonous communications std::shared_ptr getUnackedMessageBuffer(int sendRank, int recvRank); + std::shared_ptr recvBatchReturnLast(int sendRank, int recvRank, int batchSize = 0); diff --git a/include/faabric/transport/common.h b/include/faabric/transport/common.h index f9d1f95e2..dcdcf182b 100644 --- a/include/faabric/transport/common.h +++ b/include/faabric/transport/common.h @@ -11,4 +11,4 @@ #define SNAPSHOT_SYNC_PORT 8007 #define SNAPSHOT_ASYNC_PORT 8008 -#define MPI_BASE_PORT 8800 +#define DEFAULT_MPI_BASE_PORT 8800 diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index b9ed87646..8e93942bb 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -25,10 +25,10 @@ static std::vector rankMessages; namespace faabric::scheduler { -MpiWorld::MpiWorld() +MpiWorld::MpiWorld(int basePortIn) : thisHost(faabric::util::getSystemConfig().endpointHost) + , basePort(basePortIn) , creationTime(faabric::util::startTimer()) - , ranksRecvEndpoint(MPI_BASE_PORT) , cartProcsPerDim(2) {} @@ -41,8 +41,20 @@ faabric::MpiHostsToRanksMessage MpiWorld::recvMpiHostRankMsg() return msg; } - SPDLOG_TRACE("Receiving MPI host ranks on {}", MPI_BASE_PORT); - faabric::transport::Message m = ranksRecvEndpoint.recv(); + if (ranksRecvEndpoint == nullptr) { + faabric::util::FullLock lock(worldMutex); + if (ranksRecvEndpoint == nullptr) { + ranksRecvEndpoint = + std::make_unique( + basePort); + } + } + + // Shared lock to ensure it's initialised before use + faabric::util::SharedLock lock(worldMutex); + + SPDLOG_TRACE("Receiving MPI host ranks on {}", basePort); + faabric::transport::Message m = ranksRecvEndpoint->recv(); PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); return msg; @@ -56,11 +68,22 @@ void MpiWorld::sendMpiHostRankMsg(const std::string& hostIn, return; } - SPDLOG_TRACE("Sending MPI host ranks to {}:{}", hostIn, MPI_BASE_PORT); - faabric::transport::AsyncSendMessageEndpoint endpoint(hostIn, - MPI_BASE_PORT); + if (ranksSendEndpoints.find(hostIn) == ranksSendEndpoints.end()) { + faabric::util::FullLock lock(worldMutex); + if (ranksSendEndpoints.find(hostIn) == ranksSendEndpoints.end()) { + ranksSendEndpoints.emplace( + hostIn, + std::make_unique( + hostIn, basePort)); + } + } + + // Shared lock to ensure endpoint is initialised before use + faabric::util::SharedLock lock(worldMutex); + + SPDLOG_TRACE("Sending MPI host ranks to {}:{}", hostIn, basePort); SERIALISE_MSG(msg) - endpoint.send(buffer, msgSize, false); + ranksSendEndpoints[hostIn]->send(buffer, msgSize, false); } void MpiWorld::initRemoteMpiEndpoint(int localRank, int remoteRank) @@ -1250,10 +1273,10 @@ std::vector MpiWorld::initLocalBasePorts( basePortForRank.reserve(size); std::string lastHost = thisHost; - int lastPort = MPI_BASE_PORT; + int lastPort = basePort; for (const auto& host : executedAt) { if (host == thisHost) { - basePortForRank.push_back(MPI_BASE_PORT); + basePortForRank.push_back(basePort); } else if (host == lastHost) { basePortForRank.push_back(lastPort); } else { diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index 525b0359c..58ffba9eb 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -70,7 +70,9 @@ void clearMockSnapshotRequests() // ----------------------------------- SnapshotClient::SnapshotClient(const std::string& hostIn) - : faabric::transport::MessageEndpointClient(hostIn, SNAPSHOT_ASYNC_PORT, SNAPSHOT_SYNC_PORT) + : faabric::transport::MessageEndpointClient(hostIn, + SNAPSHOT_ASYNC_PORT, + SNAPSHOT_SYNC_PORT) {} void SnapshotClient::pushSnapshot(const std::string& key, @@ -87,8 +89,8 @@ void SnapshotClient::pushSnapshot(const std::string& key, flatbuffers::FlatBufferBuilder mb; auto keyOffset = mb.CreateString(key); auto dataOffset = mb.CreateVector(data.data, data.size); - auto requestOffset = CreateSnapshotPushRequest( - mb, keyOffset, dataOffset); + auto requestOffset = + CreateSnapshotPushRequest(mb, keyOffset, dataOffset); // Send it mb.Finish(requestOffset); @@ -129,8 +131,8 @@ void SnapshotClient::pushSnapshotDiffs( // TODO - avoid copying data here auto keyOffset = mb.CreateString(snapshotKey); auto diffsOffset = mb.CreateVector(diffsFbVector); - auto requestOffset = CreateSnapshotDiffPushRequest( - mb, keyOffset, diffsOffset); + auto requestOffset = + CreateSnapshotDiffPushRequest(mb, keyOffset, diffsOffset); mb.Finish(requestOffset); uint8_t* buffer = mb.GetBufferPointer(); @@ -222,7 +224,7 @@ void SnapshotClient::pushThreadResult( uint8_t* buffer = mb.GetBufferPointer(); int size = mb.GetSize(); asyncSend( - faabric::scheduler::SnapshotCalls::PushSnapshotDiffs, buffer, size); + faabric::scheduler::SnapshotCalls::ThreadResult, buffer, size); } } } diff --git a/src/state/InMemoryStateRegistry.cpp b/src/state/InMemoryStateRegistry.cpp index f29b9426c..e115aa28b 100644 --- a/src/state/InMemoryStateRegistry.cpp +++ b/src/state/InMemoryStateRegistry.cpp @@ -30,7 +30,6 @@ std::string InMemoryStateRegistry::getMasterIP(const std::string& user, const std::string& thisIP, bool claim) { - std::string lookupKey = faabric::util::keyForUser(user, key); // See if we already have the master diff --git a/src/state/StateClient.cpp b/src/state/StateClient.cpp index 91f40f835..66e81524d 100644 --- a/src/state/StateClient.cpp +++ b/src/state/StateClient.cpp @@ -28,7 +28,7 @@ void StateClient::sendStateRequest(faabric::state::StateCalls header, } faabric::EmptyResponse resp; - syncSend(header, data, length, &resp); + syncSend(header, &request, &resp); } void StateClient::pushChunks(const std::vector& chunks) @@ -56,9 +56,11 @@ void StateClient::pullChunks(const std::vector& chunks, request.set_offset(chunk.offset); request.set_chunksize(chunk.length); - // Send and copy response into place + // Send request faabric::StatePart response; syncSend(faabric::state::StateCalls::Pull, &request, &response); + + // Copy response data std::copy(response.data().begin(), response.data().end(), bufferStart + response.offset()); @@ -109,7 +111,7 @@ size_t StateClient::stateSize() request.set_key(key); faabric::StateSizeResponse response; - syncSend(faabric::state::StateCalls::Size, nullptr, 0, &response); + syncSend(faabric::state::StateCalls::Size, &request, &response); return response.statesize(); } diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index d4f962bc8..e34cab9b3 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -43,8 +43,8 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test rank allocation", "[mpi]") // Init worlds MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); - faabric::util::setMockMode(false); + faabric::util::setMockMode(false); localWorld.broadcastHostsToRanks(); remoteWorld.initialiseFromMsg(msg); @@ -79,7 +79,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") remoteWorld.send( rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); - usleep(1000 * 500); + usleep(500 * 1000); remoteWorld.destroy(); }); @@ -137,7 +137,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, std::vector actual(buffer, buffer + messageData2.size()); REQUIRE(actual == messageData2); - usleep(1000 * 500); + usleep(500 * 1000); remoteWorld.destroy(); }); @@ -231,7 +231,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); } - usleep(1000 * 500); + usleep(500 * 1000); remoteWorld.destroy(); }); @@ -286,7 +286,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, assert(actual == messageData); } - usleep(1000 * 500); + usleep(500 * 1000); remoteWorld.destroy(); }); @@ -362,7 +362,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); assert(actual == std::vector({ 12, 13, 14, 15 })); - usleep(1000 * 500); + usleep(500 * 1000); remoteWorld.destroy(); }); @@ -452,7 +452,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); } - usleep(1000 * 500); + usleep(500 * 1000); remoteWorld.destroy(); }); @@ -518,7 +518,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_INT, messageData.size()); - usleep(1000 * 500); + usleep(500 * 1000); remoteWorld.destroy(); }); @@ -574,7 +574,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, remoteWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); } - usleep(1000 * 500); + usleep(500 * 1000); remoteWorld.destroy(); }); @@ -649,7 +649,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_STATUS_IGNORE); } - usleep(1000 * 500); + usleep(500 * 1000); remoteWorld.destroy(); }); @@ -674,7 +674,10 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } // Destroy world - senderThread.join(); + if (senderThread.joinable()) { + senderThread.join(); + } + localWorld.destroy(); } } diff --git a/tests/test/transport/test_mpi_message_endpoint.cpp b/tests/test/transport/test_mpi_message_endpoint.cpp index 32a67d994..643c9bda0 100644 --- a/tests/test/transport/test_mpi_message_endpoint.cpp +++ b/tests/test/transport/test_mpi_message_endpoint.cpp @@ -6,24 +6,6 @@ using namespace faabric::transport; namespace tests { -TEST_CASE_METHOD(SchedulerTestFixture, - "Test send and recv the hosts to rank message", - "[transport]") -{ - std::vector expected = { "foo", "bar" }; - - // Send the message - faabric::MpiHostsToRanksMessage sendMsg; - *sendMsg.mutable_hosts() = { expected.begin(), expected.end() }; - sendMpiHostRankMsg(LOCALHOST, sendMsg); - - // Receive and check - faabric::MpiHostsToRanksMessage actual = recvMpiHostRankMsg(); - assert(actual.hosts().size() == expected.size()); - for (int i = 0; i < actual.hosts().size(); i++) { - assert(actual.hosts().Get(i) == expected[i]); - } -} TEST_CASE_METHOD(SchedulerTestFixture, "Test send and recv an MPI message", diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index 532b1f7e6..30525e239 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -157,9 +157,7 @@ class MpiBaseTestFixture : public SchedulerTestFixture class MpiTestFixture : public MpiBaseTestFixture { public: - MpiTestFixture() { - world.create(msg, worldId, worldSize); - } + MpiTestFixture() { world.create(msg, worldId, worldSize); } ~MpiTestFixture() { world.destroy(); } @@ -177,6 +175,11 @@ class RemoteMpiTestFixture : public MpiBaseTestFixture remoteWorld.overrideHost(otherHost); } + ~RemoteMpiTestFixture() + { + faabric::scheduler::getMpiWorldRegistry().clear(); + } + void setWorldsSizes(int worldSize, int ranksWorldOne, int ranksWorldTwo) { // Update message From 7f617bfbd539b4ceb37523021f6f7cf9e9f878c3 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 25 Jun 2021 14:38:39 +0000 Subject: [PATCH 40/66] Thread-local sockets for rank-to-host mappings --- include/faabric/scheduler/MpiWorld.h | 9 ------- src/scheduler/MpiWorld.cpp | 35 +++++++++++++--------------- 2 files changed, 16 insertions(+), 28 deletions(-) diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index d5ee0d1ff..efd53efdd 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -187,15 +187,6 @@ class MpiWorld int basePort = -1; faabric::util::TimePoint creationTime; - std::unique_ptr - ranksRecvEndpoint; - - std::unordered_map< - std::string, - std::unique_ptr> - ranksSendEndpoints; - - std::shared_mutex worldMutex; std::atomic_flag isDestroyed = false; std::string user; diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 8e93942bb..ca75988d9 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -17,6 +17,15 @@ static thread_local std::vector< std::shared_ptr> unackedMessageBuffers; +static thread_local std::unique_ptr< + faabric::transport::AsyncRecvMessageEndpoint> + ranksRecvEndpoint; + +static thread_local std::unordered_map< + std::string, + std::unique_ptr> + ranksSendEndpoints; + static thread_local std::set iSendRequests; static thread_local std::map> reqIdToRanks; @@ -42,17 +51,11 @@ faabric::MpiHostsToRanksMessage MpiWorld::recvMpiHostRankMsg() } if (ranksRecvEndpoint == nullptr) { - faabric::util::FullLock lock(worldMutex); - if (ranksRecvEndpoint == nullptr) { - ranksRecvEndpoint = - std::make_unique( - basePort); - } + ranksRecvEndpoint = + std::make_unique( + basePort); } - // Shared lock to ensure it's initialised before use - faabric::util::SharedLock lock(worldMutex); - SPDLOG_TRACE("Receiving MPI host ranks on {}", basePort); faabric::transport::Message m = ranksRecvEndpoint->recv(); PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); @@ -69,18 +72,12 @@ void MpiWorld::sendMpiHostRankMsg(const std::string& hostIn, } if (ranksSendEndpoints.find(hostIn) == ranksSendEndpoints.end()) { - faabric::util::FullLock lock(worldMutex); - if (ranksSendEndpoints.find(hostIn) == ranksSendEndpoints.end()) { - ranksSendEndpoints.emplace( - hostIn, - std::make_unique( - hostIn, basePort)); - } + ranksSendEndpoints.emplace( + hostIn, + std::make_unique( + hostIn, basePort)); } - // Shared lock to ensure endpoint is initialised before use - faabric::util::SharedLock lock(worldMutex); - SPDLOG_TRACE("Sending MPI host ranks to {}:{}", hostIn, basePort); SERIALISE_MSG(msg) ranksSendEndpoints[hostIn]->send(buffer, msgSize, false); From 6c15fcb488b615fe0e17c90e8b43d5523d937ad1 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 25 Jun 2021 14:39:10 +0000 Subject: [PATCH 41/66] Detailed logging of send/recv --- .../faabric/transport/MpiMessageEndpoint.h | 1 - src/transport/MessageEndpoint.cpp | 22 ++--- .../test/scheduler/test_remote_mpi_worlds.cpp | 90 +++++++++++-------- tests/utils/fixtures.h | 7 +- 4 files changed, 70 insertions(+), 50 deletions(-) diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 0b917bf8d..0f5f5a397 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -2,7 +2,6 @@ #include #include -#include #include #include diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index cdaf69a09..1a79c8a32 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -61,29 +61,25 @@ zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType, switch (socketType) { case zmq::socket_type::req: { SPDLOG_TRACE( - "Opening req socket {}:{} (timeout {}ms)", host, port, timeoutMs); + "REQ socket open {}:{} (timeout {}ms)", host, port, timeoutMs); CATCH_ZMQ_ERR(socket.connect(address), "connect") break; } case zmq::socket_type::push: { - SPDLOG_TRACE("Opening push socket {}:{} (timeout {}ms)", - host, - port, - timeoutMs); + SPDLOG_TRACE( + "PUSH socket open {}:{} (timeout {}ms)", host, port, timeoutMs); CATCH_ZMQ_ERR(socket.connect(address), "connect") break; } case zmq::socket_type::pull: { - SPDLOG_TRACE("Opening pull socket {}:{} (timeout {}ms)", - host, - port, - timeoutMs); + SPDLOG_TRACE( + "PULL socket open {}:{} (timeout {}ms)", host, port, timeoutMs); CATCH_ZMQ_ERR(socket.bind(address), "bind") break; } case zmq::socket_type::rep: { SPDLOG_TRACE( - "Opening rep socket {}:{} (timeout {}ms)", host, port, timeoutMs); + "REP socket open {}:{} (timeout {}ms)", host, port, timeoutMs); CATCH_ZMQ_ERR(socket.bind(address), "bind") break; } @@ -227,6 +223,7 @@ void AsyncSendMessageEndpoint::send(uint8_t* serialisedMsg, size_t msgSize, bool more) { + SPDLOG_TRACE("PUSH {}:{} ({} bytes, more {})", host, port, msgSize, more); doSend(pushSocket, serialisedMsg, msgSize, more); } @@ -259,6 +256,8 @@ Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* serialisedMsg, size_t msgSize, bool more) { + SPDLOG_TRACE("REQ {}:{} ({} bytes, more {})", host, port, msgSize, more); + doSend(reqSocket, serialisedMsg, msgSize, more); // Do the receive @@ -294,6 +293,7 @@ AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint(int portIn, int timeoutMs) Message AsyncRecvMessageEndpoint::recv(int size) { + SPDLOG_TRACE("PULL {} ({} bytes)", port, size); return doRecv(pullSocket, size); } @@ -309,6 +309,7 @@ SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(int portIn, int timeoutMs) Message SyncRecvMessageEndpoint::recv(int size) { + SPDLOG_TRACE("RECV {} (REP) ({} bytes)", port, size); return doRecv(repSocket, size); } @@ -316,6 +317,7 @@ Message SyncRecvMessageEndpoint::recv(int size) // optimisation if needed. void SyncRecvMessageEndpoint::sendResponse(uint8_t* data, int size) { + SPDLOG_TRACE("REP {} ({} bytes)", port, size); doSend(repSocket, data, size, false); } } diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index e34cab9b3..4015a37b2 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -30,8 +30,13 @@ class RemoteCollectiveTestFixture : public RemoteMpiTestFixture protected: int thisWorldSize; - int remoteRankA, remoteRankB, remoteRankC; - int localRankA, localRankB; + int remoteRankA; + int remoteRankB; + int remoteRankC; + + int localRankA; + int localRankB; + std::vector remoteWorldRanks; std::vector localWorldRanks; }; @@ -69,7 +74,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") // Init worlds MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - localWorld.broadcastHostsToRanks(); std::thread senderThread([this, rankA, rankB, &messageData] { @@ -79,9 +83,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") remoteWorld.send( rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); - usleep(500 * 1000); - - remoteWorld.destroy(); + usleep(1000 * 1000); }); // Receive the message for the given rank @@ -98,8 +100,12 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); // Destroy worlds - senderThread.join(); + if (senderThread.joinable()) { + senderThread.join(); + } + localWorld.destroy(); + remoteWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -138,8 +144,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(actual == messageData2); usleep(500 * 1000); - - remoteWorld.destroy(); }); // Receive the message for the given rank @@ -159,7 +163,11 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); // Destroy worlds - senderThread.join(); + if (senderThread.joinable()) { + senderThread.join(); + } + + remoteWorld.destroy(); localWorld.destroy(); } @@ -187,8 +195,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test barrier across hosts", "[mpi]") // Barrier on this rank remoteWorld.barrier(rankB); assert(sendData == recvData); - - remoteWorld.destroy(); }); // Receive the message for the given rank @@ -204,8 +210,11 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test barrier across hosts", "[mpi]") localWorld.barrier(rankA); // Destroy worlds - senderThread.join(); + if (senderThread.joinable()) { + senderThread.join(); + } localWorld.destroy(); + remoteWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -232,8 +241,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } usleep(500 * 1000); - - remoteWorld.destroy(); }); int recv; @@ -248,8 +255,11 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } // Destroy worlds - senderThread.join(); + if (senderThread.joinable()) { + senderThread.join(); + } localWorld.destroy(); + remoteWorld.destroy(); } TEST_CASE_METHOD(RemoteCollectiveTestFixture, @@ -264,7 +274,6 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, // Init worlds MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - localWorld.broadcastHostsToRanks(); std::thread senderThread([this, &messageData] { @@ -285,10 +294,6 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, remoteRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); assert(actual == messageData); } - - usleep(500 * 1000); - - remoteWorld.destroy(); }); // Check the local host @@ -299,8 +304,12 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, REQUIRE(actual == messageData); } - senderThread.join(); + // Destroy worlds + if (senderThread.joinable()) { + senderThread.join(); + } localWorld.destroy(); + remoteWorld.destroy(); } TEST_CASE_METHOD(RemoteCollectiveTestFixture, @@ -363,8 +372,6 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, assert(actual == std::vector({ 12, 13, 14, 15 })); usleep(500 * 1000); - - remoteWorld.destroy(); }); // Check for local ranks @@ -399,8 +406,12 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); REQUIRE(actual == std::vector({ 16, 17, 18, 19 })); - senderThread.join(); + // Destroy worlds + if (senderThread.joinable()) { + senderThread.join(); + } localWorld.destroy(); + remoteWorld.destroy(); } TEST_CASE_METHOD(RemoteCollectiveTestFixture, @@ -453,8 +464,6 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, } usleep(500 * 1000); - - remoteWorld.destroy(); }); for (int rank : localWorldRanks) { @@ -484,8 +493,12 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, // Check data REQUIRE(actual == expected); - senderThread.join(); + // Destroy worlds + if (senderThread.joinable()) { + senderThread.join(); + } localWorld.destroy(); + remoteWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -519,8 +532,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, messageData.size()); usleep(500 * 1000); - - remoteWorld.destroy(); }); // Receive one message asynchronously @@ -547,9 +558,12 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(syncMessage == messageData); REQUIRE(asyncMessage == messageData); - // Destroy world - senderThread.join(); + // Destroy worlds + if (senderThread.joinable()) { + senderThread.join(); + } localWorld.destroy(); + remoteWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -575,8 +589,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } usleep(500 * 1000); - - remoteWorld.destroy(); }); // Receive two messages asynchronously @@ -608,9 +620,12 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(recv2 == 1); REQUIRE(recv3 == 2); - // Destroy world - senderThread.join(); + // Destroy worlds + if (senderThread.joinable()) { + senderThread.join(); + } localWorld.destroy(); + remoteWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -650,8 +665,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } usleep(500 * 1000); - - remoteWorld.destroy(); }); for (auto& rank : localRanks) { @@ -679,5 +692,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } localWorld.destroy(); + remoteWorld.destroy(); } } diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index 30525e239..1b01b000d 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -118,7 +118,9 @@ class ConfTestFixture faabric::util::SystemConfig& conf; }; -class MpiBaseTestFixture : public SchedulerTestFixture +class MpiBaseTestFixture + : public SchedulerTestFixture + , public ConfTestFixture { public: MpiBaseTestFixture() @@ -172,11 +174,14 @@ class RemoteMpiTestFixture : public MpiBaseTestFixture : thisHost(faabric::util::getSystemConfig().endpointHost) , otherHost(LOCALHOST) { + // Mock everything by default + faabric::util::setMockMode(true); remoteWorld.overrideHost(otherHost); } ~RemoteMpiTestFixture() { + faabric::util::setMockMode(false); faabric::scheduler::getMpiWorldRegistry().clear(); } From 1d702f989ec4b7c44c7f681fc218d22f39aa0580 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 25 Jun 2021 15:38:23 +0000 Subject: [PATCH 42/66] Switch local/remote in remote world tests --- src/transport/MessageEndpoint.cpp | 8 ++-- .../test/scheduler/test_remote_mpi_worlds.cpp | 39 +++++++++---------- tests/test/transport/test_message_server.cpp | 3 +- tests/utils/fixtures.h | 2 - 4 files changed, 25 insertions(+), 27 deletions(-) diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 1a79c8a32..cbd750a06 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -61,25 +61,25 @@ zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType, switch (socketType) { case zmq::socket_type::req: { SPDLOG_TRACE( - "REQ socket open {}:{} (timeout {}ms)", host, port, timeoutMs); + "New socket: req {}:{} (timeout {}ms)", host, port, timeoutMs); CATCH_ZMQ_ERR(socket.connect(address), "connect") break; } case zmq::socket_type::push: { SPDLOG_TRACE( - "PUSH socket open {}:{} (timeout {}ms)", host, port, timeoutMs); + "New socket: push {}:{} (timeout {}ms)", host, port, timeoutMs); CATCH_ZMQ_ERR(socket.connect(address), "connect") break; } case zmq::socket_type::pull: { SPDLOG_TRACE( - "PULL socket open {}:{} (timeout {}ms)", host, port, timeoutMs); + "New socket: pull {}:{} (timeout {}ms)", host, port, timeoutMs); CATCH_ZMQ_ERR(socket.bind(address), "bind") break; } case zmq::socket_type::rep: { SPDLOG_TRACE( - "REP socket open {}:{} (timeout {}ms)", host, port, timeoutMs); + "New socket: rep {}:{} (timeout {}ms)", host, port, timeoutMs); CATCH_ZMQ_ERR(socket.bind(address), "bind") break; } diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 4015a37b2..76ea19272 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -76,36 +76,35 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") faabric::util::setMockMode(false); localWorld.broadcastHostsToRanks(); - std::thread senderThread([this, rankA, rankB, &messageData] { + // Start the "remote" world in the background + std::thread remoteWorldThread([this, rankA, rankB, &messageData] { remoteWorld.initialiseFromMsg(msg); - // Send a message that should get sent to this host - remoteWorld.send( - rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); + // Receive the message for the given rank + MPI_Status status{}; + auto buffer = new int[messageData.size()]; + remoteWorld.recv( + rankA, rankB, BYTES(buffer), MPI_INT, messageData.size(), &status); - usleep(1000 * 1000); - }); + std::vector actual(buffer, buffer + messageData.size()); + assert(actual == messageData); - // Receive the message for the given rank - MPI_Status status{}; - auto buffer = new int[messageData.size()]; - localWorld.recv( - rankB, rankA, BYTES(buffer), MPI_INT, messageData.size(), &status); + assert(status.MPI_SOURCE == rankA); + assert(status.MPI_ERROR == MPI_SUCCESS); + assert(status.bytesSize == messageData.size() * sizeof(int)); - std::vector actual(buffer, buffer + messageData.size()); - REQUIRE(actual == messageData); + remoteWorld.destroy(); + }); - REQUIRE(status.MPI_SOURCE == rankB); - REQUIRE(status.MPI_ERROR == MPI_SUCCESS); - REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); + // Send a message that should get sent to the "remote" world + localWorld.send( + rankA, rankB, BYTES(messageData.data()), MPI_INT, messageData.size()); - // Destroy worlds - if (senderThread.joinable()) { - senderThread.join(); + if (remoteWorldThread.joinable()) { + remoteWorldThread.join(); } localWorld.destroy(); - remoteWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 377e6e0b9..5c0496385 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -187,7 +187,6 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") cli.asyncSend(0, body, clientMsg.size()); } - usleep(1000 * 300); })); } @@ -197,6 +196,8 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") } } + usleep(2000 * 1000); + REQUIRE(server.messageCount == numMessages * numClients); server.stop(); diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index 1b01b000d..92a4ac777 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -174,8 +174,6 @@ class RemoteMpiTestFixture : public MpiBaseTestFixture : thisHost(faabric::util::getSystemConfig().endpointHost) , otherHost(LOCALHOST) { - // Mock everything by default - faabric::util::setMockMode(true); remoteWorld.overrideHost(otherHost); } From 2d88beb4e18496bd56056e23da11a6486df677df Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 28 Jun 2021 09:40:02 +0000 Subject: [PATCH 43/66] Clear up all thread-local sockets when destroying MPI world --- src/scheduler/MpiWorld.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index ca75988d9..b75665e96 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -237,6 +237,17 @@ void MpiWorld::broadcastHostsToRanks() void MpiWorld::destroy() { + SPDLOG_TRACE("Destroying MPI world {}", id); + + // Note that all ranks will call this function. + + // We must force the destructors for all message endpoints to run here + // rather than at the end of their global thread-local lifespan. If we + // don't, the ZMQ shutdown can hang. + mpiMessageEndpoints.clear(); + ranksRecvEndpoint = nullptr; + ranksSendEndpoints.clear(); + // Unacked message buffers if (!unackedMessageBuffers.empty()) { for (auto& umb : unackedMessageBuffers) { From 705a29cae8336116b0d4416a716b0af88c3f37cf Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 28 Jun 2021 10:45:17 +0000 Subject: [PATCH 44/66] Fix remote world test hanging --- include/faabric/transport/macros.h | 6 +- src/scheduler/FunctionCallClient.cpp | 3 +- src/scheduler/MpiContext.cpp | 2 +- .../test/scheduler/test_remote_mpi_worlds.cpp | 661 +++++++++--------- .../test_message_endpoint_client.cpp | 26 - tests/test/transport/test_message_server.cpp | 1 - tests/utils/fixtures.h | 41 +- 7 files changed, 372 insertions(+), 368 deletions(-) diff --git a/include/faabric/transport/macros.h b/include/faabric/transport/macros.h index abe0a7494..a37132095 100644 --- a/include/faabric/transport/macros.h +++ b/include/faabric/transport/macros.h @@ -7,13 +7,13 @@ } #define SERIALISE_MSG(msg) \ - size_t msgSize = msg.ByteSizeLong(); \ + size_t msgSize = msg.ByteSizeLong(); \ uint8_t buffer[msgSize]; \ - if (!msg.SerializeToArray(buffer, msgSize)) { \ + if (!msg.SerializeToArray(buffer, msgSize)) { \ throw std::runtime_error("Error serialising message"); \ } -#define SERIALISE_MSG_PTR(msg) \ +#define SERIALISE_MSG_PTR(msg) \ size_t msgSize = msg->ByteSizeLong(); \ uint8_t buffer[msgSize]; \ if (!msg->SerializeToArray(buffer, msgSize)) { \ diff --git a/src/scheduler/FunctionCallClient.cpp b/src/scheduler/FunctionCallClient.cpp index 701b29381..3492e04f2 100644 --- a/src/scheduler/FunctionCallClient.cpp +++ b/src/scheduler/FunctionCallClient.cpp @@ -47,8 +47,7 @@ getBatchRequests() return batchMessages; } -std::vector> -getResourceRequests() +std::vector> getResourceRequests() { return resourceRequests; } diff --git a/src/scheduler/MpiContext.cpp b/src/scheduler/MpiContext.cpp index 3e2f2fa4d..4674da0b1 100644 --- a/src/scheduler/MpiContext.cpp +++ b/src/scheduler/MpiContext.cpp @@ -25,7 +25,7 @@ int MpiContext::createWorld(const faabric::Message& msg) // Create the MPI world scheduler::MpiWorldRegistry& reg = scheduler::getMpiWorldRegistry(); - scheduler::MpiWorld &world = reg.createWorld(msg, worldId); + scheduler::MpiWorld& world = reg.createWorld(msg, worldId); // Broadcast setup to other hosts world.broadcastHostsToRanks(); diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 76ea19272..a6f17c03e 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -18,72 +18,100 @@ class RemoteCollectiveTestFixture : public RemoteMpiTestFixture { public: RemoteCollectiveTestFixture() - : thisWorldSize(6) - , remoteRankA(1) - , remoteRankB(2) - , remoteRankC(3) - , localRankA(4) - , localRankB(5) - , remoteWorldRanks({ remoteRankB, remoteRankC, remoteRankA }) - , localWorldRanks({ localRankB, localRankA, 0 }) - {} + { + thisWorldRanks = { thisHostRankB, thisHostRankA, 0 }; + otherWorldRanks = { otherHostRankB, otherHostRankC, otherHostRankA }; + + // Here we rely on the scheduler running out of resources and + // overloading this world with ranks 4 and 5 + setWorldSizes(thisWorldSize, 1, 3); + } + + MpiWorld& setUpThisWorld() + { + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); + faabric::util::setMockMode(false); + thisWorld.broadcastHostsToRanks(); + + // Check it's set up as we expect + for (auto r : otherWorldRanks) { + REQUIRE(thisWorld.getHostForRank(r) == otherHost); + } + + for (auto r : thisWorldRanks) { + REQUIRE(thisWorld.getHostForRank(r) == thisHost); + } + + return thisWorld; + } protected: - int thisWorldSize; - int remoteRankA; - int remoteRankB; - int remoteRankC; + int thisWorldSize = 6; + + int otherHostRankA = 1; + int otherHostRankB = 2; + int otherHostRankC = 3; - int localRankA; - int localRankB; + int thisHostRankA = 4; + int thisHostRankB = 5; - std::vector remoteWorldRanks; - std::vector localWorldRanks; + std::vector otherWorldRanks; + std::vector thisWorldRanks; }; TEST_CASE_METHOD(RemoteMpiTestFixture, "Test rank allocation", "[mpi]") { // Allocate two ranks in total, one rank per host - this->setWorldsSizes(2, 1, 1); - - // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + setWorldSizes(2, 1, 1); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - localWorld.broadcastHostsToRanks(); + thisWorld.broadcastHostsToRanks(); + + // Background thread to receive the allocation + std::thread otherWorldThread([this] { + otherWorld.initialiseFromMsg(msg); - remoteWorld.initialiseFromMsg(msg); + REQUIRE(otherWorld.getHostForRank(0) == thisHost); + REQUIRE(otherWorld.getHostForRank(1) == otherHost); - // Now check both world instances report the same mappings - REQUIRE(localWorld.getHostForRank(0) == thisHost); - REQUIRE(localWorld.getHostForRank(1) == otherHost); + otherWorld.destroy(); + }); + + usleep(500 * 1000); + + REQUIRE(thisWorld.getHostForRank(0) == thisHost); + REQUIRE(thisWorld.getHostForRank(1) == otherHost); + + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } - // Destroy worlds - localWorld.destroy(); - remoteWorld.destroy(); + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") { // Register two ranks (one on each host) - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); int rankA = 0; int rankB = 1; std::vector messageData = { 0, 1, 2 }; // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - localWorld.broadcastHostsToRanks(); + thisWorld.broadcastHostsToRanks(); // Start the "remote" world in the background - std::thread remoteWorldThread([this, rankA, rankB, &messageData] { - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, rankA, rankB, &messageData] { + otherWorld.initialiseFromMsg(msg); // Receive the message for the given rank MPI_Status status{}; auto buffer = new int[messageData.size()]; - remoteWorld.recv( + otherWorld.recv( rankA, rankB, BYTES(buffer), MPI_INT, messageData.size(), &status); std::vector actual(buffer, buffer + messageData.size()); @@ -93,18 +121,18 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") assert(status.MPI_ERROR == MPI_SUCCESS); assert(status.bytesSize == messageData.size() * sizeof(int)); - remoteWorld.destroy(); + otherWorld.destroy(); }); // Send a message that should get sent to the "remote" world - localWorld.send( + thisWorld.send( rankA, rankB, BYTES(messageData.data()), MPI_INT, messageData.size()); - if (remoteWorldThread.joinable()) { - remoteWorldThread.join(); + if (otherWorldThread.joinable()) { + otherWorldThread.join(); } - localWorld.destroy(); + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -112,108 +140,113 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "[mpi]") { // Register two ranks (one on each host) - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); int rankA = 0; int rankB = 1; std::vector messageData = { 0, 1, 2 }; std::vector messageData2 = { 3, 4, 5 }; // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + thisWorld.broadcastHostsToRanks(); - localWorld.broadcastHostsToRanks(); + std::thread otherWorldThread( + [this, rankA, rankB, &messageData, &messageData2] { + otherWorld.initialiseFromMsg(msg); - std::thread senderThread([this, rankA, rankB, &messageData, &messageData2] { - remoteWorld.initialiseFromMsg(msg); + // Send a message that should get sent to this host + otherWorld.send(rankB, + rankA, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); - // Send a message that should get sent to this host - remoteWorld.send( - rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); + // Now recv + auto buffer = new int[messageData2.size()]; + otherWorld.recv(rankA, + rankB, + BYTES(buffer), + MPI_INT, + messageData2.size(), + MPI_STATUS_IGNORE); + std::vector actual(buffer, buffer + messageData2.size()); + REQUIRE(actual == messageData2); - // Now recv - auto buffer = new int[messageData2.size()]; - remoteWorld.recv(rankA, - rankB, - BYTES(buffer), - MPI_INT, - messageData2.size(), - MPI_STATUS_IGNORE); - std::vector actual(buffer, buffer + messageData2.size()); - REQUIRE(actual == messageData2); + usleep(1000 * 1000); - usleep(500 * 1000); - }); + otherWorld.destroy(); + }); // Receive the message for the given rank MPI_Status status{}; auto buffer = new int[messageData.size()]; - localWorld.recv( + thisWorld.recv( rankB, rankA, BYTES(buffer), MPI_INT, messageData.size(), &status); std::vector actual(buffer, buffer + messageData.size()); REQUIRE(actual == messageData); // Now send a message - localWorld.send( + thisWorld.send( rankA, rankB, BYTES(messageData2.data()), MPI_INT, messageData2.size()); REQUIRE(status.MPI_SOURCE == rankB); REQUIRE(status.MPI_ERROR == MPI_SUCCESS); REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); - // Destroy worlds - if (senderThread.joinable()) { - senderThread.join(); + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); } - remoteWorld.destroy(); - localWorld.destroy(); + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, "Test barrier across hosts", "[mpi]") { // Register two ranks (one on each host) - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); int rankA = 0; int rankB = 1; std::vector sendData = { 0, 1, 2 }; std::vector recvData = { -1, -1, -1 }; // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - localWorld.broadcastHostsToRanks(); + thisWorld.broadcastHostsToRanks(); - std::thread senderThread([this, rankA, rankB, &sendData, &recvData] { - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, rankA, rankB, &sendData, &recvData] { + otherWorld.initialiseFromMsg(msg); - remoteWorld.send( + otherWorld.send( rankB, rankA, BYTES(sendData.data()), MPI_INT, sendData.size()); // Barrier on this rank - remoteWorld.barrier(rankB); + otherWorld.barrier(rankB); assert(sendData == recvData); + otherWorld.destroy(); }); // Receive the message for the given rank - localWorld.recv(rankB, - rankA, - BYTES(recvData.data()), - MPI_INT, - recvData.size(), - MPI_STATUS_IGNORE); + thisWorld.recv(rankB, + rankA, + BYTES(recvData.data()), + MPI_INT, + recvData.size(), + MPI_STATUS_IGNORE); REQUIRE(recvData == sendData); // Call barrier to synchronise remote host - localWorld.barrier(rankA); + thisWorld.barrier(rankA); - // Destroy worlds - if (senderThread.joinable()) { - senderThread.join(); + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); } - localWorld.destroy(); - remoteWorld.destroy(); + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -221,30 +254,31 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "[mpi]") { // Register two ranks (one on each host) - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); int rankA = 0; int rankB = 1; int numMessages = 1000; // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - localWorld.broadcastHostsToRanks(); + thisWorld.broadcastHostsToRanks(); - std::thread senderThread([this, rankA, rankB, numMessages] { - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, rankA, rankB, numMessages] { + otherWorld.initialiseFromMsg(msg); for (int i = 0; i < numMessages; i++) { - remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); + otherWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); } usleep(500 * 1000); + otherWorld.destroy(); }); int recv; for (int i = 0; i < numMessages; i++) { - localWorld.recv( + thisWorld.recv( rankB, rankA, BYTES(&recv), MPI_INT, 1, MPI_STATUS_IGNORE); // Check in-order delivery @@ -253,77 +287,70 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } } - // Destroy worlds - if (senderThread.joinable()) { - senderThread.join(); + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); } - localWorld.destroy(); - remoteWorld.destroy(); + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteCollectiveTestFixture, "Test broadcast across hosts", "[mpi]") { - // Here we rely on the scheduler running out of resources, and overloading - // the localWorld with ranks 4 and 5 - this->setWorldsSizes(thisWorldSize, 1, 3); - std::vector messageData = { 0, 1, 2 }; + MpiWorld& thisWorld = setUpThisWorld(); - // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); - faabric::util::setMockMode(false); - localWorld.broadcastHostsToRanks(); + std::vector messageData = { 0, 1, 2 }; - std::thread senderThread([this, &messageData] { - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, &messageData] { + otherWorld.initialiseFromMsg(msg); // Broadcast a message - remoteWorld.broadcast( - remoteRankB, BYTES(messageData.data()), MPI_INT, messageData.size()); - - // Check the host that the root is on - for (int rank : remoteWorldRanks) { - if (rank == remoteRankB) { + otherWorld.broadcast(otherHostRankB, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); + + // Check the broadcast is received on this host by the other ranks + for (int rank : otherWorldRanks) { + if (rank == otherHostRankB) { continue; } std::vector actual(3, -1); - remoteWorld.recv( - remoteRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); + otherWorld.recv( + otherHostRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); assert(actual == messageData); } + + // Give the other host time to receive the broadcast + usleep(1000 * 1000); + + otherWorld.destroy(); }); - // Check the local host - for (int rank : localWorldRanks) { + // Check the ranks on this host receive the broadcast + for (int rank : thisWorldRanks) { std::vector actual(3, -1); - localWorld.recv( - remoteRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); + thisWorld.recv( + otherHostRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); REQUIRE(actual == messageData); } - // Destroy worlds - if (senderThread.joinable()) { - senderThread.join(); + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); } - localWorld.destroy(); - remoteWorld.destroy(); + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteCollectiveTestFixture, "Test scatter across hosts", "[mpi]") { - // Here we rely on the scheduler running out of resources, and overloading - // the localWorld with ranks 4 and 5 - this->setWorldsSizes(thisWorldSize, 1, 3); - - // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); - faabric::util::setMockMode(false); - - localWorld.broadcastHostsToRanks(); + MpiWorld& thisWorld = setUpThisWorld(); // Build the data int nPerRank = 4; @@ -333,98 +360,94 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, messageData[i] = i; } - std::thread senderThread([this, nPerRank, &messageData] { - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, nPerRank, &messageData] { + otherWorld.initialiseFromMsg(msg); + // Do the scatter std::vector actual(nPerRank, -1); - remoteWorld.scatter(remoteRankB, - remoteRankB, - BYTES(messageData.data()), - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + otherWorld.scatter(otherHostRankB, + otherHostRankB, + BYTES(messageData.data()), + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); // Check for root assert(actual == std::vector({ 8, 9, 10, 11 })); // Check for other remote ranks - remoteWorld.scatter(remoteRankB, - remoteRankA, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + otherWorld.scatter(otherHostRankB, + otherHostRankA, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); assert(actual == std::vector({ 4, 5, 6, 7 })); - remoteWorld.scatter(remoteRankB, - remoteRankC, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + otherWorld.scatter(otherHostRankB, + otherHostRankC, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); assert(actual == std::vector({ 12, 13, 14, 15 })); usleep(500 * 1000); + + otherWorld.destroy(); }); // Check for local ranks std::vector actual(nPerRank, -1); - localWorld.scatter(remoteRankB, - 0, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + thisWorld.scatter(otherHostRankB, + 0, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); REQUIRE(actual == std::vector({ 0, 1, 2, 3 })); - localWorld.scatter(remoteRankB, - localRankB, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + thisWorld.scatter(otherHostRankB, + thisHostRankB, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); REQUIRE(actual == std::vector({ 20, 21, 22, 23 })); - localWorld.scatter(remoteRankB, - localRankA, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + thisWorld.scatter(otherHostRankB, + thisHostRankA, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); REQUIRE(actual == std::vector({ 16, 17, 18, 19 })); - // Destroy worlds - if (senderThread.joinable()) { - senderThread.join(); + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); } - localWorld.destroy(); - remoteWorld.destroy(); + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteCollectiveTestFixture, "Test gather across hosts", "[mpi]") { - // Here we rely on the scheduler running out of resources, and overloading - // the localWorld with ranks 4 and 5 - this->setWorldsSizes(thisWorldSize, 1, 3); - - // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); - faabric::util::setMockMode(false); - localWorld.broadcastHostsToRanks(); + MpiWorld& thisWorld = setUpThisWorld(); // Build the data for each rank int nPerRank = 4; @@ -447,57 +470,59 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, std::vector actual(thisWorldSize * nPerRank, -1); // Call gather for each rank other than the root (out of order) - int root = localRankA; - std::thread senderThread([this, root, &rankData, nPerRank] { - remoteWorld.initialiseFromMsg(msg); - - for (int rank : remoteWorldRanks) { - remoteWorld.gather(rank, - root, - BYTES(rankData[rank].data()), - MPI_INT, - nPerRank, - nullptr, - MPI_INT, - nPerRank); + int root = thisHostRankA; + std::thread otherWorldThread([this, root, &rankData, nPerRank] { + otherWorld.initialiseFromMsg(msg); + + for (int rank : otherWorldRanks) { + otherWorld.gather(rank, + root, + BYTES(rankData[rank].data()), + MPI_INT, + nPerRank, + nullptr, + MPI_INT, + nPerRank); } usleep(500 * 1000); + + otherWorld.destroy(); }); - for (int rank : localWorldRanks) { + for (int rank : thisWorldRanks) { if (rank == root) { continue; } - localWorld.gather(rank, - root, - BYTES(rankData[rank].data()), - MPI_INT, - nPerRank, - nullptr, - MPI_INT, - nPerRank); + thisWorld.gather(rank, + root, + BYTES(rankData[rank].data()), + MPI_INT, + nPerRank, + nullptr, + MPI_INT, + nPerRank); } // Call gather for root - localWorld.gather(root, - root, - BYTES(rankData[root].data()), - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + thisWorld.gather(root, + root, + BYTES(rankData[root].data()), + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); // Check data REQUIRE(actual == expected); - // Destroy worlds - if (senderThread.joinable()) { - senderThread.join(); + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); } - localWorld.destroy(); - remoteWorld.destroy(); + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -505,64 +530,66 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "[mpi]") { // Allocate two ranks in total, one rank per host - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); int sendRank = 1; int recvRank = 0; std::vector messageData = { 0, 1, 2 }; // Init world - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - localWorld.broadcastHostsToRanks(); + thisWorld.broadcastHostsToRanks(); - std::thread senderThread([this, sendRank, recvRank, &messageData] { - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, sendRank, recvRank, &messageData] { + otherWorld.initialiseFromMsg(msg); // Send message twice - remoteWorld.send(sendRank, - recvRank, - BYTES(messageData.data()), - MPI_INT, - messageData.size()); - remoteWorld.send(sendRank, - recvRank, - BYTES(messageData.data()), - MPI_INT, - messageData.size()); + otherWorld.send(sendRank, + recvRank, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); + otherWorld.send(sendRank, + recvRank, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); usleep(500 * 1000); + + otherWorld.destroy(); }); // Receive one message asynchronously std::vector asyncMessage(messageData.size(), 0); - int recvId = localWorld.irecv(sendRank, - recvRank, - BYTES(asyncMessage.data()), - MPI_INT, - asyncMessage.size()); + int recvId = thisWorld.irecv(sendRank, + recvRank, + BYTES(asyncMessage.data()), + MPI_INT, + asyncMessage.size()); // Receive one message synchronously std::vector syncMessage(messageData.size(), 0); - localWorld.recv(sendRank, - recvRank, - BYTES(syncMessage.data()), - MPI_INT, - syncMessage.size(), - MPI_STATUS_IGNORE); + thisWorld.recv(sendRank, + recvRank, + BYTES(syncMessage.data()), + MPI_INT, + syncMessage.size(), + MPI_STATUS_IGNORE); // Wait for the async message - localWorld.awaitAsyncRequest(recvId); + thisWorld.awaitAsyncRequest(recvId); // Checks REQUIRE(syncMessage == messageData); REQUIRE(asyncMessage == messageData); - // Destroy worlds - if (senderThread.joinable()) { - senderThread.join(); + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); } - localWorld.destroy(); - remoteWorld.destroy(); + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -570,48 +597,50 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "[mpi]") { // Allocate two ranks in total, one rank per host - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); int sendRank = 1; int recvRank = 0; // Init world - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - localWorld.broadcastHostsToRanks(); + thisWorld.broadcastHostsToRanks(); - std::thread senderThread([this, sendRank, recvRank] { - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, sendRank, recvRank] { + otherWorld.initialiseFromMsg(msg); // Send different messages for (int i = 0; i < 3; i++) { - remoteWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); + otherWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); } usleep(500 * 1000); + + otherWorld.destroy(); }); // Receive two messages asynchronously int recv1, recv2, recv3; int recvId1 = - localWorld.irecv(sendRank, recvRank, BYTES(&recv1), MPI_INT, 1); + thisWorld.irecv(sendRank, recvRank, BYTES(&recv1), MPI_INT, 1); int recvId2 = - localWorld.irecv(sendRank, recvRank, BYTES(&recv2), MPI_INT, 1); + thisWorld.irecv(sendRank, recvRank, BYTES(&recv2), MPI_INT, 1); // Receive one message synchronously - localWorld.recv( + thisWorld.recv( sendRank, recvRank, BYTES(&recv3), MPI_INT, 1, MPI_STATUS_IGNORE); SECTION("Wait out of order") { - localWorld.awaitAsyncRequest(recvId2); - localWorld.awaitAsyncRequest(recvId1); + thisWorld.awaitAsyncRequest(recvId2); + thisWorld.awaitAsyncRequest(recvId1); } SECTION("Wait in order") { - localWorld.awaitAsyncRequest(recvId1); - localWorld.awaitAsyncRequest(recvId2); + thisWorld.awaitAsyncRequest(recvId1); + thisWorld.awaitAsyncRequest(recvId2); } // Checks @@ -619,12 +648,12 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(recv2 == 1); REQUIRE(recv3 == 2); - // Destroy worlds - if (senderThread.joinable()) { - senderThread.join(); + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); } - localWorld.destroy(); - remoteWorld.destroy(); + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -632,65 +661,65 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "[mpi]") { // Allocate two ranks in total, one rank per host - this->setWorldsSizes(3, 1, 2); + setWorldSizes(3, 1, 2); int worldSize = 3; - std::vector localRanks = { 0 }; + std::vector thisHostRanks = { 0 }; // Init world - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - localWorld.broadcastHostsToRanks(); + thisWorld.broadcastHostsToRanks(); - std::thread senderThread([this, worldSize] { - std::vector remoteRanks = { 1, 2 }; - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, worldSize] { + std::vector otherHostRanks = { 1, 2 }; + otherWorld.initialiseFromMsg(msg); // Send different messages - for (auto& rank : remoteRanks) { + for (auto& rank : otherHostRanks) { int left = rank > 0 ? rank - 1 : worldSize - 1; int right = (rank + 1) % worldSize; int recvData = -1; - remoteWorld.sendRecv(BYTES(&rank), - 1, - MPI_INT, - right, - BYTES(&recvData), - 1, - MPI_INT, - left, - rank, - MPI_STATUS_IGNORE); + otherWorld.sendRecv(BYTES(&rank), + 1, + MPI_INT, + right, + BYTES(&recvData), + 1, + MPI_INT, + left, + rank, + MPI_STATUS_IGNORE); } usleep(500 * 1000); + otherWorld.destroy(); }); - for (auto& rank : localRanks) { + for (auto& rank : thisHostRanks) { int left = rank > 0 ? rank - 1 : worldSize - 1; int right = (rank + 1) % worldSize; int recvData = -1; - localWorld.sendRecv(BYTES(&rank), - 1, - MPI_INT, - right, - BYTES(&recvData), - 1, - MPI_INT, - left, - rank, - MPI_STATUS_IGNORE); + thisWorld.sendRecv(BYTES(&rank), + 1, + MPI_INT, + right, + BYTES(&recvData), + 1, + MPI_INT, + left, + rank, + MPI_STATUS_IGNORE); REQUIRE(recvData == left); } - // Destroy world - if (senderThread.joinable()) { - senderThread.join(); + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); } - localWorld.destroy(); - remoteWorld.destroy(); + thisWorld.destroy(); } } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 2fc6223b3..0961a5eb5 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -64,32 +64,6 @@ TEST_CASE_METHOD(SchedulerTestFixture, } } -TEST_CASE_METHOD(SchedulerTestFixture, - "Test send out of scope before recv", - "[transport]") -{ - std::string expectedMsg = "Hello world!"; - - // Send message and let socket go out of scope - { - AsyncSendMessageEndpoint src(thisHost, testPort); - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - src.send(msg, expectedMsg.size()); - } - - // Recieve message in its own scope too - { - usleep(100 * 1000); - AsyncRecvMessageEndpoint dst(testPort); - - faabric::transport::Message recvMsg = dst.recv(); - REQUIRE(recvMsg.size() == expectedMsg.size()); - std::string actualMsg(recvMsg.data(), recvMsg.size()); - REQUIRE(actualMsg == expectedMsg); - } -} - TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") { // Prepare common message/response diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 5c0496385..8f215b98e 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -186,7 +186,6 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") memcpy(body, clientMsg.c_str(), clientMsg.size()); cli.asyncSend(0, body, clientMsg.size()); } - })); } diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index 92a4ac777..f37f6f374 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -167,51 +167,54 @@ class MpiTestFixture : public MpiBaseTestFixture faabric::scheduler::MpiWorld world; }; +// Note that this test has two worlds, which each "think" that the other is +// remote. This is done by allowing one to have the IP of this host, the other +// to have the localhost IP, i.e. 127.0.0.1. class RemoteMpiTestFixture : public MpiBaseTestFixture { public: RemoteMpiTestFixture() : thisHost(faabric::util::getSystemConfig().endpointHost) - , otherHost(LOCALHOST) { - remoteWorld.overrideHost(otherHost); + otherWorld.overrideHost(otherHost); + + faabric::util::setMockMode(true); } ~RemoteMpiTestFixture() { faabric::util::setMockMode(false); + faabric::scheduler::getMpiWorldRegistry().clear(); } - void setWorldsSizes(int worldSize, int ranksWorldOne, int ranksWorldTwo) + void setWorldSizes(int worldSize, int ranksThisWorld, int ranksOtherWorld) { // Update message msg.set_mpiworldsize(worldSize); - // Set local ranks - faabric::HostResources localResources; - localResources.set_slots(ranksWorldOne); - // Account for the master rank that is already running in this world - localResources.set_usedslots(1); - // Set remote ranks - faabric::HostResources otherResources; - otherResources.set_slots(ranksWorldTwo); - // Note that the remaining ranks will be allocated to the world - // with the master host + // Set up the first world, holding the master rank (which already takes + // one slot). + // Note that any excess ranks will also be allocated to this world when + // the scheduler is overloaded. + faabric::HostResources thisResources; + thisResources.set_slots(ranksThisWorld); + thisResources.set_usedslots(1); + sch.setThisHostResources(thisResources); - std::string otherHost = LOCALHOST; + // Set up the other world and add it to the global set of hosts + faabric::HostResources otherResources; + otherResources.set_slots(ranksOtherWorld); sch.addHostToGlobalSet(otherHost); - // Mock everything to make sure the other host has resources as well - faabric::util::setMockMode(true); - sch.setThisHostResources(localResources); + // Queue the resource response for this other host faabric::scheduler::queueResourceResponse(otherHost, otherResources); } protected: std::string thisHost; - std::string otherHost; + std::string otherHost = LOCALHOST; - faabric::scheduler::MpiWorld remoteWorld; + faabric::scheduler::MpiWorld otherWorld; }; } From 6026ffc67c7e6435708f94d8564474ea6d56d765 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 28 Jun 2021 10:56:29 +0000 Subject: [PATCH 45/66] Switch to SLEEP_MS macro --- include/faabric/util/macros.h | 4 ++++ src/state/StateKeyValue.cpp | 2 +- tests/dist/main.cpp | 3 ++- tests/test/redis/test_redis.cpp | 3 ++- tests/test/scheduler/test_executor.cpp | 8 ++++---- .../scheduler/test_function_client_server.cpp | 11 +++++------ .../test/scheduler/test_remote_mpi_worlds.cpp | 18 +++++++++--------- tests/test/scheduler/test_scheduler.cpp | 3 ++- .../scheduler/test_snapshot_client_server.cpp | 5 +++-- tests/test/state/test_state.cpp | 5 +++-- tests/test/state/test_state_server.cpp | 3 ++- .../transport/test_message_endpoint_client.cpp | 2 +- tests/test/transport/test_message_server.cpp | 14 +++++++------- tests/test/util/test_barrier.cpp | 5 ++++- tests/test/util/test_queue.cpp | 5 +++-- 15 files changed, 52 insertions(+), 39 deletions(-) diff --git a/include/faabric/util/macros.h b/include/faabric/util/macros.h index 7b865679f..c86f29362 100644 --- a/include/faabric/util/macros.h +++ b/include/faabric/util/macros.h @@ -1,6 +1,10 @@ #pragma once +#include + #define BYTES(arr) reinterpret_cast(arr) #define BYTES_CONST(arr) reinterpret_cast(arr) +#define SLEEP_MS(ms) usleep((ms) * 1000) + #define UNUSED(x) (void)(x) diff --git a/src/state/StateKeyValue.cpp b/src/state/StateKeyValue.cpp index c53dddaf1..a1936f74b 100644 --- a/src/state/StateKeyValue.cpp +++ b/src/state/StateKeyValue.cpp @@ -559,7 +559,7 @@ uint32_t StateKeyValue::waitOnRedisRemoteLock(const std::string& redisKey) break; } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + SLEEP_MS(1); remoteLockId = redis.acquireLock(redisKey, REMOTE_LOCK_TIMEOUT_SECS); retryCount++; diff --git a/tests/dist/main.cpp b/tests/dist/main.cpp index 424110470..cca63d7e4 100644 --- a/tests/dist/main.cpp +++ b/tests/dist/main.cpp @@ -10,6 +10,7 @@ #include #include #include +#include using namespace faabric::scheduler; @@ -30,7 +31,7 @@ int main(int argc, char* argv[]) m.startBackground(); // Wait for things to start - usleep(3000 * 1000); + SLEEP_MS(3000); // Run the tests int result = Catch::Session().run(argc, argv); diff --git a/tests/test/redis/test_redis.cpp b/tests/test/redis/test_redis.cpp index 6550c0dcc..15439f60a 100644 --- a/tests/test/redis/test_redis.cpp +++ b/tests/test/redis/test_redis.cpp @@ -4,6 +4,7 @@ #include #include +#include #include @@ -544,7 +545,7 @@ TEST_CASE("Test enqueue after blocking dequeue") }); // Wait a bit (assume the waiting thread will get to block by now) - sleep(1); + SLEEP_MS(1000); redisQueue.enqueue("foobar", "baz"); // If this hangs, the redis client isn't dequeueing after an enqueue is diff --git a/tests/test/scheduler/test_executor.cpp b/tests/test/scheduler/test_executor.cpp index a5a52a527..5558700b1 100644 --- a/tests/test/scheduler/test_executor.cpp +++ b/tests/test/scheduler/test_executor.cpp @@ -454,7 +454,7 @@ TEST_CASE_METHOD(TestExecutorFixture, }); // Give it time to have made the request - usleep(SHORT_TEST_TIMEOUT_MS * 1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); // Check restore hasn't been called yet REQUIRE(restoreCount == 0); @@ -521,7 +521,7 @@ TEST_CASE_METHOD(TestExecutorFixture, // We have to manually add a wait here as the thread results won't actually // get logged on this host - usleep(SHORT_TEST_TIMEOUT_MS * 1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); auto actual = faabric::scheduler::getThreadResults(); REQUIRE(actual.size() == nThreads); @@ -681,7 +681,7 @@ TEST_CASE_METHOD(TestExecutorFixture, REQUIRE(sch.getFunctionExecutorCount(msg) == 1); - usleep((conf.boundTimeout + 500) * 1000); + SLEEP_MS(conf.boundTimeout + 500); REQUIRE(sch.getFunctionExecutorCount(msg) == 0); } @@ -712,7 +712,7 @@ TEST_CASE_METHOD(TestExecutorFixture, executeWithTestExecutor(req, true); // Wait for executor to have finished - sometimes takes a while - usleep(SHORT_TEST_TIMEOUT_MS * 1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); // Check thread results returned REQUIRE(faabric::scheduler::getThreadResults().size() == nThreads); diff --git a/tests/test/scheduler/test_function_client_server.cpp b/tests/test/scheduler/test_function_client_server.cpp index 988039b4d..06926c887 100644 --- a/tests/test/scheduler/test_function_client_server.cpp +++ b/tests/test/scheduler/test_function_client_server.cpp @@ -13,12 +13,11 @@ #include #include #include +#include #include #include #include -#define TEST_WAIT_MS 1000 - using namespace faabric::scheduler; namespace tests { @@ -37,7 +36,7 @@ class ClientServerFixture : cli(LOCALHOST) { server.start(); - usleep(TEST_WAIT_MS * 1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); // Set up executor executorFactory = std::make_shared(); @@ -80,7 +79,7 @@ TEST_CASE_METHOD(ClientServerFixture, // Send flush message cli.sendFlush(); - usleep(TEST_WAIT_MS * 1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); // Check the scheduler has been flushed REQUIRE(sch.getFunctionRegisteredHostCount(msgA) == 0); @@ -137,7 +136,7 @@ TEST_CASE_METHOD(ClientServerFixture, // Make the request cli.executeFunctions(req); - usleep(TEST_WAIT_MS * 1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); // Check no other hosts have been registered faabric::Message m = req->messages().at(0); @@ -226,7 +225,7 @@ TEST_CASE_METHOD(ClientServerFixture, "Test unregister request", "[scheduler]") reqB.set_host(otherHost); *reqB.mutable_function() = msg; cli.unregister(reqB); - usleep(TEST_WAIT_MS * 1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 0); sch.setThisHostResources(originalResources); diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index a6f17c03e..6cf91b70b 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -78,7 +78,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test rank allocation", "[mpi]") otherWorld.destroy(); }); - usleep(500 * 1000); + SLEEP_MS(500); REQUIRE(thisWorld.getHostForRank(0) == thisHost); REQUIRE(thisWorld.getHostForRank(1) == otherHost); @@ -173,7 +173,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, std::vector actual(buffer, buffer + messageData2.size()); REQUIRE(actual == messageData2); - usleep(1000 * 1000); + SLEEP_MS(1000); otherWorld.destroy(); }); @@ -272,7 +272,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, otherWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); } - usleep(500 * 1000); + SLEEP_MS(500); otherWorld.destroy(); }); @@ -325,7 +325,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, } // Give the other host time to receive the broadcast - usleep(1000 * 1000); + SLEEP_MS(1000); otherWorld.destroy(); }); @@ -398,7 +398,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); assert(actual == std::vector({ 12, 13, 14, 15 })); - usleep(500 * 1000); + SLEEP_MS(500); otherWorld.destroy(); }); @@ -485,7 +485,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); } - usleep(500 * 1000); + SLEEP_MS(500); otherWorld.destroy(); }); @@ -555,7 +555,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_INT, messageData.size()); - usleep(500 * 1000); + SLEEP_MS(500); otherWorld.destroy(); }); @@ -614,7 +614,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, otherWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); } - usleep(500 * 1000); + SLEEP_MS(500); otherWorld.destroy(); }); @@ -692,7 +692,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_STATUS_IGNORE); } - usleep(500 * 1000); + SLEEP_MS(500); otherWorld.destroy(); }); diff --git a/tests/test/scheduler/test_scheduler.cpp b/tests/test/scheduler/test_scheduler.cpp index 2b906a152..c999eb45e 100644 --- a/tests/test/scheduler/test_scheduler.cpp +++ b/tests/test/scheduler/test_scheduler.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include using namespace faabric::scheduler; @@ -35,7 +36,7 @@ class SlowExecutor final : public Executor SPDLOG_DEBUG("SlowExecutor executing task{}", req->mutable_messages()->at(msgIdx).id()); - usleep(SHORT_TEST_TIMEOUT_MS * 1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); return 0; } }; diff --git a/tests/test/scheduler/test_snapshot_client_server.cpp b/tests/test/scheduler/test_snapshot_client_server.cpp index 476e20140..4a7356118 100644 --- a/tests/test/scheduler/test_snapshot_client_server.cpp +++ b/tests/test/scheduler/test_snapshot_client_server.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -29,7 +30,7 @@ class SnapshotClientServerFixture : cli(LOCALHOST) { server.start(); - usleep(1000 * SHORT_TEST_TIMEOUT_MS); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); } ~SnapshotClientServerFixture() { server.stop(); } @@ -62,7 +63,7 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, cli.pushSnapshot(snapKeyA, snapA); cli.pushSnapshot(snapKeyB, snapB); - usleep(1000 * 500); + SLEEP_MS(500); // Check snapshots created in registry REQUIRE(reg.getSnapshotCount() == 2); diff --git a/tests/test/state/test_state.cpp b/tests/test/state/test_state.cpp index 7308aef72..e593c45a4 100644 --- a/tests/test/state/test_state.cpp +++ b/tests/test/state/test_state.cpp @@ -7,10 +7,11 @@ #include #include #include +#include #include #include - #include + #include using namespace faabric::state; @@ -43,7 +44,7 @@ class StateServerTestFixture stateServer.start(); // Give it time to start - usleep(1000 * 1000); + SLEEP_MS(1000); } ~StateServerTestFixture() { stateServer.stop(); } diff --git a/tests/test/state/test_state_server.cpp b/tests/test/state/test_state_server.cpp index 73f02ab57..be628fd9d 100644 --- a/tests/test/state/test_state_server.cpp +++ b/tests/test/state/test_state_server.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include @@ -27,7 +28,7 @@ class SimpleStateServerTestFixture conf.stateMode = "inmemory"; server.start(); - usleep(1000 * 100); + SLEEP_MS(100); } ~SimpleStateServerTestFixture() { server.stop(); } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 0961a5eb5..4448bcebf 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -44,7 +44,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, // Run the recv in the background std::thread recvThread([expectedMsg] { - usleep(1000 * 1000); + SLEEP_MS(1000); AsyncRecvMessageEndpoint dst(testPort); // Receive message diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 8f215b98e..4caedd32a 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -92,7 +92,7 @@ class SlowServer final : public MessageEndpointServer { SPDLOG_DEBUG("Slow message server test recv"); - usleep(delayMs * 1000); + SLEEP_MS(delayMs); auto response = std::make_unique(); response->set_data("From the slow server"); return response; @@ -105,7 +105,7 @@ TEST_CASE("Test start/stop server", "[transport]") DummyServer server; server.start(); - usleep(100 * 1000); + SLEEP_MS(100); server.stop(); } @@ -133,7 +133,7 @@ TEST_CASE("Test send one message to server", "[transport]") memcpy(bodyMsg, body.c_str(), body.size()); src.send(bodyMsg, body.size(), false); - usleep(1000 * 300); + SLEEP_MS(300); REQUIRE(server.messageCount == 1); // Close the server @@ -145,7 +145,7 @@ TEST_CASE("Test send response to client", "[transport]") std::thread serverThread([] { EchoServer server; server.start(); - usleep(1000 * 1000); + SLEEP_MS(1000); server.stop(); }); @@ -195,7 +195,7 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") } } - usleep(2000 * 1000); + SLEEP_MS(2000); REQUIRE(server.messageCount == numMessages * numClients); @@ -225,13 +225,13 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") server.start(); int threadSleep = server.delayMs + 500; - usleep(threadSleep * 1000); + SLEEP_MS(threadSleep); server.stop(); }); // Wait for the server to start up - usleep(500 * 1000); + SLEEP_MS(500); // Set up the client MessageEndpointClient cli( diff --git a/tests/test/util/test_barrier.cpp b/tests/test/util/test_barrier.cpp index 594c0ab90..b0e980440 100644 --- a/tests/test/util/test_barrier.cpp +++ b/tests/test/util/test_barrier.cpp @@ -1,6 +1,9 @@ #include + #include #include +#include + #include #include @@ -19,7 +22,7 @@ TEST_CASE("Test barrier operation", "[util]") auto t2 = std::thread([&b] { b.wait(); }); // Sleep for a bit while the threads spawn - usleep(500 * 1000); + SLEEP_MS(500); REQUIRE(b.getSlotCount() == 1); // Join with master to go through barrier diff --git a/tests/test/util/test_queue.cpp b/tests/test/util/test_queue.cpp index 64c4319d0..94b7f0907 100644 --- a/tests/test/util/test_queue.cpp +++ b/tests/test/util/test_queue.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -86,7 +87,7 @@ TEST_CASE("Test wait for draining queue with elements", "[util]") // Background thread to consume elements std::thread t([&q, &dequeued, nElems] { for (int i = 0; i < nElems; i++) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + SLEEP_MS(100); int j = q.dequeue(); dequeued.emplace_back(j); @@ -117,7 +118,7 @@ TEST_CASE("Test queue on non-copy-constructible object", "[util]") std::thread ta([&q] { q.dequeue().set_value(1); }); std::thread tb([&q] { - usleep(500 * 1000); + SLEEP_MS(500); q.dequeue().set_value(2); }); From b656c49efa3b9c6d590059c24e057bdd9b372a83 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 28 Jun 2021 10:56:54 +0000 Subject: [PATCH 46/66] Formatting --- include/faabric/util/macros.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/faabric/util/macros.h b/include/faabric/util/macros.h index c86f29362..504e4be40 100644 --- a/include/faabric/util/macros.h +++ b/include/faabric/util/macros.h @@ -5,6 +5,6 @@ #define BYTES(arr) reinterpret_cast(arr) #define BYTES_CONST(arr) reinterpret_cast(arr) -#define SLEEP_MS(ms) usleep((ms) * 1000) +#define SLEEP_MS(ms) usleep((ms)*1000) #define UNUSED(x) (void)(x) From 6e089d0b699f541a9570346d45702d7e513e61f8 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 28 Jun 2021 14:21:26 +0000 Subject: [PATCH 47/66] Self review and restart dummy state server on initial error --- include/faabric/scheduler/MpiWorld.h | 4 +- .../faabric/transport/MessageEndpointServer.h | 2 +- .../faabric/transport/MpiMessageEndpoint.h | 17 ++++--- include/faabric/transport/context.h | 4 ++ src/scheduler/MpiWorld.cpp | 20 ++++---- src/scheduler/SnapshotClient.cpp | 47 +++++++++---------- src/transport/MessageEndpoint.cpp | 2 - src/transport/context.cpp | 8 +--- tests/test/transport/test_message_server.cpp | 14 ++++++ 9 files changed, 64 insertions(+), 54 deletions(-) diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h index efd53efdd..2ce29ae1c 100644 --- a/include/faabric/scheduler/MpiWorld.h +++ b/include/faabric/scheduler/MpiWorld.h @@ -19,7 +19,7 @@ typedef faabric::util::Queue> class MpiWorld { public: - MpiWorld(int basePort = DEFAULT_MPI_BASE_PORT); + MpiWorld(); void create(const faabric::Message& call, int newId, int newSize); @@ -184,7 +184,7 @@ class MpiWorld int id = -1; int size = -1; std::string thisHost; - int basePort = -1; + int basePort = DEFAULT_MPI_BASE_PORT; faabric::util::TimePoint creationTime; std::atomic_flag isDestroyed = false; diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index fa8de64b6..3d7df8f74 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -19,7 +19,7 @@ class MessageEndpointServer public: MessageEndpointServer(int asyncPortIn, int syncPortIn); - void start(); + virtual void start(); virtual void stop(); diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 0f5f5a397..d715a8881 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -7,15 +7,14 @@ namespace faabric::transport { -/* - * This class abstracts the notion of a communication channel between two remote - * MPI ranks. There will always be one rank local to this host, and one remote. - * Note that the ports are unique per (user, function, sendRank, recvRank) - * tuple. - * - * This is different to our normal messaging clients as it wraps two _async_ - * sockets. - */ +// This class abstracts the notion of a communication channel between two remote +// MPI ranks. There will always be one rank local to this host, and one remote. +// Note that the ports are unique per (user, function, sendRank, recvRank) +// tuple. +// +// This is different to our normal messaging clients as it wraps two +// _async_ sockets. + class MpiMessageEndpoint { public: diff --git a/include/faabric/transport/context.h b/include/faabric/transport/context.h index 9a0a7d846..4e608254d 100644 --- a/include/faabric/transport/context.h +++ b/include/faabric/transport/context.h @@ -3,6 +3,10 @@ #include #include +// We specify a number of background I/O threads when constructing the ZeroMQ +// context. Guidelines on how to scale this can be found here: +// https://zguide.zeromq.org/docs/chapter2/#I-O-Threads + #define ZMQ_CONTEXT_IO_THREADS 1 namespace faabric::transport { diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index b75665e96..33cfa1979 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -6,9 +6,8 @@ #include #include -/* Each MPI rank runs in a separate thread, thus we use TLS to maintain the - * per-rank data structures. - */ +// Each MPI rank runs in a separate thread, thus we use TLS to maintain the +// per-rank data structures static thread_local std::vector< std::unique_ptr> mpiMessageEndpoints; @@ -17,6 +16,13 @@ static thread_local std::vector< std::shared_ptr> unackedMessageBuffers; +static thread_local std::set iSendRequests; + +static thread_local std::map> reqIdToRanks; + +// These long-lived sockets are used by each world to communicate rank-to-host +// mappings. They are thread-local to ensure separation between concurrent +// worlds executing on the same host static thread_local std::unique_ptr< faabric::transport::AsyncRecvMessageEndpoint> ranksRecvEndpoint; @@ -26,17 +32,13 @@ static thread_local std::unordered_map< std::unique_ptr> ranksSendEndpoints; -static thread_local std::set iSendRequests; - -static thread_local std::map> reqIdToRanks; - +// This is used for mocking in tests static std::vector rankMessages; namespace faabric::scheduler { -MpiWorld::MpiWorld(int basePortIn) +MpiWorld::MpiWorld() : thisHost(faabric::util::getSystemConfig().endpointHost) - , basePort(basePortIn) , creationTime(faabric::util::startTimer()) , cartProcsPerDim(2) {} diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index 58ffba9eb..020edfa45 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -69,6 +69,21 @@ void clearMockSnapshotRequests() // Snapshot client // ----------------------------------- +#define SEND_FB_MSG(T, mb) \ + { \ + uint8_t* buffer = mb.GetBufferPointer(); \ + int size = mb.GetSize(); \ + faabric::EmptyResponse response; \ + syncSend(T, buffer, size, &response); \ + } + +#define SEND_FB_MSG_ASYNC(T, mb) \ + { \ + uint8_t* buffer = mb.GetBufferPointer(); \ + int size = mb.GetSize(); \ + asyncSend(T, buffer, size); \ + } + SnapshotClient::SnapshotClient(const std::string& hostIn) : faabric::transport::MessageEndpointClient(hostIn, SNAPSHOT_ASYNC_PORT, @@ -91,16 +106,10 @@ void SnapshotClient::pushSnapshot(const std::string& key, auto dataOffset = mb.CreateVector(data.data, data.size); auto requestOffset = CreateSnapshotPushRequest(mb, keyOffset, dataOffset); + mb.Finish(requestOffset); // Send it - mb.Finish(requestOffset); - uint8_t* buffer = mb.GetBufferPointer(); - int size = mb.GetSize(); - faabric::EmptyResponse response; - syncSend(faabric::scheduler::SnapshotCalls::PushSnapshot, - buffer, - size, - &response); + SEND_FB_MSG(SnapshotCalls::PushSnapshot, mb) } } @@ -133,15 +142,9 @@ void SnapshotClient::pushSnapshotDiffs( auto diffsOffset = mb.CreateVector(diffsFbVector); auto requestOffset = CreateSnapshotDiffPushRequest(mb, keyOffset, diffsOffset); - mb.Finish(requestOffset); - uint8_t* buffer = mb.GetBufferPointer(); - int size = mb.GetSize(); - faabric::EmptyResponse response; - syncSend(faabric::scheduler::SnapshotCalls::PushSnapshotDiffs, - buffer, - size, - &response); + + SEND_FB_MSG(SnapshotCalls::PushSnapshotDiffs, mb); } } @@ -158,12 +161,9 @@ void SnapshotClient::deleteSnapshot(const std::string& key) flatbuffers::FlatBufferBuilder mb; auto keyOffset = mb.CreateString(key); auto requestOffset = CreateSnapshotDeleteRequest(mb, keyOffset); - mb.Finish(requestOffset); - uint8_t* buffer = mb.GetBufferPointer(); - int size = mb.GetSize(); - asyncSend( - faabric::scheduler::SnapshotCalls::PushSnapshotDiffs, buffer, size); + + SEND_FB_MSG_ASYNC(SnapshotCalls::DeleteSnapshot, mb); } } @@ -221,10 +221,7 @@ void SnapshotClient::pushThreadResult( } mb.Finish(requestOffset); - uint8_t* buffer = mb.GetBufferPointer(); - int size = mb.GetSize(); - asyncSend( - faabric::scheduler::SnapshotCalls::ThreadResult, buffer, size); + SEND_FB_MSG_ASYNC(SnapshotCalls::ThreadResult, mb) } } } diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index cbd750a06..51093cec2 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -313,8 +313,6 @@ Message SyncRecvMessageEndpoint::recv(int size) return doRecv(repSocket, size); } -// We create a new endpoint every time. Re-using them would be a possible -// optimisation if needed. void SyncRecvMessageEndpoint::sendResponse(uint8_t* data, int size) { SPDLOG_TRACE("REP {} ({} bytes)", port, size); diff --git a/src/transport/context.cpp b/src/transport/context.cpp index e6adfc475..3c6b9d918 100644 --- a/src/transport/context.cpp +++ b/src/transport/context.cpp @@ -4,12 +4,8 @@ namespace faabric::transport { -/* - * The zmq::context_t object is thread safe, and the constructor parameter - * indicates the number of hardware IO threads to be used. As a rule of thumb, - * use one IO thread per Gbps of data. - */ - +// The ZeroMQ context object is thread safe, so we're ok to have a single global +// instance. static std::shared_ptr instance = nullptr; void initGlobalMessageContext() diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 4caedd32a..6b41e1a52 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -26,6 +26,20 @@ class DummyServer final : public MessageEndpointServer // Variable to keep track of the received messages int messageCount; + void start() override + { + // In a CI environment tests can be slow to tear down fully, so we want + // to sleep and retry if the initial connection fails. + try { + MessageEndpointServer::start(); + } catch (zmq::error_t& ex) { + SPDLOG_DEBUG("Retrying dummy server start after delay"); + + SLEEP_MS(1000); + MessageEndpointServer::start(); + } + } + private: void doAsyncRecv(faabric::transport::Message& header, faabric::transport::Message& body) override From ec9d58d671bb4167fc5d4b6a0623bf0506dd914c Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 28 Jun 2021 14:42:31 +0000 Subject: [PATCH 48/66] Remove unnecessary threads in transport tests --- .../test_message_endpoint_client.cpp | 8 +- tests/test/transport/test_message_server.cpp | 73 +++++-------------- 2 files changed, 23 insertions(+), 58 deletions(-) diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 4448bcebf..30d6e3221 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -9,8 +9,8 @@ using namespace faabric::transport; -const std::string thisHost = "127.0.0.1"; -const int testPort = 9800; +static const std::string thisHost = "127.0.0.1"; +static const int testPort = 9800; namespace tests { @@ -114,7 +114,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, std::string baseMsg = "Hello "; std::thread senderThread([numMessages, baseMsg] { - // Open the source endpoint client, don't bind + // Open the source endpoint client AsyncSendMessageEndpoint src(thisHost, testPort); for (int i = 0; i < numMessages; i++) { std::string msgData = baseMsg + std::to_string(i); @@ -155,7 +155,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, for (int j = 0; j < numSenders; j++) { senderThreads.emplace_back(std::thread([numMessages, expectedMsg] { - // Open the source endpoint client, don't bind + // Open the source endpoint client AsyncSendMessageEndpoint src(thisHost, testPort); for (int i = 0; i < numMessages; i++) { uint8_t msg[expectedMsg.size()]; diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 6b41e1a52..55500082b 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -11,9 +11,9 @@ using namespace faabric::transport; -const std::string thisHost = "127.0.0.1"; -const int testPortAsync = 9998; -const int testPortSync = 9999; +static const std::string thisHost = "127.0.0.1"; +static const int testPortAsync = 9998; +static const int testPortSync = 9999; class DummyServer final : public MessageEndpointServer { @@ -33,7 +33,7 @@ class DummyServer final : public MessageEndpointServer try { MessageEndpointServer::start(); } catch (zmq::error_t& ex) { - SPDLOG_DEBUG("Retrying dummy server start after delay"); + SPDLOG_WARN("Error connecting dummy server, retrying after delay"); SLEEP_MS(1000); MessageEndpointServer::start(); @@ -114,58 +114,37 @@ class SlowServer final : public MessageEndpointServer }; namespace tests { -TEST_CASE("Test start/stop server", "[transport]") -{ - DummyServer server; - server.start(); - SLEEP_MS(100); - - server.stop(); -} - -TEST_CASE("Test send one message to server", "[transport]") + TEST_CASE("Test send one message to server", "[transport]") { - // Start server DummyServer server; server.start(); - // Open the source endpoint client, don't bind - AsyncSendMessageEndpoint src(thisHost, testPortAsync, testPortSync); - - // Send message: server expects header + body - std::string header = "header"; - uint8_t headerMsg[header.size()]; - memcpy(headerMsg, header.c_str(), header.size()); + REQUIRE(server.messageCount == 0); - // Mark we are sending the header - src.send(headerMsg, header.size(), true); + MessageEndpointClient cli(thisHost, testPortAsync, testPortSync); - // Send the body + // Send a message std::string body = "body"; uint8_t bodyMsg[body.size()]; memcpy(bodyMsg, body.c_str(), body.size()); - src.send(bodyMsg, body.size(), false); + cli.asyncSend(0, bodyMsg, body.size()); + + SLEEP_MS(500); - SLEEP_MS(300); REQUIRE(server.messageCount == 1); - // Close the server server.stop(); } TEST_CASE("Test send response to client", "[transport]") { - std::thread serverThread([] { - EchoServer server; - server.start(); - SLEEP_MS(1000); - server.stop(); - }); + EchoServer server; + server.start(); std::string expectedMsg = "Response from server"; - // Open the source endpoint client, don't bind + // Open the source endpoint client MessageEndpointClient cli(thisHost, testPortAsync, testPortSync); // Send and await the response @@ -174,9 +153,7 @@ TEST_CASE("Test send response to client", "[transport]") assert(response.data() == expectedMsg); - if (serverThread.joinable()) { - serverThread.join(); - } + server.stop(); } TEST_CASE("Test multiple clients talking to one server", "[transport]") @@ -233,19 +210,9 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") expectFailure = true; } - // Start the server in the background - std::thread t([] { - SlowServer server; - server.start(); - - int threadSleep = server.delayMs + 500; - SLEEP_MS(threadSleep); - - server.stop(); - }); - - // Wait for the server to start up - SLEEP_MS(500); + // Start the server + SlowServer server; + server.start(); // Set up the client MessageEndpointClient cli( @@ -264,8 +231,6 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") REQUIRE(response.data() == "From the slow server"); } - if (t.joinable()) { - t.join(); - } + server.stop(); } } From f9d45ba01971f3acb8f631bfa5ee844b675e0dcf Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 28 Jun 2021 15:33:25 +0000 Subject: [PATCH 49/66] Lengthen all timeouts --- tests/test/redis/test_redis.cpp | 2 +- .../test/scheduler/test_remote_mpi_worlds.cpp | 24 +++++++++---------- .../scheduler/test_snapshot_client_server.cpp | 2 +- tests/test/state/test_state.cpp | 2 +- tests/test/state/test_state_server.cpp | 2 +- .../test_message_endpoint_client.cpp | 2 +- tests/test/transport/test_message_server.cpp | 10 ++++---- tests/test/util/test_barrier.cpp | 4 +++- tests/test/util/test_queue.cpp | 4 +++- 9 files changed, 29 insertions(+), 23 deletions(-) diff --git a/tests/test/redis/test_redis.cpp b/tests/test/redis/test_redis.cpp index 15439f60a..c6711b67b 100644 --- a/tests/test/redis/test_redis.cpp +++ b/tests/test/redis/test_redis.cpp @@ -545,7 +545,7 @@ TEST_CASE("Test enqueue after blocking dequeue") }); // Wait a bit (assume the waiting thread will get to block by now) - SLEEP_MS(1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); redisQueue.enqueue("foobar", "baz"); // If this hangs, the redis client isn't dequeueing after an enqueue is diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 6cf91b70b..cdae57345 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -78,7 +78,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test rank allocation", "[mpi]") otherWorld.destroy(); }); - SLEEP_MS(500); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); REQUIRE(thisWorld.getHostForRank(0) == thisHost); REQUIRE(thisWorld.getHostForRank(1) == otherHost); @@ -173,7 +173,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, std::vector actual(buffer, buffer + messageData2.size()); REQUIRE(actual == messageData2); - SLEEP_MS(1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); otherWorld.destroy(); }); @@ -272,7 +272,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, otherWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); } - SLEEP_MS(500); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); otherWorld.destroy(); }); @@ -325,7 +325,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, } // Give the other host time to receive the broadcast - SLEEP_MS(1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); otherWorld.destroy(); }); @@ -363,7 +363,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, std::thread otherWorldThread([this, nPerRank, &messageData] { otherWorld.initialiseFromMsg(msg); - // Do the scatter + // Do the scatter (when send rank == recv rank) std::vector actual(nPerRank, -1); otherWorld.scatter(otherHostRankB, otherHostRankB, @@ -377,7 +377,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, // Check for root assert(actual == std::vector({ 8, 9, 10, 11 })); - // Check for other remote ranks + // Check the other ranks on this host have received the data otherWorld.scatter(otherHostRankB, otherHostRankA, nullptr, @@ -398,12 +398,12 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); assert(actual == std::vector({ 12, 13, 14, 15 })); - SLEEP_MS(500); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); otherWorld.destroy(); }); - // Check for local ranks + // Check for ranks on this host std::vector actual(nPerRank, -1); thisWorld.scatter(otherHostRankB, 0, @@ -485,7 +485,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); } - SLEEP_MS(500); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); otherWorld.destroy(); }); @@ -555,7 +555,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_INT, messageData.size()); - SLEEP_MS(500); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); otherWorld.destroy(); }); @@ -614,7 +614,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, otherWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); } - SLEEP_MS(500); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); otherWorld.destroy(); }); @@ -692,7 +692,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_STATUS_IGNORE); } - SLEEP_MS(500); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); otherWorld.destroy(); }); diff --git a/tests/test/scheduler/test_snapshot_client_server.cpp b/tests/test/scheduler/test_snapshot_client_server.cpp index 4a7356118..7c94b04ce 100644 --- a/tests/test/scheduler/test_snapshot_client_server.cpp +++ b/tests/test/scheduler/test_snapshot_client_server.cpp @@ -63,7 +63,7 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, cli.pushSnapshot(snapKeyA, snapA); cli.pushSnapshot(snapKeyB, snapB); - SLEEP_MS(500); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); // Check snapshots created in registry REQUIRE(reg.getSnapshotCount() == 2); diff --git a/tests/test/state/test_state.cpp b/tests/test/state/test_state.cpp index e593c45a4..93b353ca8 100644 --- a/tests/test/state/test_state.cpp +++ b/tests/test/state/test_state.cpp @@ -44,7 +44,7 @@ class StateServerTestFixture stateServer.start(); // Give it time to start - SLEEP_MS(1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); } ~StateServerTestFixture() { stateServer.stop(); } diff --git a/tests/test/state/test_state_server.cpp b/tests/test/state/test_state_server.cpp index be628fd9d..d0f52faa8 100644 --- a/tests/test/state/test_state_server.cpp +++ b/tests/test/state/test_state_server.cpp @@ -28,7 +28,7 @@ class SimpleStateServerTestFixture conf.stateMode = "inmemory"; server.start(); - SLEEP_MS(100); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); } ~SimpleStateServerTestFixture() { server.stop(); } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 30d6e3221..eab83e775 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -44,7 +44,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, // Run the recv in the background std::thread recvThread([expectedMsg] { - SLEEP_MS(1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); AsyncRecvMessageEndpoint dst(testPort); // Receive message diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 55500082b..2dc5e935f 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -1,5 +1,7 @@ #include +#include "faabric_utils.h" + #include #include @@ -35,7 +37,7 @@ class DummyServer final : public MessageEndpointServer } catch (zmq::error_t& ex) { SPDLOG_WARN("Error connecting dummy server, retrying after delay"); - SLEEP_MS(1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); MessageEndpointServer::start(); } } @@ -115,7 +117,7 @@ class SlowServer final : public MessageEndpointServer namespace tests { - TEST_CASE("Test send one message to server", "[transport]") +TEST_CASE("Test send one message to server", "[transport]") { DummyServer server; server.start(); @@ -130,7 +132,7 @@ namespace tests { memcpy(bodyMsg, body.c_str(), body.size()); cli.asyncSend(0, bodyMsg, body.size()); - SLEEP_MS(500); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); REQUIRE(server.messageCount == 1); @@ -186,7 +188,7 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") } } - SLEEP_MS(2000); + SLEEP_MS(2 * SHORT_TEST_TIMEOUT_MS); REQUIRE(server.messageCount == numMessages * numClients); diff --git a/tests/test/util/test_barrier.cpp b/tests/test/util/test_barrier.cpp index b0e980440..582c7b8e3 100644 --- a/tests/test/util/test_barrier.cpp +++ b/tests/test/util/test_barrier.cpp @@ -1,5 +1,7 @@ #include +#include "faabric_utils.h" + #include #include #include @@ -22,7 +24,7 @@ TEST_CASE("Test barrier operation", "[util]") auto t2 = std::thread([&b] { b.wait(); }); // Sleep for a bit while the threads spawn - SLEEP_MS(500); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); REQUIRE(b.getSlotCount() == 1); // Join with master to go through barrier diff --git a/tests/test/util/test_queue.cpp b/tests/test/util/test_queue.cpp index 94b7f0907..fbf1560ab 100644 --- a/tests/test/util/test_queue.cpp +++ b/tests/test/util/test_queue.cpp @@ -1,5 +1,7 @@ #include +#include "faabric_utils.h" + #include #include #include @@ -118,7 +120,7 @@ TEST_CASE("Test queue on non-copy-constructible object", "[util]") std::thread ta([&q] { q.dequeue().set_value(1); }); std::thread tb([&q] { - SLEEP_MS(500); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); q.dequeue().set_value(2); }); From e2671341c8735449140f88c858c808a5e4dd9e49 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 28 Jun 2021 16:03:48 +0000 Subject: [PATCH 50/66] Retry connecting socket --- src/transport/MessageEndpoint.cpp | 77 ++++++++++--------- .../test_message_endpoint_client.cpp | 1 - tests/test/transport/test_message_server.cpp | 14 ---- 3 files changed, 41 insertions(+), 51 deletions(-) diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 51093cec2..0ed85c9ac 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -3,22 +3,44 @@ #include #include #include +#include #include +#define RETRY_SLEEP_MS 1000 + #define CATCH_ZMQ_ERR(op, label) \ try { \ op; \ } catch (zmq::error_t & e) { \ - if (e.num() == ZMQ_ETERM) { \ - SPDLOG_TRACE( \ - "Got ZeroMQ ETERM for {} on address {}", label, address); \ - } else { \ - SPDLOG_ERROR("Caught ZeroMQ error for {} on address {}: {} ({})", \ - label, \ - address, \ - e.num(), \ - e.what()); \ + SPDLOG_ERROR("Caught ZeroMQ error for {} on address {}: {} ({})", \ + label, \ + address, \ + e.num(), \ + e.what()); \ + throw; \ + } + +#define CATCH_ZMQ_ERR_RETRY(op, label) \ + try { \ + op; \ + } catch (zmq::error_t & e) { \ + SPDLOG_WARN("Caught ZeroMQ error for {} on address {}: {} ({})", \ + label, \ + address, \ + e.num(), \ + e.what()); \ + SPDLOG_WARN("Retrying {} on address {}", label, address); \ + SLEEP_MS(RETRY_SLEEP_MS); \ + try { \ + op; \ + } catch (zmq::error_t & e2) { \ + SPDLOG_ERROR( \ + "Caught ZeroMQ error on retry for {} on address {}: {} ({})", \ + label, \ + address, \ + e2.num(), \ + e2.what()); \ throw; \ } \ } @@ -35,7 +57,6 @@ MessageEndpoint::MessageEndpoint(const std::string& hostIn, , tid(std::this_thread::get_id()) , id(faabric::util::generateGid()) { - // Check and set socket timeout if (timeoutMs <= 0) { SPDLOG_ERROR("Setting invalid timeout of {}", timeoutMs); @@ -62,25 +83,25 @@ zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType, case zmq::socket_type::req: { SPDLOG_TRACE( "New socket: req {}:{} (timeout {}ms)", host, port, timeoutMs); - CATCH_ZMQ_ERR(socket.connect(address), "connect") + CATCH_ZMQ_ERR_RETRY(socket.connect(address), "connect") break; } case zmq::socket_type::push: { SPDLOG_TRACE( "New socket: push {}:{} (timeout {}ms)", host, port, timeoutMs); - CATCH_ZMQ_ERR(socket.connect(address), "connect") + CATCH_ZMQ_ERR_RETRY(socket.connect(address), "connect") break; } case zmq::socket_type::pull: { SPDLOG_TRACE( "New socket: pull {}:{} (timeout {}ms)", host, port, timeoutMs); - CATCH_ZMQ_ERR(socket.bind(address), "bind") + CATCH_ZMQ_ERR_RETRY(socket.bind(address), "bind") break; } case zmq::socket_type::rep: { SPDLOG_TRACE( "New socket: rep {}:{} (timeout {}ms)", host, port, timeoutMs); - CATCH_ZMQ_ERR(socket.bind(address), "bind") + CATCH_ZMQ_ERR_RETRY(socket.bind(address), "bind") break; } default: { @@ -149,7 +170,7 @@ Message MessageEndpoint::recvBuffer(zmq::socket_t& socket, int size) } } catch (zmq::error_t& e) { if (e.num() == ZMQ_ETERM) { - SPDLOG_TRACE("Endpoint received ETERM"); + SPDLOG_WARN("Endpoint {}:{} received ETERM on recv", host, port); return Message(); } @@ -173,7 +194,7 @@ Message MessageEndpoint::recvNoBuffer(zmq::socket_t& socket) } } catch (zmq::error_t& e) { if (e.num() == ZMQ_ETERM) { - SPDLOG_TRACE("Endpoint received ETERM"); + SPDLOG_WARN("Endpoint {}:{} received ETERM on recv", host, port); return Message(); } throw; @@ -261,24 +282,8 @@ Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* serialisedMsg, doSend(reqSocket, serialisedMsg, msgSize, more); // Do the receive - zmq::message_t msg; - CATCH_ZMQ_ERR( - try { - auto res = reqSocket.recv(msg); - if (!res.has_value()) { - SPDLOG_ERROR("Timed out receiving message with no size"); - throw MessageTimeoutException("Timed out receiving message"); - } - } catch (zmq::error_t& e) { - if (e.num() == ZMQ_ETERM) { - SPDLOG_TRACE("Endpoint received ETERM"); - return Message(); - } - throw; - }, - "send_recv") - - return Message(msg); + SPDLOG_TRACE("RECV (REQ) {}", port); + return recvNoBuffer(reqSocket); } // ---------------------------------------------- @@ -300,7 +305,7 @@ Message AsyncRecvMessageEndpoint::recv(int size) // ---------------------------------------------- // SYNC RECV ENDPOINT // ---------------------------------------------- -// + SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(int portIn, int timeoutMs) : MessageEndpoint(ANY_HOST, portIn, timeoutMs) { @@ -309,7 +314,7 @@ SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(int portIn, int timeoutMs) Message SyncRecvMessageEndpoint::recv(int size) { - SPDLOG_TRACE("RECV {} (REP) ({} bytes)", port, size); + SPDLOG_TRACE("RECV (REP) {} ({} bytes)", port, size); return doRecv(repSocket, size); } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index eab83e775..ff373dd46 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -8,7 +8,6 @@ #include using namespace faabric::transport; - static const std::string thisHost = "127.0.0.1"; static const int testPort = 9800; diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 2dc5e935f..1073aba60 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -28,20 +28,6 @@ class DummyServer final : public MessageEndpointServer // Variable to keep track of the received messages int messageCount; - void start() override - { - // In a CI environment tests can be slow to tear down fully, so we want - // to sleep and retry if the initial connection fails. - try { - MessageEndpointServer::start(); - } catch (zmq::error_t& ex) { - SPDLOG_WARN("Error connecting dummy server, retrying after delay"); - - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - MessageEndpointServer::start(); - } - } - private: void doAsyncRecv(faabric::transport::Message& header, faabric::transport::Message& body) override From 8bb3888b22a1b9ec85daff86369889de3f403cc3 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 28 Jun 2021 17:59:33 +0000 Subject: [PATCH 51/66] Avoid arbitrary sleeps in tests --- include/faabric/util/macros.h | 1 + src/transport/MessageEndpointServer.cpp | 13 +++++- tests/test/scheduler/test_executor.cpp | 34 ++++++++-------- .../scheduler/test_function_client_server.cpp | 19 +++++---- .../test/scheduler/test_remote_mpi_worlds.cpp | 40 ++++++++++--------- .../scheduler/test_snapshot_client_server.cpp | 3 -- tests/test/state/test_state.cpp | 3 -- tests/test/state/test_state_server.cpp | 1 - .../test_message_endpoint_client.cpp | 14 ++++--- tests/test/transport/test_message_server.cpp | 18 +++++---- tests/test/util/test_barrier.cpp | 3 +- tests/utils/faabric_utils.h | 19 +++++++++ tests/utils/fixtures.h | 4 ++ 13 files changed, 105 insertions(+), 67 deletions(-) diff --git a/include/faabric/util/macros.h b/include/faabric/util/macros.h index 504e4be40..e78e01df2 100644 --- a/include/faabric/util/macros.h +++ b/include/faabric/util/macros.h @@ -8,3 +8,4 @@ #define SLEEP_MS(ms) usleep((ms)*1000) #define UNUSED(x) (void)(x) + diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 2d43999ec..3abdae49c 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -14,8 +15,13 @@ MessageEndpointServer::MessageEndpointServer(int asyncPortIn, int syncPortIn) void MessageEndpointServer::start() { - asyncThread = std::thread([this] { + // Callers will only pass this barrier once the server sockets have been + // opened (hence we don't need to add arbitrary sleeps all over the place). + faabric::util::Barrier startBarrier(3); + + asyncThread = std::thread([this, &startBarrier] { AsyncRecvMessageEndpoint endpoint(asyncPort); + startBarrier.wait(); // Loop until we receive a shutdown message while (true) { @@ -45,8 +51,9 @@ void MessageEndpointServer::start() } }); - syncThread = std::thread([this] { + syncThread = std::thread([this, &startBarrier] { SyncRecvMessageEndpoint endpoint(syncPort); + startBarrier.wait(); // Loop until we receive a shutdown message while (true) { @@ -83,6 +90,8 @@ void MessageEndpointServer::start() endpoint.sendResponse(buffer, respSize); } }); + + startBarrier.wait(); } void MessageEndpointServer::stop() diff --git a/tests/test/scheduler/test_executor.cpp b/tests/test/scheduler/test_executor.cpp index 5558700b1..66d7edb31 100644 --- a/tests/test/scheduler/test_executor.cpp +++ b/tests/test/scheduler/test_executor.cpp @@ -439,7 +439,8 @@ TEST_CASE_METHOD(TestExecutorFixture, faabric::scheduler::queueResourceResponse(otherHost, resOther); // Background thread to execute main function and await results - std::thread t([] { + faabric::util::Barrier barrier(2); + std::thread t([&barrier] { int nThreads = 8; std::shared_ptr req = faabric::util::batchExecFactory("dummy", "thread-check", 1); @@ -449,19 +450,20 @@ TEST_CASE_METHOD(TestExecutorFixture, auto& sch = faabric::scheduler::getScheduler(); sch.callFunctions(req, false); + barrier.wait(); faabric::Message res = sch.getFunctionResult(msg.id(), 2000); assert(res.returnvalue() == 0); }); - // Give it time to have made the request - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); + // Wait until the function has executed and submitted another request + auto reqs = faabric::scheduler::getBatchRequests(); + REQUIRE_RETRY(reqs = faabric::scheduler::getBatchRequests(), + reqs.size() == 1); // Check restore hasn't been called yet REQUIRE(restoreCount == 0); // Get the request that's been submitted - auto reqs = faabric::scheduler::getBatchRequests(); - REQUIRE(reqs.size() == 1); std::string actualHost = reqs.at(0).first; REQUIRE(actualHost == otherHost); @@ -491,6 +493,7 @@ TEST_CASE_METHOD(TestExecutorFixture, } // Rejoin the other thread + barrier.wait(); if (t.joinable()) { t.join(); } @@ -519,11 +522,11 @@ TEST_CASE_METHOD(TestExecutorFixture, executeWithTestExecutor(req, true); - // We have to manually add a wait here as the thread results won't actually - // get logged on this host - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); + // Note that because the results don't actually get logged on this host, we + // can't wait on them as usual. auto actual = faabric::scheduler::getThreadResults(); - REQUIRE(actual.size() == nThreads); + REQUIRE_RETRY(actual = faabric::scheduler::getThreadResults(), + actual.size() == nThreads); std::vector actualMessageIds; for (auto& p : actual) { @@ -681,9 +684,7 @@ TEST_CASE_METHOD(TestExecutorFixture, REQUIRE(sch.getFunctionExecutorCount(msg) == 1); - SLEEP_MS(conf.boundTimeout + 500); - - REQUIRE(sch.getFunctionExecutorCount(msg) == 0); + REQUIRE_RETRY({}, sch.getFunctionExecutorCount(msg) == 0); } TEST_CASE_METHOD(TestExecutorFixture, @@ -711,11 +712,10 @@ TEST_CASE_METHOD(TestExecutorFixture, executeWithTestExecutor(req, true); - // Wait for executor to have finished - sometimes takes a while - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - - // Check thread results returned - REQUIRE(faabric::scheduler::getThreadResults().size() == nThreads); + // Results aren't set on this host as it's not the master, so we have to + // wait + REQUIRE_RETRY({}, + faabric::scheduler::getThreadResults().size() == nThreads); // Check results have been sent back to the master host auto actualResults = faabric::scheduler::getThreadResults(); diff --git a/tests/test/scheduler/test_function_client_server.cpp b/tests/test/scheduler/test_function_client_server.cpp index 06926c887..c66dfff8d 100644 --- a/tests/test/scheduler/test_function_client_server.cpp +++ b/tests/test/scheduler/test_function_client_server.cpp @@ -35,12 +35,11 @@ class ClientServerFixture ClientServerFixture() : cli(LOCALHOST) { - server.start(); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - // Set up executor executorFactory = std::make_shared(); setExecutorFactory(executorFactory); + + server.start(); } ~ClientServerFixture() @@ -77,9 +76,8 @@ TEST_CASE_METHOD(ClientServerFixture, REQUIRE(msgs.at(1).function() == "bar"); sch.clearRecordedMessages(); - // Send flush message + // Send flush message (which is synchronous) cli.sendFlush(); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); // Check the scheduler has been flushed REQUIRE(sch.getFunctionRegisteredHostCount(msgA) == 0); @@ -136,7 +134,11 @@ TEST_CASE_METHOD(ClientServerFixture, // Make the request cli.executeFunctions(req); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); + + for (const auto& m : req->messages()) { + // This timeout can be long as it shouldn't fail + sch.getFunctionResult(m.id(), 5 * SHORT_TEST_TIMEOUT_MS); + } // Check no other hosts have been registered faabric::Message m = req->messages().at(0); @@ -217,6 +219,7 @@ TEST_CASE_METHOD(ClientServerFixture, "Test unregister request", "[scheduler]") reqA.set_host("foobar"); *reqA.mutable_function() = msg; + // Check that nothing's happened cli.unregister(reqA); REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 1); @@ -225,8 +228,8 @@ TEST_CASE_METHOD(ClientServerFixture, "Test unregister request", "[scheduler]") reqB.set_host(otherHost); *reqB.mutable_function() = msg; cli.unregister(reqB); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 0); + + REQUIRE_RETRY({}, sch.getFunctionRegisteredHostCount(msg) == 0); sch.setThisHostResources(originalResources); faabric::scheduler::clearMockRequests(); diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index cdae57345..d890ac369 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -78,16 +79,13 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test rank allocation", "[mpi]") otherWorld.destroy(); }); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - - REQUIRE(thisWorld.getHostForRank(0) == thisHost); - REQUIRE(thisWorld.getHostForRank(1) == otherHost); - - // Clean up if (otherWorldThread.joinable()) { otherWorldThread.join(); } + REQUIRE(thisWorld.getHostForRank(0) == thisHost); + REQUIRE(thisWorld.getHostForRank(1) == otherHost); + thisWorld.destroy(); } @@ -173,7 +171,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, std::vector actual(buffer, buffer + messageData2.size()); REQUIRE(actual == messageData2); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); + testBarrier.wait(); otherWorld.destroy(); }); @@ -194,6 +192,8 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(status.MPI_ERROR == MPI_SUCCESS); REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); + testBarrier.wait(); + // Clean up if (otherWorldThread.joinable()) { otherWorldThread.join(); @@ -272,7 +272,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, otherWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); } - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); + testBarrier.wait(); otherWorld.destroy(); }); @@ -288,6 +288,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } // Clean up + testBarrier.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -325,8 +326,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, } // Give the other host time to receive the broadcast - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - + testBarrier.wait(); otherWorld.destroy(); }); @@ -339,6 +339,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, } // Clean up + testBarrier.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -398,8 +399,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); assert(actual == std::vector({ 12, 13, 14, 15 })); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - + testBarrier.wait(); otherWorld.destroy(); }); @@ -436,6 +436,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, REQUIRE(actual == std::vector({ 16, 17, 18, 19 })); // Clean up + testBarrier.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -485,8 +486,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); } - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - + testBarrier.wait(); otherWorld.destroy(); }); @@ -518,6 +518,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, REQUIRE(actual == expected); // Clean up + testBarrier.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -555,8 +556,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_INT, messageData.size()); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - + testBarrier.wait(); otherWorld.destroy(); }); @@ -585,6 +585,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(asyncMessage == messageData); // Clean up + testBarrier.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -614,8 +615,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, otherWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); } - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - + testBarrier.wait(); otherWorld.destroy(); }); @@ -649,6 +649,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(recv3 == 2); // Clean up + testBarrier.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -692,7 +693,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_STATUS_IGNORE); } - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); + testBarrier.wait(); otherWorld.destroy(); }); @@ -716,6 +717,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } // Clean up + testBarrier.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } diff --git a/tests/test/scheduler/test_snapshot_client_server.cpp b/tests/test/scheduler/test_snapshot_client_server.cpp index 7c94b04ce..9319c76ed 100644 --- a/tests/test/scheduler/test_snapshot_client_server.cpp +++ b/tests/test/scheduler/test_snapshot_client_server.cpp @@ -30,7 +30,6 @@ class SnapshotClientServerFixture : cli(LOCALHOST) { server.start(); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); } ~SnapshotClientServerFixture() { server.stop(); } @@ -63,8 +62,6 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, cli.pushSnapshot(snapKeyA, snapA); cli.pushSnapshot(snapKeyB, snapB); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - // Check snapshots created in registry REQUIRE(reg.getSnapshotCount() == 2); const faabric::util::SnapshotData& actualA = reg.getSnapshot(snapKeyA); diff --git a/tests/test/state/test_state.cpp b/tests/test/state/test_state.cpp index 93b353ca8..1827531cd 100644 --- a/tests/test/state/test_state.cpp +++ b/tests/test/state/test_state.cpp @@ -42,9 +42,6 @@ class StateServerTestFixture // Start the state server SPDLOG_DEBUG("Running state server"); stateServer.start(); - - // Give it time to start - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); } ~StateServerTestFixture() { stateServer.stop(); } diff --git a/tests/test/state/test_state_server.cpp b/tests/test/state/test_state_server.cpp index d0f52faa8..fa497b28f 100644 --- a/tests/test/state/test_state_server.cpp +++ b/tests/test/state/test_state_server.cpp @@ -28,7 +28,6 @@ class SimpleStateServerTestFixture conf.stateMode = "inmemory"; server.start(); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); } ~SimpleStateServerTestFixture() { server.stop(); } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index ff373dd46..f66e90998 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -41,22 +41,26 @@ TEST_CASE_METHOD(SchedulerTestFixture, AsyncSendMessageEndpoint src(thisHost, testPort); - // Run the recv in the background - std::thread recvThread([expectedMsg] { - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - AsyncRecvMessageEndpoint dst(testPort); + faabric::util::Barrier barrier(2); + + std::thread recvThread([&barrier, expectedMsg] { + // Make sure this only runs once the send has been done + barrier.wait(); // Receive message + AsyncRecvMessageEndpoint dst(testPort); faabric::transport::Message recvMsg = dst.recv(); + assert(recvMsg.size() == expectedMsg.size()); std::string actualMsg(recvMsg.data(), recvMsg.size()); assert(actualMsg == expectedMsg); }); - // Send message (should wait for receiver to become ready) uint8_t msg[expectedMsg.size()]; memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); + src.send(msg, expectedMsg.size()); + barrier.wait(); if (recvThread.joinable()) { recvThread.join(); diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 1073aba60..d17840c48 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -118,9 +118,7 @@ TEST_CASE("Test send one message to server", "[transport]") memcpy(bodyMsg, body.c_str(), body.size()); cli.asyncSend(0, bodyMsg, body.size()); - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - - REQUIRE(server.messageCount == 1); + REQUIRE_RETRY({}, server.messageCount == 1); server.stop(); } @@ -146,6 +144,7 @@ TEST_CASE("Test send response to client", "[transport]") TEST_CASE("Test multiple clients talking to one server", "[transport]") { + // Start the server in the background DummyServer server; server.start(); @@ -153,8 +152,11 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") int numClients = 10; int numMessages = 1000; + // Set up a barrier to wait on all the clients having finished + faabric::util::Barrier barrier(numClients + 1); + for (int i = 0; i < numClients; i++) { - clientThreads.emplace_back(std::thread([numMessages] { + clientThreads.emplace_back(std::thread([&barrier, numMessages] { // Prepare client MessageEndpointClient cli(thisHost, testPortAsync, testPortSync); @@ -165,18 +167,20 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") memcpy(body, clientMsg.c_str(), clientMsg.size()); cli.asyncSend(0, body, clientMsg.size()); } + + barrier.wait(); })); } + barrier.wait(); + for (auto& t : clientThreads) { if (t.joinable()) { t.join(); } } - SLEEP_MS(2 * SHORT_TEST_TIMEOUT_MS); - - REQUIRE(server.messageCount == numMessages * numClients); + REQUIRE_RETRY({}, server.messageCount == numMessages * numClients); server.stop(); } diff --git a/tests/test/util/test_barrier.cpp b/tests/test/util/test_barrier.cpp index 582c7b8e3..40f64dfb6 100644 --- a/tests/test/util/test_barrier.cpp +++ b/tests/test/util/test_barrier.cpp @@ -24,8 +24,7 @@ TEST_CASE("Test barrier operation", "[util]") auto t2 = std::thread([&b] { b.wait(); }); // Sleep for a bit while the threads spawn - SLEEP_MS(SHORT_TEST_TIMEOUT_MS); - REQUIRE(b.getSlotCount() == 1); + REQUIRE_RETRY({}, b.getSlotCount() == 1); // Join with master to go through barrier b.wait(); diff --git a/tests/utils/faabric_utils.h b/tests/utils/faabric_utils.h index 24f078d60..af6d8a694 100644 --- a/tests/utils/faabric_utils.h +++ b/tests/utils/faabric_utils.h @@ -12,6 +12,25 @@ using namespace faabric; #define SHORT_TEST_TIMEOUT_MS 1000 +#define REQUIRE_RETRY_MAX 5 +#define REQUIRE_RETRY_SLEEP_MS 1000 + +#define REQUIRE_RETRY(updater, check) \ + { \ + { updater; }; \ + bool res = (check); \ + int count = 0; \ + while (!res && count < REQUIRE_RETRY_MAX) { \ + count++; \ + SLEEP_MS(REQUIRE_RETRY_SLEEP_MS); \ + { updater; }; \ + res = (check); \ + } \ + if (!res) { \ + FAIL(); \ + } \ + } + #define FAABRIC_CATCH_LOGGER \ struct LogListener : Catch::TestEventListenerBase \ { \ diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index f37f6f374..f54b4f426 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -175,6 +176,7 @@ class RemoteMpiTestFixture : public MpiBaseTestFixture public: RemoteMpiTestFixture() : thisHost(faabric::util::getSystemConfig().endpointHost) + , testBarrier(2) { otherWorld.overrideHost(otherHost); @@ -215,6 +217,8 @@ class RemoteMpiTestFixture : public MpiBaseTestFixture std::string thisHost; std::string otherHost = LOCALHOST; + faabric::util::Barrier testBarrier; + faabric::scheduler::MpiWorld otherWorld; }; } From 6bebb61c814e8536c5fcbc286319b206916ec551 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 29 Jun 2021 07:51:28 +0000 Subject: [PATCH 52/66] Add timeout on barrier --- include/faabric/util/barrier.h | 6 +++++- src/transport/MessageEndpointServer.cpp | 5 +++-- src/util/barrier.cpp | 11 +++++++++-- tests/test/util/test_barrier.cpp | 8 ++++++++ 4 files changed, 25 insertions(+), 5 deletions(-) diff --git a/include/faabric/util/barrier.h b/include/faabric/util/barrier.h index b5afaf276..5c3763a85 100644 --- a/include/faabric/util/barrier.h +++ b/include/faabric/util/barrier.h @@ -5,10 +5,12 @@ namespace faabric::util { +#define DEFAULT_BARRIER_TIMEOUT_MS 10000 + class Barrier { public: - explicit Barrier(int count); + explicit Barrier(int count, int timeoutMsIn=DEFAULT_BARRIER_TIMEOUT_MS); void wait(); @@ -20,6 +22,8 @@ class Barrier int threadCount; int slotCount; int uses; + int timeoutMs; + std::mutex mx; std::condition_variable cv; }; diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 3abdae49c..256ee4757 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -15,8 +15,9 @@ MessageEndpointServer::MessageEndpointServer(int asyncPortIn, int syncPortIn) void MessageEndpointServer::start() { - // Callers will only pass this barrier once the server sockets have been - // opened (hence we don't need to add arbitrary sleeps all over the place). + // This barrier means that callers can guarantee that when this function + // completes, both sockets will have been opened (and hence the server is + // ready to use). faabric::util::Barrier startBarrier(3); asyncThread = std::thread([this, &startBarrier] { diff --git a/src/util/barrier.cpp b/src/util/barrier.cpp index 67bed11a5..4dc63a876 100644 --- a/src/util/barrier.cpp +++ b/src/util/barrier.cpp @@ -2,10 +2,11 @@ #include namespace faabric::util { -Barrier::Barrier(int count) +Barrier::Barrier(int count, int timeoutMsIn) : threadCount(count) , slotCount(count) , uses(0) + , timeoutMs(timeoutMsIn) {} void Barrier::wait() @@ -24,7 +25,13 @@ void Barrier::wait() slotCount = threadCount; cv.notify_all(); } else { - cv.wait(lock, [&] { return usesCopy < uses; }); + auto timePoint = std::chrono::system_clock::now() + + std::chrono::milliseconds(timeoutMs); + bool waitRes = + cv.wait_until(lock, timePoint, [&] { return usesCopy < uses; }); + if (!waitRes) { + throw std::runtime_error("Barrier timed out"); + } } } } diff --git a/tests/test/util/test_barrier.cpp b/tests/test/util/test_barrier.cpp index 40f64dfb6..97b075c96 100644 --- a/tests/test/util/test_barrier.cpp +++ b/tests/test/util/test_barrier.cpp @@ -40,4 +40,12 @@ TEST_CASE("Test barrier operation", "[util]") REQUIRE(b.getSlotCount() == 3); REQUIRE(b.getUseCount() == 1); } + +TEST_CASE("Test barrier timeout", "[util]") +{ + int timeoutMs = 500; + Barrier b(2, timeoutMs); + + REQUIRE_THROWS(b.wait()); +} } From 51fd3d4bc3654507e7c1879ba45776c7f3430e65 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 29 Jun 2021 09:19:54 +0000 Subject: [PATCH 53/66] Long-lived shutdown endpoints in server --- include/faabric/transport/MessageEndpoint.h | 10 +-- .../faabric/transport/MessageEndpointServer.h | 5 +- include/faabric/util/barrier.h | 2 +- include/faabric/util/macros.h | 1 - src/state/StateKeyValue.cpp | 2 +- src/transport/MessageEndpoint.cpp | 30 +++----- src/transport/MessageEndpointServer.cpp | 75 ++++++++----------- tests/test/scheduler/test_scheduler.cpp | 2 +- tests/test/state/test_state_server.cpp | 1 - tests/test/transport/test_message_server.cpp | 28 ++++--- tests/utils/faabric_utils.h | 10 ++- 11 files changed, 77 insertions(+), 89 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 8b69e865d..6bb530a3d 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -74,9 +74,7 @@ class AsyncSendMessageEndpoint : public MessageEndpoint void sendHeader(int header); - void sendShutdown(); - - void send(uint8_t* serialisedMsg, size_t msgSize, bool more = false); + void send(const uint8_t* data, size_t dataSize, bool more = false); private: zmq::socket_t pushSocket; @@ -91,10 +89,10 @@ class SyncSendMessageEndpoint : public MessageEndpoint void sendHeader(int header); - void sendShutdown(); + void sendRaw(const uint8_t* data, size_t dataSize); - Message sendAwaitResponse(const uint8_t* serialisedMsg, - size_t msgSize, + Message sendAwaitResponse(const uint8_t* data, + size_t dataSize, bool more = false); private: diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 3d7df8f74..2b3ea6a46 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -41,11 +41,12 @@ class MessageEndpointServer private: const int asyncPort; - const int syncPort; std::thread asyncThread; - std::thread syncThread; + + AsyncSendMessageEndpoint asyncShutdownSender; + SyncSendMessageEndpoint syncShutdownSender; }; } diff --git a/include/faabric/util/barrier.h b/include/faabric/util/barrier.h index 5c3763a85..97139e692 100644 --- a/include/faabric/util/barrier.h +++ b/include/faabric/util/barrier.h @@ -10,7 +10,7 @@ namespace faabric::util { class Barrier { public: - explicit Barrier(int count, int timeoutMsIn=DEFAULT_BARRIER_TIMEOUT_MS); + explicit Barrier(int count, int timeoutMsIn = DEFAULT_BARRIER_TIMEOUT_MS); void wait(); diff --git a/include/faabric/util/macros.h b/include/faabric/util/macros.h index e78e01df2..504e4be40 100644 --- a/include/faabric/util/macros.h +++ b/include/faabric/util/macros.h @@ -8,4 +8,3 @@ #define SLEEP_MS(ms) usleep((ms)*1000) #define UNUSED(x) (void)(x) - diff --git a/src/state/StateKeyValue.cpp b/src/state/StateKeyValue.cpp index a1936f74b..6f7c223c9 100644 --- a/src/state/StateKeyValue.cpp +++ b/src/state/StateKeyValue.cpp @@ -559,7 +559,7 @@ uint32_t StateKeyValue::waitOnRedisRemoteLock(const std::string& redisKey) break; } - SLEEP_MS(1); + SLEEP_MS(500); remoteLockId = redis.acquireLock(redisKey, REMOTE_LOCK_TIMEOUT_SECS); retryCount++; diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 0ed85c9ac..808478b1f 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -233,19 +233,12 @@ void AsyncSendMessageEndpoint::sendHeader(int header) doSend(pushSocket, &headerBytes, sizeof(headerBytes), true); } -void AsyncSendMessageEndpoint::sendShutdown() -{ - int header = -1; - uint8_t headerBytes = static_cast(header); - doSend(pushSocket, &headerBytes, sizeof(headerBytes), false); -} - -void AsyncSendMessageEndpoint::send(uint8_t* serialisedMsg, - size_t msgSize, +void AsyncSendMessageEndpoint::send(const uint8_t* data, + size_t dataSize, bool more) { - SPDLOG_TRACE("PUSH {}:{} ({} bytes, more {})", host, port, msgSize, more); - doSend(pushSocket, serialisedMsg, msgSize, more); + SPDLOG_TRACE("PUSH {}:{} ({} bytes, more {})", host, port, dataSize, more); + doSend(pushSocket, data, dataSize, more); } // ---------------------------------------------- @@ -266,20 +259,19 @@ void SyncSendMessageEndpoint::sendHeader(int header) doSend(reqSocket, &headerBytes, sizeof(headerBytes), true); } -void SyncSendMessageEndpoint::sendShutdown() +void SyncSendMessageEndpoint::sendRaw(const uint8_t* data, size_t dataSize) { - int header = -1; - uint8_t headerBytes = static_cast(header); - doSend(reqSocket, &headerBytes, sizeof(headerBytes), false); + SPDLOG_TRACE("REQ {}:{} ({} bytes)", host, port, dataSize); + doSend(reqSocket, data, dataSize, false); } -Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* serialisedMsg, - size_t msgSize, +Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* data, + size_t dataSize, bool more) { - SPDLOG_TRACE("REQ {}:{} ({} bytes, more {})", host, port, msgSize, more); + SPDLOG_TRACE("REQ {}:{} ({} bytes, more {})", host, port, dataSize, more); - doSend(reqSocket, serialisedMsg, msgSize, more); + doSend(reqSocket, data, dataSize, more); // Do the receive SPDLOG_TRACE("RECV (REQ) {}", port); diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 256ee4757..a7fc4c2de 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -8,9 +8,34 @@ #include namespace faabric::transport { + +static const std::vector shutdownHeader = { 0, 0, 1, 1 }; + +#define SHUTDOWN_CHECK(header, label) \ + { \ + if (header.size() == shutdownHeader.size()) { \ + if (header.dataCopy() == shutdownHeader) { \ + SPDLOG_TRACE("Server {} endpoint received shutdown message", \ + label); \ + break; \ + } \ + } \ + } + +#define RECEIVE_BODY(header, endpoint) \ + if (!header.more()) { \ + throw std::runtime_error("Header sent without SNDMORE flag"); \ + } \ + Message body = endpoint.recv(); \ + if (body.more()) { \ + throw std::runtime_error("Body sent with SNDMORE flag"); \ + } + MessageEndpointServer::MessageEndpointServer(int asyncPortIn, int syncPortIn) : asyncPort(asyncPortIn) , syncPort(syncPortIn) + , asyncShutdownSender(LOCALHOST, asyncPort) + , syncShutdownSender(LOCALHOST, syncPort) {} void MessageEndpointServer::start() @@ -24,28 +49,13 @@ void MessageEndpointServer::start() AsyncRecvMessageEndpoint endpoint(asyncPort); startBarrier.wait(); - // Loop until we receive a shutdown message while (true) { // Receive header and body Message header = endpoint.recv(); - // Detect shutdown condition - if (header.size() == sizeof(uint8_t) && !header.more()) { - SPDLOG_TRACE("Async server socket received shutdown message"); - break; - } - - // Check the header was sent with ZMQ_SNDMORE flag - if (!header.more()) { - throw std::runtime_error("Header sent without SNDMORE flag"); - } + SHUTDOWN_CHECK(header, "async") - // Check that there are no more messages to receive - Message body = endpoint.recv(); - if (body.more()) { - throw std::runtime_error("Body sent with SNDMORE flag"); - } - assert(body.udata() != nullptr); + RECEIVE_BODY(header, endpoint) // Server-specific message handling doAsyncRecv(header, body); @@ -56,27 +66,13 @@ void MessageEndpointServer::start() SyncRecvMessageEndpoint endpoint(syncPort); startBarrier.wait(); - // Loop until we receive a shutdown message while (true) { // Receive header and body Message header = endpoint.recv(); - // Detect shutdown condition - if (header.size() == sizeof(uint8_t) && !header.more()) { - SPDLOG_TRACE("Sync server socket received shutdown message"); - break; - } + SHUTDOWN_CHECK(header, "sync") - // Check the header was sent with ZMQ_SNDMORE flag - if (!header.more()) { - throw std::runtime_error("Header sent without SNDMORE flag"); - } - - // Check that there are no more messages to receive - Message body = endpoint.recv(); - if (body.more()) { - throw std::runtime_error("Body sent with SNDMORE flag"); - } + RECEIVE_BODY(header, endpoint) // Server-specific message handling std::unique_ptr resp = @@ -97,17 +93,12 @@ void MessageEndpointServer::start() void MessageEndpointServer::stop() { - SPDLOG_TRACE( - "Sending sync shutdown message locally to {}:{}", LOCALHOST, syncPort); - - SyncSendMessageEndpoint syncSender(LOCALHOST, syncPort); - syncSender.sendShutdown(); + // Send shutdown messages + SPDLOG_TRACE("Server sending shutdown messages"); - SPDLOG_TRACE( - "Sending async shutdown message locally to {}:{}", LOCALHOST, asyncPort); + syncShutdownSender.sendRaw(shutdownHeader.data(), shutdownHeader.size()); - AsyncSendMessageEndpoint asyncSender(LOCALHOST, asyncPort); - asyncSender.sendShutdown(); + asyncShutdownSender.send(shutdownHeader.data(), shutdownHeader.size()); // Join the threads if (asyncThread.joinable()) { diff --git a/tests/test/scheduler/test_scheduler.cpp b/tests/test/scheduler/test_scheduler.cpp index c999eb45e..dfb00fe9c 100644 --- a/tests/test/scheduler/test_scheduler.cpp +++ b/tests/test/scheduler/test_scheduler.cpp @@ -33,7 +33,7 @@ class SlowExecutor final : public Executor int msgIdx, std::shared_ptr req) override { - SPDLOG_DEBUG("SlowExecutor executing task{}", + SPDLOG_DEBUG("Slow executor executing task{}", req->mutable_messages()->at(msgIdx).id()); SLEEP_MS(SHORT_TEST_TIMEOUT_MS); diff --git a/tests/test/state/test_state_server.cpp b/tests/test/state/test_state_server.cpp index fa497b28f..2a5603dca 100644 --- a/tests/test/state/test_state_server.cpp +++ b/tests/test/state/test_state_server.cpp @@ -162,7 +162,6 @@ TEST_CASE_METHOD(SimpleStateServerTestFixture, "Test local-only append", "[state]") { - // Append a few chunks std::vector chunkA = { 1, 1 }; std::vector chunkB = { 2, 2, 2 }; std::vector chunkC = { 3, 3 }; diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index d17840c48..e23c37d3e 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -72,12 +72,12 @@ class EchoServer final : public MessageEndpointServer } }; -class SlowServer final : public MessageEndpointServer +class SleepServer final : public MessageEndpointServer { public: int delayMs = 1000; - SlowServer() + SleepServer() : MessageEndpointServer(testPortAsync, testPortSync) {} @@ -85,18 +85,19 @@ class SlowServer final : public MessageEndpointServer void doAsyncRecv(faabric::transport::Message& header, faabric::transport::Message& body) override { - throw std::runtime_error("SlowServer not expecting async recv"); + throw std::runtime_error("Sleep server not expecting async recv"); } std::unique_ptr doSyncRecv( faabric::transport::Message& header, faabric::transport::Message& body) override { - SPDLOG_DEBUG("Slow message server test recv"); + int* sleepTimeMs = (int*)body.udata(); + SPDLOG_DEBUG("Sleep server sleeping for {}ms", *sleepTimeMs); + SLEEP_MS(*sleepTimeMs); - SLEEP_MS(delayMs); auto response = std::make_unique(); - response->set_data("From the slow server"); + response->set_data("Response after sleep"); return response; } }; @@ -188,39 +189,42 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") TEST_CASE("Test client timeout on requests to valid server", "[transport]") { int clientTimeout; + int serverSleep; bool expectFailure; SECTION("Long timeout no failure") { clientTimeout = 20000; + serverSleep = 100; expectFailure = false; } SECTION("Short timeout failure") { clientTimeout = 10; + serverSleep = 2000; expectFailure = true; } // Start the server - SlowServer server; + SleepServer server; server.start(); // Set up the client MessageEndpointClient cli( thisHost, testPortAsync, testPortSync, clientTimeout); - std::vector data = { 1, 1, 1 }; + + uint8_t* sleepBytes = BYTES(&serverSleep); faabric::StatePart response; if (expectFailure) { // Check for failure - REQUIRE_THROWS_AS(cli.syncSend(0, data.data(), data.size(), &response), + REQUIRE_THROWS_AS(cli.syncSend(0, sleepBytes, sizeof(int), &response), MessageTimeoutException); } else { - cli.syncSend(0, data.data(), data.size(), &response); - + cli.syncSend(0, sleepBytes, sizeof(int), &response); std::vector expected = { 0, 1, 2, 3 }; - REQUIRE(response.data() == "From the slow server"); + REQUIRE(response.data() == "Response after sleep"); } server.stop(); diff --git a/tests/utils/faabric_utils.h b/tests/utils/faabric_utils.h index af6d8a694..acc86cb73 100644 --- a/tests/utils/faabric_utils.h +++ b/tests/utils/faabric_utils.h @@ -17,13 +17,17 @@ using namespace faabric; #define REQUIRE_RETRY(updater, check) \ { \ - { updater; }; \ + { \ + updater; \ + }; \ bool res = (check); \ int count = 0; \ while (!res && count < REQUIRE_RETRY_MAX) { \ count++; \ - SLEEP_MS(REQUIRE_RETRY_SLEEP_MS); \ - { updater; }; \ + SLEEP_MS(REQUIRE_RETRY_SLEEP_MS); \ + { \ + updater; \ + }; \ res = (check); \ } \ if (!res) { \ From 4de6f2702135367bab0963c941fc1eefc9604637 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 29 Jun 2021 11:56:48 +0000 Subject: [PATCH 54/66] Added latch to async server --- .../faabric/transport/MessageEndpointServer.h | 23 ++++---- include/faabric/util/{barrier.h => latch.h} | 15 ++--- src/transport/MessageEndpointServer.cpp | 44 ++++++++++---- src/util/CMakeLists.txt | 2 +- src/util/barrier.cpp | 49 ---------------- src/util/latch.cpp | 31 ++++++++++ tests/test/scheduler/test_executor.cpp | 8 +-- .../scheduler/test_function_client_server.cpp | 7 ++- .../test/scheduler/test_remote_mpi_worlds.cpp | 33 ++++++----- .../test_message_endpoint_client.cpp | 56 +++++++++--------- tests/test/transport/test_message_server.cpp | 57 +++++++++---------- .../transport/test_mpi_message_endpoint.cpp | 1 - tests/test/util/test_barrier.cpp | 30 ++++------ tests/utils/fixtures.h | 6 +- 14 files changed, 174 insertions(+), 188 deletions(-) rename include/faabric/util/{barrier.h => latch.h} (59%) delete mode 100644 src/util/barrier.cpp create mode 100644 src/util/latch.cpp diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 2b3ea6a46..999a6544e 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -3,17 +3,14 @@ #include #include #include +#include #include namespace faabric::transport { -/* Server handling a long-running 0MQ socket - * - * This abstract class implements a server-like loop functionality and will - * always run in the background. Note that message endpoints (i.e. 0MQ sockets) - * are _not_ thread safe, must be open-ed and close-ed from the _same_ thread, - * and thus should preferably live in the thread's local address space. - */ + +// This server has two underlying sockets, one for synchronous communication and +// one for asynchronous. class MessageEndpointServer { public: @@ -23,13 +20,11 @@ class MessageEndpointServer virtual void stop(); + void setAsyncLatch(); + + void awaitAsyncLatch(); + protected: - /* Template function to handle message reception - * - * A message endpoint server in faabric expects each communication to be - * a multi-part 0MQ message. One message containing the header, and another - * one with the body. Note that 0MQ _guarantees_ in-order delivery. - */ virtual void doAsyncRecv(faabric::transport::Message& header, faabric::transport::Message& body) = 0; @@ -48,5 +43,7 @@ class MessageEndpointServer AsyncSendMessageEndpoint asyncShutdownSender; SyncSendMessageEndpoint syncShutdownSender; + + std::unique_ptr asyncLatch; }; } diff --git a/include/faabric/util/barrier.h b/include/faabric/util/latch.h similarity index 59% rename from include/faabric/util/barrier.h rename to include/faabric/util/latch.h index 97139e692..a10ad431b 100644 --- a/include/faabric/util/barrier.h +++ b/include/faabric/util/latch.h @@ -7,21 +7,16 @@ namespace faabric::util { #define DEFAULT_BARRIER_TIMEOUT_MS 10000 -class Barrier +class Latch { public: - explicit Barrier(int count, int timeoutMsIn = DEFAULT_BARRIER_TIMEOUT_MS); + explicit Latch(int countIn, int timeoutMsIn = DEFAULT_BARRIER_TIMEOUT_MS); void wait(); - - int getSlotCount(); - - int getUseCount(); - private: - int threadCount; - int slotCount; - int uses; + int count; + int waiters = 0; + int timeoutMs; std::mutex mx; diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index a7fc4c2de..5dd6cbaef 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include @@ -40,14 +40,14 @@ MessageEndpointServer::MessageEndpointServer(int asyncPortIn, int syncPortIn) void MessageEndpointServer::start() { - // This barrier means that callers can guarantee that when this function + // This latch means that callers can guarantee that when this function // completes, both sockets will have been opened (and hence the server is // ready to use). - faabric::util::Barrier startBarrier(3); + faabric::util::Latch startLatch(3); - asyncThread = std::thread([this, &startBarrier] { + asyncThread = std::thread([this, &startLatch] { AsyncRecvMessageEndpoint endpoint(asyncPort); - startBarrier.wait(); + startLatch.wait(); while (true) { // Receive header and body @@ -59,12 +59,20 @@ void MessageEndpointServer::start() // Server-specific message handling doAsyncRecv(header, body); + + // Wait on the async latch if necessary + if (asyncLatch != nullptr) { + SPDLOG_TRACE( + "Server thread waiting on async latch for port {}", + asyncPort); + asyncLatch->wait(); + } } }); - syncThread = std::thread([this, &startBarrier] { + syncThread = std::thread([this, &startLatch] { SyncRecvMessageEndpoint endpoint(syncPort); - startBarrier.wait(); + startLatch.wait(); while (true) { // Receive header and body @@ -88,18 +96,19 @@ void MessageEndpointServer::start() } }); - startBarrier.wait(); + startLatch.wait(); } void MessageEndpointServer::stop() { // Send shutdown messages - SPDLOG_TRACE("Server sending shutdown messages"); - - syncShutdownSender.sendRaw(shutdownHeader.data(), shutdownHeader.size()); + SPDLOG_TRACE( + "Server sending shutdown messages to ports {} {}", asyncPort, syncPort); asyncShutdownSender.send(shutdownHeader.data(), shutdownHeader.size()); + syncShutdownSender.sendRaw(shutdownHeader.data(), shutdownHeader.size()); + // Join the threads if (asyncThread.joinable()) { asyncThread.join(); @@ -110,4 +119,17 @@ void MessageEndpointServer::stop() } } +void MessageEndpointServer::setAsyncLatch() +{ + asyncLatch = std::make_unique(2); +} + +void MessageEndpointServer::awaitAsyncLatch() +{ + SPDLOG_TRACE("Waiting on async latch for port {}", asyncPort); + asyncLatch->wait(); + + SPDLOG_TRACE("Finished async latch for port {}", asyncPort); + asyncLatch = nullptr; +} } diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index cc204a2fb..3db722736 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -3,7 +3,6 @@ find_package(RapidJSON) file(GLOB HEADERS "${FAABRIC_INCLUDE_DIR}/faabric/util/*.h") set(LIB_FILES - barrier.cpp bytes.cpp config.cpp clock.cpp @@ -14,6 +13,7 @@ set(LIB_FILES gids.cpp http.cpp json.cpp + latch.cpp logging.cpp memory.cpp network.cpp diff --git a/src/util/barrier.cpp b/src/util/barrier.cpp deleted file mode 100644 index 4dc63a876..000000000 --- a/src/util/barrier.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include -#include - -namespace faabric::util { -Barrier::Barrier(int count, int timeoutMsIn) - : threadCount(count) - , slotCount(count) - , uses(0) - , timeoutMs(timeoutMsIn) -{} - -void Barrier::wait() -{ - { - UniqueLock lock(mx); - int usesCopy = uses; - - slotCount--; - if (slotCount == 0) { - uses++; - // Checks for overflow - if (uses < 0) { - throw std::runtime_error("Barrier was used too many times"); - } - slotCount = threadCount; - cv.notify_all(); - } else { - auto timePoint = std::chrono::system_clock::now() + - std::chrono::milliseconds(timeoutMs); - bool waitRes = - cv.wait_until(lock, timePoint, [&] { return usesCopy < uses; }); - if (!waitRes) { - throw std::runtime_error("Barrier timed out"); - } - } - } -} - -int Barrier::getSlotCount() -{ - return slotCount; -} - -int Barrier::getUseCount() -{ - return uses; -} - -} diff --git a/src/util/latch.cpp b/src/util/latch.cpp new file mode 100644 index 000000000..d438910fe --- /dev/null +++ b/src/util/latch.cpp @@ -0,0 +1,31 @@ +#include +#include + +namespace faabric::util { +Latch::Latch(int countIn, int timeoutMsIn) + : count(countIn) + , timeoutMs(timeoutMsIn) +{} + +void Latch::wait() +{ + UniqueLock lock(mx); + + waiters++; + + if (waiters > count) { + throw std::runtime_error("Latch already used"); + } + + if (waiters == count) { + cv.notify_all(); + } else { + auto timePoint = std::chrono::system_clock::now() + + std::chrono::milliseconds(timeoutMs); + + if (!cv.wait_until(lock, timePoint, [&] { return waiters >= count; })) { + throw std::runtime_error("latch timed out"); + } + } +} +} diff --git a/tests/test/scheduler/test_executor.cpp b/tests/test/scheduler/test_executor.cpp index 66d7edb31..88dbbc85c 100644 --- a/tests/test/scheduler/test_executor.cpp +++ b/tests/test/scheduler/test_executor.cpp @@ -439,8 +439,8 @@ TEST_CASE_METHOD(TestExecutorFixture, faabric::scheduler::queueResourceResponse(otherHost, resOther); // Background thread to execute main function and await results - faabric::util::Barrier barrier(2); - std::thread t([&barrier] { + faabric::util::Latch latch(2); + std::thread t([&latch] { int nThreads = 8; std::shared_ptr req = faabric::util::batchExecFactory("dummy", "thread-check", 1); @@ -450,7 +450,7 @@ TEST_CASE_METHOD(TestExecutorFixture, auto& sch = faabric::scheduler::getScheduler(); sch.callFunctions(req, false); - barrier.wait(); + latch.wait(); faabric::Message res = sch.getFunctionResult(msg.id(), 2000); assert(res.returnvalue() == 0); }); @@ -493,7 +493,7 @@ TEST_CASE_METHOD(TestExecutorFixture, } // Rejoin the other thread - barrier.wait(); + latch.wait(); if (t.joinable()) { t.join(); } diff --git a/tests/test/scheduler/test_function_client_server.cpp b/tests/test/scheduler/test_function_client_server.cpp index c66dfff8d..ccbfca203 100644 --- a/tests/test/scheduler/test_function_client_server.cpp +++ b/tests/test/scheduler/test_function_client_server.cpp @@ -220,16 +220,21 @@ TEST_CASE_METHOD(ClientServerFixture, "Test unregister request", "[scheduler]") *reqA.mutable_function() = msg; // Check that nothing's happened + server.setAsyncLatch(); cli.unregister(reqA); + server.awaitAsyncLatch(); REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 1); // Make the request to unregister the actual host faabric::UnregisterRequest reqB; reqB.set_host(otherHost); *reqB.mutable_function() = msg; + + server.setAsyncLatch(); cli.unregister(reqB); + server.awaitAsyncLatch(); - REQUIRE_RETRY({}, sch.getFunctionRegisteredHostCount(msg) == 0); + REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 0); sch.setThisHostResources(originalResources); faabric::scheduler::clearMockRequests(); diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index d890ac369..21780001e 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -171,7 +170,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, std::vector actual(buffer, buffer + messageData2.size()); REQUIRE(actual == messageData2); - testBarrier.wait(); + testLatch.wait(); otherWorld.destroy(); }); @@ -192,7 +191,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(status.MPI_ERROR == MPI_SUCCESS); REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); - testBarrier.wait(); + testLatch.wait(); // Clean up if (otherWorldThread.joinable()) { @@ -272,7 +271,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, otherWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); } - testBarrier.wait(); + testLatch.wait(); otherWorld.destroy(); }); @@ -288,7 +287,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } // Clean up - testBarrier.wait(); + testLatch.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -326,7 +325,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, } // Give the other host time to receive the broadcast - testBarrier.wait(); + testLatch.wait(); otherWorld.destroy(); }); @@ -339,7 +338,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, } // Clean up - testBarrier.wait(); + testLatch.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -399,7 +398,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); assert(actual == std::vector({ 12, 13, 14, 15 })); - testBarrier.wait(); + testLatch.wait(); otherWorld.destroy(); }); @@ -436,7 +435,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, REQUIRE(actual == std::vector({ 16, 17, 18, 19 })); // Clean up - testBarrier.wait(); + testLatch.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -486,7 +485,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); } - testBarrier.wait(); + testLatch.wait(); otherWorld.destroy(); }); @@ -518,7 +517,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, REQUIRE(actual == expected); // Clean up - testBarrier.wait(); + testLatch.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -556,7 +555,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_INT, messageData.size()); - testBarrier.wait(); + testLatch.wait(); otherWorld.destroy(); }); @@ -585,7 +584,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(asyncMessage == messageData); // Clean up - testBarrier.wait(); + testLatch.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -615,7 +614,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, otherWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); } - testBarrier.wait(); + testLatch.wait(); otherWorld.destroy(); }); @@ -649,7 +648,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(recv3 == 2); // Clean up - testBarrier.wait(); + testLatch.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -693,7 +692,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_STATUS_IGNORE); } - testBarrier.wait(); + testLatch.wait(); otherWorld.destroy(); }); @@ -717,7 +716,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } // Clean up - testBarrier.wait(); + testLatch.wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index f66e90998..4e463acbc 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -8,8 +8,8 @@ #include using namespace faabric::transport; -static const std::string thisHost = "127.0.0.1"; -static const int testPort = 9800; + +#define TEST_PORT 9800 namespace tests { @@ -17,8 +17,8 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test send/recv one message", "[transport]") { - AsyncSendMessageEndpoint src(thisHost, testPort); - AsyncRecvMessageEndpoint dst(testPort); + AsyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); + AsyncRecvMessageEndpoint dst(TEST_PORT); // Send message std::string expectedMsg = "Hello world!"; @@ -39,16 +39,16 @@ TEST_CASE_METHOD(SchedulerTestFixture, { std::string expectedMsg = "Hello world!"; - AsyncSendMessageEndpoint src(thisHost, testPort); + AsyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); - faabric::util::Barrier barrier(2); + faabric::util::Latch latch(2); - std::thread recvThread([&barrier, expectedMsg] { + std::thread recvThread([&latch, expectedMsg] { // Make sure this only runs once the send has been done - barrier.wait(); + latch.wait(); // Receive message - AsyncRecvMessageEndpoint dst(testPort); + AsyncRecvMessageEndpoint dst(TEST_PORT); faabric::transport::Message recvMsg = dst.recv(); assert(recvMsg.size() == expectedMsg.size()); @@ -60,7 +60,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); src.send(msg, expectedMsg.size()); - barrier.wait(); + latch.wait(); if (recvThread.joinable()) { recvThread.join(); @@ -75,7 +75,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") std::thread senderThread([expectedMsg, expectedResponse] { // Open the source endpoint client - SyncSendMessageEndpoint src(thisHost, testPort); + SyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); // Send message and wait for response std::vector bytes(BYTES_CONST(expectedMsg.c_str()), @@ -92,7 +92,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") }); // Receive message - SyncRecvMessageEndpoint dst(testPort); + SyncRecvMessageEndpoint dst(TEST_PORT); faabric::transport::Message recvMsg = dst.recv(); REQUIRE(recvMsg.size() == expectedMsg.size()); std::string actualMsg(recvMsg.data(), recvMsg.size()); @@ -118,7 +118,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, std::thread senderThread([numMessages, baseMsg] { // Open the source endpoint client - AsyncSendMessageEndpoint src(thisHost, testPort); + AsyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); for (int i = 0; i < numMessages; i++) { std::string msgData = baseMsg + std::to_string(i); uint8_t msg[msgData.size()]; @@ -128,7 +128,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, }); // Receive messages - AsyncRecvMessageEndpoint dst(testPort); + AsyncRecvMessageEndpoint dst(TEST_PORT); for (int i = 0; i < numMessages; i++) { faabric::transport::Message recvMsg = dst.recv(); // Check just a subset of the messages @@ -159,7 +159,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, for (int j = 0; j < numSenders; j++) { senderThreads.emplace_back(std::thread([numMessages, expectedMsg] { // Open the source endpoint client - AsyncSendMessageEndpoint src(thisHost, testPort); + AsyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); for (int i = 0; i < numMessages; i++) { uint8_t msg[expectedMsg.size()]; memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); @@ -169,7 +169,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, } // Receive messages - AsyncRecvMessageEndpoint dst(testPort); + AsyncRecvMessageEndpoint dst(TEST_PORT); for (int i = 0; i < numSenders * numMessages; i++) { faabric::transport::Message recvMsg = dst.recv(); // Check just a subset of the messages @@ -195,35 +195,35 @@ TEST_CASE_METHOD(SchedulerTestFixture, SECTION("Sanity check valid timeout") { - AsyncSendMessageEndpoint s(thisHost, testPort, 100); - AsyncRecvMessageEndpoint r(testPort, 100); + AsyncSendMessageEndpoint s(LOCALHOST, TEST_PORT, 100); + AsyncRecvMessageEndpoint r(TEST_PORT, 100); - SyncSendMessageEndpoint sB(thisHost, testPort + 10, 100); - SyncRecvMessageEndpoint rB(testPort + 10, 100); + SyncSendMessageEndpoint sB(LOCALHOST, TEST_PORT + 10, 100); + SyncRecvMessageEndpoint rB(TEST_PORT + 10, 100); } SECTION("Recv zero timeout") { - REQUIRE_THROWS(AsyncRecvMessageEndpoint(testPort, 0)); - REQUIRE_THROWS(SyncRecvMessageEndpoint(testPort + 10, 0)); + REQUIRE_THROWS(AsyncRecvMessageEndpoint(TEST_PORT, 0)); + REQUIRE_THROWS(SyncRecvMessageEndpoint(TEST_PORT + 10, 0)); } SECTION("Send zero timeout") { - REQUIRE_THROWS(AsyncSendMessageEndpoint(thisHost, testPort, 0)); - REQUIRE_THROWS(SyncSendMessageEndpoint(thisHost, testPort + 10, 0)); + REQUIRE_THROWS(AsyncSendMessageEndpoint(LOCALHOST, TEST_PORT, 0)); + REQUIRE_THROWS(SyncSendMessageEndpoint(LOCALHOST, TEST_PORT + 10, 0)); } SECTION("Recv negative timeout") { - REQUIRE_THROWS(AsyncRecvMessageEndpoint(testPort, -1)); - REQUIRE_THROWS(SyncRecvMessageEndpoint(testPort + 10, -1)); + REQUIRE_THROWS(AsyncRecvMessageEndpoint(TEST_PORT, -1)); + REQUIRE_THROWS(SyncRecvMessageEndpoint(TEST_PORT + 10, -1)); } SECTION("Send negative timeout") { - REQUIRE_THROWS(AsyncSendMessageEndpoint(thisHost, testPort, -1)); - REQUIRE_THROWS(SyncSendMessageEndpoint(thisHost, testPort + 10, -1)); + REQUIRE_THROWS(AsyncSendMessageEndpoint(LOCALHOST, TEST_PORT, -1)); + REQUIRE_THROWS(SyncSendMessageEndpoint(LOCALHOST, TEST_PORT + 10, -1)); } } } diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index e23c37d3e..eb8291dd1 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -13,20 +13,17 @@ using namespace faabric::transport; -static const std::string thisHost = "127.0.0.1"; -static const int testPortAsync = 9998; -static const int testPortSync = 9999; +#define TEST_PORT_ASYNC 9998 +#define TEST_PORT_SYNC 9999 class DummyServer final : public MessageEndpointServer { public: DummyServer() - : MessageEndpointServer(testPortAsync, testPortSync) - , messageCount(0) + : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC) {} - // Variable to keep track of the received messages - int messageCount; + std::atomic messageCount = 0; private: void doAsyncRecv(faabric::transport::Message& header, @@ -49,14 +46,14 @@ class EchoServer final : public MessageEndpointServer { public: EchoServer() - : MessageEndpointServer(testPortAsync, testPortSync) + : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC) {} protected: void doAsyncRecv(faabric::transport::Message& header, faabric::transport::Message& body) override { - throw std::runtime_error("EchoServer not expecting async recv"); + throw std::runtime_error("Echo server not expecting async recv"); } std::unique_ptr doSyncRecv( @@ -78,7 +75,7 @@ class SleepServer final : public MessageEndpointServer int delayMs = 1000; SleepServer() - : MessageEndpointServer(testPortAsync, testPortSync) + : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC) {} protected: @@ -111,15 +108,18 @@ TEST_CASE("Test send one message to server", "[transport]") REQUIRE(server.messageCount == 0); - MessageEndpointClient cli(thisHost, testPortAsync, testPortSync); + MessageEndpointClient cli(LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC); // Send a message std::string body = "body"; uint8_t bodyMsg[body.size()]; memcpy(bodyMsg, body.c_str(), body.size()); + + server.setAsyncLatch(); cli.asyncSend(0, bodyMsg, body.size()); + server.awaitAsyncLatch(); - REQUIRE_RETRY({}, server.messageCount == 1); + REQUIRE(server.messageCount == 1); server.stop(); } @@ -132,7 +132,7 @@ TEST_CASE("Test send response to client", "[transport]") std::string expectedMsg = "Response from server"; // Open the source endpoint client - MessageEndpointClient cli(thisHost, testPortAsync, testPortSync); + MessageEndpointClient cli(LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC); // Send and await the response faabric::StatePart response; @@ -145,44 +145,41 @@ TEST_CASE("Test send response to client", "[transport]") TEST_CASE("Test multiple clients talking to one server", "[transport]") { - // Start the server in the background - DummyServer server; + EchoServer server; server.start(); std::vector clientThreads; int numClients = 10; int numMessages = 1000; - // Set up a barrier to wait on all the clients having finished - faabric::util::Barrier barrier(numClients + 1); - for (int i = 0; i < numClients; i++) { - clientThreads.emplace_back(std::thread([&barrier, numMessages] { + clientThreads.emplace_back(std::thread([i, numMessages] { // Prepare client - MessageEndpointClient cli(thisHost, testPortAsync, testPortSync); + MessageEndpointClient cli( + LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC); - std::string clientMsg = "Message from threaded client"; for (int j = 0; j < numMessages; j++) { - // Send body + std::string clientMsg = + fmt::format("Message {} from client {}", j, i); + + // Send and get response uint8_t body[clientMsg.size()]; memcpy(body, clientMsg.c_str(), clientMsg.size()); - cli.asyncSend(0, body, clientMsg.size()); - } + faabric::StatePart response; + cli.syncSend(0, body, clientMsg.size(), &response); - barrier.wait(); + std::string actual = response.data(); + assert(actual == clientMsg); + } })); } - barrier.wait(); - for (auto& t : clientThreads) { if (t.joinable()) { t.join(); } } - REQUIRE_RETRY({}, server.messageCount == numMessages * numClients); - server.stop(); } @@ -212,7 +209,7 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") // Set up the client MessageEndpointClient cli( - thisHost, testPortAsync, testPortSync, clientTimeout); + LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC, clientTimeout); uint8_t* sleepBytes = BYTES(&serverSleep); faabric::StatePart response; diff --git a/tests/test/transport/test_mpi_message_endpoint.cpp b/tests/test/transport/test_mpi_message_endpoint.cpp index 643c9bda0..670062b20 100644 --- a/tests/test/transport/test_mpi_message_endpoint.cpp +++ b/tests/test/transport/test_mpi_message_endpoint.cpp @@ -22,7 +22,6 @@ TEST_CASE_METHOD(SchedulerTestFixture, sendEndpoint.sendMpiMessage(expected); std::shared_ptr actual = recvEndpoint.recvMpiMessage(); - // Checks REQUIRE(expected->id() == actual->id()); } } diff --git a/tests/test/util/test_barrier.cpp b/tests/test/util/test_barrier.cpp index 97b075c96..27d9eb087 100644 --- a/tests/test/util/test_barrier.cpp +++ b/tests/test/util/test_barrier.cpp @@ -2,8 +2,8 @@ #include "faabric_utils.h" -#include #include +#include #include #include @@ -12,22 +12,14 @@ using namespace faabric::util; namespace tests { -TEST_CASE("Test barrier operation", "[util]") +TEST_CASE("Test latch operation", "[util]") { - Barrier b(3); + Latch l(3); - REQUIRE(b.getSlotCount() == 3); - REQUIRE(b.getUseCount() == 0); + auto t1 = std::thread([&l] { l.wait(); }); + auto t2 = std::thread([&l] { l.wait(); }); - auto t1 = std::thread([&b] { b.wait(); }); - - auto t2 = std::thread([&b] { b.wait(); }); - - // Sleep for a bit while the threads spawn - REQUIRE_RETRY({}, b.getSlotCount() == 1); - - // Join with master to go through barrier - b.wait(); + l.wait(); if (t1.joinable()) { t1.join(); @@ -37,15 +29,13 @@ TEST_CASE("Test barrier operation", "[util]") t2.join(); } - REQUIRE(b.getSlotCount() == 3); - REQUIRE(b.getUseCount() == 1); + REQUIRE_THROWS(l.wait()); } -TEST_CASE("Test barrier timeout", "[util]") +TEST_CASE("Test latch timeout", "[util]") { int timeoutMs = 500; - Barrier b(2, timeoutMs); - - REQUIRE_THROWS(b.wait()); + Latch l(2, timeoutMs); + REQUIRE_THROWS(l.wait()); } } diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index f54b4f426..7f189b0c2 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include #include #include @@ -176,7 +176,7 @@ class RemoteMpiTestFixture : public MpiBaseTestFixture public: RemoteMpiTestFixture() : thisHost(faabric::util::getSystemConfig().endpointHost) - , testBarrier(2) + , testLatch(2) { otherWorld.overrideHost(otherHost); @@ -217,7 +217,7 @@ class RemoteMpiTestFixture : public MpiBaseTestFixture std::string thisHost; std::string otherHost = LOCALHOST; - faabric::util::Barrier testBarrier; + faabric::util::Latch testLatch; faabric::scheduler::MpiWorld otherWorld; }; From 9b0cbaf8e16fb694ce90ce43b9e23994e6e79f44 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 29 Jun 2021 13:04:43 +0000 Subject: [PATCH 55/66] Guard against null pointers and avoid memcpying --- include/faabric/transport/MessageEndpoint.h | 2 +- .../faabric/transport/MessageEndpointClient.h | 2 +- src/scheduler/SnapshotClient.cpp | 4 ++-- src/transport/Message.cpp | 10 +++++++--- src/transport/MessageEndpoint.cpp | 2 +- src/transport/MessageEndpointClient.cpp | 2 +- .../transport/test_message_endpoint_client.cpp | 18 ++++++------------ tests/test/transport/test_message_server.cpp | 7 ++----- 8 files changed, 21 insertions(+), 26 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 6bb530a3d..935c81a17 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -119,7 +119,7 @@ class SyncRecvMessageEndpoint : public MessageEndpoint Message recv(int size = 0); - void sendResponse(uint8_t* data, int size); + void sendResponse(const uint8_t* data, int size); private: zmq::socket_t repSocket; diff --git a/include/faabric/transport/MessageEndpointClient.h b/include/faabric/transport/MessageEndpointClient.h index 945e30509..02e945925 100644 --- a/include/faabric/transport/MessageEndpointClient.h +++ b/include/faabric/transport/MessageEndpointClient.h @@ -16,7 +16,7 @@ class MessageEndpointClient void asyncSend(int header, google::protobuf::Message* msg); - void asyncSend(int header, uint8_t* buffer, size_t bufferSize); + void asyncSend(int header, const uint8_t* buffer, size_t bufferSize); void syncSend(int header, google::protobuf::Message* msg, diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index 020edfa45..c3de00cb2 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -71,7 +71,7 @@ void clearMockSnapshotRequests() #define SEND_FB_MSG(T, mb) \ { \ - uint8_t* buffer = mb.GetBufferPointer(); \ + const uint8_t* buffer = mb.GetBufferPointer(); \ int size = mb.GetSize(); \ faabric::EmptyResponse response; \ syncSend(T, buffer, size, &response); \ @@ -79,7 +79,7 @@ void clearMockSnapshotRequests() #define SEND_FB_MSG_ASYNC(T, mb) \ { \ - uint8_t* buffer = mb.GetBufferPointer(); \ + const uint8_t* buffer = mb.GetBufferPointer(); \ int size = mb.GetSize(); \ asyncSend(T, buffer, size); \ } diff --git a/src/transport/Message.cpp b/src/transport/Message.cpp index 5779a9e42..19ac4614d 100644 --- a/src/transport/Message.cpp +++ b/src/transport/Message.cpp @@ -1,11 +1,15 @@ +#include #include +#include namespace faabric::transport { Message::Message(const zmq::message_t& msgIn) - : bytes(msgIn.size()) - , _more(msgIn.more()) + : _more(msgIn.more()) { - std::memcpy(bytes.data(), msgIn.data(), msgIn.size()); + if (msgIn.data() != nullptr) { + bytes = std::vector(BYTES_CONST(msgIn.data()), + BYTES_CONST(msgIn.data()) + msgIn.size()); + } } Message::Message(int sizeIn) diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 808478b1f..755f1bf87 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -310,7 +310,7 @@ Message SyncRecvMessageEndpoint::recv(int size) return doRecv(repSocket, size); } -void SyncRecvMessageEndpoint::sendResponse(uint8_t* data, int size) +void SyncRecvMessageEndpoint::sendResponse(const uint8_t* data, int size) { SPDLOG_TRACE("REP {} ({} bytes)", port, size); doSend(repSocket, data, size, false); diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index 6f028fac3..a041c5b0e 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -27,7 +27,7 @@ void MessageEndpointClient::asyncSend(int header, } void MessageEndpointClient::asyncSend(int header, - uint8_t* buffer, + const uint8_t* buffer, size_t bufferSize) { asyncEndpoint.sendHeader(header); diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 4e463acbc..afec981e6 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -22,8 +22,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, // Send message std::string expectedMsg = "Hello world!"; - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); + const uint8_t* msg = BYTES_CONST(expectedMsg.c_str()); src.send(msg, expectedMsg.size()); // Receive message @@ -56,9 +55,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, assert(actualMsg == expectedMsg); }); - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - + const uint8_t* msg = BYTES_CONST(expectedMsg.c_str()); src.send(msg, expectedMsg.size()); latch.wait(); @@ -99,8 +96,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", "[transport]") REQUIRE(actualMsg == expectedMsg); // Send response - uint8_t msg[expectedResponse.size()]; - memcpy(msg, expectedResponse.c_str(), expectedResponse.size()); + const uint8_t* msg = BYTES_CONST(expectedResponse.c_str()); dst.sendResponse(msg, expectedResponse.size()); // Wait for sender thread @@ -121,8 +117,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, AsyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); for (int i = 0; i < numMessages; i++) { std::string msgData = baseMsg + std::to_string(i); - uint8_t msg[msgData.size()]; - memcpy(msg, msgData.c_str(), msgData.size()); + const uint8_t* msg = BYTES_CONST(msgData.c_str()); src.send(msg, msgData.size()); } }); @@ -155,14 +150,13 @@ TEST_CASE_METHOD(SchedulerTestFixture, int numSenders = 10; std::string expectedMsg = "Hello from client"; std::vector senderThreads; + const uint8_t* msg = BYTES_CONST(expectedMsg.c_str()); for (int j = 0; j < numSenders; j++) { - senderThreads.emplace_back(std::thread([numMessages, expectedMsg] { + senderThreads.emplace_back(std::thread([msg, numMessages, expectedMsg] { // Open the source endpoint client AsyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); for (int i = 0; i < numMessages; i++) { - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); src.send(msg, expectedMsg.size()); } })); diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index eb8291dd1..4c88c2814 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -112,8 +112,7 @@ TEST_CASE("Test send one message to server", "[transport]") // Send a message std::string body = "body"; - uint8_t bodyMsg[body.size()]; - memcpy(bodyMsg, body.c_str(), body.size()); + const uint8_t* bodyMsg = BYTES_CONST(body.c_str()); server.setAsyncLatch(); cli.asyncSend(0, bodyMsg, body.size()); @@ -163,8 +162,7 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]") fmt::format("Message {} from client {}", j, i); // Send and get response - uint8_t body[clientMsg.size()]; - memcpy(body, clientMsg.c_str(), clientMsg.size()); + const uint8_t* body = BYTES_CONST(clientMsg.c_str()); faabric::StatePart response; cli.syncSend(0, body, clientMsg.size(), &response); @@ -220,7 +218,6 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") MessageTimeoutException); } else { cli.syncSend(0, sleepBytes, sizeof(int), &response); - std::vector expected = { 0, 1, 2, 3 }; REQUIRE(response.data() == "Response after sleep"); } From ff00429bf3f1f9464b49dbda240222b0774a7305 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 29 Jun 2021 13:35:04 +0000 Subject: [PATCH 56/66] Share latch via shared pointer --- .../faabric/transport/MessageEndpointServer.h | 2 +- include/faabric/util/latch.h | 10 +++++- src/scheduler/SnapshotClient.cpp | 4 +-- src/transport/MessageEndpointServer.cpp | 19 ++++++----- src/util/latch.cpp | 9 +++++- tests/test/scheduler/test_executor.cpp | 6 ++-- .../test/scheduler/test_remote_mpi_worlds.cpp | 32 +++++++++---------- .../test_message_endpoint_client.cpp | 6 ++-- .../util/{test_barrier.cpp => test_latch.cpp} | 14 ++++---- tests/utils/fixtures.h | 4 +-- 10 files changed, 60 insertions(+), 46 deletions(-) rename tests/test/util/{test_barrier.cpp => test_latch.cpp} (65%) diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 999a6544e..13458791c 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -44,6 +44,6 @@ class MessageEndpointServer AsyncSendMessageEndpoint asyncShutdownSender; SyncSendMessageEndpoint syncShutdownSender; - std::unique_ptr asyncLatch; + std::shared_ptr asyncLatch; }; } diff --git a/include/faabric/util/latch.h b/include/faabric/util/latch.h index a10ad431b..cdd84c863 100644 --- a/include/faabric/util/latch.h +++ b/include/faabric/util/latch.h @@ -10,9 +10,17 @@ namespace faabric::util { class Latch { public: - explicit Latch(int countIn, int timeoutMsIn = DEFAULT_BARRIER_TIMEOUT_MS); + // WARNING: this latch must be shared between threads using a shared + // pointer, otherwise there seems to be some nasty race conditions related + // to its destruction. + static std::shared_ptr create( + int count, + int timeoutMs = DEFAULT_BARRIER_TIMEOUT_MS); void wait(); + + explicit Latch(int countIn, int timeoutMsIn); + private: int count; int waiters = 0; diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index c3de00cb2..ce821adc9 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -71,7 +71,7 @@ void clearMockSnapshotRequests() #define SEND_FB_MSG(T, mb) \ { \ - const uint8_t* buffer = mb.GetBufferPointer(); \ + const uint8_t* buffer = mb.GetBufferPointer(); \ int size = mb.GetSize(); \ faabric::EmptyResponse response; \ syncSend(T, buffer, size, &response); \ @@ -79,7 +79,7 @@ void clearMockSnapshotRequests() #define SEND_FB_MSG_ASYNC(T, mb) \ { \ - const uint8_t* buffer = mb.GetBufferPointer(); \ + const uint8_t* buffer = mb.GetBufferPointer(); \ int size = mb.GetSize(); \ asyncSend(T, buffer, size); \ } diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 5dd6cbaef..4421d8d9d 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -43,11 +43,11 @@ void MessageEndpointServer::start() // This latch means that callers can guarantee that when this function // completes, both sockets will have been opened (and hence the server is // ready to use). - faabric::util::Latch startLatch(3); + auto startLatch = faabric::util::Latch::create(3); - asyncThread = std::thread([this, &startLatch] { + asyncThread = std::thread([this, startLatch] { AsyncRecvMessageEndpoint endpoint(asyncPort); - startLatch.wait(); + startLatch->wait(); while (true) { // Receive header and body @@ -62,17 +62,16 @@ void MessageEndpointServer::start() // Wait on the async latch if necessary if (asyncLatch != nullptr) { - SPDLOG_TRACE( - "Server thread waiting on async latch for port {}", - asyncPort); + SPDLOG_TRACE("Server thread waiting on async latch for port {}", + asyncPort); asyncLatch->wait(); } } }); - syncThread = std::thread([this, &startLatch] { + syncThread = std::thread([this, startLatch] { SyncRecvMessageEndpoint endpoint(syncPort); - startLatch.wait(); + startLatch->wait(); while (true) { // Receive header and body @@ -96,7 +95,7 @@ void MessageEndpointServer::start() } }); - startLatch.wait(); + startLatch->wait(); } void MessageEndpointServer::stop() @@ -121,7 +120,7 @@ void MessageEndpointServer::stop() void MessageEndpointServer::setAsyncLatch() { - asyncLatch = std::make_unique(2); + asyncLatch = faabric::util::Latch::create(2); } void MessageEndpointServer::awaitAsyncLatch() diff --git a/src/util/latch.cpp b/src/util/latch.cpp index d438910fe..8d90e007d 100644 --- a/src/util/latch.cpp +++ b/src/util/latch.cpp @@ -1,7 +1,14 @@ #include #include +#include namespace faabric::util { + +std::shared_ptr Latch::create(int count, int timeoutMs) +{ + return std::make_shared(count, timeoutMs); +} + Latch::Latch(int countIn, int timeoutMsIn) : count(countIn) , timeoutMs(timeoutMsIn) @@ -24,7 +31,7 @@ void Latch::wait() std::chrono::milliseconds(timeoutMs); if (!cv.wait_until(lock, timePoint, [&] { return waiters >= count; })) { - throw std::runtime_error("latch timed out"); + throw std::runtime_error("Latch timed out"); } } } diff --git a/tests/test/scheduler/test_executor.cpp b/tests/test/scheduler/test_executor.cpp index 88dbbc85c..67417c938 100644 --- a/tests/test/scheduler/test_executor.cpp +++ b/tests/test/scheduler/test_executor.cpp @@ -439,7 +439,7 @@ TEST_CASE_METHOD(TestExecutorFixture, faabric::scheduler::queueResourceResponse(otherHost, resOther); // Background thread to execute main function and await results - faabric::util::Latch latch(2); + auto latch = faabric::util::Latch::create(2); std::thread t([&latch] { int nThreads = 8; std::shared_ptr req = @@ -450,7 +450,7 @@ TEST_CASE_METHOD(TestExecutorFixture, auto& sch = faabric::scheduler::getScheduler(); sch.callFunctions(req, false); - latch.wait(); + latch->wait(); faabric::Message res = sch.getFunctionResult(msg.id(), 2000); assert(res.returnvalue() == 0); }); @@ -493,7 +493,7 @@ TEST_CASE_METHOD(TestExecutorFixture, } // Rejoin the other thread - latch.wait(); + latch->wait(); if (t.joinable()) { t.join(); } diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 21780001e..1cceed404 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -170,7 +170,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, std::vector actual(buffer, buffer + messageData2.size()); REQUIRE(actual == messageData2); - testLatch.wait(); + testLatch->wait(); otherWorld.destroy(); }); @@ -191,7 +191,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(status.MPI_ERROR == MPI_SUCCESS); REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); - testLatch.wait(); + testLatch->wait(); // Clean up if (otherWorldThread.joinable()) { @@ -271,7 +271,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, otherWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); } - testLatch.wait(); + testLatch->wait(); otherWorld.destroy(); }); @@ -287,7 +287,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } // Clean up - testLatch.wait(); + testLatch->wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -325,7 +325,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, } // Give the other host time to receive the broadcast - testLatch.wait(); + testLatch->wait(); otherWorld.destroy(); }); @@ -338,7 +338,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, } // Clean up - testLatch.wait(); + testLatch->wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -398,7 +398,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); assert(actual == std::vector({ 12, 13, 14, 15 })); - testLatch.wait(); + testLatch->wait(); otherWorld.destroy(); }); @@ -435,7 +435,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, REQUIRE(actual == std::vector({ 16, 17, 18, 19 })); // Clean up - testLatch.wait(); + testLatch->wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -485,7 +485,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, nPerRank); } - testLatch.wait(); + testLatch->wait(); otherWorld.destroy(); }); @@ -517,7 +517,7 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, REQUIRE(actual == expected); // Clean up - testLatch.wait(); + testLatch->wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -555,7 +555,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_INT, messageData.size()); - testLatch.wait(); + testLatch->wait(); otherWorld.destroy(); }); @@ -584,7 +584,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(asyncMessage == messageData); // Clean up - testLatch.wait(); + testLatch->wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -614,7 +614,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, otherWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); } - testLatch.wait(); + testLatch->wait(); otherWorld.destroy(); }); @@ -648,7 +648,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(recv3 == 2); // Clean up - testLatch.wait(); + testLatch->wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } @@ -692,7 +692,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, MPI_STATUS_IGNORE); } - testLatch.wait(); + testLatch->wait(); otherWorld.destroy(); }); @@ -716,7 +716,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } // Clean up - testLatch.wait(); + testLatch->wait(); if (otherWorldThread.joinable()) { otherWorldThread.join(); } diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index afec981e6..b42ff3e6d 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -40,11 +40,11 @@ TEST_CASE_METHOD(SchedulerTestFixture, AsyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); - faabric::util::Latch latch(2); + auto latch = faabric::util::Latch::create(2); std::thread recvThread([&latch, expectedMsg] { // Make sure this only runs once the send has been done - latch.wait(); + latch->wait(); // Receive message AsyncRecvMessageEndpoint dst(TEST_PORT); @@ -57,7 +57,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, const uint8_t* msg = BYTES_CONST(expectedMsg.c_str()); src.send(msg, expectedMsg.size()); - latch.wait(); + latch->wait(); if (recvThread.joinable()) { recvThread.join(); diff --git a/tests/test/util/test_barrier.cpp b/tests/test/util/test_latch.cpp similarity index 65% rename from tests/test/util/test_barrier.cpp rename to tests/test/util/test_latch.cpp index 27d9eb087..3025ed06a 100644 --- a/tests/test/util/test_barrier.cpp +++ b/tests/test/util/test_latch.cpp @@ -14,12 +14,12 @@ using namespace faabric::util; namespace tests { TEST_CASE("Test latch operation", "[util]") { - Latch l(3); + auto l = Latch::create(3); - auto t1 = std::thread([&l] { l.wait(); }); - auto t2 = std::thread([&l] { l.wait(); }); + auto t1 = std::thread([l] { l->wait(); }); + auto t2 = std::thread([l] { l->wait(); }); - l.wait(); + l->wait(); if (t1.joinable()) { t1.join(); @@ -29,13 +29,13 @@ TEST_CASE("Test latch operation", "[util]") t2.join(); } - REQUIRE_THROWS(l.wait()); + REQUIRE_THROWS(l->wait()); } TEST_CASE("Test latch timeout", "[util]") { int timeoutMs = 500; - Latch l(2, timeoutMs); - REQUIRE_THROWS(l.wait()); + auto l = Latch::create(2, timeoutMs); + REQUIRE_THROWS(l->wait()); } } diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index 7f189b0c2..eeac4aa9f 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -176,7 +176,7 @@ class RemoteMpiTestFixture : public MpiBaseTestFixture public: RemoteMpiTestFixture() : thisHost(faabric::util::getSystemConfig().endpointHost) - , testLatch(2) + , testLatch(faabric::util::Latch::create(2)) { otherWorld.overrideHost(otherHost); @@ -217,7 +217,7 @@ class RemoteMpiTestFixture : public MpiBaseTestFixture std::string thisHost; std::string otherHost = LOCALHOST; - faabric::util::Latch testLatch; + std::shared_ptr testLatch; faabric::scheduler::MpiWorld otherWorld; }; From 413ce6e16691e42ae2afa390a71d945a9fc56f2a Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 29 Jun 2021 13:50:33 +0000 Subject: [PATCH 57/66] Typos --- include/faabric/transport/MessageEndpoint.h | 8 +++----- include/faabric/util/latch.h | 4 ++-- tests/test/scheduler/test_executor.cpp | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 935c81a17..eb0f3d913 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -23,11 +23,9 @@ namespace faabric::transport { -/* - * Note, that sockets must be open-ed and close-ed from the _same_ thread. For a - * proto://host:pair triple, one socket may bind, and all the rest must connect. - * Order does not matter. - */ +// Note: sockets must be open-ed and close-ed from the _same_ thread. In a given +// communication group, one socket may bind, and all the rest must connect. +// Order does not matter. class MessageEndpoint { public: diff --git a/include/faabric/util/latch.h b/include/faabric/util/latch.h index cdd84c863..462d261b6 100644 --- a/include/faabric/util/latch.h +++ b/include/faabric/util/latch.h @@ -5,7 +5,7 @@ namespace faabric::util { -#define DEFAULT_BARRIER_TIMEOUT_MS 10000 +#define DEFAULT_LATCH_TIMEOUT_MS 10000 class Latch { @@ -15,7 +15,7 @@ class Latch // to its destruction. static std::shared_ptr create( int count, - int timeoutMs = DEFAULT_BARRIER_TIMEOUT_MS); + int timeoutMs = DEFAULT_LATCH_TIMEOUT_MS); void wait(); diff --git a/tests/test/scheduler/test_executor.cpp b/tests/test/scheduler/test_executor.cpp index 67417c938..87a87b236 100644 --- a/tests/test/scheduler/test_executor.cpp +++ b/tests/test/scheduler/test_executor.cpp @@ -455,7 +455,7 @@ TEST_CASE_METHOD(TestExecutorFixture, assert(res.returnvalue() == 0); }); - // Wait until the function has executed and submitted another request + // Wait until the function has executed and submit another request auto reqs = faabric::scheduler::getBatchRequests(); REQUIRE_RETRY(reqs = faabric::scheduler::getBatchRequests(), reqs.size() == 1); From 4e6dbf22e5748de77aaa106ff557f4e494098d3c Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 29 Jun 2021 13:59:30 +0000 Subject: [PATCH 58/66] Link util to transport --- src/transport/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt index 4d5b3a651..b49d705f9 100644 --- a/src/transport/CMakeLists.txt +++ b/src/transport/CMakeLists.txt @@ -25,5 +25,5 @@ set(LIB_FILES faabric_lib(transport "${LIB_FILES}") -target_link_libraries(transport proto zeromq_imported) +target_link_libraries(transport util proto zeromq_imported) From b4d8cc71517e8561dd8c6aea95d6abd430faaa9f Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Tue, 29 Jun 2021 14:24:06 +0000 Subject: [PATCH 59/66] Move global message context handling out of FaabricMain --- examples/server.cpp | 3 +++ src/runner/FaabricMain.cpp | 5 ----- tests/dist/main.cpp | 4 ++++ tests/dist/server.cpp | 3 +++ 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/server.cpp b/examples/server.cpp index 19108c0a9..3bdcbc5b7 100644 --- a/examples/server.cpp +++ b/examples/server.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include using namespace faabric::scheduler; @@ -38,6 +39,7 @@ class ExampleExecutorFactory : public ExecutorFactory int main() { faabric::util::initLogging(); + faabric::transport::initGlobalMessageContext(); // Start the worker pool SPDLOG_INFO("Starting executor pool in the background"); @@ -53,6 +55,7 @@ int main() SPDLOG_INFO("Shutting down endpoint"); m.shutdown(); + faabric::transport::closeGlobalMessageContext(); return EXIT_SUCCESS; } diff --git a/src/runner/FaabricMain.cpp b/src/runner/FaabricMain.cpp index e80d9a9a8..ed9764b58 100644 --- a/src/runner/FaabricMain.cpp +++ b/src/runner/FaabricMain.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include #include @@ -21,8 +20,6 @@ FaabricMain::FaabricMain( void FaabricMain::startBackground() { - faabric::transport::initGlobalMessageContext(); - // Start basics startRunner(); @@ -94,8 +91,6 @@ void FaabricMain::shutdown() SPDLOG_INFO("Waiting for the snapshot server to finish"); snapshotServer.stop(); - faabric::transport::closeGlobalMessageContext(); - SPDLOG_INFO("Faabric pool successfully shut down"); } } diff --git a/tests/dist/main.cpp b/tests/dist/main.cpp index cca63d7e4..77f944364 100644 --- a/tests/dist/main.cpp +++ b/tests/dist/main.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -19,6 +20,7 @@ FAABRIC_CATCH_LOGGER int main(int argc, char* argv[]) { faabric::util::initLogging(); + faabric::transport::initGlobalMessageContext(); // Set up the distributed tests tests::initDistTests(); @@ -41,5 +43,7 @@ int main(int argc, char* argv[]) SPDLOG_INFO("Shutting down"); m.shutdown(); + faabric::transport::closeGlobalMessageContext(); + return result; } diff --git a/tests/dist/server.cpp b/tests/dist/server.cpp index 2ab2d8431..9fe79011b 100644 --- a/tests/dist/server.cpp +++ b/tests/dist/server.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include using namespace faabric::scheduler; @@ -11,6 +12,7 @@ using namespace faabric::scheduler; int main() { faabric::util::initLogging(); + faabric::transport::initGlobalMessageContext(); tests::initDistTests(); SPDLOG_INFO("Starting distributed test server on worker"); @@ -25,6 +27,7 @@ int main() SPDLOG_INFO("Shutting down"); m.shutdown(); + faabric::transport::closeGlobalMessageContext(); return EXIT_SUCCESS; } From 7584527d5525ffae5caed50f3ac09ab8cdef8664 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 30 Jun 2021 14:06:30 +0000 Subject: [PATCH 60/66] Add retry logic in servers, unify server threads into signle class --- include/faabric/transport/MessageEndpoint.h | 33 ++-- .../faabric/transport/MessageEndpointServer.h | 28 ++- src/transport/MessageEndpoint.cpp | 33 ++-- src/transport/MessageEndpointServer.cpp | 164 ++++++++++++------ tests/test/runner/test_main.cpp | 19 +- 5 files changed, 191 insertions(+), 86 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index eb0f3d913..0e4dbef9d 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -63,7 +63,7 @@ class MessageEndpoint Message recvNoBuffer(zmq::socket_t& socket); }; -class AsyncSendMessageEndpoint : public MessageEndpoint +class AsyncSendMessageEndpoint final : public MessageEndpoint { public: AsyncSendMessageEndpoint(const std::string& hostIn, @@ -78,7 +78,7 @@ class AsyncSendMessageEndpoint : public MessageEndpoint zmq::socket_t pushSocket; }; -class SyncSendMessageEndpoint : public MessageEndpoint +class SyncSendMessageEndpoint final : public MessageEndpoint { public: SyncSendMessageEndpoint(const std::string& hostIn, @@ -97,33 +97,40 @@ class SyncSendMessageEndpoint : public MessageEndpoint zmq::socket_t reqSocket; }; -class AsyncRecvMessageEndpoint : public MessageEndpoint +class RecvMessageEndpoint : public MessageEndpoint +{ + public: + RecvMessageEndpoint(int portIn, int timeoutMs, zmq::socket_type socketType); + + virtual ~RecvMessageEndpoint(){}; + + virtual Message recv(int size = 0); + + protected: + zmq::socket_t socket; +}; + +class AsyncRecvMessageEndpoint final : public RecvMessageEndpoint { public: AsyncRecvMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); - Message recv(int size = 0); - - private: - zmq::socket_t pullSocket; + Message recv(int size = 0) override; }; -class SyncRecvMessageEndpoint : public MessageEndpoint +class SyncRecvMessageEndpoint final : public RecvMessageEndpoint { public: SyncRecvMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); - Message recv(int size = 0); + Message recv(int size = 0) override; void sendResponse(const uint8_t* data, int size); - - private: - zmq::socket_t repSocket; }; -class MessageTimeoutException : public faabric::util::FaabricException +class MessageTimeoutException final : public faabric::util::FaabricException { public: explicit MessageTimeoutException(std::string message) diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 13458791c..e1cc756e7 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -9,8 +9,26 @@ namespace faabric::transport { -// This server has two underlying sockets, one for synchronous communication and -// one for asynchronous. +// Each server has two underlying sockets, one for synchronous communication and +// one for asynchronous. Each is run inside its own background thread. +class MessageEndpointServer; + +class MessageEndpointServerThread +{ + public: + MessageEndpointServerThread(MessageEndpointServer* serverIn, bool asyncIn); + + void start(std::shared_ptr latch); + + void join(); + + private: + MessageEndpointServer* server; + bool async = false; + + std::thread backgroundThread; +}; + class MessageEndpointServer { public: @@ -35,11 +53,13 @@ class MessageEndpointServer void sendSyncResponse(google::protobuf::Message* resp); private: + friend class MessageEndpointServerThread; + const int asyncPort; const int syncPort; - std::thread asyncThread; - std::thread syncThread; + MessageEndpointServerThread asyncThread; + MessageEndpointServerThread syncThread; AsyncSendMessageEndpoint asyncShutdownSender; SyncSendMessageEndpoint syncShutdownSender; diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 755f1bf87..b5994476e 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -157,7 +157,7 @@ Message MessageEndpoint::recvBuffer(zmq::socket_t& socket, int size) auto res = socket.recv(zmq::buffer(msg.udata(), msg.size())); if (!res.has_value()) { - SPDLOG_ERROR("Timed out receiving message of size {}", size); + SPDLOG_TRACE("Timed out receiving message of size {}", size); throw MessageTimeoutException("Timed out receiving message"); } @@ -189,7 +189,7 @@ Message MessageEndpoint::recvNoBuffer(zmq::socket_t& socket) try { auto res = socket.recv(msg); if (!res.has_value()) { - SPDLOG_ERROR("Timed out receiving message with no size"); + SPDLOG_TRACE("Timed out receiving message with no size"); throw MessageTimeoutException("Timed out receiving message"); } } catch (zmq::error_t& e) { @@ -250,7 +250,7 @@ SyncSendMessageEndpoint::SyncSendMessageEndpoint(const std::string& hostIn, int timeoutMs) : MessageEndpoint(hostIn, portIn, timeoutMs) { - reqSocket = setUpSocket(zmq::socket_type::req, portIn + 1); + reqSocket = setUpSocket(zmq::socket_type::req, portIn); } void SyncSendMessageEndpoint::sendHeader(int header) @@ -270,7 +270,6 @@ Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* data, bool more) { SPDLOG_TRACE("REQ {}:{} ({} bytes, more {})", host, port, dataSize, more); - doSend(reqSocket, data, dataSize, more); // Do the receive @@ -278,20 +277,33 @@ Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* data, return recvNoBuffer(reqSocket); } +// ---------------------------------------------- +// RECV ENDPOINT +// ---------------------------------------------- + +RecvMessageEndpoint::RecvMessageEndpoint(int portIn, int timeoutMs, zmq::socket_type socketType) + : MessageEndpoint(ANY_HOST, portIn, timeoutMs) +{ + socket = setUpSocket(socketType, portIn); +} + +Message RecvMessageEndpoint::recv(int size) { + return doRecv(socket, size); +} + // ---------------------------------------------- // ASYNC RECV ENDPOINT // ---------------------------------------------- AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint(int portIn, int timeoutMs) - : MessageEndpoint(ANY_HOST, portIn, timeoutMs) + : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) { - pullSocket = setUpSocket(zmq::socket_type::pull, portIn); } Message AsyncRecvMessageEndpoint::recv(int size) { SPDLOG_TRACE("PULL {} ({} bytes)", port, size); - return doRecv(pullSocket, size); + return RecvMessageEndpoint::recv(size); } // ---------------------------------------------- @@ -299,20 +311,19 @@ Message AsyncRecvMessageEndpoint::recv(int size) // ---------------------------------------------- SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(int portIn, int timeoutMs) - : MessageEndpoint(ANY_HOST, portIn, timeoutMs) + : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::rep) { - repSocket = setUpSocket(zmq::socket_type::rep, portIn + 1); } Message SyncRecvMessageEndpoint::recv(int size) { SPDLOG_TRACE("RECV (REP) {} ({} bytes)", port, size); - return doRecv(repSocket, size); + return RecvMessageEndpoint::recv(size); } void SyncRecvMessageEndpoint::sendResponse(const uint8_t* data, int size) { SPDLOG_TRACE("REP {} ({} bytes)", port, size); - doSend(repSocket, data, size, false); + doSend(socket, data, size, false); } } diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 4421d8d9d..9ed3df7e4 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -31,69 +31,128 @@ static const std::vector shutdownHeader = { 0, 0, 1, 1 }; throw std::runtime_error("Body sent with SNDMORE flag"); \ } -MessageEndpointServer::MessageEndpointServer(int asyncPortIn, int syncPortIn) - : asyncPort(asyncPortIn) - , syncPort(syncPortIn) - , asyncShutdownSender(LOCALHOST, asyncPort) - , syncShutdownSender(LOCALHOST, syncPort) +MessageEndpointServerThread::MessageEndpointServerThread( + MessageEndpointServer* serverIn, + bool asyncIn) + : server(serverIn) + , async(asyncIn) {} -void MessageEndpointServer::start() +void MessageEndpointServerThread::start( + std::shared_ptr latch) { - // This latch means that callers can guarantee that when this function - // completes, both sockets will have been opened (and hence the server is - // ready to use). - auto startLatch = faabric::util::Latch::create(3); + backgroundThread = std::thread([this, latch] { + std::unique_ptr endpoint = nullptr; + int port = -1; + + if (async) { + port = server->asyncPort; + endpoint = std::make_unique(port); + } else { + port = server->syncPort; + endpoint = std::make_unique(port); + } - asyncThread = std::thread([this, startLatch] { - AsyncRecvMessageEndpoint endpoint(asyncPort); - startLatch->wait(); + latch->wait(); while (true) { - // Receive header and body - Message header = endpoint.recv(); - - SHUTDOWN_CHECK(header, "async") - - RECEIVE_BODY(header, endpoint) - - // Server-specific message handling - doAsyncRecv(header, body); + bool headerReceived = false; + bool bodyReceived = false; + try { + // Receive header and body + Message header = endpoint->recv(); + headerReceived = true; + + if (header.size() == shutdownHeader.size()) { + if (header.dataCopy() == shutdownHeader) { + SPDLOG_TRACE( + "Server {} endpoint received shutdown message"); + break; + } + } + + if (!header.more()) { + throw std::runtime_error( + "Header sent without SNDMORE flag"); + } + + Message body = endpoint->recv(); + if (body.more()) { + throw std::runtime_error("Body sent with SNDMORE flag"); + } + bodyReceived = true; + + if (async) { + // Server-specific async handling + server->doAsyncRecv(header, body); + } else { + // Server-specific sync handling + std::unique_ptr resp = + server->doSyncRecv(header, body); + size_t respSize = resp->ByteSizeLong(); + + uint8_t buffer[respSize]; + if (!resp->SerializeToArray(buffer, respSize)) { + throw std::runtime_error("Error serialising message"); + } + + // Return the response + static_cast(endpoint.get()) + ->sendResponse(buffer, respSize); + } + } catch (MessageTimeoutException& ex) { + // If we don't get a header in the timeout, we're ok to just + // loop round and try again + if (!headerReceived) { + SPDLOG_TRACE("Server on port {}, looping after no message", + port); + continue; + } + + if (headerReceived && !bodyReceived) { + SPDLOG_ERROR( + "Server on port {}, got header, timed out on body", port); + throw; + } + } // Wait on the async latch if necessary - if (asyncLatch != nullptr) { - SPDLOG_TRACE("Server thread waiting on async latch for port {}", - asyncPort); - asyncLatch->wait(); + if (server->asyncLatch != nullptr) { + SPDLOG_TRACE("Server thread waiting on async latch"); + server->asyncLatch->wait(); } + + headerReceived = false; + bodyReceived = false; } }); +} - syncThread = std::thread([this, startLatch] { - SyncRecvMessageEndpoint endpoint(syncPort); - startLatch->wait(); - - while (true) { - // Receive header and body - Message header = endpoint.recv(); - - SHUTDOWN_CHECK(header, "sync") - - RECEIVE_BODY(header, endpoint) +void MessageEndpointServerThread::join() +{ + if (backgroundThread.joinable()) { + backgroundThread.join(); + } +} - // Server-specific message handling - std::unique_ptr resp = - doSyncRecv(header, body); - size_t respSize = resp->ByteSizeLong(); +MessageEndpointServer::MessageEndpointServer(int asyncPortIn, int syncPortIn) + : asyncPort(asyncPortIn) + , syncPort(syncPortIn) + , asyncThread(this, true) + , syncThread(this, false) + , asyncShutdownSender(LOCALHOST, asyncPort) + , syncShutdownSender(LOCALHOST, syncPort) +{} - uint8_t buffer[respSize]; - if (!resp->SerializeToArray(buffer, respSize)) { - throw std::runtime_error("Error serialising message"); - } +void MessageEndpointServer::start() +{ + // This latch means that callers can guarantee that when this function + // completes, both sockets will have been opened (and hence the server is + // ready to use). + auto startLatch = faabric::util::Latch::create(3); - endpoint.sendResponse(buffer, respSize); - } - }); + asyncThread.start(startLatch); + syncThread.start(startLatch); startLatch->wait(); } @@ -109,13 +168,8 @@ void MessageEndpointServer::stop() syncShutdownSender.sendRaw(shutdownHeader.data(), shutdownHeader.size()); // Join the threads - if (asyncThread.joinable()) { - asyncThread.join(); - } - - if (syncThread.joinable()) { - syncThread.join(); - } + asyncThread.join(); + syncThread.join(); } void MessageEndpointServer::setAsyncLatch() diff --git a/tests/test/runner/test_main.cpp b/tests/test/runner/test_main.cpp index 4268f7be3..a97696103 100644 --- a/tests/test/runner/test_main.cpp +++ b/tests/test/runner/test_main.cpp @@ -13,14 +13,24 @@ using namespace faabric::scheduler; namespace tests { -TEST_CASE("Test main runner", "[runner]") +class MainRunnerTestFixture : public SchedulerTestFixture +{ + public: + MainRunnerTestFixture() + { + std::shared_ptr fac = + std::make_shared(); + faabric::scheduler::setExecutorFactory(fac); + } +}; + +TEST_CASE_METHOD(MainRunnerTestFixture, "Test main runner", "[runner]") { - cleanFaabric(); std::shared_ptr fac = faabric::scheduler::getExecutorFactory(); faabric::runner::FaabricMain m(fac); - m.startRunner(); + m.startBackground(); std::shared_ptr req = faabric::util::batchExecFactory("foo", "bar", 4); @@ -34,5 +44,8 @@ TEST_CASE("Test main runner", "[runner]") sch.getFunctionResult(m.id(), SHORT_TEST_TIMEOUT_MS); REQUIRE(res.outputdata() == expected); } + + m.shutdown(); } + } From 19a7d6edcd16819cd6520372104700b09181691e Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 30 Jun 2021 14:47:44 +0000 Subject: [PATCH 61/66] Fix port mixup and remove unused macros --- include/faabric/transport/common.h | 4 ++-- src/scheduler/Scheduler.cpp | 2 ++ src/transport/MessageEndpointServer.cpp | 22 +--------------------- 3 files changed, 5 insertions(+), 23 deletions(-) diff --git a/include/faabric/transport/common.h b/include/faabric/transport/common.h index dcdcf182b..7ee8ee759 100644 --- a/include/faabric/transport/common.h +++ b/include/faabric/transport/common.h @@ -8,7 +8,7 @@ #define STATE_SYNC_PORT 8004 #define FUNCTION_CALL_ASYNC_PORT 8005 #define FUNCTION_CALL_SYNC_PORT 8006 -#define SNAPSHOT_SYNC_PORT 8007 -#define SNAPSHOT_ASYNC_PORT 8008 +#define SNAPSHOT_ASYNC_PORT 8007 +#define SNAPSHOT_SYNC_PORT 8008 #define DEFAULT_MPI_BASE_PORT 8800 diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 866ca51a0..f3f48437c 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -63,6 +63,8 @@ void Scheduler::addHostToGlobalSet() void Scheduler::reset() { + SPDLOG_DEBUG("Resetting scheduler"); + // Shut down all Executors for (auto& p : executors) { for (auto& e : p.second) { diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 9ed3df7e4..2606026ce 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -11,26 +11,6 @@ namespace faabric::transport { static const std::vector shutdownHeader = { 0, 0, 1, 1 }; -#define SHUTDOWN_CHECK(header, label) \ - { \ - if (header.size() == shutdownHeader.size()) { \ - if (header.dataCopy() == shutdownHeader) { \ - SPDLOG_TRACE("Server {} endpoint received shutdown message", \ - label); \ - break; \ - } \ - } \ - } - -#define RECEIVE_BODY(header, endpoint) \ - if (!header.more()) { \ - throw std::runtime_error("Header sent without SNDMORE flag"); \ - } \ - Message body = endpoint.recv(); \ - if (body.more()) { \ - throw std::runtime_error("Body sent with SNDMORE flag"); \ - } - MessageEndpointServerThread::MessageEndpointServerThread( MessageEndpointServer* serverIn, bool asyncIn) @@ -66,7 +46,7 @@ void MessageEndpointServerThread::start( if (header.size() == shutdownHeader.size()) { if (header.dataCopy() == shutdownHeader) { SPDLOG_TRACE( - "Server {} endpoint received shutdown message"); + "Server on {} received shutdown message", port); break; } } From bf6b15cb17733af5ebe88d0a15a1c4a5eb52de2a Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Thu, 1 Jul 2021 10:38:09 +0000 Subject: [PATCH 62/66] Remove use thread-local cache in scheduler --- include/faabric/scheduler/Scheduler.h | 8 +-- src/runner/FaabricMain.cpp | 6 +-- src/scheduler/Executor.cpp | 4 ++ src/scheduler/Scheduler.cpp | 71 +++++++++++-------------- src/transport/MessageEndpoint.cpp | 13 ++--- src/transport/MessageEndpointServer.cpp | 4 +- tests/dist/main.cpp | 45 ++++++++-------- tests/dist/server.cpp | 1 + tests/test/main.cpp | 1 - tests/test/runner/test_main.cpp | 26 +++++---- 10 files changed, 87 insertions(+), 92 deletions(-) diff --git a/include/faabric/scheduler/Scheduler.h b/include/faabric/scheduler/Scheduler.h index cc9dee83e..68ed0574e 100644 --- a/include/faabric/scheduler/Scheduler.h +++ b/include/faabric/scheduler/Scheduler.h @@ -91,6 +91,8 @@ class Scheduler void reset(); + void resetThreadLocalCache(); + void shutdown(); void broadcastSnapshotDelete(const faabric::Message& msg, @@ -182,15 +184,9 @@ class Scheduler std::unordered_map> threadResults; - std::shared_mutex functionCallClientsMx; - std::unordered_map - functionCallClients; faabric::scheduler::FunctionCallClient& getFunctionCallClient( const std::string& otherHost); - std::shared_mutex snapshotClientsMx; - std::unordered_map - snapshotClients; faabric::scheduler::SnapshotClient& getSnapshotClient( const std::string& otherHost); diff --git a/src/runner/FaabricMain.cpp b/src/runner/FaabricMain.cpp index ed9764b58..18625cfe4 100644 --- a/src/runner/FaabricMain.cpp +++ b/src/runner/FaabricMain.cpp @@ -79,9 +79,6 @@ void FaabricMain::shutdown() { SPDLOG_INFO("Removing from global working set"); - auto& sch = faabric::scheduler::getScheduler(); - sch.shutdown(); - SPDLOG_INFO("Waiting for the state server to finish"); stateServer.stop(); @@ -91,6 +88,9 @@ void FaabricMain::shutdown() SPDLOG_INFO("Waiting for the snapshot server to finish"); snapshotServer.stop(); + auto& sch = faabric::scheduler::getScheduler(); + sch.shutdown(); + SPDLOG_INFO("Faabric pool successfully shut down"); } } diff --git a/src/scheduler/Executor.cpp b/src/scheduler/Executor.cpp index 0551eb80f..8fb228cee 100644 --- a/src/scheduler/Executor.cpp +++ b/src/scheduler/Executor.cpp @@ -298,6 +298,10 @@ void Executor::threadPoolThread(int threadPoolIdx) sch.notifyExecutorShutdown(this, boundMessage); } } + + // We have to clean up TLS here as this should be the last use of the + // scheduler from this thread + sch.resetThreadLocalCache(); } bool Executor::tryClaim() diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index f3f48437c..79c2a8515 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -22,6 +22,18 @@ using namespace faabric::util; namespace faabric::scheduler { +// 0MQ sockets are not thread-safe, and opening them and closing them from +// different threads messes things up. However, we don't want to constatnly +// create and recreate them to make calls in the scheduler, therefore we cache +// them in TLS, and perform thread-specific tidy-up. +static thread_local std::unordered_map + functionCallClients; + +static thread_local std::unordered_map + snapshotClients; + Scheduler& getScheduler() { static Scheduler sch; @@ -61,10 +73,21 @@ void Scheduler::addHostToGlobalSet() redis.sadd(AVAILABLE_HOST_SET, thisHost); } +void Scheduler::resetThreadLocalCache() +{ + auto tid = (pid_t)syscall(SYS_gettid); + SPDLOG_DEBUG("Resetting scheduler thread-local cache for thread {}", tid); + + functionCallClients.clear(); + snapshotClients.clear(); +} + void Scheduler::reset() { SPDLOG_DEBUG("Resetting scheduler"); + resetThreadLocalCache(); + // Shut down all Executors for (auto& p : executors) { for (auto& e : p.second) { @@ -95,9 +118,6 @@ void Scheduler::reset() recordedMessagesAll.clear(); recordedMessagesLocal.clear(); recordedMessagesShared.clear(); - - functionCallClients.clear(); - snapshotClients.clear(); } void Scheduler::shutdown() @@ -545,54 +565,25 @@ std::vector Scheduler::getRecordedMessagesLocal() return recordedMessagesLocal; } -std::string getClientKey(const std::string& otherHost) -{ - // Note, our keys here have to include the tid as the clients can only be - // used within the same thread. - std::thread::id tid = std::this_thread::get_id(); - std::stringstream ss; - ss << otherHost << "_" << tid; - std::string key = ss.str(); - return key; -} - FunctionCallClient& Scheduler::getFunctionCallClient( const std::string& otherHost) { - std::string key = getClientKey(otherHost); - if (functionCallClients.find(key) == functionCallClients.end()) { - faabric::util::FullLock lock(functionCallClientsMx); - - if (functionCallClients.find(key) == functionCallClients.end()) { - SPDLOG_DEBUG( - "Adding new function call client for {} ({})", otherHost, key); - functionCallClients.emplace(key, otherHost); - } + if (functionCallClients.find(otherHost) == functionCallClients.end()) { + SPDLOG_DEBUG("Adding new function call client for {}", otherHost); + functionCallClients.emplace(otherHost, otherHost); } - { - faabric::util::SharedLock lock(functionCallClientsMx); - return functionCallClients.at(key); - } + return functionCallClients.at(otherHost); } SnapshotClient& Scheduler::getSnapshotClient(const std::string& otherHost) { - std::string key = getClientKey(otherHost); - if (snapshotClients.find(key) == snapshotClients.end()) { - faabric::util::FullLock lock(snapshotClientsMx); - - if (snapshotClients.find(key) == snapshotClients.end()) { - SPDLOG_DEBUG( - "Adding new snapshot client for {} ({})", otherHost, key); - snapshotClients.emplace(key, otherHost); - } + if (snapshotClients.find(otherHost) == snapshotClients.end()) { + SPDLOG_DEBUG("Adding new snapshot client for {}", otherHost); + snapshotClients.emplace(otherHost, otherHost); } - { - faabric::util::SharedLock lock(snapshotClientsMx); - return snapshotClients.at(key); - } + return snapshotClients.at(otherHost); } std::vector> diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index b5994476e..182f41728 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -281,13 +281,16 @@ Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* data, // RECV ENDPOINT // ---------------------------------------------- -RecvMessageEndpoint::RecvMessageEndpoint(int portIn, int timeoutMs, zmq::socket_type socketType) +RecvMessageEndpoint::RecvMessageEndpoint(int portIn, + int timeoutMs, + zmq::socket_type socketType) : MessageEndpoint(ANY_HOST, portIn, timeoutMs) { socket = setUpSocket(socketType, portIn); } -Message RecvMessageEndpoint::recv(int size) { +Message RecvMessageEndpoint::recv(int size) +{ return doRecv(socket, size); } @@ -297,8 +300,7 @@ Message RecvMessageEndpoint::recv(int size) { AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint(int portIn, int timeoutMs) : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) -{ -} +{} Message AsyncRecvMessageEndpoint::recv(int size) { @@ -312,8 +314,7 @@ Message AsyncRecvMessageEndpoint::recv(int size) SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(int portIn, int timeoutMs) : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::rep) -{ -} +{} Message SyncRecvMessageEndpoint::recv(int size) { diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 2606026ce..08f623afa 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -45,8 +45,8 @@ void MessageEndpointServerThread::start( if (header.size() == shutdownHeader.size()) { if (header.dataCopy() == shutdownHeader) { - SPDLOG_TRACE( - "Server on {} received shutdown message", port); + SPDLOG_TRACE("Server on {} received shutdown message", + port); break; } } diff --git a/tests/dist/main.cpp b/tests/dist/main.cpp index 77f944364..a93cfcd35 100644 --- a/tests/dist/main.cpp +++ b/tests/dist/main.cpp @@ -1,47 +1,44 @@ #define CATCH_CONFIG_RUNNER -#include "DistTestExecutor.h" -#include "faabric_utils.h" #include +#include "DistTestExecutor.h" +#include "faabric_utils.h" #include "init.h" -#include #include #include #include #include -#include - -using namespace faabric::scheduler; FAABRIC_CATCH_LOGGER int main(int argc, char* argv[]) { - faabric::util::initLogging(); faabric::transport::initGlobalMessageContext(); - - // Set up the distributed tests + faabric::util::initLogging(); tests::initDistTests(); - // Start everything up - SPDLOG_INFO("Starting distributed test server on master"); - std::shared_ptr fac = + std::shared_ptr fac = std::make_shared(); - faabric::runner::FaabricMain m(fac); - m.startBackground(); - - // Wait for things to start - SLEEP_MS(3000); - - // Run the tests - int result = Catch::Session().run(argc, argv); - fflush(stdout); - // Shut down - SPDLOG_INFO("Shutting down"); - m.shutdown(); + // WARNING: all 0MQ sockets have to have gone *out of scope* before we shut + // down the context, therefore this segment must be in a nested scope (or + // another function). + int result; + { + SPDLOG_INFO("Starting distributed test server on master"); + faabric::runner::FaabricMain m(fac); + m.startBackground(); + + // Run the tests + result = Catch::Session().run(argc, argv); + fflush(stdout); + + // Shut down + SPDLOG_INFO("Shutting down"); + m.shutdown(); + } faabric::transport::closeGlobalMessageContext(); diff --git a/tests/dist/server.cpp b/tests/dist/server.cpp index 9fe79011b..d62248bd9 100644 --- a/tests/dist/server.cpp +++ b/tests/dist/server.cpp @@ -21,6 +21,7 @@ int main() faabric::runner::FaabricMain m(fac); m.startBackground(); + // Note, endpoint will block until killed SPDLOG_INFO("Starting HTTP endpoint on worker"); faabric::endpoint::FaabricEndpoint endpoint; endpoint.start(); diff --git a/tests/test/main.cpp b/tests/test/main.cpp index 30a52adb3..28aac75a0 100644 --- a/tests/test/main.cpp +++ b/tests/test/main.cpp @@ -13,7 +13,6 @@ FAABRIC_CATCH_LOGGER int main(int argc, char* argv[]) { faabric::transport::initGlobalMessageContext(); - faabric::util::setTestMode(true); faabric::util::initLogging(); diff --git a/tests/test/runner/test_main.cpp b/tests/test/runner/test_main.cpp index a97696103..9d81fc8aa 100644 --- a/tests/test/runner/test_main.cpp +++ b/tests/test/runner/test_main.cpp @@ -32,17 +32,23 @@ TEST_CASE_METHOD(MainRunnerTestFixture, "Test main runner", "[runner]") m.startBackground(); - std::shared_ptr req = - faabric::util::batchExecFactory("foo", "bar", 4); + SECTION("Do nothing") {} - auto& sch = faabric::scheduler::getScheduler(); - sch.callFunctions(req); - - for (const auto& m : req->messages()) { - std::string expected = fmt::format("DummyExecutor executed {}", m.id()); - faabric::Message res = - sch.getFunctionResult(m.id(), SHORT_TEST_TIMEOUT_MS); - REQUIRE(res.outputdata() == expected); + SECTION("Make calls") + { + std::shared_ptr req = + faabric::util::batchExecFactory("foo", "bar", 4); + + auto& sch = faabric::scheduler::getScheduler(); + sch.callFunctions(req); + + for (const auto& m : req->messages()) { + std::string expected = + fmt::format("DummyExecutor executed {}", m.id()); + faabric::Message res = + sch.getFunctionResult(m.id(), SHORT_TEST_TIMEOUT_MS); + REQUIRE(res.outputdata() == expected); + } } m.shutdown(); From b3327e98a2f44e71471a516bfa4e8ae3da76b2ff Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 5 Jul 2021 08:30:02 +0000 Subject: [PATCH 63/66] Rename retry macro --- src/transport/MessageEndpoint.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 182f41728..6ef573b73 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -21,7 +21,7 @@ throw; \ } -#define CATCH_ZMQ_ERR_RETRY(op, label) \ +#define CATCH_ZMQ_ERR_RETRY_ONCE(op, label) \ try { \ op; \ } catch (zmq::error_t & e) { \ @@ -83,25 +83,25 @@ zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType, case zmq::socket_type::req: { SPDLOG_TRACE( "New socket: req {}:{} (timeout {}ms)", host, port, timeoutMs); - CATCH_ZMQ_ERR_RETRY(socket.connect(address), "connect") + CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") break; } case zmq::socket_type::push: { SPDLOG_TRACE( "New socket: push {}:{} (timeout {}ms)", host, port, timeoutMs); - CATCH_ZMQ_ERR_RETRY(socket.connect(address), "connect") + CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") break; } case zmq::socket_type::pull: { SPDLOG_TRACE( "New socket: pull {}:{} (timeout {}ms)", host, port, timeoutMs); - CATCH_ZMQ_ERR_RETRY(socket.bind(address), "bind") + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") break; } case zmq::socket_type::rep: { SPDLOG_TRACE( "New socket: rep {}:{} (timeout {}ms)", host, port, timeoutMs); - CATCH_ZMQ_ERR_RETRY(socket.bind(address), "bind") + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") break; } default: { From fd9ad8af1a7006d59751f87456f7ae7755c3e91b Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 5 Jul 2021 10:33:32 +0000 Subject: [PATCH 64/66] Move latch constructor --- include/faabric/util/latch.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/faabric/util/latch.h b/include/faabric/util/latch.h index 462d261b6..b87419fe4 100644 --- a/include/faabric/util/latch.h +++ b/include/faabric/util/latch.h @@ -17,10 +17,10 @@ class Latch int count, int timeoutMs = DEFAULT_LATCH_TIMEOUT_MS); - void wait(); - explicit Latch(int countIn, int timeoutMsIn); + void wait(); + private: int count; int waiters = 0; From 1f71e38f48988b74b98bccd6b3417eb73821bfb9 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 5 Jul 2021 10:34:09 +0000 Subject: [PATCH 65/66] Move FB macro --- include/faabric/transport/macros.h | 15 +++++++++++++++ src/scheduler/SnapshotClient.cpp | 16 +--------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/include/faabric/transport/macros.h b/include/faabric/transport/macros.h index a37132095..c24d5d125 100644 --- a/include/faabric/transport/macros.h +++ b/include/faabric/transport/macros.h @@ -19,3 +19,18 @@ if (!msg->SerializeToArray(buffer, msgSize)) { \ throw std::runtime_error("Error serialising message"); \ } + +#define SEND_FB_MSG(T, mb) \ + { \ + const uint8_t* buffer = mb.GetBufferPointer(); \ + int size = mb.GetSize(); \ + faabric::EmptyResponse response; \ + syncSend(T, buffer, size, &response); \ + } + +#define SEND_FB_MSG_ASYNC(T, mb) \ + { \ + const uint8_t* buffer = mb.GetBufferPointer(); \ + int size = mb.GetSize(); \ + asyncSend(T, buffer, size); \ + } diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index ce821adc9..d3f80bc87 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -69,21 +70,6 @@ void clearMockSnapshotRequests() // Snapshot client // ----------------------------------- -#define SEND_FB_MSG(T, mb) \ - { \ - const uint8_t* buffer = mb.GetBufferPointer(); \ - int size = mb.GetSize(); \ - faabric::EmptyResponse response; \ - syncSend(T, buffer, size, &response); \ - } - -#define SEND_FB_MSG_ASYNC(T, mb) \ - { \ - const uint8_t* buffer = mb.GetBufferPointer(); \ - int size = mb.GetSize(); \ - asyncSend(T, buffer, size); \ - } - SnapshotClient::SnapshotClient(const std::string& hostIn) : faabric::transport::MessageEndpointClient(hostIn, SNAPSHOT_ASYNC_PORT, From 8d7565dcfdd258aac9af991d9bd7b52c9d611f24 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 5 Jul 2021 10:42:27 +0000 Subject: [PATCH 66/66] Message server interface to buffers --- .../faabric/scheduler/FunctionCallServer.h | 21 ++--- include/faabric/scheduler/SnapshotServer.h | 20 ++--- include/faabric/state/StateServer.h | 44 +++++----- .../faabric/transport/MessageEndpointServer.h | 12 ++- src/scheduler/FunctionCallServer.cpp | 47 +++++------ src/scheduler/SnapshotServer.cpp | 52 ++++++------ src/state/StateServer.cpp | 80 ++++++++++--------- src/transport/Message.cpp | 1 - src/transport/MessageEndpoint.cpp | 2 +- src/transport/MessageEndpointServer.cpp | 15 ++-- tests/test/transport/test_message_server.cpp | 36 ++++----- 11 files changed, 170 insertions(+), 160 deletions(-) diff --git a/include/faabric/scheduler/FunctionCallServer.h b/include/faabric/scheduler/FunctionCallServer.h index 733a6dec6..23227c43c 100644 --- a/include/faabric/scheduler/FunctionCallServer.h +++ b/include/faabric/scheduler/FunctionCallServer.h @@ -15,21 +15,22 @@ class FunctionCallServer final private: Scheduler& scheduler; - void doAsyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override; + void doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) override; - std::unique_ptr doSyncRecv( - faabric::transport::Message& header, - faabric::transport::Message& body) override; + std::unique_ptr + doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override; - std::unique_ptr recvFlush( - faabric::transport::Message& body); + std::unique_ptr recvFlush(const uint8_t* buffer, + size_t bufferSize); std::unique_ptr recvGetResources( - faabric::transport::Message& body); + const uint8_t* buffer, + size_t bufferSize); - void recvExecuteFunctions(faabric::transport::Message& body); + void recvExecuteFunctions(const uint8_t* buffer, size_t bufferSize); - void recvUnregister(faabric::transport::Message& body); + void recvUnregister(const uint8_t* buffer, size_t bufferSize); }; } diff --git a/include/faabric/scheduler/SnapshotServer.h b/include/faabric/scheduler/SnapshotServer.h index 8d0ea4a66..861bb75a0 100644 --- a/include/faabric/scheduler/SnapshotServer.h +++ b/include/faabric/scheduler/SnapshotServer.h @@ -12,22 +12,24 @@ class SnapshotServer final : public faabric::transport::MessageEndpointServer SnapshotServer(); protected: - void doAsyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override; + void doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) override; - std::unique_ptr doSyncRecv( - faabric::transport::Message& header, - faabric::transport::Message& body) override; + std::unique_ptr + doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override; std::unique_ptr recvPushSnapshot( - faabric::transport::Message& msg); + const uint8_t* buffer, + size_t bufferSize); std::unique_ptr recvPushSnapshotDiffs( - faabric::transport::Message& msg); + const uint8_t* buffer, + size_t bufferSize); - void recvDeleteSnapshot(faabric::transport::Message& msg); + void recvDeleteSnapshot(const uint8_t* buffer, size_t bufferSize); - void recvThreadResult(faabric::transport::Message& msg); + void recvThreadResult(const uint8_t* buffer, size_t bufferSize); private: void applyDiffsToSnapshot( diff --git a/include/faabric/state/StateServer.h b/include/faabric/state/StateServer.h index 883ab503e..ebd62760a 100644 --- a/include/faabric/state/StateServer.h +++ b/include/faabric/state/StateServer.h @@ -13,40 +13,42 @@ class StateServer final : public faabric::transport::MessageEndpointServer private: State& state; - void doAsyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override; + void doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) override; - std::unique_ptr doSyncRecv( - faabric::transport::Message& header, - faabric::transport::Message& body) override; + std::unique_ptr + doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override; // Sync methods - std::unique_ptr recvSize( - faabric::transport::Message& body); + std::unique_ptr recvSize(const uint8_t* buffer, + size_t bufferSize); - std::unique_ptr recvPull( - faabric::transport::Message& body); + std::unique_ptr recvPull(const uint8_t* buffer, + size_t bufferSize); - std::unique_ptr recvPush( - faabric::transport::Message& body); + std::unique_ptr recvPush(const uint8_t* buffer, + size_t bufferSize); - std::unique_ptr recvAppend( - faabric::transport::Message& body); + std::unique_ptr recvAppend(const uint8_t* buffer, + size_t bufferSize); std::unique_ptr recvPullAppended( - faabric::transport::Message& body); + const uint8_t* buffer, + size_t bufferSize); std::unique_ptr recvClearAppended( - faabric::transport::Message& body); + const uint8_t* buffer, + size_t bufferSize); - std::unique_ptr recvDelete( - faabric::transport::Message& body); + std::unique_ptr recvDelete(const uint8_t* buffer, + size_t bufferSize); - std::unique_ptr recvLock( - faabric::transport::Message& body); + std::unique_ptr recvLock(const uint8_t* buffer, + size_t bufferSize); - std::unique_ptr recvUnlock( - faabric::transport::Message& body); + std::unique_ptr recvUnlock(const uint8_t* buffer, + size_t bufferSize); }; } diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index e1cc756e7..a7c3f88a2 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -43,14 +43,12 @@ class MessageEndpointServer void awaitAsyncLatch(); protected: - virtual void doAsyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) = 0; + virtual void doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) = 0; - virtual std::unique_ptr doSyncRecv( - faabric::transport::Message& header, - faabric::transport::Message& body) = 0; - - void sendSyncResponse(google::protobuf::Message* resp); + virtual std::unique_ptr + doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) = 0; private: friend class MessageEndpointServerThread; diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index 53fd0b7f5..dcede0518 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -14,50 +14,48 @@ FunctionCallServer::FunctionCallServer() , scheduler(getScheduler()) {} -void FunctionCallServer::doAsyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) +void FunctionCallServer::doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) { - assert(header.size() == sizeof(uint8_t)); - uint8_t call = static_cast(*header.data()); - switch (call) { + switch (header) { case faabric::scheduler::FunctionCalls::ExecuteFunctions: { - recvExecuteFunctions(body); + recvExecuteFunctions(buffer, bufferSize); break; } case faabric::scheduler::FunctionCalls::Unregister: { - recvUnregister(body); + recvUnregister(buffer, bufferSize); break; } default: { throw std::runtime_error( - fmt::format("Unrecognized async call header: {}", call)); + fmt::format("Unrecognized async call header: {}", header)); } } } std::unique_ptr FunctionCallServer::doSyncRecv( - faabric::transport::Message& header, - faabric::transport::Message& body) + int header, + const uint8_t* buffer, + size_t bufferSize) { - assert(header.size() == sizeof(uint8_t)); - - uint8_t call = static_cast(*header.data()); - switch (call) { + switch (header) { case faabric::scheduler::FunctionCalls::Flush: { - return recvFlush(body); + return recvFlush(buffer, bufferSize); } case faabric::scheduler::FunctionCalls::GetResources: { - return recvGetResources(body); + return recvGetResources(buffer, bufferSize); } default: { throw std::runtime_error( - fmt::format("Unrecognized sync call header: {}", call)); + fmt::format("Unrecognized sync call header: {}", header)); } } } std::unique_ptr FunctionCallServer::recvFlush( - faabric::transport::Message& body) + const uint8_t* buffer, + size_t bufferSize) { // Clear out any cached state faabric::state::getGlobalState().forceClearAll(false); @@ -68,18 +66,20 @@ std::unique_ptr FunctionCallServer::recvFlush( return std::make_unique(); } -void FunctionCallServer::recvExecuteFunctions(faabric::transport::Message& body) +void FunctionCallServer::recvExecuteFunctions(const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::BatchExecuteRequest, body.data(), body.size()) + PARSE_MSG(faabric::BatchExecuteRequest, buffer, bufferSize) // This host has now been told to execute these functions no matter what scheduler.callFunctions(std::make_shared(msg), true); } -void FunctionCallServer::recvUnregister(faabric::transport::Message& body) +void FunctionCallServer::recvUnregister(const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::UnregisterRequest, body.data(), body.size()) + PARSE_MSG(faabric::UnregisterRequest, buffer, bufferSize) std::string funcStr = faabric::util::funcToString(msg.function(), false); SPDLOG_DEBUG("Unregistering host {} for {}", msg.host(), funcStr); @@ -89,7 +89,8 @@ void FunctionCallServer::recvUnregister(faabric::transport::Message& body) } std::unique_ptr FunctionCallServer::recvGetResources( - faabric::transport::Message& body) + const uint8_t* buffer, + size_t bufferSize) { auto response = std::make_unique( scheduler.getThisHostResources()); diff --git a/src/scheduler/SnapshotServer.cpp b/src/scheduler/SnapshotServer.cpp index 934c54e78..f220b5ed1 100644 --- a/src/scheduler/SnapshotServer.cpp +++ b/src/scheduler/SnapshotServer.cpp @@ -16,52 +16,49 @@ SnapshotServer::SnapshotServer() SNAPSHOT_SYNC_PORT) {} -void SnapshotServer::doAsyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) +void SnapshotServer::doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) { - assert(header.size() == sizeof(uint8_t)); - uint8_t call = static_cast(*header.data()); - switch (call) { + switch (header) { case faabric::scheduler::SnapshotCalls::DeleteSnapshot: { - this->recvDeleteSnapshot(body); + this->recvDeleteSnapshot(buffer, bufferSize); break; } case faabric::scheduler::SnapshotCalls::ThreadResult: { - this->recvThreadResult(body); + this->recvThreadResult(buffer, bufferSize); break; } default: { throw std::runtime_error( - fmt::format("Unrecognized async call header: {}", call)); + fmt::format("Unrecognized async call header: {}", header)); } } } -std::unique_ptr SnapshotServer::doSyncRecv( - faabric::transport::Message& header, - faabric::transport::Message& body) +std::unique_ptr +SnapshotServer::doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) { - assert(header.size() == sizeof(uint8_t)); - uint8_t call = static_cast(*header.data()); - switch (call) { + switch (header) { case faabric::scheduler::SnapshotCalls::PushSnapshot: { - return recvPushSnapshot(body); + return recvPushSnapshot(buffer, bufferSize); } case faabric::scheduler::SnapshotCalls::PushSnapshotDiffs: { - return recvPushSnapshotDiffs(body); + return recvPushSnapshotDiffs(buffer, bufferSize); } default: { throw std::runtime_error( - fmt::format("Unrecognized sync call header: {}", call)); + fmt::format("Unrecognized sync call header: {}", header)); } } } std::unique_ptr SnapshotServer::recvPushSnapshot( - faabric::transport::Message& msg) + const uint8_t* buffer, + size_t bufferSize) { - SnapshotPushRequest* r = - flatbuffers::GetMutableRoot(msg.udata()); + const SnapshotPushRequest* r = + flatbuffers::GetRoot(buffer); SPDLOG_DEBUG("Receiving shapshot {} (size {})", r->key()->c_str(), @@ -80,7 +77,7 @@ std::unique_ptr SnapshotServer::recvPushSnapshot( // this data? data.data = (uint8_t*)mmap( nullptr, data.size, PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - std::memcpy(data.data, r->mutable_contents()->Data(), data.size); + std::memcpy(data.data, r->contents()->Data(), data.size); reg.takeSnapshot(r->key()->str(), data, true); @@ -88,10 +85,10 @@ std::unique_ptr SnapshotServer::recvPushSnapshot( return std::make_unique(); } -void SnapshotServer::recvThreadResult(faabric::transport::Message& msg) +void SnapshotServer::recvThreadResult(const uint8_t* buffer, size_t bufferSize) { const ThreadResultRequest* r = - flatbuffers::GetMutableRoot(msg.udata()); + flatbuffers::GetRoot(buffer); // Apply snapshot diffs *first* (these must be applied before other threads // can continue) @@ -108,10 +105,10 @@ void SnapshotServer::recvThreadResult(faabric::transport::Message& msg) } std::unique_ptr -SnapshotServer::recvPushSnapshotDiffs(faabric::transport::Message& msg) +SnapshotServer::recvPushSnapshotDiffs(const uint8_t* buffer, size_t bufferSize) { const SnapshotDiffPushRequest* r = - flatbuffers::GetMutableRoot(msg.udata()); + flatbuffers::GetRoot(buffer); applyDiffsToSnapshot(r->key()->str(), r->chunks()); @@ -139,10 +136,11 @@ void SnapshotServer::applyDiffsToSnapshot( } } -void SnapshotServer::recvDeleteSnapshot(faabric::transport::Message& msg) +void SnapshotServer::recvDeleteSnapshot(const uint8_t* buffer, + size_t bufferSize) { const SnapshotDeleteRequest* r = - flatbuffers::GetRoot(msg.udata()); + flatbuffers::GetRoot(buffer); SPDLOG_INFO("Deleting shapshot {}", r->key()->c_str()); faabric::snapshot::SnapshotRegistry& reg = diff --git a/src/state/StateServer.cpp b/src/state/StateServer.cpp index e83f72470..b162fbee6 100644 --- a/src/state/StateServer.cpp +++ b/src/state/StateServer.cpp @@ -17,58 +17,56 @@ StateServer::StateServer(State& stateIn) , state(stateIn) {} -void StateServer::doAsyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) +void StateServer::doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) { throw std::runtime_error("State server does not support async recv"); } -std::unique_ptr StateServer::doSyncRecv( - faabric::transport::Message& header, - faabric::transport::Message& body) +std::unique_ptr +StateServer::doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) { - assert(header.size() == sizeof(uint8_t)); - - uint8_t call = static_cast(*header.data()); - switch (call) { + switch (header) { case faabric::state::StateCalls::Pull: { - return recvPull(body); + return recvPull(buffer, bufferSize); } case faabric::state::StateCalls::Push: { - return recvPush(body); + return recvPush(buffer, bufferSize); } case faabric::state::StateCalls::Size: { - return recvSize(body); + return recvSize(buffer, bufferSize); } case faabric::state::StateCalls::Append: { - return recvAppend(body); + return recvAppend(buffer, bufferSize); } case faabric::state::StateCalls::ClearAppended: { - return recvClearAppended(body); + return recvClearAppended(buffer, bufferSize); } case faabric::state::StateCalls::PullAppended: { - return recvPullAppended(body); + return recvPullAppended(buffer, bufferSize); } case faabric::state::StateCalls::Lock: { - return recvLock(body); + return recvLock(buffer, bufferSize); } case faabric::state::StateCalls::Unlock: { - return recvUnlock(body); + return recvUnlock(buffer, bufferSize); } case faabric::state::StateCalls::Delete: { - return recvDelete(body); + return recvDelete(buffer, bufferSize); } default: { throw std::runtime_error( - fmt::format("Unrecognized state call header: {}", call)); + fmt::format("Unrecognized state call header: {}", header)); } } } std::unique_ptr StateServer::recvSize( - faabric::transport::Message& body) + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Prepare the response SPDLOG_TRACE("Size {}/{}", msg.user(), msg.key()); @@ -82,9 +80,10 @@ std::unique_ptr StateServer::recvSize( } std::unique_ptr StateServer::recvPull( - faabric::transport::Message& body) + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateChunkRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateChunkRequest, buffer, bufferSize) SPDLOG_TRACE("Pull {}/{} ({}->{})", msg.user(), @@ -109,9 +108,10 @@ std::unique_ptr StateServer::recvPull( } std::unique_ptr StateServer::recvPush( - faabric::transport::Message& body) + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StatePart, body.data(), body.size()) + PARSE_MSG(faabric::StatePart, buffer, bufferSize) // Update the KV store SPDLOG_TRACE("Push {}/{} ({}->{})", @@ -129,9 +129,10 @@ std::unique_ptr StateServer::recvPush( } std::unique_ptr StateServer::recvAppend( - faabric::transport::Message& body) + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Update the KV KV_FROM_REQUEST(msg) @@ -144,9 +145,10 @@ std::unique_ptr StateServer::recvAppend( } std::unique_ptr StateServer::recvPullAppended( - faabric::transport::Message& body) + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateAppendedRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateAppendedRequest, buffer, bufferSize) // Prepare response SPDLOG_TRACE("Pull appended {}/{}", msg.user(), msg.key()); @@ -166,9 +168,10 @@ std::unique_ptr StateServer::recvPullAppended( } std::unique_ptr StateServer::recvDelete( - faabric::transport::Message& body) + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Delete value SPDLOG_TRACE("Delete {}/{}", msg.user(), msg.key()); @@ -179,9 +182,10 @@ std::unique_ptr StateServer::recvDelete( } std::unique_ptr StateServer::recvClearAppended( - faabric::transport::Message& body) + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Perform operation SPDLOG_TRACE("Clear appended {}/{}", msg.user(), msg.key()); @@ -193,9 +197,10 @@ std::unique_ptr StateServer::recvClearAppended( } std::unique_ptr StateServer::recvLock( - faabric::transport::Message& body) + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Perform operation SPDLOG_TRACE("Lock {}/{}", msg.user(), msg.key()); @@ -207,9 +212,10 @@ std::unique_ptr StateServer::recvLock( } std::unique_ptr StateServer::recvUnlock( - faabric::transport::Message& body) + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Perform operation SPDLOG_TRACE("Unlock {}/{}", msg.user(), msg.key()); diff --git a/src/transport/Message.cpp b/src/transport/Message.cpp index 19ac4614d..37ee3e01a 100644 --- a/src/transport/Message.cpp +++ b/src/transport/Message.cpp @@ -1,4 +1,3 @@ -#include #include #include diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 6ef573b73..2ef6d92b6 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -21,7 +21,7 @@ throw; \ } -#define CATCH_ZMQ_ERR_RETRY_ONCE(op, label) \ +#define CATCH_ZMQ_ERR_RETRY_ONCE(op, label) \ try { \ op; \ } catch (zmq::error_t & e) { \ diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 08f623afa..52a59eb57 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -40,18 +40,18 @@ void MessageEndpointServerThread::start( bool bodyReceived = false; try { // Receive header and body - Message header = endpoint->recv(); + Message headerMessage = endpoint->recv(); headerReceived = true; - if (header.size() == shutdownHeader.size()) { - if (header.dataCopy() == shutdownHeader) { + if (headerMessage.size() == shutdownHeader.size()) { + if (headerMessage.dataCopy() == shutdownHeader) { SPDLOG_TRACE("Server on {} received shutdown message", port); break; } } - if (!header.more()) { + if (!headerMessage.more()) { throw std::runtime_error( "Header sent without SNDMORE flag"); } @@ -62,13 +62,16 @@ void MessageEndpointServerThread::start( } bodyReceived = true; + assert(headerMessage.size() == sizeof(uint8_t)); + uint8_t header = static_cast(*headerMessage.data()); + if (async) { // Server-specific async handling - server->doAsyncRecv(header, body); + server->doAsyncRecv(header, body.udata(), body.size()); } else { // Server-specific sync handling std::unique_ptr resp = - server->doSyncRecv(header, body); + server->doSyncRecv(header, body.udata(), body.size()); size_t respSize = resp->ByteSizeLong(); uint8_t buffer[respSize]; diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 4c88c2814..2449d55b8 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -26,15 +26,15 @@ class DummyServer final : public MessageEndpointServer std::atomic messageCount = 0; private: - void doAsyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override + void doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) override { messageCount++; } - std::unique_ptr doSyncRecv( - faabric::transport::Message& header, - faabric::transport::Message& body) override + std::unique_ptr + doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override { messageCount++; @@ -50,20 +50,20 @@ class EchoServer final : public MessageEndpointServer {} protected: - void doAsyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override + void doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) override { throw std::runtime_error("Echo server not expecting async recv"); } - std::unique_ptr doSyncRecv( - faabric::transport::Message& header, - faabric::transport::Message& body) override + std::unique_ptr + doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override { - SPDLOG_TRACE("Echo server received {} bytes", body.size()); + SPDLOG_TRACE("Echo server received {} bytes", bufferSize); auto response = std::make_unique(); - response->set_data(body.data(), body.size()); + response->set_data(buffer, bufferSize); return response; } @@ -79,17 +79,17 @@ class SleepServer final : public MessageEndpointServer {} protected: - void doAsyncRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override + void doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) override { throw std::runtime_error("Sleep server not expecting async recv"); } - std::unique_ptr doSyncRecv( - faabric::transport::Message& header, - faabric::transport::Message& body) override + std::unique_ptr + doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override { - int* sleepTimeMs = (int*)body.udata(); + int* sleepTimeMs = (int*)buffer; SPDLOG_DEBUG("Sleep server sleeping for {}ms", *sleepTimeMs); SLEEP_MS(*sleepTimeMs);