diff --git a/examples/server.cpp b/examples/server.cpp
index 19108c0a9..3bdcbc5b7 100644
--- a/examples/server.cpp
+++ b/examples/server.cpp
@@ -1,6 +1,7 @@
 #include
 #include
 #include
+#include <faabric/transport/context.h>
 #include
 
 using namespace faabric::scheduler;
@@ -38,6 +39,7 @@ class ExampleExecutorFactory : public ExecutorFactory
 int main()
 {
     faabric::util::initLogging();
+    faabric::transport::initGlobalMessageContext();
 
     // Start the worker pool
     SPDLOG_INFO("Starting executor pool in the background");
@@ -53,6 +55,7 @@ int main()
     SPDLOG_INFO("Shutting down endpoint");
     m.shutdown();
 
+    faabric::transport::closeGlobalMessageContext();
     return EXIT_SUCCESS;
 }
diff --git a/include/faabric/scheduler/FunctionCallClient.h b/include/faabric/scheduler/FunctionCallClient.h
index 55e105231..be86d98a1 100644
--- a/include/faabric/scheduler/FunctionCallClient.h
+++ b/include/faabric/scheduler/FunctionCallClient.h
@@ -2,7 +2,7 @@
 
 #include
 #include
-#include
+#include
 
 #include
 #include
@@ -13,13 +13,13 @@ namespace faabric::scheduler {
 // -----------------------------------
 std::vector<std::pair<std::string, faabric::Message>> getFunctionCalls();
 
-std::vector<std::pair<std::string, faabric::ResponseRequest>> getFlushCalls();
+std::vector<std::pair<std::string, faabric::EmptyRequest>> getFlushCalls();
 
 std::vector<
   std::pair<std::string, std::shared_ptr<faabric::BatchExecuteRequest>>>
 getBatchRequests();
 
-std::vector<std::pair<std::string, faabric::ResponseRequest>>
+std::vector<std::pair<std::string, faabric::EmptyRequest>>
 getResourceRequests();
 
 std::vector<std::pair<std::string, faabric::UnregisterRequest>>
@@ -47,7 +47,7 @@ class FunctionCallClient : public faabric::transport::MessageEndpointClient
     void executeFunctions(
       const std::shared_ptr<faabric::BatchExecuteRequest> req);
 
-    void unregister(const faabric::UnregisterRequest& req);
+    void unregister(faabric::UnregisterRequest& req);
 
   private:
     void sendHeader(faabric::scheduler::FunctionCalls call);
diff --git a/include/faabric/scheduler/FunctionCallServer.h b/include/faabric/scheduler/FunctionCallServer.h
index 962cc7439..23227c43c 100644
--- a/include/faabric/scheduler/FunctionCallServer.h
+++ b/include/faabric/scheduler/FunctionCallServer.h
@@ -3,7 +3,6 @@
 #include
 #include
 #include
-#include
 #include
 
 namespace faabric::scheduler {
@@ -13,24 +12,25 @@ class FunctionCallServer final
   public:
     FunctionCallServer();
 
-    void stop() override;
-
   private:
     Scheduler& scheduler;
 
-    void doRecv(faabric::transport::Message& header,
-                faabric::transport::Message& body) override;
-
-    /* Function call server API */
+    void doAsyncRecv(int header,
+                     const uint8_t* buffer,
+                     size_t bufferSize) override;
 
-    void recvFlush(faabric::transport::Message& body);
+    std::unique_ptr<google::protobuf::Message>
+    doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override;
 
-    void recvExecuteFunctions(faabric::transport::Message& body);
+    std::unique_ptr<google::protobuf::Message> recvFlush(const uint8_t* buffer,
+                                                         size_t bufferSize);
 
-    void recvGetResources(faabric::transport::Message& body);
+    std::unique_ptr<google::protobuf::Message> recvGetResources(
+      const uint8_t* buffer,
+      size_t bufferSize);
 
-    void recvUnregister(faabric::transport::Message& body);
+    void recvExecuteFunctions(const uint8_t* buffer, size_t bufferSize);
 
-    void recvSetThreadResult(faabric::transport::Message& body);
+    void recvUnregister(const uint8_t* buffer, size_t bufferSize);
 };
 }
diff --git a/include/faabric/scheduler/MpiWorld.h b/include/faabric/scheduler/MpiWorld.h
index fc91d4bbb..2ce29ae1c 100644
--- a/include/faabric/scheduler/MpiWorld.h
+++ b/include/faabric/scheduler/MpiWorld.h
@@ -10,6 +10,7 @@
 #include
 #include
+#include
 
 namespace faabric::scheduler {
 typedef faabric::util::Queue<std::shared_ptr<faabric::MPIMessage>>
   InMemoryMpiQueue;
@@ -22,12 +23,11 @@ class MpiWorld
     void create(const faabric::Message& call, int newId, int newSize);
 
-    void initialiseFromMsg(const faabric::Message& msg,
-                           bool forceLocal = false);
+    void broadcastHostsToRanks();
 
-    std::string getHostForRank(int rank);
+    void initialiseFromMsg(const faabric::Message& msg);
 
-    void setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg);
+    std::string getHostForRank(int rank);
 
     std::string getUser();
 
@@ -181,12 +181,12 @@ class MpiWorld
     double getWTime();
 
   private:
-    int id;
-    int size;
+    int id = -1;
+    int size = -1;
 
     std::string thisHost;
+    int basePort = DEFAULT_MPI_BASE_PORT;
     faabric::util::TimePoint creationTime;
 
-    std::shared_mutex worldMutex;
     std::atomic_flag isDestroyed = false;
 
     std::string user;
@@ -208,18 +208,29 @@ class MpiWorld
     std::vector<int> basePorts;
     std::vector<int> initLocalBasePorts(
       const std::vector<std::string>& executedAt);
+
     void initRemoteMpiEndpoint(int localRank, int remoteRank);
+
     std::pair<int, int> getPortForRanks(int localRank, int remoteRank);
+
     void sendRemoteMpiMessage(int sendRank,
                               int recvRank,
                               const std::shared_ptr<faabric::MPIMessage>& msg);
+
     std::shared_ptr<faabric::MPIMessage> recvRemoteMpiMessage(int sendRank,
                                                               int recvRank);
+
+    faabric::MpiHostsToRanksMessage recvMpiHostRankMsg();
+
+    void sendMpiHostRankMsg(const std::string& hostIn,
+                            const faabric::MpiHostsToRanksMessage msg);
+
     void closeMpiMessageEndpoints();
 
     // Support for asynchronous communications
     std::shared_ptr<faabric::scheduler::MpiMessageBuffer>
       getUnackedMessageBuffer(int sendRank, int recvRank);
+
     std::shared_ptr<faabric::MPIMessage> recvBatchReturnLast(int sendRank,
                                                              int recvRank,
                                                              int batchSize = 0);
diff --git a/include/faabric/scheduler/Scheduler.h b/include/faabric/scheduler/Scheduler.h
index 725feb224..68ed0574e 100644
--- a/include/faabric/scheduler/Scheduler.h
+++ b/include/faabric/scheduler/Scheduler.h
@@ -91,6 +91,8 @@ class Scheduler
     void reset();
 
+    void resetThreadLocalCache();
+
     void shutdown();
 
     void broadcastSnapshotDelete(const faabric::Message& msg,
@@ -168,10 +170,6 @@ class Scheduler
     ExecGraph getFunctionExecGraph(unsigned int msgId);
 
-    void closeFunctionCallClients();
-
-    void closeSnapshotClients();
-
   private:
     std::string thisHost;
 
@@ -186,15 +184,9 @@ class Scheduler
     std::unordered_map<uint32_t, std::promise<int32_t>> threadResults;
 
-    std::shared_mutex functionCallClientsMx;
-    std::unordered_map<std::string, faabric::scheduler::FunctionCallClient>
-      functionCallClients;
     faabric::scheduler::FunctionCallClient& getFunctionCallClient(
       const std::string& otherHost);
 
-    std::shared_mutex snapshotClientsMx;
-    std::unordered_map<std::string, faabric::scheduler::SnapshotClient>
-      snapshotClients;
     faabric::scheduler::SnapshotClient& getSnapshotClient(
       const std::string& otherHost);
diff --git a/include/faabric/scheduler/SnapshotClient.h b/include/faabric/scheduler/SnapshotClient.h
index c97cc497c..6d7fd8655 100644
--- a/include/faabric/scheduler/SnapshotClient.h
+++ b/include/faabric/scheduler/SnapshotClient.h
@@ -2,6 +2,7 @@
 
 #include
 #include
+#include
 #include
 #include
diff --git a/include/faabric/scheduler/SnapshotServer.h b/include/faabric/scheduler/SnapshotServer.h
index 326b0c7a2..861bb75a0 100644
--- a/include/faabric/scheduler/SnapshotServer.h
+++ b/include/faabric/scheduler/SnapshotServer.h
@@ -11,21 +11,25 @@ class SnapshotServer final : public faabric::transport::MessageEndpointServer
   public:
     SnapshotServer();
 
-    void stop() override;
-
   protected:
-    void doRecv(faabric::transport::Message& header,
-                faabric::transport::Message& body) override;
+    void doAsyncRecv(int header,
+                     const uint8_t* buffer,
+                     size_t bufferSize) override;
 
-    /* Snapshot server API */
+    std::unique_ptr<google::protobuf::Message>
+    doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override;
 
-    void recvPushSnapshot(faabric::transport::Message& msg);
+    std::unique_ptr<google::protobuf::Message> recvPushSnapshot(
+      const uint8_t* buffer,
+      size_t bufferSize);
 
-    void recvDeleteSnapshot(faabric::transport::Message& msg);
+    std::unique_ptr<google::protobuf::Message> recvPushSnapshotDiffs(
+      const uint8_t* buffer,
+      size_t bufferSize);
 
-    void
recvPushSnapshotDiffs(faabric::transport::Message& msg); + void recvDeleteSnapshot(const uint8_t* buffer, size_t bufferSize); - void recvThreadResult(faabric::transport::Message& msg); + void recvThreadResult(const uint8_t* buffer, size_t bufferSize); private: void applyDiffsToSnapshot( diff --git a/include/faabric/state/DummyStateServer.h b/include/faabric/state/DummyStateServer.h deleted file mode 100644 index df3010c57..000000000 --- a/include/faabric/state/DummyStateServer.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -#include - -namespace faabric::state { -class DummyStateServer -{ - public: - DummyStateServer(); - - std::shared_ptr getRemoteKv(); - - std::shared_ptr getLocalKv(); - - std::vector getRemoteKvValue(); - - std::vector getLocalKvValue(); - - std::vector dummyData; - std::string dummyUser; - std::string dummyKey; - - void start(); - - void stop(); - - state::State remoteState; - state::StateServer stateServer; -}; - -} diff --git a/include/faabric/state/StateClient.h b/include/faabric/state/StateClient.h index 79b5ae9f1..fff3e2b4a 100644 --- a/include/faabric/state/StateClient.h +++ b/include/faabric/state/StateClient.h @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace faabric::state { @@ -15,11 +16,6 @@ class StateClient : public faabric::transport::MessageEndpointClient const std::string user; const std::string key; - const std::string host; - - InMemoryStateRegistry& reg; - - /* External state client API */ void pushChunks(const std::vector& chunks); @@ -41,16 +37,8 @@ class StateClient : public faabric::transport::MessageEndpointClient void unlock(); private: - void sendHeader(faabric::state::StateCalls call); - - // Block, but ignore return value - faabric::transport::Message awaitResponse(); - - void sendStateRequest(faabric::state::StateCalls header, bool expectReply); - void sendStateRequest(faabric::state::StateCalls header, - const uint8_t* data = nullptr, - int length = 0, - bool expectReply = false); + const uint8_t* data, + int length); }; } diff --git a/include/faabric/state/StateServer.h b/include/faabric/state/StateServer.h index 29508b88d..ebd62760a 100644 --- a/include/faabric/state/StateServer.h +++ b/include/faabric/state/StateServer.h @@ -13,27 +13,42 @@ class StateServer final : public faabric::transport::MessageEndpointServer private: State& state; - void doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) override; + void doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) override; - /* State server API */ + std::unique_ptr + doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override; - void recvSize(faabric::transport::Message& body); + // Sync methods - void recvPull(faabric::transport::Message& body); + std::unique_ptr recvSize(const uint8_t* buffer, + size_t bufferSize); - void recvPush(faabric::transport::Message& body); + std::unique_ptr recvPull(const uint8_t* buffer, + size_t bufferSize); - void recvAppend(faabric::transport::Message& body); + std::unique_ptr recvPush(const uint8_t* buffer, + size_t bufferSize); - void recvPullAppended(faabric::transport::Message& body); + std::unique_ptr recvAppend(const uint8_t* buffer, + size_t bufferSize); - void recvClearAppended(faabric::transport::Message& body); + std::unique_ptr recvPullAppended( + const uint8_t* buffer, + size_t bufferSize); - void recvDelete(faabric::transport::Message& body); + std::unique_ptr recvClearAppended( + const uint8_t* buffer, + size_t bufferSize); - void 
recvLock(faabric::transport::Message& body); + std::unique_ptr recvDelete(const uint8_t* buffer, + size_t bufferSize); - void recvUnlock(faabric::transport::Message& body); + std::unique_ptr recvLock(const uint8_t* buffer, + size_t bufferSize); + + std::unique_ptr recvUnlock(const uint8_t* buffer, + size_t bufferSize); }; } diff --git a/include/faabric/transport/Message.h b/include/faabric/transport/Message.h index f5c5eadbe..8e3378d63 100644 --- a/include/faabric/transport/Message.h +++ b/include/faabric/transport/Message.h @@ -18,22 +18,19 @@ class Message Message(); - ~Message(); - char* data(); uint8_t* udata(); + std::vector dataCopy(); + int size(); bool more(); - void persist(); - private: - uint8_t* msg; - int _size; + std::vector bytes; + bool _more; - bool _persist; }; } diff --git a/include/faabric/transport/MessageContext.h b/include/faabric/transport/MessageContext.h deleted file mode 100644 index 78c17afb2..000000000 --- a/include/faabric/transport/MessageContext.h +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once - -#include - -namespace faabric::transport { -/* Wrapper around zmq::context_t - * - * The context object is thread safe, and the constructor parameter indicates - * the number of hardware IO threads to be used. As a rule of thumb, use one - * IO thread per Gbps of data. - */ -class MessageContext -{ - public: - MessageContext(); - - // Message context should not be copied as there must only be one ZMQ - // context - MessageContext(const MessageContext& ctx) = delete; - - MessageContext(int overrideIoThreads); - - ~MessageContext(); - - zmq::context_t ctx; - - zmq::context_t& get(); - - /* Close the message context - * - * In 0MQ terms, this method calls close() on the context, which in turn - * first shuts down (i.e. stop blocking operations) and then closes. - */ - void close(); - - bool isContextShutDown; -}; - -faabric::transport::MessageContext& getGlobalMessageContext(); -} diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index c1852f89d..0e4dbef9d 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -1,9 +1,6 @@ #pragma once -#include - #include -#include #include #include @@ -19,91 +16,121 @@ #define DEFAULT_RECV_TIMEOUT_MS 20000 #define DEFAULT_SEND_TIMEOUT_MS 20000 +// How long undelivered messages will hang around when the socket is closed, +// which also determines how long the context will hang for when closing if +// things haven't yet completed (usually only when there's an error). +#define LINGER_MS 100 + namespace faabric::transport { -enum class SocketType -{ - PUSH, - PULL -}; -/* Wrapper arround zmq::socket_t - * - * Thread-unsafe socket-like object. MUST be open-ed and close-ed from the - * _same_ thread. For a proto://host:pair triple, one socket may bind, and all - * the rest must connect. Order does not matter. Sockets either send (PUSH) - * or recv (PULL) data. - */ +// Note: sockets must be open-ed and close-ed from the _same_ thread. In a given +// communication group, one socket may bind, and all the rest must connect. +// Order does not matter. 
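Before the class declarations, note how the timeout and linger settings above come into play: every endpoint socket gets them applied when it is created. A minimal sketch of that set-up logic follows (illustrative only: the helper name, the context argument, and the cppzmq sockopt API are assumptions, not taken from this diff):

    // Sketch: create a socket with bounded send/recv waits and a capped
    // linger period, so closing the context cannot hang indefinitely
    zmq::socket_t makeSocket(zmq::context_t& ctx,
                             zmq::socket_type type,
                             const std::string& address,
                             int timeoutMs)
    {
        zmq::socket_t socket(ctx, type);
        socket.set(zmq::sockopt::rcvtimeo, timeoutMs);
        socket.set(zmq::sockopt::sndtimeo, timeoutMs);
        socket.set(zmq::sockopt::linger, LINGER_MS);
        socket.connect(address); // or bind(), depending on the endpoint
        return socket;
    }

With linger capped, closing the context can discard undelivered messages after LINGER_MS rather than blocking forever. The class declarations follow.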
 class MessageEndpoint
 {
   public:
-    MessageEndpoint(const std::string& hostIn, int portIn);
+    MessageEndpoint(const std::string& hostIn, int portIn, int timeoutMsIn);
 
-    // Message endpoints shouldn't be assigned as ZeroMQ sockets are not thread
-    // safe
+    // Delete assignment and copy-constructor as we need to be very careful with
+    // scoping and same-thread instantiation
     MessageEndpoint& operator=(const MessageEndpoint&) = delete;
 
-    // Neither copied
     MessageEndpoint(const MessageEndpoint& ctx) = delete;
 
-    ~MessageEndpoint();
-
-    void open(faabric::transport::MessageContext& context,
-              faabric::transport::SocketType sockTypeIn,
-              bool bind);
+    std::string getHost();
 
-    void close(bool bind);
+    int getPort();
 
-    void send(uint8_t* serialisedMsg, size_t msgSize, bool more = false);
+  protected:
+    const std::string host;
+    const int port;
+    const std::string address;
+    const int timeoutMs;
+    const std::thread::id tid;
+    const int id;
 
-    // If known, pass a size parameter to pre-allocate a recv buffer
-    Message recv(int size = 0);
+    zmq::socket_t setUpSocket(zmq::socket_type socketType, int socketPort);
 
-    // The MessageEndpointServer needs direct access to the socket
-    std::unique_ptr<zmq::socket_t> socket;
+    void doSend(zmq::socket_t& socket,
+                const uint8_t* data,
+                size_t dataSize,
+                bool more);
 
-    std::string getHost();
+    Message doRecv(zmq::socket_t& socket, int size = 0);
 
-    int getPort();
+    Message recvBuffer(zmq::socket_t& socket, int size);
 
-    void setRecvTimeoutMs(int value);
+    Message recvNoBuffer(zmq::socket_t& socket);
+};
 
-    void setSendTimeoutMs(int value);
+class AsyncSendMessageEndpoint final : public MessageEndpoint
+{
+  public:
+    AsyncSendMessageEndpoint(const std::string& hostIn,
+                             int portIn,
+                             int timeoutMs = DEFAULT_SEND_TIMEOUT_MS);
 
-  protected:
-    const std::string host;
-    const int port;
-    std::thread::id tid;
-    int id;
+    void sendHeader(int header);
 
-    int recvTimeoutMs = DEFAULT_RECV_TIMEOUT_MS;
-    int sendTimeoutMs = DEFAULT_SEND_TIMEOUT_MS;
+    void send(const uint8_t* data, size_t dataSize, bool more = false);
 
-    void validateTimeout(int value);
+  private:
+    zmq::socket_t pushSocket;
 };
 
-/* Send and Recv Message Endpoints */
-
-class SendMessageEndpoint : public MessageEndpoint
+class SyncSendMessageEndpoint final : public MessageEndpoint
 {
   public:
-    SendMessageEndpoint(const std::string& hostIn, int portIn);
+    SyncSendMessageEndpoint(const std::string& hostIn,
+                            int portIn,
+                            int timeoutMs = DEFAULT_SEND_TIMEOUT_MS);
+
+    void sendHeader(int header);
 
-    void open(MessageContext& context);
+    void sendRaw(const uint8_t* data, size_t dataSize);
 
-    void close();
+    Message sendAwaitResponse(const uint8_t* data,
+                              size_t dataSize,
+                              bool more = false);
+
+  private:
+    zmq::socket_t reqSocket;
 };
 
 class RecvMessageEndpoint : public MessageEndpoint
 {
   public:
-    RecvMessageEndpoint(int portIn);
+    RecvMessageEndpoint(int portIn, int timeoutMs, zmq::socket_type socketType);
+
+    virtual ~RecvMessageEndpoint(){};
+
+    virtual Message recv(int size = 0);
+
+  protected:
+    zmq::socket_t socket;
+};
+
+class AsyncRecvMessageEndpoint final : public RecvMessageEndpoint
+{
+  public:
+    AsyncRecvMessageEndpoint(int portIn,
+                             int timeoutMs = DEFAULT_RECV_TIMEOUT_MS);
+
+    Message recv(int size = 0) override;
+};
+
+class SyncRecvMessageEndpoint final : public RecvMessageEndpoint
+{
+  public:
+    SyncRecvMessageEndpoint(int portIn,
+                            int timeoutMs = DEFAULT_RECV_TIMEOUT_MS);
 
-    void open(MessageContext& context);
+    Message recv(int size = 0) override;
 
-    void close();
+    void sendResponse(const uint8_t* data, int size);
 };
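Taken together, one synchronous round trip with the new endpoint pair looks roughly like this (a hand-driven sketch with an arbitrary port, not code from this diff; in practice the receiving side lives inside a MessageEndpointServer):

    #include <faabric/transport/MessageEndpoint.h>

    #include <thread>
    #include <vector>

    void syncPingPong()
    {
        // REP side: bind, block on the request, echo it back as the response
        std::thread serverThread([] {
            faabric::transport::SyncRecvMessageEndpoint server(8010);
            faabric::transport::Message req = server.recv();
            server.sendResponse(req.udata(), req.size());
        });

        // REQ side: connect, then block until the response (or timeout)
        faabric::transport::SyncSendMessageEndpoint client("127.0.0.1", 8010);
        std::vector<uint8_t> body = { 1, 2, 3 };
        faabric::transport::Message resp =
          client.sendAwaitResponse(body.data(), body.size());

        serverThread.join();
    }

The async pair (AsyncSendMessageEndpoint and AsyncRecvMessageEndpoint) follows the same shape over PUSH/PULL, minus the response leg.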
-class MessageTimeoutException : public faabric::util::FaabricException +class MessageTimeoutException final : public faabric::util::FaabricException { public: explicit MessageTimeoutException(std::string message) diff --git a/include/faabric/transport/MessageEndpointClient.h b/include/faabric/transport/MessageEndpointClient.h index 1cbfbc966..02e945925 100644 --- a/include/faabric/transport/MessageEndpointClient.h +++ b/include/faabric/transport/MessageEndpointClient.h @@ -1,25 +1,42 @@ #pragma once +#include +#include +#include #include -#include namespace faabric::transport { -/* Minimal message endpoint client - * - * Low-level and minimal message endpoint client to run in companion with - * a background-running server. - */ -class MessageEndpointClient : public faabric::transport::SendMessageEndpoint +class MessageEndpointClient { public: - MessageEndpointClient(const std::string& host, int port); - - /* Wait for a message - * - * This method blocks the calling thread until we receive a message from - * the specified host:port pair. When pointed at a server, this method - * allows for blocking communications. - */ - Message awaitResponse(int port); + MessageEndpointClient(std::string hostIn, + int asyncPort, + int syncPort, + int timeoutMs = DEFAULT_SEND_TIMEOUT_MS); + + void asyncSend(int header, google::protobuf::Message* msg); + + void asyncSend(int header, const uint8_t* buffer, size_t bufferSize); + + void syncSend(int header, + google::protobuf::Message* msg, + google::protobuf::Message* response); + + void syncSend(int header, + const uint8_t* buffer, + size_t bufferSize, + google::protobuf::Message* response); + + protected: + const std::string host; + + private: + const int asyncPort; + + const int syncPort; + + faabric::transport::AsyncSendMessageEndpoint asyncEndpoint; + + faabric::transport::SyncSendMessageEndpoint syncEndpoint; }; } diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 1f84d75ff..a7c3f88a2 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -1,72 +1,67 @@ #pragma once +#include #include -#include #include -#include +#include #include -#define ENDPOINT_SERVER_SHUTDOWN -1 - namespace faabric::transport { -/* Server handling a long-running 0MQ socket - * - * This abstract class implements a server-like loop functionality and will - * always run in the background. Note that message endpoints (i.e. 0MQ sockets) - * are _not_ thread safe, must be open-ed and close-ed from the _same_ thread, - * and thus should preferably live in the thread's local address space. - */ -class MessageEndpointServer + +// Each server has two underlying sockets, one for synchronous communication and +// one for asynchronous. Each is run inside its own background thread. +class MessageEndpointServer; + +class MessageEndpointServerThread { public: - MessageEndpointServer(int portIn); + MessageEndpointServerThread(MessageEndpointServer* serverIn, bool asyncIn); + + void start(std::shared_ptr latch); - /* Start and stop the server - * - * Generic methods to start and stop a message endpoint server. They take - * a, thread-safe, 0MQ context as an argument. The stop method will block - * until _all_ sockets within the context have been closed. Sockets blocking - * on a `recv` will be interrupted with ETERM upon context closure. 
- */ - void start(faabric::transport::MessageContext& context); + void join(); + + private: + MessageEndpointServer* server; + bool async = false; + + std::thread backgroundThread; +}; - void stop(faabric::transport::MessageContext& context); +class MessageEndpointServer +{ + public: + MessageEndpointServer(int asyncPortIn, int syncPortIn); - /* Common start and stop entrypoint - * - * Call the generic methods with the default global message context. - */ - void start(); + virtual void start(); virtual void stop(); + void setAsyncLatch(); + + void awaitAsyncLatch(); + protected: - int recv(faabric::transport::RecvMessageEndpoint& endpoint); - - /* Template function to handle message reception - * - * A message endpoint server in faabric expects each communication to be - * a multi-part 0MQ message. One message containing the header, and another - * one with the body. Note that 0MQ _guarantees_ in-order delivery. - */ - virtual void doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) = 0; - - /* Send response to the client - * - * Send a one-off response to a client identified by host:port pair. - * Together with a blocking recv at the client side, this - * method can be used to achieve synchronous client-server communication. - */ - void sendResponse(uint8_t* serialisedMsg, - int size, - const std::string& returnHost, - int returnPort); + virtual void doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) = 0; + + virtual std::unique_ptr + doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) = 0; private: - const int port; + friend class MessageEndpointServerThread; + + const int asyncPort; + const int syncPort; + + MessageEndpointServerThread asyncThread; + MessageEndpointServerThread syncThread; + + AsyncSendMessageEndpoint asyncShutdownSender; + SyncSendMessageEndpoint syncShutdownSender; - std::thread servingThread; + std::shared_ptr asyncLatch; }; } diff --git a/include/faabric/transport/MpiMessageEndpoint.h b/include/faabric/transport/MpiMessageEndpoint.h index 991331bc5..d715a8881 100644 --- a/include/faabric/transport/MpiMessageEndpoint.h +++ b/include/faabric/transport/MpiMessageEndpoint.h @@ -6,33 +6,28 @@ #include namespace faabric::transport { -/* These two abstract methods are used to broadcast the host-rank mapping at - * initialisation time. - */ -faabric::MpiHostsToRanksMessage recvMpiHostRankMsg(); - -void sendMpiHostRankMsg(const std::string& hostIn, - const faabric::MpiHostsToRanksMessage msg); - -/* This class abstracts the notion of a communication channel between two remote - * MPI ranks. There will always be one rank local to this host, and one remote. - * Note that the port is unique per (user, function, sendRank, recvRank) tuple. - */ + +// This class abstracts the notion of a communication channel between two remote +// MPI ranks. There will always be one rank local to this host, and one remote. +// Note that the ports are unique per (user, function, sendRank, recvRank) +// tuple. +// +// This is different to our normal messaging clients as it wraps two +// _async_ sockets. 
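With this split, a concrete server only routes headers: fire-and-forget messages arrive through doAsyncRecv, while request-response messages arrive through doSyncRecv, which must always return a protobuf message for the REQ/REP socket. A minimal subclass might look like the following sketch (ports are arbitrary; the std::unique_ptr<google::protobuf::Message> return type and faabric::EmptyResponse usage are taken from the server implementations later in this diff):

    #include <faabric/proto/faabric.pb.h>
    #include <faabric/transport/MessageEndpointServer.h>
    #include <faabric/util/logging.h>

    #include <memory>

    class EchoServer final : public faabric::transport::MessageEndpointServer
    {
      public:
        // One port for the async (PUSH/PULL) socket, one for sync (REQ/REP)
        EchoServer()
          : MessageEndpointServer(9010, 9011)
        {}

      protected:
        void doAsyncRecv(int header,
                         const uint8_t* buffer,
                         size_t bufferSize) override
        {
            // Async handlers send nothing back
            SPDLOG_DEBUG("Async message, header {}, {} bytes", header, bufferSize);
        }

        std::unique_ptr<google::protobuf::Message> doSyncRecv(
          int header,
          const uint8_t* buffer,
          size_t bufferSize) override
        {
            // Sync handlers must return a response, even if it is empty
            return std::make_unique<faabric::EmptyResponse>();
        }
    };

The MpiMessageEndpoint class below is the exception to this client/server pattern: as the comment above says, it wraps two async sockets directly, one in each direction.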
+ class MpiMessageEndpoint { public: - MpiMessageEndpoint(const std::string& hostIn, int portIn); - MpiMessageEndpoint(const std::string& hostIn, int sendPort, int recvPort); void sendMpiMessage(const std::shared_ptr& msg); std::shared_ptr recvMpiMessage(); - void close(); - private: - SendMessageEndpoint sendMessageEndpoint; - RecvMessageEndpoint recvMessageEndpoint; + std::string host; + + AsyncSendMessageEndpoint sendSocket; + AsyncRecvMessageEndpoint recvSocket; }; } diff --git a/include/faabric/transport/common.h b/include/faabric/transport/common.h index b7957801e..7ee8ee759 100644 --- a/include/faabric/transport/common.h +++ b/include/faabric/transport/common.h @@ -3,9 +3,12 @@ #define DEFAULT_STATE_HOST "0.0.0.0" #define DEFAULT_FUNCTION_CALL_HOST "0.0.0.0" #define DEFAULT_SNAPSHOT_HOST "0.0.0.0" -#define STATE_PORT 8003 -#define FUNCTION_CALL_PORT 8004 -#define SNAPSHOT_PORT 8005 -#define REPLY_PORT_OFFSET 100 -#define MPI_PORT 8800 +#define STATE_ASYNC_PORT 8003 +#define STATE_SYNC_PORT 8004 +#define FUNCTION_CALL_ASYNC_PORT 8005 +#define FUNCTION_CALL_SYNC_PORT 8006 +#define SNAPSHOT_ASYNC_PORT 8007 +#define SNAPSHOT_SYNC_PORT 8008 + +#define DEFAULT_MPI_BASE_PORT 8800 diff --git a/include/faabric/transport/context.h b/include/faabric/transport/context.h new file mode 100644 index 000000000..4e608254d --- /dev/null +++ b/include/faabric/transport/context.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +// We specify a number of background I/O threads when constructing the ZeroMQ +// context. Guidelines on how to scale this can be found here: +// https://zguide.zeromq.org/docs/chapter2/#I-O-Threads + +#define ZMQ_CONTEXT_IO_THREADS 1 + +namespace faabric::transport { + +void initGlobalMessageContext(); + +std::shared_ptr getGlobalMessageContext(); + +void closeGlobalMessageContext(); + +} diff --git a/include/faabric/transport/macros.h b/include/faabric/transport/macros.h index 6dfac596e..c24d5d125 100644 --- a/include/faabric/transport/macros.h +++ b/include/faabric/transport/macros.h @@ -1,45 +1,36 @@ #pragma once -#define SEND_MESSAGE(header, msg) \ - sendHeader(header); \ +#define PARSE_MSG(T, data, size) \ + T msg; \ + if (!msg.ParseFromArray(data, size)) { \ + throw std::runtime_error("Error deserialising message"); \ + } + +#define SERIALISE_MSG(msg) \ size_t msgSize = msg.ByteSizeLong(); \ - { \ - uint8_t sMsg[msgSize]; \ - if (!msg.SerializeToArray(sMsg, msgSize)) { \ - throw std::runtime_error("Error serialising message"); \ - } \ - send(sMsg, msgSize); \ + uint8_t buffer[msgSize]; \ + if (!msg.SerializeToArray(buffer, msgSize)) { \ + throw std::runtime_error("Error serialising message"); \ } -#define SEND_MESSAGE_PTR(header, msg) \ - sendHeader(header); \ +#define SERIALISE_MSG_PTR(msg) \ size_t msgSize = msg->ByteSizeLong(); \ - { \ - uint8_t sMsg[msgSize]; \ - if (!msg->SerializeToArray(sMsg, msgSize)) { \ - throw std::runtime_error("Error serialising message"); \ - } \ - send(sMsg, msgSize); \ + uint8_t buffer[msgSize]; \ + if (!msg->SerializeToArray(buffer, msgSize)) { \ + throw std::runtime_error("Error serialising message"); \ } -#define SEND_SERVER_RESPONSE(msg, host, port) \ - size_t msgSize = msg.ByteSizeLong(); \ +#define SEND_FB_MSG(T, mb) \ { \ - uint8_t sMsg[msgSize]; \ - if (!msg.SerializeToArray(sMsg, msgSize)) { \ - throw std::runtime_error("Error serialising message"); \ - } \ - sendResponse(sMsg, msgSize, host, port); \ + const uint8_t* buffer = mb.GetBufferPointer(); \ + int size = mb.GetSize(); \ + faabric::EmptyResponse response; \ + 
syncSend(T, buffer, size, &response); \ } -#define PARSE_MSG(T, data, size) \ - T msg; \ - if (!msg.ParseFromArray(data, size)) { \ - throw std::runtime_error("Error deserialising message"); \ - } - -#define PARSE_RESPONSE(T, data, size) \ - T response; \ - if (!response.ParseFromArray(data, size)) { \ - throw std::runtime_error("Error deserialising message"); \ +#define SEND_FB_MSG_ASYNC(T, mb) \ + { \ + const uint8_t* buffer = mb.GetBufferPointer(); \ + int size = mb.GetSize(); \ + asyncSend(T, buffer, size); \ } diff --git a/include/faabric/util/barrier.h b/include/faabric/util/barrier.h deleted file mode 100644 index b5afaf276..000000000 --- a/include/faabric/util/barrier.h +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - -#include -#include - -namespace faabric::util { - -class Barrier -{ - public: - explicit Barrier(int count); - - void wait(); - - int getSlotCount(); - - int getUseCount(); - - private: - int threadCount; - int slotCount; - int uses; - std::mutex mx; - std::condition_variable cv; -}; -} diff --git a/include/faabric/util/latch.h b/include/faabric/util/latch.h new file mode 100644 index 000000000..b87419fe4 --- /dev/null +++ b/include/faabric/util/latch.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include + +namespace faabric::util { + +#define DEFAULT_LATCH_TIMEOUT_MS 10000 + +class Latch +{ + public: + // WARNING: this latch must be shared between threads using a shared + // pointer, otherwise there seems to be some nasty race conditions related + // to its destruction. + static std::shared_ptr create( + int count, + int timeoutMs = DEFAULT_LATCH_TIMEOUT_MS); + + explicit Latch(int countIn, int timeoutMsIn); + + void wait(); + + private: + int count; + int waiters = 0; + + int timeoutMs; + + std::mutex mx; + std::condition_variable cv; +}; +} diff --git a/include/faabric/util/macros.h b/include/faabric/util/macros.h index 7b865679f..504e4be40 100644 --- a/include/faabric/util/macros.h +++ b/include/faabric/util/macros.h @@ -1,6 +1,10 @@ #pragma once +#include + #define BYTES(arr) reinterpret_cast(arr) #define BYTES_CONST(arr) reinterpret_cast(arr) +#define SLEEP_MS(ms) usleep((ms)*1000) + #define UNUSED(x) (void)(x) diff --git a/src/flat/faabric.fbs b/src/flat/faabric.fbs index d173a2d7d..80bf93568 100644 --- a/src/flat/faabric.fbs +++ b/src/flat/faabric.fbs @@ -1,5 +1,4 @@ table SnapshotPushRequest { - return_host:string; key:string; contents:[ubyte]; } @@ -14,7 +13,6 @@ table SnapshotDiffChunk { } table SnapshotDiffPushRequest { - return_host:string; key:string; chunks:[SnapshotDiffChunk]; } diff --git a/src/mpi_native/MpiExecutor.cpp b/src/mpi_native/MpiExecutor.cpp index edf7b35ab..0427830df 100644 --- a/src/mpi_native/MpiExecutor.cpp +++ b/src/mpi_native/MpiExecutor.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace faabric::mpi_native { @@ -26,6 +27,8 @@ int32_t MpiExecutor::executeTask( int mpiNativeMain(int argc, char** argv) { + faabric::transport::initGlobalMessageContext(); + auto& scheduler = faabric::scheduler::getScheduler(); auto& conf = faabric::util::getSystemConfig(); @@ -54,6 +57,8 @@ int mpiNativeMain(int argc, char** argv) scheduler.callFunction(msg); } + faabric::transport::closeGlobalMessageContext(); + return 0; } } diff --git a/src/proto/faabric.proto b/src/proto/faabric.proto index a44a4668c..296f4e232 100644 --- a/src/proto/faabric.proto +++ b/src/proto/faabric.proto @@ -10,6 +10,10 @@ message EmptyResponse { int32 empty = 1; } +message EmptyRequest { + int32 empty = 1; +} + // --------------------------------------------- // 
FUNCTION SCHEDULING // --------------------------------------------- @@ -35,10 +39,6 @@ message BatchExecuteRequest { bytes contextData = 7; } -message ResponseRequest { - string returnHost = 1; -} - message HostResources { int32 slots = 1; int32 usedSlots = 2; @@ -157,7 +157,6 @@ message StateRequest { string user = 1; string key = 2; bytes data = 3; - string returnHost = 4; } message StateChunkRequest { @@ -165,7 +164,6 @@ message StateChunkRequest { string key = 2; uint64 offset = 3; uint64 chunkSize = 4; - string returnHost = 5; } message StateResponse { @@ -179,7 +177,6 @@ message StatePart { string key = 2; uint64 offset = 3; bytes data = 4; - string returnHost = 5; } message StateSizeResponse { @@ -192,7 +189,6 @@ message StateAppendedRequest { string user = 1; string key = 2; uint32 nValues = 3; - string returnHost = 4; } message StateAppendedResponse { diff --git a/src/runner/FaabricMain.cpp b/src/runner/FaabricMain.cpp index ed9764b58..18625cfe4 100644 --- a/src/runner/FaabricMain.cpp +++ b/src/runner/FaabricMain.cpp @@ -79,9 +79,6 @@ void FaabricMain::shutdown() { SPDLOG_INFO("Removing from global working set"); - auto& sch = faabric::scheduler::getScheduler(); - sch.shutdown(); - SPDLOG_INFO("Waiting for the state server to finish"); stateServer.stop(); @@ -91,6 +88,9 @@ void FaabricMain::shutdown() SPDLOG_INFO("Waiting for the snapshot server to finish"); snapshotServer.stop(); + auto& sch = faabric::scheduler::getScheduler(); + sch.shutdown(); + SPDLOG_INFO("Faabric pool successfully shut down"); } } diff --git a/src/scheduler/Executor.cpp b/src/scheduler/Executor.cpp index 0551eb80f..8fb228cee 100644 --- a/src/scheduler/Executor.cpp +++ b/src/scheduler/Executor.cpp @@ -298,6 +298,10 @@ void Executor::threadPoolThread(int threadPoolIdx) sch.notifyExecutorShutdown(this, boundMessage); } } + + // We have to clean up TLS here as this should be the last use of the + // scheduler from this thread + sch.resetThreadLocalCache(); } bool Executor::tryClaim() diff --git a/src/scheduler/FunctionCallClient.cpp b/src/scheduler/FunctionCallClient.cpp index 27e974151..3492e04f2 100644 --- a/src/scheduler/FunctionCallClient.cpp +++ b/src/scheduler/FunctionCallClient.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -13,13 +14,13 @@ std::mutex mockMutex; static std::vector> functionCalls; -static std::vector> flushCalls; +static std::vector> flushCalls; static std::vector< std::pair>> batchMessages; -static std::vector> +static std::vector> resourceRequests; static std::unordered_map> getFunctionCalls() return functionCalls; } -std::vector> getFlushCalls() +std::vector> getFlushCalls() { return flushCalls; } @@ -46,8 +47,7 @@ getBatchRequests() return batchMessages; } -std::vector> -getResourceRequests() +std::vector> getResourceRequests() { return resourceRequests; } @@ -80,38 +80,28 @@ void clearMockRequests() // Message Client // ----------------------------------- FunctionCallClient::FunctionCallClient(const std::string& hostIn) - : faabric::transport::MessageEndpointClient(hostIn, FUNCTION_CALL_PORT) -{ - this->open(faabric::transport::getGlobalMessageContext()); -} - -void FunctionCallClient::sendHeader(faabric::scheduler::FunctionCalls call) -{ - uint8_t header = static_cast(call); - send(&header, sizeof(header), true); -} + : faabric::transport::MessageEndpointClient(hostIn, + FUNCTION_CALL_ASYNC_PORT, + FUNCTION_CALL_SYNC_PORT) +{} void FunctionCallClient::sendFlush() { - faabric::ResponseRequest call; + faabric::EmptyRequest req; if 
(faabric::util::isMockMode()) { faabric::util::UniqueLock lock(mockMutex); - flushCalls.emplace_back(host, call); + flushCalls.emplace_back(host, req); } else { - // Prepare the message body - call.set_returnhost(faabric::util::getSystemConfig().endpointHost); - - SEND_MESSAGE(faabric::scheduler::FunctionCalls::Flush, call); - - // Await the response - awaitResponse(FUNCTION_CALL_PORT + REPLY_PORT_OFFSET); + faabric::EmptyResponse resp; + syncSend(faabric::scheduler::FunctionCalls::Flush, &req, &resp); } } faabric::HostResources FunctionCallClient::getResources() { - faabric::ResponseRequest request; + faabric::EmptyRequest request; faabric::HostResources response; + if (faabric::util::isMockMode()) { faabric::util::UniqueLock lock(mockMutex); @@ -123,17 +113,8 @@ faabric::HostResources FunctionCallClient::getResources() response = queuedResourceResponses[host].dequeue(); } } else { - request.set_returnhost(faabric::util::getSystemConfig().endpointHost); - - SEND_MESSAGE(faabric::scheduler::FunctionCalls::GetResources, request); - - // Receive message - faabric::transport::Message msg = - awaitResponse(FUNCTION_CALL_PORT + REPLY_PORT_OFFSET); - // Deserialise message string - if (!response.ParseFromArray(msg.data(), msg.size())) { - throw std::runtime_error("Error deserialising message"); - } + syncSend( + faabric::scheduler::FunctionCalls::GetResources, &request, &response); } return response; @@ -146,18 +127,18 @@ void FunctionCallClient::executeFunctions( faabric::util::UniqueLock lock(mockMutex); batchMessages.emplace_back(host, req); } else { - SEND_MESSAGE_PTR(faabric::scheduler::FunctionCalls::ExecuteFunctions, - req); + asyncSend(faabric::scheduler::FunctionCalls::ExecuteFunctions, + req.get()); } } -void FunctionCallClient::unregister(const faabric::UnregisterRequest& req) +void FunctionCallClient::unregister(faabric::UnregisterRequest& req) { if (faabric::util::isMockMode()) { faabric::util::UniqueLock lock(mockMutex); unregisterRequests.emplace_back(host, req); } else { - SEND_MESSAGE(faabric::scheduler::FunctionCalls::Unregister, req); + asyncSend(faabric::scheduler::FunctionCalls::Unregister, &req); } } } diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index e17a13b06..dcede0518 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -8,69 +9,77 @@ namespace faabric::scheduler { FunctionCallServer::FunctionCallServer() - : faabric::transport::MessageEndpointServer(FUNCTION_CALL_PORT) + : faabric::transport::MessageEndpointServer(FUNCTION_CALL_ASYNC_PORT, + FUNCTION_CALL_SYNC_PORT) , scheduler(getScheduler()) {} -void FunctionCallServer::stop() +void FunctionCallServer::doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) { - // Close the dangling scheduler endpoints - faabric::scheduler::getScheduler().closeFunctionCallClients(); - - // Call the parent stop - MessageEndpointServer::stop(faabric::transport::getGlobalMessageContext()); -} - -void FunctionCallServer::doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) -{ - assert(header.size() == sizeof(uint8_t)); - uint8_t call = static_cast(*header.data()); - switch (call) { - case faabric::scheduler::FunctionCalls::Flush: - this->recvFlush(body); - break; - case faabric::scheduler::FunctionCalls::ExecuteFunctions: - this->recvExecuteFunctions(body); - break; - case faabric::scheduler::FunctionCalls::Unregister: - 
this->recvUnregister(body); + switch (header) { + case faabric::scheduler::FunctionCalls::ExecuteFunctions: { + recvExecuteFunctions(buffer, bufferSize); break; - case faabric::scheduler::FunctionCalls::GetResources: - this->recvGetResources(body); + } + case faabric::scheduler::FunctionCalls::Unregister: { + recvUnregister(buffer, bufferSize); break; - default: + } + default: { throw std::runtime_error( - fmt::format("Unrecognized call header: {}", call)); + fmt::format("Unrecognized async call header: {}", header)); + } } } -void FunctionCallServer::recvFlush(faabric::transport::Message& body) +std::unique_ptr FunctionCallServer::doSyncRecv( + int header, + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::ResponseRequest, body.data(), body.size()); + switch (header) { + case faabric::scheduler::FunctionCalls::Flush: { + return recvFlush(buffer, bufferSize); + } + case faabric::scheduler::FunctionCalls::GetResources: { + return recvGetResources(buffer, bufferSize); + } + default: { + throw std::runtime_error( + fmt::format("Unrecognized sync call header: {}", header)); + } + } +} +std::unique_ptr FunctionCallServer::recvFlush( + const uint8_t* buffer, + size_t bufferSize) +{ // Clear out any cached state faabric::state::getGlobalState().forceClearAll(false); // Clear the scheduler scheduler.flushLocally(); - faabric::EmptyResponse response; - SEND_SERVER_RESPONSE(response, msg.returnhost(), FUNCTION_CALL_PORT) + return std::make_unique(); } -void FunctionCallServer::recvExecuteFunctions(faabric::transport::Message& body) +void FunctionCallServer::recvExecuteFunctions(const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::BatchExecuteRequest, body.data(), body.size()) + PARSE_MSG(faabric::BatchExecuteRequest, buffer, bufferSize) // This host has now been told to execute these functions no matter what scheduler.callFunctions(std::make_shared(msg), true); } -void FunctionCallServer::recvUnregister(faabric::transport::Message& body) +void FunctionCallServer::recvUnregister(const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::UnregisterRequest, body.data(), body.size()) + PARSE_MSG(faabric::UnregisterRequest, buffer, bufferSize) std::string funcStr = faabric::util::funcToString(msg.function(), false); SPDLOG_DEBUG("Unregistering host {} for {}", msg.host(), funcStr); @@ -79,12 +88,12 @@ void FunctionCallServer::recvUnregister(faabric::transport::Message& body) scheduler.removeRegisteredHost(msg.host(), msg.function()); } -void FunctionCallServer::recvGetResources(faabric::transport::Message& body) +std::unique_ptr FunctionCallServer::recvGetResources( + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::ResponseRequest, body.data(), body.size()) - - // Send the response body - faabric::HostResources response = scheduler.getThisHostResources(); - SEND_SERVER_RESPONSE(response, msg.returnhost(), FUNCTION_CALL_PORT) + auto response = std::make_unique( + scheduler.getThisHostResources()); + return response; } } diff --git a/src/scheduler/MpiContext.cpp b/src/scheduler/MpiContext.cpp index c279f5be4..4674da0b1 100644 --- a/src/scheduler/MpiContext.cpp +++ b/src/scheduler/MpiContext.cpp @@ -25,7 +25,10 @@ int MpiContext::createWorld(const faabric::Message& msg) // Create the MPI world scheduler::MpiWorldRegistry& reg = scheduler::getMpiWorldRegistry(); - reg.createWorld(msg, worldId); + scheduler::MpiWorld& world = reg.createWorld(msg, worldId); + + // Broadcast setup to other hosts + world.broadcastHostsToRanks(); // Set up this context 
isMpi = true; diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index e0a214fdd..33cfa1979 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -4,28 +4,87 @@ #include #include #include +#include -/* Each MPI rank runs in a separate thread, thus we use TLS to maintain the - * per-rank data structures. - */ +// Each MPI rank runs in a separate thread, thus we use TLS to maintain the +// per-rank data structures static thread_local std::vector< std::unique_ptr> mpiMessageEndpoints; + static thread_local std::vector< std::shared_ptr> unackedMessageBuffers; + static thread_local std::set iSendRequests; + static thread_local std::map> reqIdToRanks; +// These long-lived sockets are used by each world to communicate rank-to-host +// mappings. They are thread-local to ensure separation between concurrent +// worlds executing on the same host +static thread_local std::unique_ptr< + faabric::transport::AsyncRecvMessageEndpoint> + ranksRecvEndpoint; + +static thread_local std::unordered_map< + std::string, + std::unique_ptr> + ranksSendEndpoints; + +// This is used for mocking in tests +static std::vector rankMessages; + namespace faabric::scheduler { + MpiWorld::MpiWorld() - : id(-1) - , size(-1) - , thisHost(faabric::util::getSystemConfig().endpointHost) + : thisHost(faabric::util::getSystemConfig().endpointHost) , creationTime(faabric::util::startTimer()) , cartProcsPerDim(2) {} +faabric::MpiHostsToRanksMessage MpiWorld::recvMpiHostRankMsg() +{ + if (faabric::util::isMockMode()) { + assert(!rankMessages.empty()); + faabric::MpiHostsToRanksMessage msg = rankMessages.back(); + rankMessages.pop_back(); + return msg; + } + + if (ranksRecvEndpoint == nullptr) { + ranksRecvEndpoint = + std::make_unique( + basePort); + } + + SPDLOG_TRACE("Receiving MPI host ranks on {}", basePort); + faabric::transport::Message m = ranksRecvEndpoint->recv(); + PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); + + return msg; +} + +void MpiWorld::sendMpiHostRankMsg(const std::string& hostIn, + const faabric::MpiHostsToRanksMessage msg) +{ + if (faabric::util::isMockMode()) { + rankMessages.push_back(msg); + return; + } + + if (ranksSendEndpoints.find(hostIn) == ranksSendEndpoints.end()) { + ranksSendEndpoints.emplace( + hostIn, + std::make_unique( + hostIn, basePort)); + } + + SPDLOG_TRACE("Sending MPI host ranks to {}:{}", hostIn, basePort); + SERIALISE_MSG(msg) + ranksSendEndpoints[hostIn]->send(buffer, msgSize, false); +} + void MpiWorld::initRemoteMpiEndpoint(int localRank, int remoteRank) { SPDLOG_TRACE("Open MPI endpoint between ranks (local-remote) {} - {}", @@ -146,43 +205,50 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize) // Prepend this host for rank 0 executedAt.insert(executedAt.begin(), thisHost); + // Record rank-to-host mapping and base ports + rankHosts = executedAt; + basePorts = initLocalBasePorts(executedAt); + + // Initialise the memory queues for message reception + initLocalQueues(); +} + +void MpiWorld::broadcastHostsToRanks() +{ + // Set up a list of hosts to broadcast to (excluding this host) + std::set targetHosts(rankHosts.begin(), rankHosts.end()); + targetHosts.erase(thisHost); + + if (targetHosts.empty()) { + SPDLOG_DEBUG("Not broadcasting rank-to-host mapping, no other hosts"); + return; + } + // Register hosts to rank mappings on this host faabric::MpiHostsToRanksMessage hostRankMsg; - *hostRankMsg.mutable_hosts() = { executedAt.begin(), executedAt.end() }; + *hostRankMsg.mutable_hosts() = { 
rankHosts.begin(), rankHosts.end() }; // Prepare the base port for each rank - std::vector basePortForRank = initLocalBasePorts(executedAt); - *hostRankMsg.mutable_baseports() = { basePortForRank.begin(), - basePortForRank.end() }; - - // Register hosts to rank mappins on this host - setAllRankHostsPorts(hostRankMsg); - - // Set up a list of hosts to broadcast to (excluding this host) - std::set hosts(executedAt.begin(), executedAt.end()); - hosts.erase(thisHost); + *hostRankMsg.mutable_baseports() = { basePorts.begin(), basePorts.end() }; // Do the broadcast - for (const auto& h : hosts) { - faabric::transport::sendMpiHostRankMsg(h, hostRankMsg); + for (const auto& h : targetHosts) { + sendMpiHostRankMsg(h, hostRankMsg); } - - // Initialise the memory queues for message reception - initLocalQueues(); } void MpiWorld::destroy() { - // Destroy once per thread the rank-specific data structures - // Remote message endpoints - if (!mpiMessageEndpoints.empty()) { - for (auto& e : mpiMessageEndpoints) { - if (e != nullptr) { - e->close(); - } - } - mpiMessageEndpoints.clear(); - } + SPDLOG_TRACE("Destroying MPI world {}", id); + + // Note that all ranks will call this function. + + // We must force the destructors for all message endpoints to run here + // rather than at the end of their global thread-local lifespan. If we + // don't, the ZMQ shutdown can hang. + mpiMessageEndpoints.clear(); + ranksRecvEndpoint = nullptr; + ranksSendEndpoints.clear(); // Unacked message buffers if (!unackedMessageBuffers.empty()) { @@ -217,25 +283,33 @@ void MpiWorld::destroy() } } -void MpiWorld::initialiseFromMsg(const faabric::Message& msg, bool forceLocal) +void MpiWorld::initialiseFromMsg(const faabric::Message& msg) { id = msg.mpiworldid(); user = msg.user(); function = msg.function(); size = msg.mpiworldsize(); - // Sometimes for testing purposes we may want to initialise a world in the - // _same_ host we have created one (note that this would never happen in - // reality). If so, we skip initialising resources already initialised - if (!forceLocal) { - // Block until we receive - faabric::MpiHostsToRanksMessage hostRankMsg = - faabric::transport::recvMpiHostRankMsg(); - setAllRankHostsPorts(hostRankMsg); - - // Initialise the memory queues for message reception - initLocalQueues(); - } + // Block until we receive + faabric::MpiHostsToRanksMessage hostRankMsg = recvMpiHostRankMsg(); + + // Prepare the host-rank map with a vector containing _all_ ranks + // Note - this method should be called by only one rank. This is + // enforced in the world registry. + + // Assert we are only setting the values once + assert(rankHosts.empty()); + assert(basePorts.empty()); + + assert(hostRankMsg.hosts().size() == size); + assert(hostRankMsg.baseports().size() == size); + + rankHosts = { hostRankMsg.hosts().begin(), hostRankMsg.hosts().end() }; + basePorts = { hostRankMsg.baseports().begin(), + hostRankMsg.baseports().end() }; + + // Initialise the memory queues for message reception + initLocalQueues(); } std::string MpiWorld::getHostForRank(int rank) @@ -283,21 +357,6 @@ std::pair MpiWorld::getPortForRanks(int localRank, int remoteRank) return sendRecvPortPair; } -// Prepare the host-rank map with a vector containing _all_ ranks -// Note - this method should be called by only one rank. 
This is enforced in
-// the world registry
-void MpiWorld::setAllRankHostsPorts(const faabric::MpiHostsToRanksMessage& msg)
-{
-    // Assert we are only setting the values once
-    assert(rankHosts.size() == 0);
-    assert(basePorts.size() == 0);
-
-    assert(msg.hosts().size() == size);
-    assert(msg.baseports().size() == size);
-    rankHosts = { msg.hosts().begin(), msg.hosts().end() };
-    basePorts = { msg.baseports().begin(), msg.baseports().end() };
-}
-
 void MpiWorld::getCartesianRank(int rank,
                                 int maxDims,
                                 const int* dims,
@@ -1224,10 +1283,10 @@ std::vector<int> MpiWorld::initLocalBasePorts(
     basePortForRank.reserve(size);
 
     std::string lastHost = thisHost;
-    int lastPort = MPI_PORT;
+    int lastPort = basePort;
     for (const auto& host : executedAt) {
         if (host == thisHost) {
-            basePortForRank.push_back(MPI_PORT);
+            basePortForRank.push_back(basePort);
         } else if (host == lastHost) {
             basePortForRank.push_back(lastPort);
         } else {
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 8f271edc1..79c2a8515 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -22,6 +22,18 @@ using namespace faabric::util;
 
 namespace faabric::scheduler {
 
+// 0MQ sockets are not thread-safe, and opening and closing them from
+// different threads messes things up. However, we don't want to constantly
+// create and recreate them to make calls in the scheduler, therefore we cache
+// them in TLS, and perform thread-specific tidy-up.
+static thread_local std::unordered_map<std::string,
+                                       faabric::scheduler::FunctionCallClient>
+  functionCallClients;
+
+static thread_local std::unordered_map<std::string,
+                                       faabric::scheduler::SnapshotClient>
+  snapshotClients;
+
 Scheduler& getScheduler()
 {
     static Scheduler sch;
@@ -61,24 +73,21 @@ void Scheduler::addHostToGlobalSet()
     redis.sadd(AVAILABLE_HOST_SET, thisHost);
 }
 
-void Scheduler::closeFunctionCallClients()
+void Scheduler::resetThreadLocalCache()
 {
-    for (auto& iter : functionCallClients) {
-        iter.second.close();
-    }
-    functionCallClients.clear();
-}
+    auto tid = (pid_t)syscall(SYS_gettid);
+    SPDLOG_DEBUG("Resetting scheduler thread-local cache for thread {}", tid);
 
-void Scheduler::closeSnapshotClients()
-{
-    for (auto& iter : snapshotClients) {
-        iter.second.close();
-    }
+    functionCallClients.clear();
     snapshotClients.clear();
 }
 
 void Scheduler::reset()
 {
+    SPDLOG_DEBUG("Resetting scheduler");
+
+    resetThreadLocalCache();
+
     // Shut down all Executors
     for (auto& p : executors) {
         for (auto& e : p.second) {
@@ -109,9 +118,6 @@ void Scheduler::reset()
     recordedMessagesAll.clear();
     recordedMessagesLocal.clear();
     recordedMessagesShared.clear();
-
-    closeFunctionCallClients();
-    closeSnapshotClients();
 }
 
 void Scheduler::shutdown()
@@ -559,54 +565,25 @@ std::vector<faabric::Message> Scheduler::getRecordedMessagesLocal()
     return recordedMessagesLocal;
 }
 
-std::string getClientKey(const std::string& otherHost)
-{
-    // Note, our keys here have to include the tid as the clients can only be
-    // used within the same thread.
- std::thread::id tid = std::this_thread::get_id(); - std::stringstream ss; - ss << otherHost << "_" << tid; - std::string key = ss.str(); - return key; -} - FunctionCallClient& Scheduler::getFunctionCallClient( const std::string& otherHost) { - std::string key = getClientKey(otherHost); - if (functionCallClients.find(key) == functionCallClients.end()) { - faabric::util::FullLock lock(functionCallClientsMx); - - if (functionCallClients.find(key) == functionCallClients.end()) { - SPDLOG_DEBUG( - "Adding new function call client for {} ({})", otherHost, key); - functionCallClients.emplace(key, otherHost); - } + if (functionCallClients.find(otherHost) == functionCallClients.end()) { + SPDLOG_DEBUG("Adding new function call client for {}", otherHost); + functionCallClients.emplace(otherHost, otherHost); } - { - faabric::util::SharedLock lock(functionCallClientsMx); - return functionCallClients.at(key); - } + return functionCallClients.at(otherHost); } SnapshotClient& Scheduler::getSnapshotClient(const std::string& otherHost) { - std::string key = getClientKey(otherHost); - if (snapshotClients.find(key) == snapshotClients.end()) { - faabric::util::FullLock lock(snapshotClientsMx); - - if (snapshotClients.find(key) == snapshotClients.end()) { - SPDLOG_DEBUG( - "Adding new snapshot client for {} ({})", otherHost, key); - snapshotClients.emplace(key, otherHost); - } + if (snapshotClients.find(otherHost) == snapshotClients.end()) { + SPDLOG_DEBUG("Adding new snapshot client for {}", otherHost); + snapshotClients.emplace(otherHost, otherHost); } - { - faabric::util::SharedLock lock(snapshotClientsMx); - return snapshotClients.at(key); - } + return snapshotClients.at(otherHost); } std::vector> diff --git a/src/scheduler/SnapshotClient.cpp b/src/scheduler/SnapshotClient.cpp index 7e99bc0c4..d3f80bc87 100644 --- a/src/scheduler/SnapshotClient.cpp +++ b/src/scheduler/SnapshotClient.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include #include #include @@ -68,24 +70,11 @@ void clearMockSnapshotRequests() // Snapshot client // ----------------------------------- -#define SEND_FB_REQUEST(T) \ - sendHeader(T); \ - mb.Finish(requestOffset); \ - uint8_t* buffer = mb.GetBufferPointer(); \ - int size = mb.GetSize(); \ - send(buffer, size); - SnapshotClient::SnapshotClient(const std::string& hostIn) - : faabric::transport::MessageEndpointClient(hostIn, SNAPSHOT_PORT) -{ - this->open(faabric::transport::getGlobalMessageContext()); -} - -void SnapshotClient::sendHeader(faabric::scheduler::SnapshotCalls call) -{ - uint8_t header = static_cast(call); - send(&header, sizeof(header), true); -} + : faabric::transport::MessageEndpointClient(hostIn, + SNAPSHOT_ASYNC_PORT, + SNAPSHOT_SYNC_PORT) +{} void SnapshotClient::pushSnapshot(const std::string& key, const faabric::util::SnapshotData& data) @@ -96,23 +85,17 @@ void SnapshotClient::pushSnapshot(const std::string& key, faabric::util::UniqueLock lock(mockMutex); snapshotPushes.emplace_back(host, data); } else { - const faabric::util::SystemConfig& conf = - faabric::util::getSystemConfig(); - // Set up the main request // TODO - avoid copying data here flatbuffers::FlatBufferBuilder mb; - auto returnHostOffset = mb.CreateString(conf.endpointHost); auto keyOffset = mb.CreateString(key); auto dataOffset = mb.CreateVector(data.data, data.size); - auto requestOffset = CreateSnapshotPushRequest( - mb, returnHostOffset, keyOffset, dataOffset); + auto requestOffset = + CreateSnapshotPushRequest(mb, keyOffset, dataOffset); + mb.Finish(requestOffset); // Send it - 
SEND_FB_REQUEST(faabric::scheduler::SnapshotCalls::PushSnapshot) - - // Await a response as this call must be synchronous - awaitResponse(SNAPSHOT_PORT + REPLY_PORT_OFFSET); + SEND_FB_MSG(SnapshotCalls::PushSnapshot, mb) } } @@ -129,9 +112,6 @@ void SnapshotClient::pushSnapshotDiffs( snapshotKey, host); - const faabric::util::SystemConfig& conf = - faabric::util::getSystemConfig(); - flatbuffers::FlatBufferBuilder mb; // Create objects for all the chunks @@ -144,17 +124,13 @@ void SnapshotClient::pushSnapshotDiffs( // Set up the main request // TODO - avoid copying data here - auto returnHostOffset = mb.CreateString(conf.endpointHost); auto keyOffset = mb.CreateString(snapshotKey); auto diffsOffset = mb.CreateVector(diffsFbVector); - auto requestOffset = CreateSnapshotDiffPushRequest( - mb, returnHostOffset, keyOffset, diffsOffset); - - // Send it - SEND_FB_REQUEST(faabric::scheduler::SnapshotCalls::PushSnapshotDiffs) + auto requestOffset = + CreateSnapshotDiffPushRequest(mb, keyOffset, diffsOffset); + mb.Finish(requestOffset); - // Await a response as this call must be synchronous - awaitResponse(SNAPSHOT_PORT + REPLY_PORT_OFFSET); + SEND_FB_MSG(SnapshotCalls::PushSnapshotDiffs, mb); } } @@ -171,8 +147,9 @@ void SnapshotClient::deleteSnapshot(const std::string& key) flatbuffers::FlatBufferBuilder mb; auto keyOffset = mb.CreateString(key); auto requestOffset = CreateSnapshotDeleteRequest(mb, keyOffset); + mb.Finish(requestOffset); - SEND_FB_REQUEST(faabric::scheduler::SnapshotCalls::DeleteSnapshot) + SEND_FB_MSG_ASYNC(SnapshotCalls::DeleteSnapshot, mb); } } @@ -229,7 +206,8 @@ void SnapshotClient::pushThreadResult( CreateThreadResultRequest(mb, messageId, returnValue); } - SEND_FB_REQUEST(faabric::scheduler::SnapshotCalls::ThreadResult) + mb.Finish(requestOffset); + SEND_FB_MSG_ASYNC(SnapshotCalls::ThreadResult, mb) } } } diff --git a/src/scheduler/SnapshotServer.cpp b/src/scheduler/SnapshotServer.cpp index 4312c0568..f220b5ed1 100644 --- a/src/scheduler/SnapshotServer.cpp +++ b/src/scheduler/SnapshotServer.cpp @@ -3,52 +3,62 @@ #include #include #include +#include #include #include #include +#include + namespace faabric::scheduler { SnapshotServer::SnapshotServer() - : faabric::transport::MessageEndpointServer(SNAPSHOT_PORT) + : faabric::transport::MessageEndpointServer(SNAPSHOT_ASYNC_PORT, + SNAPSHOT_SYNC_PORT) {} -void SnapshotServer::stop() +void SnapshotServer::doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) { - // Close the dangling clients - faabric::scheduler::getScheduler().closeSnapshotClients(); - - // Call the parent stop - MessageEndpointServer::stop(faabric::transport::getGlobalMessageContext()); + switch (header) { + case faabric::scheduler::SnapshotCalls::DeleteSnapshot: { + this->recvDeleteSnapshot(buffer, bufferSize); + break; + } + case faabric::scheduler::SnapshotCalls::ThreadResult: { + this->recvThreadResult(buffer, bufferSize); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unrecognized async call header: {}", header)); + } + } } -void SnapshotServer::doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) +std::unique_ptr +SnapshotServer::doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) { - assert(header.size() == sizeof(uint8_t)); - uint8_t call = static_cast(*header.data()); - switch (call) { - case faabric::scheduler::SnapshotCalls::PushSnapshot: - this->recvPushSnapshot(body); - break; - case faabric::scheduler::SnapshotCalls::PushSnapshotDiffs: - 
this->recvPushSnapshotDiffs(body); - break; - case faabric::scheduler::SnapshotCalls::DeleteSnapshot: - this->recvDeleteSnapshot(body); - break; - case faabric::scheduler::SnapshotCalls::ThreadResult: - this->recvThreadResult(body); - break; - default: + switch (header) { + case faabric::scheduler::SnapshotCalls::PushSnapshot: { + return recvPushSnapshot(buffer, bufferSize); + } + case faabric::scheduler::SnapshotCalls::PushSnapshotDiffs: { + return recvPushSnapshotDiffs(buffer, bufferSize); + } + default: { throw std::runtime_error( - fmt::format("Unrecognized call header: {}", call)); + fmt::format("Unrecognized sync call header: {}", header)); + } } } -void SnapshotServer::recvPushSnapshot(faabric::transport::Message& msg) +std::unique_ptr<google::protobuf::Message> SnapshotServer::recvPushSnapshot( + const uint8_t* buffer, + size_t bufferSize) { - SnapshotPushRequest* r = - flatbuffers::GetMutableRoot<SnapshotPushRequest>(msg.udata()); + const SnapshotPushRequest* r = + flatbuffers::GetRoot<SnapshotPushRequest>(buffer); SPDLOG_DEBUG("Receiving snapshot {} (size {})", r->key()->c_str(), @@ -60,22 +70,25 @@ void SnapshotServer::recvPushSnapshot(faabric::transport::Message& msg) // Set up the snapshot faabric::util::SnapshotData data; data.size = r->contents()->size(); - data.data = r->mutable_contents()->Data(); - reg.takeSnapshot(r->key()->str(), data, true); - // Note that now the snapshot data is owned by Faabric and will be deleted - // later, so we don't want the message to delete it - msg.persist(); + // TODO - avoid this copy by changing server superclass to allow subclasses + // to provide a buffer to receive data. + // TODO - work out snapshot ownership here, how do we know when to delete + // this data? + data.data = (uint8_t*)mmap( + nullptr, data.size, PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + std::memcpy(data.data, r->contents()->Data(), data.size); + + reg.takeSnapshot(r->key()->str(), data, true); // Send response - faabric::EmptyResponse response; - SEND_SERVER_RESPONSE(response, r->return_host()->str(), SNAPSHOT_PORT) + return std::make_unique<faabric::EmptyResponse>(); } -void SnapshotServer::recvThreadResult(faabric::transport::Message& msg) +void SnapshotServer::recvThreadResult(const uint8_t* buffer, size_t bufferSize) { const ThreadResultRequest* r = - flatbuffers::GetMutableRoot<ThreadResultRequest>(msg.udata()); + flatbuffers::GetRoot<ThreadResultRequest>(buffer); // Apply snapshot diffs *first* (these must be applied before other threads // can continue) @@ -91,16 +104,16 @@ void SnapshotServer::recvThreadResult(faabric::transport::Message& msg) sch.setThreadResultLocally(r->message_id(), r->return_value()); } -void SnapshotServer::recvPushSnapshotDiffs(faabric::transport::Message& msg) +std::unique_ptr<google::protobuf::Message> +SnapshotServer::recvPushSnapshotDiffs(const uint8_t* buffer, size_t bufferSize) { const SnapshotDiffPushRequest* r = - flatbuffers::GetMutableRoot<SnapshotDiffPushRequest>(msg.udata()); + flatbuffers::GetRoot<SnapshotDiffPushRequest>(buffer); applyDiffsToSnapshot(r->key()->str(), r->chunks()); // Send response - faabric::EmptyResponse response; - SEND_SERVER_RESPONSE(response, r->return_host()->str(), SNAPSHOT_PORT) + return std::make_unique<faabric::EmptyResponse>(); } void SnapshotServer::applyDiffsToSnapshot( @@ -123,10 +136,11 @@ void SnapshotServer::applyDiffsToSnapshot( } } -void SnapshotServer::recvDeleteSnapshot(faabric::transport::Message& msg) +void SnapshotServer::recvDeleteSnapshot(const uint8_t* buffer, + size_t bufferSize) { const SnapshotDeleteRequest* r = - flatbuffers::GetRoot<SnapshotDeleteRequest>(msg.udata()); + flatbuffers::GetRoot<SnapshotDeleteRequest>(buffer); SPDLOG_INFO("Deleting snapshot {}", r->key()->c_str()); faabric::snapshot::SnapshotRegistry& reg = diff --git
a/src/state/CMakeLists.txt b/src/state/CMakeLists.txt index 18b2f73d4..ea7455753 100644 --- a/src/state/CMakeLists.txt +++ b/src/state/CMakeLists.txt @@ -1,7 +1,6 @@ file(GLOB HEADERS "${FAABRIC_INCLUDE_DIR}/faabric/state/*.h") set(LIB_FILES - DummyStateServer.cpp InMemoryStateKeyValue.cpp InMemoryStateRegistry.cpp State.cpp diff --git a/src/state/DummyStateServer.cpp b/src/state/DummyStateServer.cpp deleted file mode 100644 index e45c4e6cb..000000000 --- a/src/state/DummyStateServer.cpp +++ /dev/null @@ -1,87 +0,0 @@ -#include -#include -#include -#include - -namespace faabric::state { -DummyStateServer::DummyStateServer() - : remoteState(LOCALHOST) - , stateServer(remoteState) -{} - -std::shared_ptr DummyStateServer::getRemoteKv() -{ - if (dummyData.empty()) { - return remoteState.getKV(dummyUser, dummyKey); - } else { - return remoteState.getKV(dummyUser, dummyKey, dummyData.size()); - } -} - -std::shared_ptr DummyStateServer::getLocalKv() -{ - if (dummyData.empty()) { - return state::getGlobalState().getKV(dummyUser, dummyKey); - } else { - return state::getGlobalState().getKV( - dummyUser, dummyKey, dummyData.size()); - } -} - -std::vector DummyStateServer::getRemoteKvValue() -{ - std::vector actual(dummyData.size(), 0); - getRemoteKv()->get(actual.data()); - return actual; -} - -std::vector DummyStateServer::getLocalKvValue() -{ - std::vector actual(dummyData.size(), 0); - getLocalKv()->get(actual.data()); - return actual; -} - -void DummyStateServer::start() -{ - // NOTE - We want to test the server being on a different host. - // To do this we run the server in a separate thread, forcing it to - // have a localhost IP, then the main thread is the "client" with a - // junk IP. - - // Override the host endpoint for the server thread. Must be localhost - faabric::util::getSystemConfig().endpointHost = LOCALHOST; - - // Master the dummy data in this thread - if (!dummyData.empty()) { - const std::shared_ptr& kv = - remoteState.getKV(dummyUser, dummyKey, dummyData.size()); - std::shared_ptr inMemKv = - std::static_pointer_cast(kv); - - // Check this kv "thinks" it's master - if (!inMemKv->isMaster()) { - SPDLOG_ERROR("Dummy state server not master for data"); - throw std::runtime_error("Remote state server failed"); - } - - // Set the data - kv->set(dummyData.data()); - SPDLOG_DEBUG( - "Finished setting master for test {}/{}", kv->user, kv->key); - } - - // Start the state server - // Note - by default the state server runs in a background thread - SPDLOG_DEBUG("Running state server"); - stateServer.start(); - - // Give it time to start - usleep(1000 * 1000); -} - -void DummyStateServer::stop() -{ - stateServer.stop(); -} -} diff --git a/src/state/InMemoryStateKeyValue.cpp b/src/state/InMemoryStateKeyValue.cpp index e5d15833c..a73703cdf 100644 --- a/src/state/InMemoryStateKeyValue.cpp +++ b/src/state/InMemoryStateKeyValue.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -25,7 +26,6 @@ size_t InMemoryStateKeyValue::getStateSizeFromRemote(const std::string& userIn, StateClient stateClient(userIn, keyIn, masterIP); size_t stateSize = stateClient.stateSize(); - stateClient.close(); return stateSize; } @@ -43,7 +43,6 @@ void InMemoryStateKeyValue::deleteFromRemote(const std::string& userIn, StateClient stateClient(userIn, keyIn, masterIP); stateClient.deleteState(); - stateClient.close(); } void InMemoryStateKeyValue::clearAll(bool global) @@ -66,7 +65,15 @@ InMemoryStateKeyValue::InMemoryStateKeyValue(const std::string& userIn, , status(masterIP == thisIP ? 
InMemoryStateKeyStatus::MASTER : InMemoryStateKeyStatus::NOT_MASTER) , stateRegistry(getInMemoryStateRegistry()) -{} +{ + SPDLOG_TRACE("Creating in-memory state key-value for {}/{} size {} (this " + "host {}, master {})", + userIn, + keyIn, + sizeIn, + thisIP, + masterIP); +} InMemoryStateKeyValue::InMemoryStateKeyValue(const std::string& userIn, const std::string& keyIn, @@ -90,7 +97,6 @@ void InMemoryStateKeyValue::lockGlobal() } else { StateClient cli(user, key, masterIP); cli.lock(); - cli.close(); } } @@ -101,7 +107,6 @@ void InMemoryStateKeyValue::unlockGlobal() } else { StateClient cli(user, key, masterIP); cli.unlock(); - cli.close(); } } @@ -114,7 +119,6 @@ void InMemoryStateKeyValue::pullFromRemote() std::vector chunks = getAllChunks(); StateClient cli(user, key, masterIP); cli.pullChunks(chunks, BYTES(sharedMemory)); - cli.close(); } void InMemoryStateKeyValue::pullChunkFromRemote(long offset, size_t length) @@ -127,7 +131,6 @@ void InMemoryStateKeyValue::pullChunkFromRemote(long offset, size_t length) std::vector chunks = { StateChunk(offset, length, chunkStart) }; StateClient cli(user, key, masterIP); cli.pullChunks(chunks, BYTES(sharedMemory)); - cli.close(); } void InMemoryStateKeyValue::pushToRemote() @@ -139,7 +142,6 @@ void InMemoryStateKeyValue::pushToRemote() std::vector allChunks = getAllChunks(); StateClient cli(user, key, masterIP); cli.pushChunks(allChunks); - cli.close(); } void InMemoryStateKeyValue::pushPartialToRemote( @@ -150,7 +152,6 @@ void InMemoryStateKeyValue::pushPartialToRemote( } else { StateClient cli(user, key, masterIP); cli.pushChunks(chunks); - cli.close(); } } @@ -166,7 +167,6 @@ void InMemoryStateKeyValue::appendToRemote(const uint8_t* data, size_t length) } else { StateClient cli(user, key, masterIP); cli.append(data, length); - cli.close(); } } @@ -186,7 +186,6 @@ void InMemoryStateKeyValue::pullAppendedFromRemote(uint8_t* data, } else { StateClient cli(user, key, masterIP); cli.pullAppended(data, length, nValues); - cli.close(); } } @@ -198,7 +197,6 @@ void InMemoryStateKeyValue::clearAppendedFromRemote() } else { StateClient cli(user, key, masterIP); cli.clearAppended(); - cli.close(); } } diff --git a/src/state/InMemoryStateRegistry.cpp b/src/state/InMemoryStateRegistry.cpp index f29b9426c..e115aa28b 100644 --- a/src/state/InMemoryStateRegistry.cpp +++ b/src/state/InMemoryStateRegistry.cpp @@ -30,7 +30,6 @@ std::string InMemoryStateRegistry::getMasterIP(const std::string& user, const std::string& thisIP, bool claim) { - std::string lookupKey = faabric::util::keyForUser(user, key); // See if we already have the master diff --git a/src/state/StateClient.cpp b/src/state/StateClient.cpp index bc1a8242e..66e81524d 100644 --- a/src/state/StateClient.cpp +++ b/src/state/StateClient.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -7,46 +8,27 @@ namespace faabric::state { StateClient::StateClient(const std::string& userIn, const std::string& keyIn, const std::string& hostIn) - : faabric::transport::MessageEndpointClient(hostIn, STATE_PORT) + : faabric::transport::MessageEndpointClient(hostIn, + STATE_ASYNC_PORT, + STATE_SYNC_PORT) , user(userIn) , key(keyIn) - , host(hostIn) - , reg(state::getInMemoryStateRegistry()) -{ - this->open(faabric::transport::getGlobalMessageContext()); -} - -void StateClient::sendHeader(faabric::state::StateCalls call) -{ - uint8_t header = static_cast(call); - send(&header, sizeof(header), true); -} - -faabric::transport::Message StateClient::awaitResponse() -{ - // Call the superclass implementation 
- return MessageEndpointClient::awaitResponse(STATE_PORT + REPLY_PORT_OFFSET); -} - -void StateClient::sendStateRequest(faabric::state::StateCalls header, - bool expectReply) -{ - sendStateRequest(header, nullptr, 0, expectReply); -} +{} void StateClient::sendStateRequest(faabric::state::StateCalls header, const uint8_t* data, - int length, - bool expectReply) + int length) { faabric::StateRequest request; request.set_user(user); request.set_key(key); + if (length > 0) { request.set_data(data, length); } - request.set_returnhost(faabric::util::getSystemConfig().endpointHost); - SEND_MESSAGE(header, request) + + faabric::EmptyResponse resp; + syncSend(header, &request, &resp); } void StateClient::pushChunks(const std::vector& chunks) @@ -57,17 +39,9 @@ void StateClient::pushChunks(const std::vector& chunks) stateChunk.set_key(key); stateChunk.set_offset(chunk.offset); stateChunk.set_data(chunk.data, chunk.length); - stateChunk.set_returnhost( - faabric::util::getSystemConfig().endpointHost); - SEND_MESSAGE(faabric::state::StateCalls::Push, stateChunk) - - // Await for a response, but discard it as it is empty - try { - (void)awaitResponse(); - } catch (...) { - SPDLOG_ERROR("Error in awaitReponse"); - throw; - } + + faabric::EmptyResponse resp; + syncSend(faabric::state::StateCalls::Push, &stateChunk, &resp); } } @@ -81,12 +55,12 @@ void StateClient::pullChunks(const std::vector& chunks, request.set_key(key); request.set_offset(chunk.offset); request.set_chunksize(chunk.length); - request.set_returnhost(faabric::util::getSystemConfig().endpointHost); - SEND_MESSAGE(faabric::state::StateCalls::Pull, request) - // Receive message - faabric::transport::Message recvMsg = awaitResponse(); - PARSE_RESPONSE(faabric::StatePart, recvMsg.data(), recvMsg.size()) + // Send request + faabric::StatePart response; + syncSend(faabric::state::StateCalls::Pull, &request, &response); + + // Copy response data std::copy(response.data().begin(), response.data().end(), bufferStart + response.offset()); @@ -95,11 +69,7 @@ void StateClient::pullChunks(const std::vector& chunks, void StateClient::append(const uint8_t* data, size_t length) { - // Send request sendStateRequest(faabric::state::StateCalls::Append, data, length); - - // Await for a response, but discard it as it is empty - (void)awaitResponse(); } void StateClient::pullAppended(uint8_t* buffer, size_t length, long nValues) @@ -109,13 +79,9 @@ void StateClient::pullAppended(uint8_t* buffer, size_t length, long nValues) request.set_user(user); request.set_key(key); request.set_nvalues(nValues); - request.set_returnhost(faabric::util::getSystemConfig().endpointHost); - SEND_MESSAGE(faabric::state::StateCalls::PullAppended, request) - // Receive response - faabric::transport::Message recvMsg = awaitResponse(); - PARSE_RESPONSE( - faabric::StateAppendedResponse, recvMsg.data(), recvMsg.size()) + faabric::StateAppendedResponse response; + syncSend(faabric::state::StateCalls::PullAppended, &request, &response); // Process response size_t offset = 0; @@ -135,47 +101,33 @@ void StateClient::pullAppended(uint8_t* buffer, size_t length, long nValues) void StateClient::clearAppended() { - // Send request - sendStateRequest(faabric::state::StateCalls::ClearAppended); - - // Await for a response, but discard it as it is empty - (void)awaitResponse(); + sendStateRequest(faabric::state::StateCalls::ClearAppended, nullptr, 0); } size_t StateClient::stateSize() { - // Include the return address in the message body - sendStateRequest(faabric::state::StateCalls::Size, 
true); + faabric::StateRequest request; + request.set_user(user); + request.set_key(key); + + faabric::StateSizeResponse response; + syncSend(faabric::state::StateCalls::Size, &request, &response); - // Receive message - faabric::transport::Message recvMsg = awaitResponse(); - PARSE_RESPONSE(faabric::StateSizeResponse, recvMsg.data(), recvMsg.size()) return response.statesize(); } void StateClient::deleteState() { - // Send request - sendStateRequest(faabric::state::StateCalls::Delete, true); - - // Await for a response, but discard it as it is empty - (void)awaitResponse(); + sendStateRequest(faabric::state::StateCalls::Delete, nullptr, 0); } void StateClient::lock() { - // Send request - sendStateRequest(faabric::state::StateCalls::Lock); - - // Await for a response, but discard it as it is empty - (void)awaitResponse(); + sendStateRequest(faabric::state::StateCalls::Lock, nullptr, 0); } void StateClient::unlock() { - sendStateRequest(faabric::state::StateCalls::Unlock); - - // Await for a response, but discard it as it is empty - (void)awaitResponse(); + sendStateRequest(faabric::state::StateCalls::Unlock, nullptr, 0); } } diff --git a/src/state/StateKeyValue.cpp b/src/state/StateKeyValue.cpp index c53dddaf1..6f7c223c9 100644 --- a/src/state/StateKeyValue.cpp +++ b/src/state/StateKeyValue.cpp @@ -559,7 +559,7 @@ uint32_t StateKeyValue::waitOnRedisRemoteLock(const std::string& redisKey) break; } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + SLEEP_MS(500); remoteLockId = redis.acquireLock(redisKey, REMOTE_LOCK_TIMEOUT_SECS); retryCount++; diff --git a/src/state/StateServer.cpp b/src/state/StateServer.cpp index 4c3feffd1..b162fbee6 100644 --- a/src/state/StateServer.cpp +++ b/src/state/StateServer.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -12,66 +13,77 @@ namespace faabric::state { StateServer::StateServer(State& stateIn) - : faabric::transport::MessageEndpointServer(STATE_PORT) + : faabric::transport::MessageEndpointServer(STATE_ASYNC_PORT, STATE_SYNC_PORT) , state(stateIn) {} -void StateServer::doRecv(faabric::transport::Message& header, - faabric::transport::Message& body) +void StateServer::doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) { - assert(header.size() == sizeof(uint8_t)); - uint8_t call = static_cast(*header.data()); - switch (call) { - case faabric::state::StateCalls::Pull: - this->recvPull(body); - break; - case faabric::state::StateCalls::Push: - this->recvPush(body); - break; - case faabric::state::StateCalls::Size: - this->recvSize(body); - break; - case faabric::state::StateCalls::Append: - this->recvAppend(body); - break; - case faabric::state::StateCalls::ClearAppended: - this->recvClearAppended(body); - break; - case faabric::state::StateCalls::PullAppended: - this->recvPullAppended(body); - break; - case faabric::state::StateCalls::Lock: - this->recvLock(body); - break; - case faabric::state::StateCalls::Unlock: - this->recvUnlock(body); - break; - case faabric::state::StateCalls::Delete: - this->recvDelete(body); - break; - default: + throw std::runtime_error("State server does not support async recv"); +} + +std::unique_ptr +StateServer::doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) +{ + switch (header) { + case faabric::state::StateCalls::Pull: { + return recvPull(buffer, bufferSize); + } + case faabric::state::StateCalls::Push: { + return recvPush(buffer, bufferSize); + } + case faabric::state::StateCalls::Size: { + return recvSize(buffer, bufferSize); + } + 
case faabric::state::StateCalls::Append: { + return recvAppend(buffer, bufferSize); + } + case faabric::state::StateCalls::ClearAppended: { + return recvClearAppended(buffer, bufferSize); + } + case faabric::state::StateCalls::PullAppended: { + return recvPullAppended(buffer, bufferSize); + } + case faabric::state::StateCalls::Lock: { + return recvLock(buffer, bufferSize); + } + case faabric::state::StateCalls::Unlock: { + return recvUnlock(buffer, bufferSize); + } + case faabric::state::StateCalls::Delete: { + return recvDelete(buffer, bufferSize); + } + default: { throw std::runtime_error( - fmt::format("Unrecognized state call header: {}", call)); + fmt::format("Unrecognized state call header: {}", header)); + } } } -void StateServer::recvSize(faabric::transport::Message& body) +std::unique_ptr<google::protobuf::Message> StateServer::recvSize( + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Prepare the response SPDLOG_TRACE("Size {}/{}", msg.user(), msg.key()); KV_FROM_REQUEST(msg) - faabric::StateSizeResponse response; - response.set_user(kv->user); - response.set_key(kv->key); - response.set_statesize(kv->size()); - SEND_SERVER_RESPONSE(response, msg.returnhost(), STATE_PORT) + auto response = std::make_unique<faabric::StateSizeResponse>(); + response->set_user(kv->user); + response->set_key(kv->key); + response->set_statesize(kv->size()); + + return response; } -void StateServer::recvPull(faabric::transport::Message& body) +std::unique_ptr<google::protobuf::Message> StateServer::recvPull( + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateChunkRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateChunkRequest, buffer, bufferSize) SPDLOG_TRACE("Pull {}/{} ({}->{})", msg.user(), @@ -80,22 +92,26 @@ void StateServer::recvPull(faabric::transport::Message& body) msg.offset() + msg.chunksize()); // Write the response - faabric::StatePart response; KV_FROM_REQUEST(msg) uint64_t chunkOffset = msg.offset(); uint64_t chunkLen = msg.chunksize(); uint8_t* chunk = kv->getChunk(chunkOffset, chunkLen); - response.set_user(msg.user()); - response.set_key(msg.key()); - response.set_offset(chunkOffset); + + auto response = std::make_unique<faabric::StatePart>(); + response->set_user(msg.user()); + response->set_key(msg.key()); + response->set_offset(chunkOffset); // TODO: avoid copying here - response.set_data(chunk, chunkLen); - SEND_SERVER_RESPONSE(response, msg.returnhost(), STATE_PORT) + response->set_data(chunk, chunkLen); + + return response; } -void StateServer::recvPush(faabric::transport::Message& body) +std::unique_ptr<google::protobuf::Message> StateServer::recvPush( + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StatePart, body.data(), body.size()) + PARSE_MSG(faabric::StatePart, buffer, bufferSize) // Update the KV store SPDLOG_TRACE("Push {}/{} ({}->{})", @@ -103,17 +119,20 @@ void StateServer::recvPush(faabric::transport::Message& body) msg.key(), msg.offset(), msg.offset() + msg.data().size()); + KV_FROM_REQUEST(msg) kv->setChunk( msg.offset(), BYTES_CONST(msg.data().c_str()), msg.data().size()); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + auto response = std::make_unique<faabric::StateResponse>(); + return response; } -void StateServer::recvAppend(faabric::transport::Message& body) +std::unique_ptr<google::protobuf::Message> StateServer::recvAppend( + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Update the KV
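The two macros these handlers lean on are defined outside this diff (PARSE_MSG in the transport macros header, KV_FROM_REQUEST alongside the state server). A sketch of their assumed shape, consistent with the msg and kv names the handlers use:

    // Sketch only; the real definitions may differ in detail.
    #define PARSE_MSG(T, data, size)                                           \
        T msg;                                                                 \
        if (!msg.ParseFromArray(data, size)) {                                 \
            throw std::runtime_error("Error deserialising message");           \
        }

    #define KV_FROM_REQUEST(msg)                                               \
        auto kv = std::static_pointer_cast<InMemoryStateKeyValue>(             \
          state.getKV(msg.user(), msg.key()));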
KV_FROM_REQUEST(msg) @@ -121,77 +140,89 @@ void StateServer::recvAppend(faabric::transport::Message& body) uint64_t dataLen = msg.data().size(); kv->append(reqData, dataLen); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + auto response = std::make_unique(); + return response; } -void StateServer::recvPullAppended(faabric::transport::Message& body) +std::unique_ptr StateServer::recvPullAppended( + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateAppendedRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateAppendedRequest, buffer, bufferSize) // Prepare response - faabric::StateAppendedResponse response; SPDLOG_TRACE("Pull appended {}/{}", msg.user(), msg.key()); KV_FROM_REQUEST(msg) - response.set_user(msg.user()); - response.set_key(msg.key()); + + auto response = std::make_unique(); + response->set_user(msg.user()); + response->set_key(msg.key()); for (uint32_t i = 0; i < msg.nvalues(); i++) { AppendedInMemoryState& value = kv->getAppendedValue(i); - auto appendedValue = response.add_values(); + auto appendedValue = response->add_values(); appendedValue->set_data(reinterpret_cast(value.data.get()), value.length); } - SEND_SERVER_RESPONSE(response, msg.returnhost(), STATE_PORT) + + return response; } -void StateServer::recvDelete(faabric::transport::Message& body) +std::unique_ptr StateServer::recvDelete( + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Delete value SPDLOG_TRACE("Delete {}/{}", msg.user(), msg.key()); state.deleteKV(msg.user(), msg.key()); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + auto response = std::make_unique(); + return response; } -void StateServer::recvClearAppended(faabric::transport::Message& body) +std::unique_ptr StateServer::recvClearAppended( + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Perform operation SPDLOG_TRACE("Clear appended {}/{}", msg.user(), msg.key()); KV_FROM_REQUEST(msg) kv->clearAppended(); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + auto response = std::make_unique(); + return response; } -void StateServer::recvLock(faabric::transport::Message& body) +std::unique_ptr StateServer::recvLock( + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Perform operation SPDLOG_TRACE("Lock {}/{}", msg.user(), msg.key()); KV_FROM_REQUEST(msg) kv->lockWrite(); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + auto response = std::make_unique(); + return response; } -void StateServer::recvUnlock(faabric::transport::Message& body) +std::unique_ptr StateServer::recvUnlock( + const uint8_t* buffer, + size_t bufferSize) { - PARSE_MSG(faabric::StateRequest, body.data(), body.size()) + PARSE_MSG(faabric::StateRequest, buffer, bufferSize) // Perform operation SPDLOG_TRACE("Unlock {}/{}", msg.user(), msg.key()); KV_FROM_REQUEST(msg) kv->unlockWrite(); - faabric::StateResponse emptyResponse; - SEND_SERVER_RESPONSE(emptyResponse, msg.returnhost(), STATE_PORT) + auto response = std::make_unique(); + return response; } } diff --git 
a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt index 3ac400f19..b49d705f9 100644 --- a/src/transport/CMakeLists.txt +++ b/src/transport/CMakeLists.txt @@ -4,18 +4,18 @@ set(HEADERS "${FAABRIC_INCLUDE_DIR}/faabric/transport/common.h" + "${FAABRIC_INCLUDE_DIR}/faabric/transport/context.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/macros.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/Message.h" - "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageContext.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpoint.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointClient.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointServer.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MpiMessageEndpoint.h" ) -set(LIB_FILES +set(LIB_FILES + context.cpp Message.cpp - MessageContext.cpp MessageEndpoint.cpp MessageEndpointClient.cpp MessageEndpointServer.cpp @@ -25,5 +25,5 @@ set(LIB_FILES faabric_lib(transport "${LIB_FILES}") -target_link_libraries(transport proto zeromq_imported) +target_link_libraries(transport util proto zeromq_imported) diff --git a/src/transport/Message.cpp b/src/transport/Message.cpp index 489eab098..37ee3e01a 100644 --- a/src/transport/Message.cpp +++ b/src/transport/Message.cpp @@ -1,57 +1,46 @@ #include +#include namespace faabric::transport { Message::Message(const zmq::message_t& msgIn) - : _size(msgIn.size()) - , _more(msgIn.more()) - , _persist(false) + : _more(msgIn.more()) { - msg = reinterpret_cast(malloc(_size * sizeof(uint8_t))); - memcpy(msg, msgIn.data(), _size); + if (msgIn.data() != nullptr) { + bytes = std::vector(BYTES_CONST(msgIn.data()), + BYTES_CONST(msgIn.data()) + msgIn.size()); + } } Message::Message(int sizeIn) - : _size(sizeIn) + : bytes(sizeIn) , _more(false) - , _persist(false) -{ - msg = reinterpret_cast(malloc(_size * sizeof(uint8_t))); -} +{} // Empty message signals shutdown -Message::Message() - : msg(nullptr) -{} +Message::Message() {} -Message::~Message() +char* Message::data() { - if (!_persist) { - free(reinterpret_cast(msg)); - } + return reinterpret_cast(bytes.data()); } -char* Message::data() +uint8_t* Message::udata() { - return reinterpret_cast(msg); + return bytes.data(); } -uint8_t* Message::udata() +std::vector Message::dataCopy() { - return msg; + return bytes; } int Message::size() { - return _size; + return bytes.size(); } bool Message::more() { return _more; } - -void Message::persist() -{ - _persist = true; -} } diff --git a/src/transport/MessageContext.cpp b/src/transport/MessageContext.cpp deleted file mode 100644 index 8c7f74649..000000000 --- a/src/transport/MessageContext.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -namespace faabric::transport { -MessageContext::MessageContext() - : ctx(1) - , isContextShutDown(false) -{} - -MessageContext::MessageContext(int overrideIoThreads) - : ctx(overrideIoThreads) -{} - -MessageContext::~MessageContext() -{ - this->close(); -} - -void MessageContext::close() -{ - this->ctx.close(); - this->isContextShutDown = true; -} - -zmq::context_t& MessageContext::get() -{ - return this->ctx; -} - -faabric::transport::MessageContext& getGlobalMessageContext() -{ - static auto msgContext = - std::make_unique(); - // The message context needs to be opened and closed every server instance. - // Sometimes (e.g. tests) the scheduler is re-used, but the message context - // needs to be reset. In this situations we override the shut-down message - // context. 
- if (msgContext->isContextShutDown) { - msgContext = std::make_unique(); - } - return *msgContext; -} -} diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index f122a680b..2ef6d92b6 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -1,258 +1,208 @@ #include +#include +#include #include #include +#include #include +#define RETRY_SLEEP_MS 1000 + +#define CATCH_ZMQ_ERR(op, label) \ + try { \ + op; \ + } catch (zmq::error_t & e) { \ + SPDLOG_ERROR("Caught ZeroMQ error for {} on address {}: {} ({})", \ + label, \ + address, \ + e.num(), \ + e.what()); \ + throw; \ + } + +#define CATCH_ZMQ_ERR_RETRY_ONCE(op, label) \ + try { \ + op; \ + } catch (zmq::error_t & e) { \ + SPDLOG_WARN("Caught ZeroMQ error for {} on address {}: {} ({})", \ + label, \ + address, \ + e.num(), \ + e.what()); \ + SPDLOG_WARN("Retrying {} on address {}", label, address); \ + SLEEP_MS(RETRY_SLEEP_MS); \ + try { \ + op; \ + } catch (zmq::error_t & e2) { \ + SPDLOG_ERROR( \ + "Caught ZeroMQ error on retry for {} on address {}: {} ({})", \ + label, \ + address, \ + e2.num(), \ + e2.what()); \ + throw; \ + } \ + } + namespace faabric::transport { -MessageEndpoint::MessageEndpoint(const std::string& hostIn, int portIn) + +MessageEndpoint::MessageEndpoint(const std::string& hostIn, + int portIn, + int timeoutMsIn) : host(hostIn) , port(portIn) + , address("tcp://" + host + ":" + std::to_string(port)) + , timeoutMs(timeoutMsIn) , tid(std::this_thread::get_id()) , id(faabric::util::generateGid()) -{} - -MessageEndpoint::~MessageEndpoint() { - if (this->socket != nullptr) { - SPDLOG_WARN("Destroying an open message endpoint!"); - this->close(false); + // Check and set socket timeout + if (timeoutMs <= 0) { + SPDLOG_ERROR("Setting invalid timeout of {}", timeoutMs); + throw std::runtime_error("Setting invalid timeout"); } } -void MessageEndpoint::open(faabric::transport::MessageContext& context, - faabric::transport::SocketType sockType, - bool bind) +zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType, + int socketPort) { - // Check we are opening from the same thread. We assert not to incur in - // costly checks when running a Release build. - assert(tid == std::this_thread::get_id()); - - std::string address = - "tcp://" + this->host + ":" + std::to_string(this->port); - - // Note - only one socket may bind, but several can connect. This - // allows for easy N - 1 or 1 - N PUSH/PULL patterns. Order between - // bind and connect does not matter. 
- switch (sockType) { - case faabric::transport::SocketType::PUSH: - try { - this->socket = std::make_unique( - context.get(), zmq::socket_type::push); - } catch (zmq::error_t& e) { - SPDLOG_ERROR( - "Error opening SEND socket to {}: {}", address, e.what()); - throw; - } + zmq::socket_t socket; + + // Create the socket + CATCH_ZMQ_ERR(socket = + zmq::socket_t(*getGlobalMessageContext(), socketType), + "socket_create") + socket.set(zmq::sockopt::rcvtimeo, timeoutMs); + socket.set(zmq::sockopt::sndtimeo, timeoutMs); + + // Note - setting linger here is essential to avoid infinite hangs + socket.set(zmq::sockopt::linger, LINGER_MS); + + switch (socketType) { + case zmq::socket_type::req: { + SPDLOG_TRACE( + "New socket: req {}:{} (timeout {}ms)", host, port, timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") break; - case faabric::transport::SocketType::PULL: - try { - this->socket = std::make_unique( - context.get(), zmq::socket_type::pull); - } catch (zmq::error_t& e) { - SPDLOG_ERROR("Error opening RECV socket bound to {}: {}", - address, - e.what()); - throw; - } - + } + case zmq::socket_type::push: { + SPDLOG_TRACE( + "New socket: push {}:{} (timeout {}ms)", host, port, timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") + break; + } + case zmq::socket_type::pull: { + SPDLOG_TRACE( + "New socket: pull {}:{} (timeout {}ms)", host, port, timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + break; + } + case zmq::socket_type::rep: { + SPDLOG_TRACE( + "New socket: rep {}:{} (timeout {}ms)", host, port, timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") break; - default: - throw std::runtime_error("Unrecognized socket type"); - } - assert(this->socket != nullptr); - - // Bind or connect the socket - if (bind) { - try { - this->socket->bind(address); - } catch (zmq::error_t& e) { - SPDLOG_ERROR("Error binding socket to {}: {}", address, e.what()); - throw; } - } else { - try { - this->socket->connect(address); - } catch (zmq::error_t& e) { - SPDLOG_ERROR( - "Error connecting socket to {}: {}", address, e.what()); - throw; + default: { + throw std::runtime_error("Opening unrecognized socket type"); } } - // Set socket options - this->socket->set(zmq::sockopt::rcvtimeo, recvTimeoutMs); - this->socket->set(zmq::sockopt::sndtimeo, recvTimeoutMs); + return socket; } -void MessageEndpoint::send(uint8_t* serialisedMsg, size_t msgSize, bool more) +void MessageEndpoint::doSend(zmq::socket_t& socket, + const uint8_t* data, + size_t dataSize, + bool more) { assert(tid == std::this_thread::get_id()); - assert(this->socket != nullptr); - - if (more) { - try { - auto res = this->socket->send(zmq::buffer(serialisedMsg, msgSize), - zmq::send_flags::sndmore); - if (res != msgSize) { - SPDLOG_ERROR("Sent different bytes than expected (sent " - "{}, expected {})", - res.value_or(0), - msgSize); - throw std::runtime_error("Error sending message"); - } - } catch (zmq::error_t& e) { - SPDLOG_ERROR("Error sending message: {}", e.what()); - throw; - } - } else { - try { - auto res = this->socket->send(zmq::buffer(serialisedMsg, msgSize), - zmq::send_flags::none); - if (res != msgSize) { - SPDLOG_ERROR("Sent different bytes than expected (sent " - "{}, expected {})", - res.value_or(0), - msgSize); - throw std::runtime_error("Error sending message"); - } - } catch (zmq::error_t& e) { - SPDLOG_ERROR("Error sending message: {}", e.what()); - throw; - } - } + zmq::send_flags sendFlags = + more ? 
zmq::send_flags::sndmore : zmq::send_flags::none; + + CATCH_ZMQ_ERR( + { + auto res = socket.send(zmq::buffer(data, dataSize), sendFlags); + if (res != dataSize) { + SPDLOG_ERROR("Sent different bytes than expected (sent " + "{}, expected {})", + res.value_or(0), + dataSize); + throw std::runtime_error("Error sending message"); + } + }, + "send") } -// By passing the expected recv buffer size, we instrument zeromq to receive on -// our provisioned buffer -Message MessageEndpoint::recv(int size) +Message MessageEndpoint::doRecv(zmq::socket_t& socket, int size) { assert(tid == std::this_thread::get_id()); - assert(this->socket != nullptr); assert(size >= 0); - // Pre-allocate buffer to avoid copying data - if (size > 0) { - Message msg(size); - - try { - auto res = this->socket->recv(zmq::buffer(msg.udata(), msg.size())); - - if (!res.has_value()) { - SPDLOG_ERROR("Timed out receiving message of size {}", size); - throw MessageTimeoutException("Timed out receiving message"); - } - - if (res.has_value() && (res->size != res->untruncated_size)) { - SPDLOG_ERROR("Received more bytes than buffer can hold. " - "Received: {}, capacity {}", - res->untruncated_size, - res->size); - throw std::runtime_error("Error receiving message"); - } - } catch (zmq::error_t& e) { - if (e.num() == ZMQ_ETERM) { - // Return empty message to signify termination - SPDLOG_TRACE("Shutting endpoint down after receiving ETERM"); - return Message(); - } - - // Print default message and rethrow - SPDLOG_ERROR("Error receiving message: {} ({})", e.num(), e.what()); - throw; - } - - return msg; - } - - // Allocate a message to receive data - zmq::message_t msg; - try { - auto res = this->socket->recv(msg); - if (!res.has_value()) { - SPDLOG_ERROR("Timed out receiving message with no size"); - throw MessageTimeoutException("Timed out receiving message"); - } - } catch (zmq::error_t& e) { - if (e.num() == ZMQ_ETERM) { - // Return empty message to signify termination - SPDLOG_TRACE("Shutting endpoint down after receiving ETERM"); - return Message(); - } else { - SPDLOG_ERROR("Error receiving message: {} ({})", e.num(), e.what()); - throw; - } + if (size == 0) { + return recvNoBuffer(socket); } - // Copy the received message to a buffer whose scope we control - return Message(msg); + return recvBuffer(socket, size); } -void MessageEndpoint::close(bool bind) +Message MessageEndpoint::recvBuffer(zmq::socket_t& socket, int size) { - if (this->socket != nullptr) { - - if (tid != std::this_thread::get_id()) { - SPDLOG_WARN("Closing socket from a different thread"); - } + // Pre-allocate buffer to avoid copying data + Message msg(size); + + CATCH_ZMQ_ERR( + try { + auto res = socket.recv(zmq::buffer(msg.udata(), msg.size())); + + if (!res.has_value()) { + SPDLOG_TRACE("Timed out receiving message of size {}", size); + throw MessageTimeoutException("Timed out receiving message"); + } + + if (res.has_value() && (res->size != res->untruncated_size)) { + SPDLOG_ERROR("Received more bytes than buffer can hold. 
" + "Received: {}, capacity {}", + res->untruncated_size, + res->size); + throw std::runtime_error("Error receiving message"); + } + } catch (zmq::error_t& e) { + if (e.num() == ZMQ_ETERM) { + SPDLOG_WARN("Endpoint {}:{} received ETERM on recv", host, port); + return Message(); + } + + throw; + }, + "recv_buffer") + + return msg; +} - std::string address = - "tcp://" + this->host + ":" + std::to_string(this->port); - - // We duplicate the call to close() because when unbinding, we want to - // block until we _actually_ have unbinded, i.e. 0MQ has closed the - // socket (which happens asynchronously). For connect()-ed sockets we - // don't care. - // Not blobking on un-bind can cause race-conditions when the underlying - // system is slow at closing sockets, and the application relies a lot - // on synchronous message-passing. - if (bind) { - try { - this->socket->unbind(address); - } catch (zmq::error_t& e) { - if (e.num() != ZMQ_ETERM) { - SPDLOG_ERROR("Error unbinding socket: {}", e.what()); - throw; - } - } - // NOTE - unbinding a socket has a considerable overhead compared to - // disconnecting it. - // TODO - could we reuse the monitor? - try { - { - zmq::monitor_t mon; - const std::string monAddr = - "inproc://monitor_" + std::to_string(id); - mon.init(*(this->socket), monAddr, ZMQ_EVENT_CLOSED); - this->socket->close(); - mon.check_event(-1); - } - } catch (zmq::error_t& e) { - if (e.num() != ZMQ_ETERM) { - SPDLOG_ERROR("Error closing bind socket: {}", e.what()); - throw; - } - } - } else { - try { - this->socket->disconnect(address); - } catch (zmq::error_t& e) { - if (e.num() != ZMQ_ETERM) { - SPDLOG_ERROR("Error disconnecting socket: {}", e.what()); - throw; - } - } - try { - this->socket->close(); - } catch (zmq::error_t& e) { - SPDLOG_ERROR("Error closing connect socket: {}", e.what()); - throw; - } - } +Message MessageEndpoint::recvNoBuffer(zmq::socket_t& socket) +{ + // Allocate a message to receive data + zmq::message_t msg; + CATCH_ZMQ_ERR( + try { + auto res = socket.recv(msg); + if (!res.has_value()) { + SPDLOG_TRACE("Timed out receiving message with no size"); + throw MessageTimeoutException("Timed out receiving message"); + } + } catch (zmq::error_t& e) { + if (e.num() == ZMQ_ETERM) { + SPDLOG_WARN("Endpoint {}:{} received ETERM on recv", host, port); + return Message(); + } + throw; + }, + "recv_no_buffer") - // Finally, null the socket - this->socket = nullptr; - } + // Copy the received message to a buffer whose scope we control + return Message(msg); } std::string MessageEndpoint::getHost() @@ -265,70 +215,116 @@ int MessageEndpoint::getPort() return port; } -void MessageEndpoint::validateTimeout(int value) +// ---------------------------------------------- +// ASYNC SEND ENDPOINT +// ---------------------------------------------- + +AsyncSendMessageEndpoint::AsyncSendMessageEndpoint(const std::string& hostIn, + int portIn, + int timeoutMs) + : MessageEndpoint(hostIn, portIn, timeoutMs) { - if (value <= 0) { - SPDLOG_ERROR("Setting invalid timeout of {}", value); - throw std::runtime_error("Setting invalid timeout"); - } + pushSocket = setUpSocket(zmq::socket_type::push, portIn); +} - if (socket != nullptr) { - SPDLOG_ERROR("Setting timeout of {} after socket created", value); - throw std::runtime_error("Setting timeout after socket created"); - } +void AsyncSendMessageEndpoint::sendHeader(int header) +{ + uint8_t headerBytes = static_cast(header); + doSend(pushSocket, &headerBytes, sizeof(headerBytes), true); } -void MessageEndpoint::setRecvTimeoutMs(int value) 
+void AsyncSendMessageEndpoint::send(const uint8_t* data, + size_t dataSize, + bool more) { - validateTimeout(value); - recvTimeoutMs = value; + SPDLOG_TRACE("PUSH {}:{} ({} bytes, more {})", host, port, dataSize, more); + doSend(pushSocket, data, dataSize, more); } -void MessageEndpoint::setSendTimeoutMs(int value) +// ---------------------------------------------- +// SYNC SEND ENDPOINT +// ---------------------------------------------- + +SyncSendMessageEndpoint::SyncSendMessageEndpoint(const std::string& hostIn, + int portIn, + int timeoutMs) + : MessageEndpoint(hostIn, portIn, timeoutMs) { - validateTimeout(value); - sendTimeoutMs = value; + reqSocket = setUpSocket(zmq::socket_type::req, portIn); } -/* Send and Recv Message Endpoints */ +void SyncSendMessageEndpoint::sendHeader(int header) +{ + uint8_t headerBytes = static_cast(header); + doSend(reqSocket, &headerBytes, sizeof(headerBytes), true); +} -SendMessageEndpoint::SendMessageEndpoint(const std::string& hostIn, int portIn) - : MessageEndpoint(hostIn, portIn) -{} +void SyncSendMessageEndpoint::sendRaw(const uint8_t* data, size_t dataSize) +{ + SPDLOG_TRACE("REQ {}:{} ({} bytes)", host, port, dataSize); + doSend(reqSocket, data, dataSize, false); +} -void SendMessageEndpoint::open(MessageContext& context) +Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* data, + size_t dataSize, + bool more) { - SPDLOG_TRACE( - fmt::format("Opening socket: {} (SEND {}:{})", id, host, port)); + SPDLOG_TRACE("REQ {}:{} ({} bytes, more {})", host, port, dataSize, more); + doSend(reqSocket, data, dataSize, more); - MessageEndpoint::open(context, SocketType::PUSH, false); + // Do the receive + SPDLOG_TRACE("RECV (REQ) {}", port); + return recvNoBuffer(reqSocket); } -void SendMessageEndpoint::close() +// ---------------------------------------------- +// RECV ENDPOINT +// ---------------------------------------------- + +RecvMessageEndpoint::RecvMessageEndpoint(int portIn, + int timeoutMs, + zmq::socket_type socketType) + : MessageEndpoint(ANY_HOST, portIn, timeoutMs) { - SPDLOG_TRACE( - fmt::format("Closing socket: {} (SEND {}:{})", id, host, port)); + socket = setUpSocket(socketType, portIn); +} - MessageEndpoint::close(false); +Message RecvMessageEndpoint::recv(int size) +{ + return doRecv(socket, size); } -RecvMessageEndpoint::RecvMessageEndpoint(int portIn) - : MessageEndpoint(ANY_HOST, portIn) +// ---------------------------------------------- +// ASYNC RECV ENDPOINT +// ---------------------------------------------- + +AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint(int portIn, int timeoutMs) + : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) {} -void RecvMessageEndpoint::open(MessageContext& context) +Message AsyncRecvMessageEndpoint::recv(int size) { - SPDLOG_TRACE( - fmt::format("Opening socket: {} (RECV {}:{})", id, host, port)); - - MessageEndpoint::open(context, SocketType::PULL, true); + SPDLOG_TRACE("PULL {} ({} bytes)", port, size); + return RecvMessageEndpoint::recv(size); } -void RecvMessageEndpoint::close() +// ---------------------------------------------- +// SYNC RECV ENDPOINT +// ---------------------------------------------- + +SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(int portIn, int timeoutMs) + : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::rep) +{} + +Message SyncRecvMessageEndpoint::recv(int size) { - SPDLOG_TRACE( - fmt::format("Closing socket: {} (RECV {}:{})", id, host, port)); + SPDLOG_TRACE("RECV (REP) {} ({} bytes)", port, size); + return 
RecvMessageEndpoint::recv(size); +} - MessageEndpoint::close(true); +void SyncRecvMessageEndpoint::sendResponse(const uint8_t* data, int size) +{ + SPDLOG_TRACE("REP {} ({} bytes)", port, size); + doSend(socket, data, size, false); } } diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index cdd2461f9..a041c5b0e 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -1,25 +1,65 @@ #include namespace faabric::transport { -MessageEndpointClient::MessageEndpointClient(const std::string& host, int port) - : SendMessageEndpoint(host, port) + +MessageEndpointClient::MessageEndpointClient(std::string hostIn, + int asyncPortIn, + int syncPortIn, + int timeoutMs) + : host(hostIn) + , asyncPort(asyncPortIn) + , syncPort(syncPortIn) + , asyncEndpoint(host, asyncPort, timeoutMs) + , syncEndpoint(host, syncPort, timeoutMs) {} -// Block until we receive a response from the server -Message MessageEndpointClient::awaitResponse(int port) +void MessageEndpointClient::asyncSend(int header, + google::protobuf::Message* msg) { - // Wait for the response, open a temporary endpoint for it - // Note - we use a different host/port not to clash with existing server - RecvMessageEndpoint endpoint(port); + size_t msgSize = msg->ByteSizeLong(); + uint8_t buffer[msgSize]; + + if (!msg->SerializeToArray(buffer, msgSize)) { + throw std::runtime_error("Error serialising message"); + } + + asyncSend(header, buffer, msgSize); +} - // Inherit timeouts on temporary endpoint - endpoint.setRecvTimeoutMs(recvTimeoutMs); - endpoint.setSendTimeoutMs(sendTimeoutMs); +void MessageEndpointClient::asyncSend(int header, + const uint8_t* buffer, + size_t bufferSize) +{ + asyncEndpoint.sendHeader(header); + + asyncEndpoint.send(buffer, bufferSize); +} + +void MessageEndpointClient::syncSend(int header, + google::protobuf::Message* msg, + google::protobuf::Message* response) +{ + size_t msgSize = msg->ByteSizeLong(); + uint8_t buffer[msgSize]; + if (!msg->SerializeToArray(buffer, msgSize)) { + throw std::runtime_error("Error serialising message"); + } + + syncSend(header, buffer, msgSize, response); +} + +void MessageEndpointClient::syncSend(int header, + const uint8_t* buffer, + const size_t bufferSize, + google::protobuf::Message* response) +{ + syncEndpoint.sendHeader(header); - endpoint.open(faabric::transport::getGlobalMessageContext()); - Message receivedMessage = endpoint.recv(); - endpoint.close(); + Message responseMsg = syncEndpoint.sendAwaitResponse(buffer, bufferSize); - return receivedMessage; + // Deserialise response + if (!response->ParseFromArray(responseMsg.data(), responseMsg.size())) { + throw std::runtime_error("Error deserialising message"); + } } } diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 13600759c..52a59eb57 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -1,95 +1,171 @@ #include +#include +#include +#include +#include #include #include namespace faabric::transport { -MessageEndpointServer::MessageEndpointServer(int portIn) - : port(portIn) + +static const std::vector shutdownHeader = { 0, 0, 1, 1 }; + +MessageEndpointServerThread::MessageEndpointServerThread( + MessageEndpointServer* serverIn, + bool asyncIn) + : server(serverIn) + , async(asyncIn) {} -void MessageEndpointServer::start() +void MessageEndpointServerThread::start( + std::shared_ptr latch) { - start(faabric::transport::getGlobalMessageContext()); -} 
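Before the server internals, a concrete picture of the client side may help. This is roughly what the reworked StateClient::stateSize() earlier in this diff does through the new MessageEndpointClient API (host and literals are illustrative; the port constants come from the transport common header):

    faabric::transport::MessageEndpointClient cli(
      "localhost", STATE_ASYNC_PORT, STATE_SYNC_PORT);

    faabric::StateRequest request;
    request.set_user("demo");
    request.set_key("example");

    faabric::StateSizeResponse response;
    cli.syncSend(faabric::state::StateCalls::Size, &request, &response);
    // response.statesize() now holds the size reported by the remote host

The header byte travels as a separate SNDMORE frame (sendHeader), which is what lets the server thread below dispatch on it before reading the body.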
+ backgroundThread = std::thread([this, latch] { + std::unique_ptr<RecvMessageEndpoint> endpoint = nullptr; + int port = -1; -void MessageEndpointServer::start(faabric::transport::MessageContext& context) -{ - // Start serving thread in background - this->servingThread = std::thread([this, &context] { - RecvMessageEndpoint serverEndpoint(this->port); + if (async) { + port = server->asyncPort; + endpoint = std::make_unique<AsyncRecvMessageEndpoint>(port); + } else { + port = server->syncPort; + endpoint = std::make_unique<SyncRecvMessageEndpoint>(port); + } - // Open message endpoint, and bind - serverEndpoint.open(context); - assert(serverEndpoint.socket != nullptr); + latch->wait(); - // Loop until context is terminated while (true) { - int rc = this->recv(serverEndpoint); - if (rc == ENDPOINT_SERVER_SHUTDOWN) { - serverEndpoint.close(); - break; + bool headerReceived = false; + bool bodyReceived = false; + try { + // Receive header and body + Message headerMessage = endpoint->recv(); + headerReceived = true; + + if (headerMessage.size() == shutdownHeader.size()) { + if (headerMessage.dataCopy() == shutdownHeader) { + SPDLOG_TRACE("Server on {} received shutdown message", + port); + break; + } + } + + if (!headerMessage.more()) { + throw std::runtime_error( + "Header sent without SNDMORE flag"); + } + + Message body = endpoint->recv(); + if (body.more()) { + throw std::runtime_error("Body sent with SNDMORE flag"); + } + bodyReceived = true; + + assert(headerMessage.size() == sizeof(uint8_t)); + uint8_t header = static_cast<uint8_t>(*headerMessage.data()); + + if (async) { + // Server-specific async handling + server->doAsyncRecv(header, body.udata(), body.size()); + } else { + // Server-specific sync handling + std::unique_ptr<google::protobuf::Message> resp = + server->doSyncRecv(header, body.udata(), body.size()); + size_t respSize = resp->ByteSizeLong(); + + uint8_t buffer[respSize]; + if (!resp->SerializeToArray(buffer, respSize)) { + throw std::runtime_error("Error serialising message"); + } + + // Return the response + static_cast<SyncRecvMessageEndpoint*>(endpoint.get()) + ->sendResponse(buffer, respSize); + } + } catch (MessageTimeoutException& ex) { + // If we don't get a header in the timeout, we're ok to just + // loop round and try again + if (!headerReceived) { + SPDLOG_TRACE("Server on port {}, looping after no message", + port); + continue; + } + + if (headerReceived && !bodyReceived) { + SPDLOG_ERROR( + "Server on port {}, got header, timed out on body", port); + throw; + } + } + + // Wait on the async latch if necessary + if (server->asyncLatch != nullptr) { + SPDLOG_TRACE("Server thread waiting on async latch"); + server->asyncLatch->wait(); } + + headerReceived = false; + bodyReceived = false; } }); } -void MessageEndpointServer::stop() +void MessageEndpointServerThread::join() { - stop(faabric::transport::getGlobalMessageContext()); + if (backgroundThread.joinable()) { + backgroundThread.join(); + } } -void MessageEndpointServer::stop(faabric::transport::MessageContext& context) +MessageEndpointServer::MessageEndpointServer(int asyncPortIn, int syncPortIn) + : asyncPort(asyncPortIn) + , syncPort(syncPortIn) + , asyncThread(this, true) + , syncThread(this, false) + , asyncShutdownSender(LOCALHOST, asyncPort) + , syncShutdownSender(LOCALHOST, syncPort) +{} + +void MessageEndpointServer::start() { - // Note - different servers will concurrently close the server context, but - // this structure is thread-safe, and the close operation idempotent.
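The start-up handshake below replaces sleep-based waits with a counted latch (added in src/util/latch.cpp later in this diff). The pattern, sketched in isolation with the caller plus two worker threads:

    // Latch of size 3: two server threads plus the caller of start().
    auto latch = faabric::util::Latch::create(3);

    std::thread workerA([latch] { /* bind socket */ latch->wait(); /* serve */ });
    std::thread workerB([latch] { /* bind socket */ latch->wait(); /* serve */ });

    latch->wait(); // returns only once both sockets are open
    workerA.join();
    workerB.join();

Each Latch is single-use and throws on timeout, so a hung socket set-up fails loudly instead of deadlocking.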
- context.close(); + // This latch means that callers can guarantee that when this function + // completes, both sockets will have been opened (and hence the server is + // ready to use). + auto startLatch = faabric::util::Latch::create(3); - // Finally join the serving thread - if (this->servingThread.joinable()) { - this->servingThread.join(); - } + asyncThread.start(startLatch); + syncThread.start(startLatch); + + startLatch->wait(); } -int MessageEndpointServer::recv(RecvMessageEndpoint& endpoint) +void MessageEndpointServer::stop() { - assert(endpoint.socket != nullptr); + // Send shutdown messages + SPDLOG_TRACE( + "Server sending shutdown messages to ports {} {}", asyncPort, syncPort); - // Receive header and body - Message header = endpoint.recv(); - // Detect shutdown condition - if (header.udata() == nullptr) { - return ENDPOINT_SERVER_SHUTDOWN; - } - // Check the header was sent with ZMQ_SNDMORE flag - if (!header.more()) { - throw std::runtime_error("Header sent without SNDMORE flag"); - } + asyncShutdownSender.send(shutdownHeader.data(), shutdownHeader.size()); - Message body = endpoint.recv(); - // Check that there are no more messages to receive - if (body.more()) { - throw std::runtime_error("Body sent with SNDMORE flag"); - } - assert(body.udata() != nullptr); + syncShutdownSender.sendRaw(shutdownHeader.data(), shutdownHeader.size()); - // Server-specific message handling - doRecv(header, body); + // Join the threads + asyncThread.join(); + syncThread.join(); +} - return 0; +void MessageEndpointServer::setAsyncLatch() +{ + asyncLatch = faabric::util::Latch::create(2); } -// We create a new endpoint every time. Re-using them would be a possible -// optimisation if needed. -void MessageEndpointServer::sendResponse(uint8_t* serialisedMsg, - int size, - const std::string& returnHost, - int returnPort) +void MessageEndpointServer::awaitAsyncLatch() { - // Open the endpoint socket, server connects (not bind) to remote address - SendMessageEndpoint endpoint(returnHost, returnPort + REPLY_PORT_OFFSET); - endpoint.open(faabric::transport::getGlobalMessageContext()); - endpoint.send(serialisedMsg, size); - endpoint.close(); + SPDLOG_TRACE("Waiting on async latch for port {}", asyncPort); + asyncLatch->wait(); + + SPDLOG_TRACE("Finished async latch for port {}", asyncPort); + asyncLatch = nullptr; } } diff --git a/src/transport/MpiMessageEndpoint.cpp b/src/transport/MpiMessageEndpoint.cpp index 67be90801..136456238 100644 --- a/src/transport/MpiMessageEndpoint.cpp +++ b/src/transport/MpiMessageEndpoint.cpp @@ -1,79 +1,29 @@ #include +#include +#include namespace faabric::transport { -faabric::MpiHostsToRanksMessage recvMpiHostRankMsg() -{ - faabric::transport::RecvMessageEndpoint endpoint(MPI_PORT); - endpoint.open(faabric::transport::getGlobalMessageContext()); - faabric::transport::Message m = endpoint.recv(); - PARSE_MSG(faabric::MpiHostsToRanksMessage, m.data(), m.size()); - endpoint.close(); - - return msg; -} - -void sendMpiHostRankMsg(const std::string& hostIn, - const faabric::MpiHostsToRanksMessage msg) -{ - size_t msgSize = msg.ByteSizeLong(); - { - uint8_t sMsg[msgSize]; - if (!msg.SerializeToArray(sMsg, msgSize)) { - throw std::runtime_error("Error serialising message"); - } - faabric::transport::SendMessageEndpoint endpoint(hostIn, MPI_PORT); - endpoint.open(faabric::transport::getGlobalMessageContext()); - endpoint.send(sMsg, msgSize, false); - endpoint.close(); - } -} - -MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int portIn) - : 
sendMessageEndpoint(hostIn, portIn) - , recvMessageEndpoint(portIn) -{ - sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); - recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); -} MpiMessageEndpoint::MpiMessageEndpoint(const std::string& hostIn, int sendPort, int recvPort) - : sendMessageEndpoint(hostIn, sendPort) - , recvMessageEndpoint(recvPort) -{ - sendMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); - recvMessageEndpoint.open(faabric::transport::getGlobalMessageContext()); -} + : host(hostIn) + , sendSocket(hostIn, sendPort) + , recvSocket(recvPort) +{} void MpiMessageEndpoint::sendMpiMessage( const std::shared_ptr<faabric::MPIMessage>& msg) { - size_t msgSize = msg->ByteSizeLong(); - { - uint8_t sMsg[msgSize]; - if (!msg->SerializeToArray(sMsg, msgSize)) { - throw std::runtime_error("Error serialising message"); - } - sendMessageEndpoint.send(sMsg, msgSize, false); - } + SERIALISE_MSG_PTR(msg) + sendSocket.send(buffer, msgSize, false); } std::shared_ptr<faabric::MPIMessage> MpiMessageEndpoint::recvMpiMessage() { - Message m = recvMessageEndpoint.recv(); + Message m = recvSocket.recv(); PARSE_MSG(faabric::MPIMessage, m.data(), m.size()); return std::make_shared<faabric::MPIMessage>(msg); } - -void MpiMessageEndpoint::close() -{ - if (sendMessageEndpoint.socket != nullptr) { - sendMessageEndpoint.close(); - } - if (recvMessageEndpoint.socket != nullptr) { - recvMessageEndpoint.close(); - } -} } diff --git a/src/transport/context.cpp b/src/transport/context.cpp new file mode 100644 index 000000000..3c6b9d918 --- /dev/null +++ b/src/transport/context.cpp @@ -0,0 +1,48 @@ +#include +#include +#include + +namespace faabric::transport { + +// The ZeroMQ context object is thread safe, so we're ok to have a single global +// instance. +static std::shared_ptr<zmq::context_t> instance = nullptr; + +void initGlobalMessageContext() +{ + if (instance != nullptr) { + SPDLOG_WARN("ZeroMQ context already initialised. Skipping"); + return; + } + + SPDLOG_TRACE("Initialising global ZeroMQ context"); + instance = std::make_shared<zmq::context_t>(ZMQ_CONTEXT_IO_THREADS); +} + +std::shared_ptr<zmq::context_t> getGlobalMessageContext() +{ + if (instance == nullptr) { + throw std::runtime_error( + "Trying to access uninitialised ZeroMQ context"); + } + + return instance; +} + +void closeGlobalMessageContext() +{ + if (instance == nullptr) { + SPDLOG_WARN( + "ZeroMQ context already closed (or not initialised). 
Skipping"); + return; + } + + SPDLOG_TRACE("Destroying global ZeroMQ context"); + + // Force outstanding ops to return ETERM + instance->shutdown(); + + // Close the context + instance->close(); +} +} diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index cc204a2fb..3db722736 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -3,7 +3,6 @@ find_package(RapidJSON) file(GLOB HEADERS "${FAABRIC_INCLUDE_DIR}/faabric/util/*.h") set(LIB_FILES - barrier.cpp bytes.cpp config.cpp clock.cpp @@ -14,6 +13,7 @@ set(LIB_FILES gids.cpp http.cpp json.cpp + latch.cpp logging.cpp memory.cpp network.cpp diff --git a/src/util/barrier.cpp b/src/util/barrier.cpp deleted file mode 100644 index 67bed11a5..000000000 --- a/src/util/barrier.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include -#include - -namespace faabric::util { -Barrier::Barrier(int count) - : threadCount(count) - , slotCount(count) - , uses(0) -{} - -void Barrier::wait() -{ - { - UniqueLock lock(mx); - int usesCopy = uses; - - slotCount--; - if (slotCount == 0) { - uses++; - // Checks for overflow - if (uses < 0) { - throw std::runtime_error("Barrier was used too many times"); - } - slotCount = threadCount; - cv.notify_all(); - } else { - cv.wait(lock, [&] { return usesCopy < uses; }); - } - } -} - -int Barrier::getSlotCount() -{ - return slotCount; -} - -int Barrier::getUseCount() -{ - return uses; -} - -} diff --git a/src/util/latch.cpp b/src/util/latch.cpp new file mode 100644 index 000000000..8d90e007d --- /dev/null +++ b/src/util/latch.cpp @@ -0,0 +1,38 @@ +#include +#include +#include + +namespace faabric::util { + +std::shared_ptr Latch::create(int count, int timeoutMs) +{ + return std::make_shared(count, timeoutMs); +} + +Latch::Latch(int countIn, int timeoutMsIn) + : count(countIn) + , timeoutMs(timeoutMsIn) +{} + +void Latch::wait() +{ + UniqueLock lock(mx); + + waiters++; + + if (waiters > count) { + throw std::runtime_error("Latch already used"); + } + + if (waiters == count) { + cv.notify_all(); + } else { + auto timePoint = std::chrono::system_clock::now() + + std::chrono::milliseconds(timeoutMs); + + if (!cv.wait_until(lock, timePoint, [&] { return waiters >= count; })) { + throw std::runtime_error("Latch timed out"); + } + } +} +} diff --git a/tests/dist/main.cpp b/tests/dist/main.cpp index 424110470..a93cfcd35 100644 --- a/tests/dist/main.cpp +++ b/tests/dist/main.cpp @@ -1,44 +1,46 @@ #define CATCH_CONFIG_RUNNER -#include "DistTestExecutor.h" -#include "faabric_utils.h" #include +#include "DistTestExecutor.h" +#include "faabric_utils.h" #include "init.h" -#include #include #include +#include #include -using namespace faabric::scheduler; - FAABRIC_CATCH_LOGGER int main(int argc, char* argv[]) { + faabric::transport::initGlobalMessageContext(); faabric::util::initLogging(); - - // Set up the distributed tests tests::initDistTests(); - // Start everything up - SPDLOG_INFO("Starting distributed test server on master"); - std::shared_ptr fac = + std::shared_ptr fac = std::make_shared(); - faabric::runner::FaabricMain m(fac); - m.startBackground(); - - // Wait for things to start - usleep(3000 * 1000); - - // Run the tests - int result = Catch::Session().run(argc, argv); - fflush(stdout); - // Shut down - SPDLOG_INFO("Shutting down"); - m.shutdown(); + // WARNING: all 0MQ sockets have to have gone *out of scope* before we shut + // down the context, therefore this segment must be in a nested scope (or + // another function). 
+ int result; + { + SPDLOG_INFO("Starting distributed test server on master"); + faabric::runner::FaabricMain m(fac); + m.startBackground(); + + // Run the tests + result = Catch::Session().run(argc, argv); + fflush(stdout); + + // Shut down + SPDLOG_INFO("Shutting down"); + m.shutdown(); + } + + faabric::transport::closeGlobalMessageContext(); return result; } diff --git a/tests/dist/server.cpp b/tests/dist/server.cpp index 2ab2d8431..d62248bd9 100644 --- a/tests/dist/server.cpp +++ b/tests/dist/server.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include using namespace faabric::scheduler; @@ -11,6 +12,7 @@ using namespace faabric::scheduler; int main() { faabric::util::initLogging(); + faabric::transport::initGlobalMessageContext(); tests::initDistTests(); SPDLOG_INFO("Starting distributed test server on worker"); @@ -19,12 +21,14 @@ int main() faabric::runner::FaabricMain m(fac); m.startBackground(); + // Note, endpoint will block until killed SPDLOG_INFO("Starting HTTP endpoint on worker"); faabric::endpoint::FaabricEndpoint endpoint; endpoint.start(); SPDLOG_INFO("Shutting down"); m.shutdown(); + faabric::transport::closeGlobalMessageContext(); return EXIT_SUCCESS; } diff --git a/tests/test/main.cpp b/tests/test/main.cpp index bddede144..28aac75a0 100644 --- a/tests/test/main.cpp +++ b/tests/test/main.cpp @@ -4,6 +4,7 @@ #include "faabric_utils.h" +#include #include #include @@ -11,6 +12,7 @@ FAABRIC_CATCH_LOGGER int main(int argc, char* argv[]) { + faabric::transport::initGlobalMessageContext(); faabric::util::setTestMode(true); faabric::util::initLogging(); @@ -18,5 +20,7 @@ int main(int argc, char* argv[]) fflush(stdout); + faabric::transport::closeGlobalMessageContext(); + return result; } diff --git a/tests/test/redis/test_redis.cpp b/tests/test/redis/test_redis.cpp index 6550c0dcc..c6711b67b 100644 --- a/tests/test/redis/test_redis.cpp +++ b/tests/test/redis/test_redis.cpp @@ -4,6 +4,7 @@ #include #include +#include #include @@ -544,7 +545,7 @@ TEST_CASE("Test enqueue after blocking dequeue") }); // Wait a bit (assume the waiting thread will get to block by now) - sleep(1); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); redisQueue.enqueue("foobar", "baz"); // If this hangs, the redis client isn't dequeueing after an enqueue is diff --git a/tests/test/runner/test_main.cpp b/tests/test/runner/test_main.cpp index 4268f7be3..9d81fc8aa 100644 --- a/tests/test/runner/test_main.cpp +++ b/tests/test/runner/test_main.cpp @@ -13,26 +13,45 @@ using namespace faabric::scheduler; namespace tests { -TEST_CASE("Test main runner", "[runner]") +class MainRunnerTestFixture : public SchedulerTestFixture +{ + public: + MainRunnerTestFixture() + { + std::shared_ptr fac = + std::make_shared(); + faabric::scheduler::setExecutorFactory(fac); + } +}; + +TEST_CASE_METHOD(MainRunnerTestFixture, "Test main runner", "[runner]") { - cleanFaabric(); std::shared_ptr fac = faabric::scheduler::getExecutorFactory(); faabric::runner::FaabricMain m(fac); - m.startRunner(); + m.startBackground(); - std::shared_ptr req = - faabric::util::batchExecFactory("foo", "bar", 4); + SECTION("Do nothing") {} - auto& sch = faabric::scheduler::getScheduler(); - sch.callFunctions(req); + SECTION("Make calls") + { + std::shared_ptr req = + faabric::util::batchExecFactory("foo", "bar", 4); - for (const auto& m : req->messages()) { - std::string expected = fmt::format("DummyExecutor executed {}", m.id()); - faabric::Message res = - sch.getFunctionResult(m.id(), SHORT_TEST_TIMEOUT_MS); - REQUIRE(res.outputdata() == expected); + 
auto& sch = faabric::scheduler::getScheduler(); + sch.callFunctions(req); + + for (const auto& m : req->messages()) { + std::string expected = + fmt::format("DummyExecutor executed {}", m.id()); + faabric::Message res = + sch.getFunctionResult(m.id(), SHORT_TEST_TIMEOUT_MS); + REQUIRE(res.outputdata() == expected); + } + } + + m.shutdown(); + } + } diff --git a/tests/test/scheduler/test_executor.cpp b/tests/test/scheduler/test_executor.cpp index a5a52a527..87a87b236 100644 --- a/tests/test/scheduler/test_executor.cpp +++ b/tests/test/scheduler/test_executor.cpp @@ -439,7 +439,8 @@ TEST_CASE_METHOD(TestExecutorFixture, faabric::scheduler::queueResourceResponse(otherHost, resOther); // Background thread to execute main function and await results - std::thread t([] { + auto latch = faabric::util::Latch::create(2); + std::thread t([&latch] { int nThreads = 8; std::shared_ptr req = faabric::util::batchExecFactory("dummy", "thread-check", 1); @@ -449,19 +450,20 @@ TEST_CASE_METHOD(TestExecutorFixture, auto& sch = faabric::scheduler::getScheduler(); sch.callFunctions(req, false); + latch->wait(); faabric::Message res = sch.getFunctionResult(msg.id(), 2000); assert(res.returnvalue() == 0); }); - // Give it time to have made the request - usleep(SHORT_TEST_TIMEOUT_MS * 1000); + // Wait until the function has executed and submitted its request to the other host + auto reqs = faabric::scheduler::getBatchRequests(); + REQUIRE_RETRY(reqs = faabric::scheduler::getBatchRequests(), + reqs.size() == 1); // Check restore hasn't been called yet REQUIRE(restoreCount == 0); // Get the request that's been submitted - auto reqs = faabric::scheduler::getBatchRequests(); - REQUIRE(reqs.size() == 1); std::string actualHost = reqs.at(0).first; REQUIRE(actualHost == otherHost); @@ -491,6 +493,7 @@ TEST_CASE_METHOD(TestExecutorFixture, } // Rejoin the other thread + latch->wait(); if (t.joinable()) { t.join(); } @@ -519,11 +522,11 @@ TEST_CASE_METHOD(TestExecutorFixture, executeWithTestExecutor(req, true); - // We have to manually add a wait here as the thread results won't actually - // get logged on this host - usleep(SHORT_TEST_TIMEOUT_MS * 1000); + // Note that because the results don't actually get logged on this host, we + // can't wait on them as usual. 
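REQUIRE_RETRY, used here and above, is the polling replacement for the old fixed usleep waits. Its definition lives in the test utilities rather than in this diff; a plausible sketch, where MAX_RETRIES and RETRY_SLEEP_MS are hypothetical names, is:

    // Hypothetical sketch only, not faabric's actual definition: re-run an
    // update statement and poll the condition, sleeping between attempts
    #define MAX_RETRIES 10
    #define RETRY_SLEEP_MS 100
    #define REQUIRE_RETRY(updateOp, check)                                    \
        {                                                                     \
            updateOp;                                                         \
            int retryCount = 0;                                               \
            while (!(check) && retryCount < MAX_RETRIES) {                    \
                SLEEP_MS(RETRY_SLEEP_MS);                                     \
                updateOp;                                                     \
                retryCount++;                                                 \
            }                                                                 \
        }                                                                     \
        REQUIRE(check)

Under this reading, REQUIRE_RETRY({}, cond) just polls cond, while REQUIRE_RETRY(x = f(), cond) also refreshes x between checks.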
auto actual = faabric::scheduler::getThreadResults(); - REQUIRE(actual.size() == nThreads); + REQUIRE_RETRY(actual = faabric::scheduler::getThreadResults(), + actual.size() == nThreads); std::vector actualMessageIds; for (auto& p : actual) { @@ -681,9 +684,7 @@ TEST_CASE_METHOD(TestExecutorFixture, REQUIRE(sch.getFunctionExecutorCount(msg) == 1); - usleep((conf.boundTimeout + 500) * 1000); - - REQUIRE(sch.getFunctionExecutorCount(msg) == 0); + REQUIRE_RETRY({}, sch.getFunctionExecutorCount(msg) == 0); } TEST_CASE_METHOD(TestExecutorFixture, @@ -711,11 +712,10 @@ TEST_CASE_METHOD(TestExecutorFixture, executeWithTestExecutor(req, true); - // Wait for executor to have finished - sometimes takes a while - usleep(SHORT_TEST_TIMEOUT_MS * 1000); - - // Check thread results returned - REQUIRE(faabric::scheduler::getThreadResults().size() == nThreads); + // Results aren't set on this host as it's not the master, so we have to + // wait + REQUIRE_RETRY({}, + faabric::scheduler::getThreadResults().size() == nThreads); // Check results have been sent back to the master host auto actualResults = faabric::scheduler::getThreadResults(); diff --git a/tests/test/scheduler/test_function_client_server.cpp b/tests/test/scheduler/test_function_client_server.cpp index b6f5c7b4f..ccbfca203 100644 --- a/tests/test/scheduler/test_function_client_server.cpp +++ b/tests/test/scheduler/test_function_client_server.cpp @@ -13,12 +13,11 @@ #include #include #include +#include #include #include #include -#define TEST_TIMEOUT_MS 500 - using namespace faabric::scheduler; namespace tests { @@ -36,17 +35,15 @@ class ClientServerFixture ClientServerFixture() : cli(LOCALHOST) { - server.start(); - usleep(1000 * TEST_TIMEOUT_MS); - // Set up executor executorFactory = std::make_shared(); setExecutorFactory(executorFactory); + + server.start(); } ~ClientServerFixture() { - cli.close(); server.stop(); executorFactory->reset(); } @@ -79,9 +76,8 @@ TEST_CASE_METHOD(ClientServerFixture, REQUIRE(msgs.at(1).function() == "bar"); sch.clearRecordedMessages(); - // Send flush message + // Send flush message (which is synchronous) cli.sendFlush(); - usleep(1000 * TEST_TIMEOUT_MS); // Check the scheduler has been flushed REQUIRE(sch.getFunctionRegisteredHostCount(msgA) == 0); @@ -138,7 +134,11 @@ TEST_CASE_METHOD(ClientServerFixture, // Make the request cli.executeFunctions(req); - usleep(1000 * TEST_TIMEOUT_MS); + + for (const auto& m : req->messages()) { + // This timeout can be long as it shouldn't fail + sch.getFunctionResult(m.id(), 5 * SHORT_TEST_TIMEOUT_MS); + } // Check no other hosts have been registered faabric::Message m = req->messages().at(0); @@ -219,15 +219,21 @@ TEST_CASE_METHOD(ClientServerFixture, "Test unregister request", "[scheduler]") reqA.set_host("foobar"); *reqA.mutable_function() = msg; + // Check that nothing's happened + server.setAsyncLatch(); cli.unregister(reqA); + server.awaitAsyncLatch(); REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 1); // Make the request to unregister the actual host faabric::UnregisterRequest reqB; reqB.set_host(otherHost); *reqB.mutable_function() = msg; + + server.setAsyncLatch(); cli.unregister(reqB); - usleep(1000 * TEST_TIMEOUT_MS); + server.awaitAsyncLatch(); + REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 0); sch.setThisHostResources(originalResources); diff --git a/tests/test/scheduler/test_mpi_world.cpp b/tests/test/scheduler/test_mpi_world.cpp index 5144de4ba..0dec3519e 100644 --- a/tests/test/scheduler/test_mpi_world.cpp +++ b/tests/test/scheduler/test_mpi_world.cpp 
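The setAsyncLatch/awaitAsyncLatch pair in the unregister test above is the general replacement for sleeping after fire-and-forget calls: the test arms a two-party latch on the server, makes the asynchronous call, and blocks until the message has actually been handled (the server-side wait on the latch presumably sits in MessageEndpointServer's receive path, which this excerpt does not show). The pattern, taken from that test:

    server.setAsyncLatch();   // arm a latch of size two
    cli.unregister(reqB);     // async call, returns immediately
    server.awaitAsyncLatch(); // returns once the server has handled it
    // Server-side state can now be asserted deterministically
    REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 0);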
@@ -63,22 +63,6 @@ TEST_CASE_METHOD(MpiBaseTestFixture, "Test creating world of size 1", "[mpi]") world.destroy(); } -TEST_CASE_METHOD(MpiTestFixture, "Test world loading from msg", "[mpi]") -{ - // Create another copy from state - scheduler::MpiWorld worldB; - // Force creating the second world in the _same_ host - bool forceLocal = true; - worldB.initialiseFromMsg(msg, forceLocal); - - REQUIRE(worldB.getSize() == worldSize); - REQUIRE(worldB.getId() == worldId); - REQUIRE(worldB.getUser() == user); - REQUIRE(worldB.getFunction() == func); - - worldB.destroy(); -} - TEST_CASE_METHOD(MpiBaseTestFixture, "Test cartesian communicator", "[mpi]") { MpiWorld world; diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp index 0669fb245..1cceed404 100644 --- a/tests/test/scheduler/test_remote_mpi_worlds.cpp +++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp @@ -18,83 +18,118 @@ class RemoteCollectiveTestFixture : public RemoteMpiTestFixture { public: RemoteCollectiveTestFixture() - : thisWorldSize(6) - , remoteRankA(1) - , remoteRankB(2) - , remoteRankC(3) - , localRankA(4) - , localRankB(5) - , remoteWorldRanks({ remoteRankB, remoteRankC, remoteRankA }) - , localWorldRanks({ localRankB, localRankA, 0 }) - {} + { + thisWorldRanks = { thisHostRankB, thisHostRankA, 0 }; + otherWorldRanks = { otherHostRankB, otherHostRankC, otherHostRankA }; + + // Here we rely on the scheduler running out of resources and + // overloading this world with ranks 4 and 5 + setWorldSizes(thisWorldSize, 1, 3); + } + + MpiWorld& setUpThisWorld() + { + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); + faabric::util::setMockMode(false); + thisWorld.broadcastHostsToRanks(); + + // Check it's set up as we expect + for (auto r : otherWorldRanks) { + REQUIRE(thisWorld.getHostForRank(r) == otherHost); + } + + for (auto r : thisWorldRanks) { + REQUIRE(thisWorld.getHostForRank(r) == thisHost); + } + + return thisWorld; + } protected: - int thisWorldSize; - int remoteRankA, remoteRankB, remoteRankC; - int localRankA, localRankB; - std::vector remoteWorldRanks; - std::vector localWorldRanks; + int thisWorldSize = 6; + + int otherHostRankA = 1; + int otherHostRankB = 2; + int otherHostRankC = 3; + + int thisHostRankA = 4; + int thisHostRankB = 5; + + std::vector otherWorldRanks; + std::vector thisWorldRanks; }; TEST_CASE_METHOD(RemoteMpiTestFixture, "Test rank allocation", "[mpi]") { // Allocate two ranks in total, one rank per host - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); - // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); - remoteWorld.initialiseFromMsg(msg); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + thisWorld.broadcastHostsToRanks(); + + // Background thread to receive the allocation + std::thread otherWorldThread([this] { + otherWorld.initialiseFromMsg(msg); + + REQUIRE(otherWorld.getHostForRank(0) == thisHost); + REQUIRE(otherWorld.getHostForRank(1) == otherHost); + + otherWorld.destroy(); + }); + + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } - // Now check both world instances report the same mappings - REQUIRE(localWorld.getHostForRank(0) == thisHost); - REQUIRE(localWorld.getHostForRank(1) == otherHost); + REQUIRE(thisWorld.getHostForRank(0) == thisHost); + REQUIRE(thisWorld.getHostForRank(1) == otherHost); - // Destroy worlds - localWorld.destroy(); - remoteWorld.destroy(); + thisWorld.destroy(); } 
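This test demonstrates the new two-sided MPI bootstrap that replaces the old forceLocal flag: the world created on the master host pushes the rank-to-host mapping, and every other world blocks until it arrives, which is why these tests now run the "remote" side on a background thread. Distilled from the code above:

    // Master host: create the world, then push the rank -> host mapping
    MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId);
    thisWorld.broadcastHostsToRanks();

    // Every other host: blocks until the mapping has been received
    otherWorld.initialiseFromMsg(msg);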
TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]") { // Register two ranks (one on each host) - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); int rankA = 0; int rankB = 1; std::vector messageData = { 0, 1, 2 }; // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + thisWorld.broadcastHostsToRanks(); - std::thread senderThread([this, rankA, rankB, &messageData] { - remoteWorld.initialiseFromMsg(msg); + // Start the "remote" world in the background + std::thread otherWorldThread([this, rankA, rankB, &messageData] { + otherWorld.initialiseFromMsg(msg); - // Send a message that should get sent to this host - remoteWorld.send( - rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); + // Receive the message for the given rank + MPI_Status status{}; + auto buffer = new int[messageData.size()]; + otherWorld.recv( + rankA, rankB, BYTES(buffer), MPI_INT, messageData.size(), &status); - usleep(1000 * 500); + std::vector actual(buffer, buffer + messageData.size()); + assert(actual == messageData); - remoteWorld.destroy(); - }); + assert(status.MPI_SOURCE == rankA); + assert(status.MPI_ERROR == MPI_SUCCESS); + assert(status.bytesSize == messageData.size() * sizeof(int)); - // Receive the message for the given rank - MPI_Status status{}; - auto buffer = new int[messageData.size()]; - localWorld.recv( - rankB, rankA, BYTES(buffer), MPI_INT, messageData.size(), &status); + otherWorld.destroy(); + }); - std::vector actual(buffer, buffer + messageData.size()); - REQUIRE(actual == messageData); + // Send a message that should get sent to the "remote" world + thisWorld.send( + rankA, rankB, BYTES(messageData.data()), MPI_INT, messageData.size()); - REQUIRE(status.MPI_SOURCE == rankB); - REQUIRE(status.MPI_ERROR == MPI_SUCCESS); - REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } - // Destroy worlds - senderThread.join(); - localWorld.destroy(); + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -102,101 +137,115 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "[mpi]") { // Register two ranks (one on each host) - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); int rankA = 0; int rankB = 1; std::vector messageData = { 0, 1, 2 }; std::vector messageData2 = { 3, 4, 5 }; // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + thisWorld.broadcastHostsToRanks(); - std::thread senderThread([this, rankA, rankB, &messageData, &messageData2] { - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread( + [this, rankA, rankB, &messageData, &messageData2] { + otherWorld.initialiseFromMsg(msg); - // Send a message that should get sent to this host - remoteWorld.send( - rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size()); + // Send a message that should get sent to this host + otherWorld.send(rankB, + rankA, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); - // Now recv - auto buffer = new int[messageData2.size()]; - remoteWorld.recv(rankA, - rankB, - BYTES(buffer), - MPI_INT, - messageData2.size(), - MPI_STATUS_IGNORE); - std::vector actual(buffer, buffer + messageData2.size()); - REQUIRE(actual == messageData2); + // Now recv + auto buffer = new 
int[messageData2.size()]; + otherWorld.recv(rankA, + rankB, + BYTES(buffer), + MPI_INT, + messageData2.size(), + MPI_STATUS_IGNORE); + std::vector actual(buffer, buffer + messageData2.size()); + REQUIRE(actual == messageData2); - usleep(1000 * 500); + testLatch->wait(); - remoteWorld.destroy(); - }); + otherWorld.destroy(); + }); // Receive the message for the given rank MPI_Status status{}; auto buffer = new int[messageData.size()]; - localWorld.recv( + thisWorld.recv( rankB, rankA, BYTES(buffer), MPI_INT, messageData.size(), &status); std::vector actual(buffer, buffer + messageData.size()); REQUIRE(actual == messageData); // Now send a message - localWorld.send( + thisWorld.send( rankA, rankB, BYTES(messageData2.data()), MPI_INT, messageData2.size()); REQUIRE(status.MPI_SOURCE == rankB); REQUIRE(status.MPI_ERROR == MPI_SUCCESS); REQUIRE(status.bytesSize == messageData.size() * sizeof(int)); - // Destroy worlds - senderThread.join(); - localWorld.destroy(); + testLatch->wait(); + + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, "Test barrier across hosts", "[mpi]") { // Register two ranks (one on each host) - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); int rankA = 0; int rankB = 1; std::vector sendData = { 0, 1, 2 }; std::vector recvData = { -1, -1, -1 }; // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - std::thread senderThread([this, rankA, rankB, &sendData, &recvData] { - remoteWorld.initialiseFromMsg(msg); + thisWorld.broadcastHostsToRanks(); + + std::thread otherWorldThread([this, rankA, rankB, &sendData, &recvData] { + otherWorld.initialiseFromMsg(msg); - remoteWorld.send( + otherWorld.send( rankB, rankA, BYTES(sendData.data()), MPI_INT, sendData.size()); // Barrier on this rank - remoteWorld.barrier(rankB); + otherWorld.barrier(rankB); assert(sendData == recvData); - - remoteWorld.destroy(); + otherWorld.destroy(); }); // Receive the message for the given rank - localWorld.recv(rankB, - rankA, - BYTES(recvData.data()), - MPI_INT, - recvData.size(), - MPI_STATUS_IGNORE); + thisWorld.recv(rankB, + rankA, + BYTES(recvData.data()), + MPI_INT, + recvData.size(), + MPI_STATUS_IGNORE); REQUIRE(recvData == sendData); // Call barrier to synchronise remote host - localWorld.barrier(rankA); + thisWorld.barrier(rankA); + + // Clean up + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } - // Destroy worlds - senderThread.join(); - localWorld.destroy(); + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -204,30 +253,31 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "[mpi]") { // Register two ranks (one on each host) - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); int rankA = 0; int rankB = 1; int numMessages = 1000; // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); - std::thread senderThread([this, rankA, rankB, numMessages] { - remoteWorld.initialiseFromMsg(msg); + thisWorld.broadcastHostsToRanks(); + + std::thread otherWorldThread([this, rankA, rankB, numMessages] { + otherWorld.initialiseFromMsg(msg); for (int i = 0; i < numMessages; i++) { - remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); + otherWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1); } - 
usleep(1000 * 500); - - remoteWorld.destroy(); + testLatch->wait(); + otherWorld.destroy(); }); int recv; for (int i = 0; i < numMessages; i++) { - localWorld.recv( + thisWorld.recv( rankB, rankA, BYTES(&recv), MPI_INT, 1, MPI_STATUS_IGNORE); // Check in-order delivery @@ -236,71 +286,71 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, } } - // Destroy worlds - senderThread.join(); - localWorld.destroy(); + // Clean up + testLatch->wait(); + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteCollectiveTestFixture, "Test broadcast across hosts", "[mpi]") { - // Here we rely on the scheduler running out of resources, and overloading - // the localWorld with ranks 4 and 5 - this->setWorldsSizes(thisWorldSize, 1, 3); - std::vector messageData = { 0, 1, 2 }; + MpiWorld& thisWorld = setUpThisWorld(); - // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); - faabric::util::setMockMode(false); + std::vector messageData = { 0, 1, 2 }; - std::thread senderThread([this, &messageData] { - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, &messageData] { + otherWorld.initialiseFromMsg(msg); // Broadcast a message - remoteWorld.broadcast( - remoteRankB, BYTES(messageData.data()), MPI_INT, messageData.size()); - - // Check the host that the root is on - for (int rank : remoteWorldRanks) { - if (rank == remoteRankB) { + otherWorld.broadcast(otherHostRankB, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); + + // Check the broadcast is received on this host by the other ranks + for (int rank : otherWorldRanks) { + if (rank == otherHostRankB) { continue; } std::vector actual(3, -1); - remoteWorld.recv( - remoteRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); + otherWorld.recv( + otherHostRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); assert(actual == messageData); } - usleep(1000 * 500); - - remoteWorld.destroy(); + // Give the other host time to receive the broadcast + testLatch->wait(); + otherWorld.destroy(); }); - // Check the local host - for (int rank : localWorldRanks) { + // Check the ranks on this host receive the broadcast + for (int rank : thisWorldRanks) { std::vector actual(3, -1); - localWorld.recv( - remoteRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); + thisWorld.recv( + otherHostRankB, rank, BYTES(actual.data()), MPI_INT, 3, nullptr); REQUIRE(actual == messageData); } - senderThread.join(); - localWorld.destroy(); + // Clean up + testLatch->wait(); + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteCollectiveTestFixture, "Test scatter across hosts", "[mpi]") { - // Here we rely on the scheduler running out of resources, and overloading - // the localWorld with ranks 4 and 5 - this->setWorldsSizes(thisWorldSize, 1, 3); - - // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); - faabric::util::setMockMode(false); + MpiWorld& thisWorld = setUpThisWorld(); // Build the data int nPerRank = 4; @@ -310,95 +360,94 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, messageData[i] = i; } - std::thread senderThread([this, nPerRank, &messageData] { - remoteWorld.initialiseFromMsg(msg); - // Do the scatter + std::thread otherWorldThread([this, nPerRank, &messageData] { + otherWorld.initialiseFromMsg(msg); + + // Do the scatter (when send rank == recv rank) std::vector actual(nPerRank, -1); - remoteWorld.scatter(remoteRankB, - remoteRankB, - 
BYTES(messageData.data()), - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + otherWorld.scatter(otherHostRankB, + otherHostRankB, + BYTES(messageData.data()), + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); // Check for root assert(actual == std::vector({ 8, 9, 10, 11 })); - // Check for other remote ranks - remoteWorld.scatter(remoteRankB, - remoteRankA, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + // Check the other ranks on this host have received the data + otherWorld.scatter(otherHostRankB, + otherHostRankA, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); assert(actual == std::vector({ 4, 5, 6, 7 })); - remoteWorld.scatter(remoteRankB, - remoteRankC, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + otherWorld.scatter(otherHostRankB, + otherHostRankC, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); assert(actual == std::vector({ 12, 13, 14, 15 })); - usleep(1000 * 500); - - remoteWorld.destroy(); + testLatch->wait(); + otherWorld.destroy(); }); - // Check for local ranks + // Check for ranks on this host std::vector actual(nPerRank, -1); - localWorld.scatter(remoteRankB, - 0, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + thisWorld.scatter(otherHostRankB, + 0, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); REQUIRE(actual == std::vector({ 0, 1, 2, 3 })); - localWorld.scatter(remoteRankB, - localRankB, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + thisWorld.scatter(otherHostRankB, + thisHostRankB, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); REQUIRE(actual == std::vector({ 20, 21, 22, 23 })); - localWorld.scatter(remoteRankB, - localRankA, - nullptr, - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + thisWorld.scatter(otherHostRankB, + thisHostRankA, + nullptr, + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); REQUIRE(actual == std::vector({ 16, 17, 18, 19 })); - senderThread.join(); - localWorld.destroy(); + // Clean up + testLatch->wait(); + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteCollectiveTestFixture, "Test gather across hosts", "[mpi]") { - // Here we rely on the scheduler running out of resources, and overloading - // the localWorld with ranks 4 and 5 - this->setWorldsSizes(thisWorldSize, 1, 3); - - // Init worlds - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); - faabric::util::setMockMode(false); + MpiWorld& thisWorld = setUpThisWorld(); // Build the data for each rank int nPerRank = 4; @@ -421,55 +470,59 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture, std::vector actual(thisWorldSize * nPerRank, -1); // Call gather for each rank other than the root (out of order) - int root = localRankA; - std::thread senderThread([this, root, &rankData, nPerRank] { - remoteWorld.initialiseFromMsg(msg); - - for (int rank : remoteWorldRanks) { - remoteWorld.gather(rank, - root, - BYTES(rankData[rank].data()), - MPI_INT, - nPerRank, - nullptr, - MPI_INT, - nPerRank); + int root = thisHostRankA; + std::thread otherWorldThread([this, root, &rankData, nPerRank] { + otherWorld.initialiseFromMsg(msg); + + for (int rank : otherWorldRanks) { + otherWorld.gather(rank, + root, + BYTES(rankData[rank].data()), + MPI_INT, + nPerRank, + 
nullptr, + MPI_INT, + nPerRank); } - usleep(1000 * 500); - - remoteWorld.destroy(); + testLatch->wait(); + otherWorld.destroy(); }); - for (int rank : localWorldRanks) { + for (int rank : thisWorldRanks) { if (rank == root) { continue; } - localWorld.gather(rank, - root, - BYTES(rankData[rank].data()), - MPI_INT, - nPerRank, - nullptr, - MPI_INT, - nPerRank); + thisWorld.gather(rank, + root, + BYTES(rankData[rank].data()), + MPI_INT, + nPerRank, + nullptr, + MPI_INT, + nPerRank); } // Call gather for root - localWorld.gather(root, - root, - BYTES(rankData[root].data()), - MPI_INT, - nPerRank, - BYTES(actual.data()), - MPI_INT, - nPerRank); + thisWorld.gather(root, + root, + BYTES(rankData[root].data()), + MPI_INT, + nPerRank, + BYTES(actual.data()), + MPI_INT, + nPerRank); // Check data REQUIRE(actual == expected); - senderThread.join(); - localWorld.destroy(); + // Clean up + testLatch->wait(); + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -477,62 +530,66 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "[mpi]") { // Allocate two ranks in total, one rank per host - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); int sendRank = 1; int recvRank = 0; std::vector messageData = { 0, 1, 2 }; // Init world - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + thisWorld.broadcastHostsToRanks(); - std::thread senderThread([this, sendRank, recvRank, &messageData] { - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, sendRank, recvRank, &messageData] { + otherWorld.initialiseFromMsg(msg); // Send message twice - remoteWorld.send(sendRank, - recvRank, - BYTES(messageData.data()), - MPI_INT, - messageData.size()); - remoteWorld.send(sendRank, - recvRank, - BYTES(messageData.data()), - MPI_INT, - messageData.size()); - - usleep(1000 * 500); - - remoteWorld.destroy(); + otherWorld.send(sendRank, + recvRank, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); + otherWorld.send(sendRank, + recvRank, + BYTES(messageData.data()), + MPI_INT, + messageData.size()); + + testLatch->wait(); + otherWorld.destroy(); }); // Receive one message asynchronously std::vector asyncMessage(messageData.size(), 0); - int recvId = localWorld.irecv(sendRank, - recvRank, - BYTES(asyncMessage.data()), - MPI_INT, - asyncMessage.size()); + int recvId = thisWorld.irecv(sendRank, + recvRank, + BYTES(asyncMessage.data()), + MPI_INT, + asyncMessage.size()); // Receive one message synchronously std::vector syncMessage(messageData.size(), 0); - localWorld.recv(sendRank, - recvRank, - BYTES(syncMessage.data()), - MPI_INT, - syncMessage.size(), - MPI_STATUS_IGNORE); + thisWorld.recv(sendRank, + recvRank, + BYTES(syncMessage.data()), + MPI_INT, + syncMessage.size(), + MPI_STATUS_IGNORE); // Wait for the async message - localWorld.awaitAsyncRequest(recvId); + thisWorld.awaitAsyncRequest(recvId); // Checks REQUIRE(syncMessage == messageData); REQUIRE(asyncMessage == messageData); - // Destroy world - senderThread.join(); - localWorld.destroy(); + // Clean up + testLatch->wait(); + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -540,49 +597,49 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "[mpi]") { // Allocate two ranks in total, one rank per host - this->setWorldsSizes(2, 1, 1); + setWorldSizes(2, 1, 1); 
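// (Context for the async receives below: irecv returns a request id
// immediately and queues the receive, while awaitAsyncRequest blocks until
// that particular message has arrived; as the sections further down check,
// the ids can be awaited in any order. A minimal sketch of the pattern,
// assuming a single int payload:
//
//     int value = 0;
//     int reqId = world.irecv(sendRank, recvRank, BYTES(&value), MPI_INT, 1);
//     // ... other useful work ...
//     world.awaitAsyncRequest(reqId); // value is now populated
// )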
int sendRank = 1; int recvRank = 0; // Init world - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + thisWorld.broadcastHostsToRanks(); - std::thread senderThread([this, sendRank, recvRank] { - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, sendRank, recvRank] { + otherWorld.initialiseFromMsg(msg); // Send different messages for (int i = 0; i < 3; i++) { - remoteWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); + otherWorld.send(sendRank, recvRank, BYTES(&i), MPI_INT, 1); } - usleep(1000 * 500); - - remoteWorld.destroy(); + testLatch->wait(); + otherWorld.destroy(); }); // Receive two messages asynchronously int recv1, recv2, recv3; int recvId1 = - localWorld.irecv(sendRank, recvRank, BYTES(&recv1), MPI_INT, 1); + thisWorld.irecv(sendRank, recvRank, BYTES(&recv1), MPI_INT, 1); int recvId2 = - localWorld.irecv(sendRank, recvRank, BYTES(&recv2), MPI_INT, 1); + thisWorld.irecv(sendRank, recvRank, BYTES(&recv2), MPI_INT, 1); // Receive one message synchronously - localWorld.recv( + thisWorld.recv( sendRank, recvRank, BYTES(&recv3), MPI_INT, 1, MPI_STATUS_IGNORE); SECTION("Wait out of order") { - localWorld.awaitAsyncRequest(recvId2); - localWorld.awaitAsyncRequest(recvId1); + thisWorld.awaitAsyncRequest(recvId2); + thisWorld.awaitAsyncRequest(recvId1); } SECTION("Wait in order") { - localWorld.awaitAsyncRequest(recvId1); - localWorld.awaitAsyncRequest(recvId2); + thisWorld.awaitAsyncRequest(recvId1); + thisWorld.awaitAsyncRequest(recvId2); } // Checks @@ -590,9 +647,13 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, REQUIRE(recv2 == 1); REQUIRE(recv3 == 2); - // Destroy world - senderThread.join(); - localWorld.destroy(); + // Clean up + testLatch->wait(); + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } + + thisWorld.destroy(); } TEST_CASE_METHOD(RemoteMpiTestFixture, @@ -600,62 +661,66 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "[mpi]") { // Allocate two ranks in total, one rank per host - this->setWorldsSizes(3, 1, 2); + setWorldSizes(3, 1, 2); int worldSize = 3; - std::vector localRanks = { 0 }; + std::vector thisHostRanks = { 0 }; // Init world - MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId); + MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId); faabric::util::setMockMode(false); + thisWorld.broadcastHostsToRanks(); - std::thread senderThread([this, worldSize] { - std::vector remoteRanks = { 1, 2 }; - remoteWorld.initialiseFromMsg(msg); + std::thread otherWorldThread([this, worldSize] { + std::vector otherHostRanks = { 1, 2 }; + otherWorld.initialiseFromMsg(msg); // Send different messages - for (auto& rank : remoteRanks) { + for (auto& rank : otherHostRanks) { int left = rank > 0 ? rank - 1 : worldSize - 1; int right = (rank + 1) % worldSize; int recvData = -1; - remoteWorld.sendRecv(BYTES(&rank), - 1, - MPI_INT, - right, - BYTES(&recvData), - 1, - MPI_INT, - left, - rank, - MPI_STATUS_IGNORE); + otherWorld.sendRecv(BYTES(&rank), + 1, + MPI_INT, + right, + BYTES(&recvData), + 1, + MPI_INT, + left, + rank, + MPI_STATUS_IGNORE); } - usleep(1000 * 500); - - remoteWorld.destroy(); + testLatch->wait(); + otherWorld.destroy(); }); - for (auto& rank : localRanks) { + for (auto& rank : thisHostRanks) { int left = rank > 0 ? 
rank - 1 : worldSize - 1; int right = (rank + 1) % worldSize; int recvData = -1; - localWorld.sendRecv(BYTES(&rank), - 1, - MPI_INT, - right, - BYTES(&recvData), - 1, - MPI_INT, - left, - rank, - MPI_STATUS_IGNORE); + thisWorld.sendRecv(BYTES(&rank), + 1, + MPI_INT, + right, + BYTES(&recvData), + 1, + MPI_INT, + left, + rank, + MPI_STATUS_IGNORE); REQUIRE(recvData == left); } - // Destroy world - senderThread.join(); - localWorld.destroy(); + // Clean up + testLatch->wait(); + if (otherWorldThread.joinable()) { + otherWorldThread.join(); + } + + thisWorld.destroy(); } } diff --git a/tests/test/scheduler/test_scheduler.cpp b/tests/test/scheduler/test_scheduler.cpp index 2b906a152..dfb00fe9c 100644 --- a/tests/test/scheduler/test_scheduler.cpp +++ b/tests/test/scheduler/test_scheduler.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include using namespace faabric::scheduler; @@ -32,10 +33,10 @@ class SlowExecutor final : public Executor int msgIdx, std::shared_ptr req) override { - SPDLOG_DEBUG("SlowExecutor executing task{}", + SPDLOG_DEBUG("Slow executor executing task{}", req->mutable_messages()->at(msgIdx).id()); - usleep(SHORT_TEST_TIMEOUT_MS * 1000); + SLEEP_MS(SHORT_TEST_TIMEOUT_MS); return 0; } }; diff --git a/tests/test/scheduler/test_snapshot_client_server.cpp b/tests/test/scheduler/test_snapshot_client_server.cpp index 27461c97e..9319c76ed 100644 --- a/tests/test/scheduler/test_snapshot_client_server.cpp +++ b/tests/test/scheduler/test_snapshot_client_server.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -29,14 +30,9 @@ class SnapshotClientServerFixture : cli(LOCALHOST) { server.start(); - usleep(1000 * SHORT_TEST_TIMEOUT_MS); } - ~SnapshotClientServerFixture() - { - cli.close(); - server.stop(); - } + ~SnapshotClientServerFixture() { server.stop(); } }; TEST_CASE_METHOD(SnapshotClientServerFixture, @@ -66,9 +62,7 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, cli.pushSnapshot(snapKeyA, snapA); cli.pushSnapshot(snapKeyB, snapB); - usleep(1000 * 500); - - // Check snapshots created in regsitry + // Check snapshots created in registry REQUIRE(reg.getSnapshotCount() == 2); const faabric::util::SnapshotData& actualA = reg.getSnapshot(snapKeyA); const faabric::util::SnapshotData& actualB = reg.getSnapshot(snapKeyB); diff --git a/tests/test/state/test_state.cpp b/tests/test/state/test_state.cpp index 9b2a21e5c..1827531cd 100644 --- a/tests/test/state/test_state.cpp +++ b/tests/test/state/test_state.cpp @@ -3,33 +3,154 @@ #include "faabric_utils.h" #include -#include #include #include +#include #include +#include #include - +#include #include + #include using namespace faabric::state; namespace tests { + static int staticCount = 0; -static void setUpDummyServer(DummyStateServer& server, - const std::vector& values) +class StateServerTestFixture + : public StateTestFixture + , ConfTestFixture { - cleanFaabric(); + public: + // Set up a local server with a *different* state instance to the main + // thread. 
This way we can fake the master/non-master setup + StateServerTestFixture() + : remoteState(LOCALHOST) + , stateServer(remoteState) + { + staticCount++; + const std::string stateKey = "state_key_" + std::to_string(staticCount); - staticCount++; - const std::string stateKey = "state_key_" + std::to_string(staticCount); + conf.stateMode = "inmemory"; - // Set state remotely - server.dummyUser = "demo"; - server.dummyKey = stateKey; - server.dummyData = values; -} + dummyUser = "demo"; + dummyKey = stateKey; + + // Start the state server + SPDLOG_DEBUG("Running state server"); + stateServer.start(); + } + + ~StateServerTestFixture() { stateServer.stop(); } + + void setDummyData(std::vector data) + { + dummyData = data; + + // Master the dummy data in the "remote" state + if (!dummyData.empty()) { + std::string originalHost = conf.endpointHost; + conf.endpointHost = LOCALHOST; + + const std::shared_ptr& kv = + remoteState.getKV(dummyUser, dummyKey, dummyData.size()); + + std::shared_ptr inMemKv = + std::static_pointer_cast(kv); + + // Check this kv "thinks" it's master + if (!inMemKv->isMaster()) { + SPDLOG_ERROR("Dummy state server not master for data"); + throw std::runtime_error("Dummy state server failed"); + } + + // Set the data + kv->set(dummyData.data()); + SPDLOG_DEBUG( + "Finished setting master for test {}/{}", kv->user, kv->key); + + conf.endpointHost = originalHost; + } + } + + std::shared_ptr getRemoteKv() + { + if (dummyData.empty()) { + return remoteState.getKV(dummyUser, dummyKey); + } + return remoteState.getKV(dummyUser, dummyKey, dummyData.size()); + } + + std::shared_ptr getLocalKv() + { + if (dummyData.empty()) { + return state::getGlobalState().getKV(dummyUser, dummyKey); + } + + return state::getGlobalState().getKV( + dummyUser, dummyKey, dummyData.size()); + } + + std::vector getRemoteKvValue() + { + std::vector actual(dummyData.size(), 0); + getRemoteKv()->get(actual.data()); + return actual; + } + + std::vector getLocalKvValue() + { + std::vector actual(dummyData.size(), 0); + getLocalKv()->get(actual.data()); + return actual; + } + + void checkPulling(bool doPull) + { + std::vector values = { 0, 1, 2, 3 }; + std::vector actual(values.size(), 0); + setDummyData(values); + + // Get, with optional pull + int nMessages = 1; + if (doPull) { + nMessages = 2; + } + + // Initial pull + const std::shared_ptr& localKv = getLocalKv(); + localKv->pull(); + + // Update directly on the remote KV + std::shared_ptr remoteKv = getRemoteKv(); + std::vector newValues = { 5, 5, 5, 5 }; + remoteKv->set(newValues.data()); + + if (doPull) { + // Check local changed with another pull + localKv->pull(); + localKv->get(actual.data()); + REQUIRE(actual == newValues); + } else { + // Check local unchanged without another pull + localKv->get(actual.data()); + REQUIRE(actual == values); + } + } + + protected: + std::string dummyUser; + std::string dummyKey; + + state::State remoteState; + state::StateServer stateServer; + + private: + std::vector dummyData; +}; static std::shared_ptr setupKV(size_t size) { @@ -46,30 +167,29 @@ static std::shared_ptr setupKV(size_t size) return kv; } -TEST_CASE("Test in-memory state sizes", "[state]") +TEST_CASE_METHOD(StateTestFixture, "Test in-memory state sizes", "[state]") { - cleanFaabric(); - State& s = getGlobalState(); std::string user = "alpha"; std::string key = "beta"; // Empty should be none - size_t initialSize = s.getStateSize(user, key); + size_t initialSize = state.getStateSize(user, key); REQUIRE(initialSize == 0); // Set a value std::vector 
bytes = { 0, 1, 2, 3, 4 }; - auto kv = s.getKV(user, key, bytes.size()); + auto kv = state.getKV(user, key, bytes.size()); kv->set(bytes.data()); kv->pushFull(); // Get size - REQUIRE(s.getStateSize(user, key) == bytes.size()); + REQUIRE(state.getStateSize(user, key) == bytes.size()); } -TEST_CASE("Test simple in memory state get/set", "[state]") +TEST_CASE_METHOD(StateTestFixture, + "Test simple in memory state get/set", + "[state]") { - cleanFaabric(); auto kv = setupKV(5); std::vector actual(5); @@ -91,22 +211,20 @@ TEST_CASE("Test simple in memory state get/set", "[state]") REQUIRE(actual == values); } -TEST_CASE("Test in memory get/ set chunk", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test in memory get/ set chunk", + "[state]") { - DummyStateServer server; std::vector values = { 0, 0, 1, 1, 2, 2, 3, 3, 4, 4 }; - setUpDummyServer(server, values); - - // Start the server - server.start(); + setDummyData(values); // Get locally - std::vector actual = server.getLocalKvValue(); + std::vector actual = getLocalKvValue(); REQUIRE(actual == values); // Update a subsection std::vector update = { 8, 8, 8 }; - std::shared_ptr localKv = server.getLocalKv(); + std::shared_ptr localKv = getLocalKv(); localKv->setChunk(3, update.data(), 3); std::vector expected = { 0, 0, 1, 8, 8, 8, 3, 3, 4, 4 }; @@ -114,7 +232,7 @@ TEST_CASE("Test in memory get/ set chunk", "[state]") REQUIRE(actual == expected); // Check remote is unchanged - REQUIRE(server.getRemoteKvValue() == values); + REQUIRE(getRemoteKvValue() == values); // Try getting a chunk locally std::vector actualChunk(3); @@ -123,21 +241,18 @@ TEST_CASE("Test in memory get/ set chunk", "[state]") // Run push and check remote is updated localKv->pushPartial(); - REQUIRE(server.getRemoteKvValue() == expected); - - // Wait for server to finish - server.stop(); + REQUIRE(getRemoteKvValue() == expected); } -TEST_CASE("Test in memory marking chunks dirty", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test in memory marking chunks dirty", + "[state]") { std::vector values = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; - DummyStateServer server; - setUpDummyServer(server, values); - server.start(); + setDummyData(values); // Get pointer to local and update in memory only - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); uint8_t* ptr = localKv->get(); ptr[0] = 8; ptr[5] = 7; @@ -150,27 +265,23 @@ TEST_CASE("Test in memory marking chunks dirty", "[state]") values.at(0) = 8; // Check remote - REQUIRE(server.getRemoteKvValue() == values); + REQUIRE(getRemoteKvValue() == values); // Check local value has been set with the latest remote value std::vector actualMemory(ptr, ptr + values.size()); REQUIRE(actualMemory == values); - - server.stop(); } -TEST_CASE("Test overlaps with multiple chunks dirty", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test overlaps with multiple chunks dirty", + "[state]") { std::vector values = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; - DummyStateServer server; - setUpDummyServer(server, values); - - // Get, push, pull - server.start(); + setDummyData(values); // Get pointer to local data - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); uint8_t* statePtr = localKv->get(); // Update a couple of areas @@ -193,8 +304,7 @@ TEST_CASE("Test overlaps with multiple chunks dirty", "[state]") // Update one non-overlapping value remotely std::vector directA = { 2, 2 }; - const 
std::shared_ptr& remoteKv = - server.getRemoteKv(); + const std::shared_ptr& remoteKv = getRemoteKv(); remoteKv->setChunk(6, directA.data(), 2); // Update one overlapping value remotely @@ -207,8 +317,8 @@ TEST_CASE("Test overlaps with multiple chunks dirty", "[state]") std::vector expectedRemote = { 6, 6, 6, 6, 6, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; - REQUIRE(server.getLocalKvValue() == expectedLocal); - REQUIRE(server.getRemoteKvValue() == expectedRemote); + REQUIRE(getLocalKvValue() == expectedLocal); + REQUIRE(getRemoteKvValue() == expectedRemote); // Push changes localKv->pushPartial(); @@ -217,23 +327,19 @@ TEST_CASE("Test overlaps with multiple chunks dirty", "[state]") std::vector expected = { 6, 1, 2, 3, 6, 0, 2, 2, 0, 0, 4, 5, 0, 0, 7, 7, 7, 7, 0, 0 }; - REQUIRE(server.getLocalKvValue() == expected); - REQUIRE(server.getRemoteKvValue() == expected); - - server.stop(); + REQUIRE(getLocalKvValue() == expected); + REQUIRE(getRemoteKvValue() == expected); } -TEST_CASE("Test in memory partial update of doubles in state", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test in memory partial update of doubles in state", + "[state]") { long nDoubles = 20; long nBytes = nDoubles * sizeof(double); std::vector values(nBytes, 0); - DummyStateServer server; - setUpDummyServer(server, values); - - // Get, push, pull - server.start(); + setDummyData(values); // Set up both with zeroes initially std::vector expected(nDoubles); @@ -242,7 +348,7 @@ TEST_CASE("Test in memory partial update of doubles in state", "[state]") memset(actualBytes.data(), 0, nBytes); // Update a value locally and flag dirty - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); auto actualPtr = reinterpret_cast(localKv->get()); auto expectedPtr = expected.data(); actualPtr[0] = 123.456; @@ -274,18 +380,17 @@ TEST_CASE("Test in memory partial update of doubles in state", "[state]") REQUIRE(expected == actualPostPush); // Check remote - std::vector remoteValue = server.getRemoteKvValue(); + std::vector remoteValue = getRemoteKvValue(); std::vector actualPostPushRemote(postPushDoublePtr, postPushDoublePtr + nDoubles); REQUIRE(expected == actualPostPushRemote); - - server.stop(); } -TEST_CASE("Test set chunk cannot be over the size of the allocated memory", - "[state]") +TEST_CASE_METHOD( + StateServerTestFixture, + "Test set chunk cannot be over the size of the allocated memory", + "[state]") { - cleanFaabric(); auto kv = setupKV(2); // Set a chunk offset @@ -295,29 +400,27 @@ TEST_CASE("Test set chunk cannot be over the size of the allocated memory", REQUIRE_THROWS(kv->setChunk(offset, update.data(), 3)); } -TEST_CASE("Test partially setting just first/ last element", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test partially setting just first/ last element", + "[state]") { std::vector values = { 0, 1, 2, 3, 4 }; - DummyStateServer server; - setUpDummyServer(server, values); - - // Only 3 push-partial messages as kv not fully allocated - server.start(); + setDummyData(values); // Update just the last element std::vector update = { 8 }; - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); localKv->setChunk(4, update.data(), 1); localKv->pushPartial(); std::vector expected = { 0, 1, 2, 3, 8 }; - REQUIRE(server.getRemoteKvValue() == expected); + REQUIRE(getRemoteKvValue() == expected); // Update the first localKv->setChunk(0, update.data(), 1); localKv->pushPartial(); expected = { 8, 1, 2, 3, 8 
}; - REQUIRE(server.getRemoteKvValue() == expected); + REQUIRE(getRemoteKvValue() == expected); // Update two update = { 6 }; @@ -326,27 +429,22 @@ TEST_CASE("Test partially setting just first/ last element", "[state]") localKv->pushPartial(); expected = { 6, 1, 2, 3, 6 }; - REQUIRE(server.getRemoteKvValue() == expected); - - server.stop(); + REQUIRE(getRemoteKvValue() == expected); } -TEST_CASE("Test push partial with mask", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test push partial with mask", + "[state]") { size_t stateSize = 4 * sizeof(double); std::vector values(stateSize, 0); - DummyStateServer server; - setUpDummyServer(server, values); - - // Get, full push, push partial - server.start(); + setDummyData(values); // Create another local KV of same size - State& state = getGlobalState(); auto maskKv = state.getKV("demo", "dummy_mask", stateSize); // Set up value locally - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); uint8_t* dataBytePtr = localKv->get(); auto dataDoublePtr = reinterpret_cast(dataBytePtr); std::vector initial = { 1.2345, 12.345, 987.6543, 10987654.3 }; @@ -358,7 +456,7 @@ TEST_CASE("Test push partial with mask", "[state]") localKv->pushFull(); // Check pushed remotely - std::vector actualBytes = server.getRemoteKvValue(); + std::vector actualBytes = getRemoteKvValue(); auto actualDoublePtr = reinterpret_cast(actualBytes.data()); std::vector actualDoubles(actualDoublePtr, actualDoublePtr + 4); @@ -387,100 +485,58 @@ TEST_CASE("Test push partial with mask", "[state]") }; // Check remotely - std::vector actualValue2 = server.getRemoteKvValue(); + std::vector actualValue2 = getRemoteKvValue(); auto actualDoublesPtr = reinterpret_cast(actualValue2.data()); std::vector actualDoubles2(actualDoublesPtr, actualDoublesPtr + 4); REQUIRE(actualDoubles2 == expected); - - server.stop(); } -void checkPulling(bool doPull) -{ - std::vector values = { 0, 1, 2, 3 }; - std::vector actual(values.size(), 0); - - DummyStateServer server; - setUpDummyServer(server, values); - - // Get, with optional pull - int nMessages = 1; - if (doPull) { - nMessages = 2; - } - - server.start(); - - // Initial pull - const std::shared_ptr& localKv = server.getLocalKv(); - localKv->pull(); - - // Update directly on the remote KV - std::shared_ptr remoteKv = server.getRemoteKv(); - std::vector newValues = { 5, 5, 5, 5 }; - remoteKv->set(newValues.data()); - - if (doPull) { - // Check locak changed with another pull - localKv->pull(); - localKv->get(actual.data()); - REQUIRE(actual == newValues); - } else { - // Check local unchanged without another pull - localKv->get(actual.data()); - REQUIRE(actual == values); - } - - server.stop(); -} - -TEST_CASE("Test updates pulled from remote", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test updates pulled from remote", + "[state]") { checkPulling(true); } -TEST_CASE("Test updates not pulled from remote without call to pull", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test updates not pulled from remote without call to pull", + "[state]") { checkPulling(false); } -TEST_CASE("Test pushing only happens when dirty", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test pushing only happens when dirty", + "[state]") { std::vector values = { 0, 1, 2, 3 }; - DummyStateServer server; - setUpDummyServer(server, values); - - server.start(); + setDummyData(values); // Pull locally - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = 
getLocalKv(); localKv->pull(); // Change remote directly - std::shared_ptr remoteKv = server.getRemoteKv(); + std::shared_ptr remoteKv = getRemoteKv(); std::vector newValues = { 3, 4, 5, 6 }; remoteKv->set(newValues.data()); // Push and make sure remote has not changed without local being dirty localKv->pushFull(); - REQUIRE(server.getRemoteKvValue() == newValues); + REQUIRE(getRemoteKvValue() == newValues); // Now change locally and check push happens std::vector newValues2 = { 7, 7, 7, 7 }; localKv->set(newValues2.data()); localKv->pushFull(); - REQUIRE(server.getRemoteKvValue() == newValues2); - - server.stop(); + REQUIRE(getRemoteKvValue() == newValues2); } -TEST_CASE("Test mapping shared memory", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test mapping shared memory", + "[state]") { - cleanFaabric(); - - // Set up the KV - State& s = getGlobalState(); - auto kv = s.getKV("demo", "mapping_test", 5); + auto kv = state.getKV("demo", "mapping_test", 5); std::vector value = { 0, 1, 2, 3, 4 }; kv->set(value.data()); @@ -529,19 +585,16 @@ TEST_CASE("Test mapping shared memory", "[state]") } } -TEST_CASE("Test mapping shared memory does not pull", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test mapping shared memory does not pull", + "[state]") { std::vector values = { 0, 1, 2, 3, 4 }; std::vector zeroes(values.size(), 0); - DummyStateServer server; - setUpDummyServer(server, values); - - // One implicit pull - server.start(); + setDummyData(values); // Write value to remote - const std::shared_ptr& remoteKv = - server.getRemoteKv(); + const std::shared_ptr& remoteKv = getRemoteKv(); remoteKv->set(values.data()); // Map the KV locally @@ -551,7 +604,7 @@ TEST_CASE("Test mapping shared memory does not pull", "[state]") MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); localKv->mapSharedMemory(mappedRegion, 0, 1); // Check it's zeroed @@ -564,18 +617,17 @@ TEST_CASE("Test mapping shared memory does not pull", "[state]") std::vector actualValueAfterGet(byteRegion, byteRegion + values.size()); REQUIRE(actualValueAfterGet == values); - - server.stop(); } -TEST_CASE("Test mapping small shared memory offsets", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test mapping small shared memory offsets", + "[state]") { - cleanFaabric(); - // Set up the KV std::vector values = { 0, 1, 2, 3, 4, 5, 6 }; - State& s = getGlobalState(); - auto kv = s.getKV("demo", "mapping_small_test", values.size()); + setDummyData(values); + + auto kv = state.getKV("demo", "mapping_small_test", values.size()); kv->set(values.data()); // Map a single page of host memory @@ -622,7 +674,9 @@ TEST_CASE("Test mapping small shared memory offsets", "[state]") REQUIRE(chunkB[1] == 1); } -TEST_CASE("Test mapping bigger uninitialized shared memory offsets", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test mapping bigger uninitialized shared memory offsets", + "[state]") { // Define some mapping larger than a page size_t mappingSize = 3 * faabric::util::HOST_PAGE_SIZE; @@ -630,13 +684,7 @@ TEST_CASE("Test mapping bigger uninitialized shared memory offsets", "[state]") // Set up a larger total value full of ones size_t totalSize = (10 * faabric::util::HOST_PAGE_SIZE) + 15; std::vector values(totalSize, 1); - - // Set up remote server - DummyStateServer server; - setUpDummyServer(server, values); - - // Expecting two implicit pulls - server.start(); + setDummyData(values); // Map a couple of chunks in 
host memory (as would be done by the wasm // module) @@ -646,7 +694,7 @@ TEST_CASE("Test mapping bigger uninitialized shared memory offsets", "[state]") nullptr, mappingSize, PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); // Do the mapping - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); localKv->mapSharedMemory(mappedRegionA, 6, 3); localKv->mapSharedMemory(mappedRegionB, 2, 3); @@ -668,36 +716,28 @@ TEST_CASE("Test mapping bigger uninitialized shared memory offsets", "[state]") REQUIRE(chunkB[0] == 1); REQUIRE(chunkA[5] == 5); REQUIRE(chunkB[9] == 9); - - server.stop(); } -TEST_CASE("Test deletion", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, "Test deletion", "[state]") { std::vector values = { 0, 1, 2, 3, 4 }; - DummyStateServer server; - setUpDummyServer(server, values); - - // One pull, one deletion - server.start(); + setDummyData(values); // Check data remotely and locally - REQUIRE(server.getLocalKvValue() == values); - REQUIRE(server.getRemoteKvValue() == values); + REQUIRE(getLocalKvValue() == values); + REQUIRE(getRemoteKvValue() == values); // Delete from state - getGlobalState().deleteKV(server.dummyUser, server.dummyKey); + getGlobalState().deleteKV(dummyUser, dummyKey); // Check it's gone - REQUIRE(server.remoteState.getKVCount() == 0); - - server.stop(); + REQUIRE(remoteState.getKVCount() == 0); } -TEST_CASE("Test appended state with KV", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test appended state with KV", + "[state]") { - cleanFaabric(); - // Set up the KV State& s = getGlobalState(); std::shared_ptr kv; @@ -750,25 +790,19 @@ TEST_CASE("Test appended state with KV", "[state]") REQUIRE(actualAfterClear == expectedAfterClear); } -TEST_CASE("Test remote appended state", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test remote appended state", + "[state]") { - DummyStateServer server; - std::vector empty; - setUpDummyServer(server, empty); - - // One appends, two retrievals, one clear - server.start(); - std::vector valuesA = { 0, 1, 2, 3, 4 }; std::vector valuesB = { 3, 3, 5, 5 }; // Append some data remotely - const std::shared_ptr& remoteKv = - server.getRemoteKv(); + const std::shared_ptr& remoteKv = getRemoteKv(); remoteKv->append(valuesA.data(), valuesA.size()); // Append locally - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); auto localInMemKv = std::static_pointer_cast(localKv); REQUIRE(!localInMemKv->isMaster()); @@ -797,28 +831,24 @@ TEST_CASE("Test remote appended state", "[state]") localKv->getAppended( actualLocalAfterClear.data(), actualLocalAfterClear.size(), 1); REQUIRE(actualLocalAfterClear == valuesB); - - server.stop(); } -TEST_CASE("Test pushing pulling large state", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test pushing pulling large state", + "[state]") { size_t valueSize = (3 * STATE_STREAMING_CHUNK_SIZE) + 123; std::vector valuesA(valueSize, 1); std::vector valuesB(valueSize, 2); - DummyStateServer server; - setUpDummyServer(server, valuesA); - - // One pull, one push - server.start(); + setDummyData(valuesA); // Pull locally - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); localKv->pull(); // Check equality - const std::vector& actualLocal = server.getLocalKvValue(); + const std::vector& actualLocal = getLocalKvValue(); REQUIRE(actualLocal == valuesA); // Push @@ -826,28 +856,24 @@ TEST_CASE("Test pushing pulling large state", 
"[state]") localKv->pushFull(); // Check equality of remote - const std::vector& actualRemote = server.getRemoteKvValue(); + const std::vector& actualRemote = getRemoteKvValue(); REQUIRE(actualRemote == valuesB); - - server.stop(); } -TEST_CASE("Test pushing pulling chunks over multiple requests", "[state]") +TEST_CASE_METHOD(StateServerTestFixture, + "Test pushing pulling chunks over multiple requests", + "[state]") { // Set up a chunk of state size_t chunkSize = 1024; size_t valueSize = 10 * chunkSize + 123; std::vector values(valueSize, 1); - DummyStateServer server; - setUpDummyServer(server, values); - - // Two chunk pulls, one push partial - server.start(); + setDummyData(values); // Set a chunk in the remote value size_t offsetA = 3 * chunkSize + 5; std::vector segA = { 4, 4 }; - std::shared_ptr remoteKv = server.getRemoteKv(); + std::shared_ptr remoteKv = getRemoteKv(); remoteKv->setChunk(offsetA, segA.data(), segA.size()); // Set a chunk at the end of the remote value @@ -856,7 +882,7 @@ TEST_CASE("Test pushing pulling chunks over multiple requests", "[state]") remoteKv->setChunk(offsetB, segB.data(), segB.size()); // Get only these chunks locally - std::shared_ptr localKv = server.getLocalKv(); + std::shared_ptr localKv = getLocalKv(); std::vector actualSegA(segA.size(), 0); localKv->getChunk(offsetA, actualSegA.data(), segA.size()); REQUIRE(actualSegA == segA); @@ -879,7 +905,7 @@ TEST_CASE("Test pushing pulling chunks over multiple requests", "[state]") localKv->pushPartial(); // Check the chunks in the remote value - std::vector actualAfterPush = server.getRemoteKvValue(); + std::vector actualAfterPush = getRemoteKvValue(); std::vector actualSegC(actualAfterPush.begin() + offsetC, actualAfterPush.begin() + offsetC + segC.size()); @@ -889,21 +915,17 @@ TEST_CASE("Test pushing pulling chunks over multiple requests", "[state]") REQUIRE(actualSegC == segC); REQUIRE(actualSegD == segD); - - server.stop(); } -TEST_CASE("Test pulling disjoint chunks of the same value which share pages", - "[state]") +TEST_CASE_METHOD( + StateServerTestFixture, + "Test pulling disjoint chunks of the same value which share pages", + "[state]") { // Set up state size_t valueSize = 20 * faabric::util::HOST_PAGE_SIZE + 123; std::vector values(valueSize, 1); - DummyStateServer server; - setUpDummyServer(server, values); - - // Expect two chunk pulls - server.start(); + setDummyData(values); // Set up two chunks both from the same page of memory but not overlapping long offsetA = 2 * faabric::util::HOST_PAGE_SIZE + 10; @@ -916,7 +938,7 @@ TEST_CASE("Test pulling disjoint chunks of the same value which share pages", std::vector actualB(lenA, 0); std::vector expectedB(lenA, 1); - const std::shared_ptr& localKv = server.getLocalKv(); + const std::shared_ptr& localKv = getLocalKv(); localKv->getChunk(offsetA, actualA.data(), lenA); localKv->getChunk(offsetB, actualB.data(), lenB); @@ -927,7 +949,49 @@ TEST_CASE("Test pulling disjoint chunks of the same value which share pages", // Check both chunks are as expected REQUIRE(actualA == expectedA); REQUIRE(actualB == expectedB); +} + +TEST_CASE_METHOD(StateServerTestFixture, + "Test state server as remote master", + "[state]") +{ + REQUIRE(state.getKVCount() == 0); + + const char* userA = "foo"; + const char* keyA = "bar"; + std::vector dataA = { 0, 1, 2, 3, 4, 5, 6, 7 }; + std::vector dataB = { 7, 6, 5, 4, 3, 2, 1, 0 }; + + dummyUser = userA; + dummyKey = keyA; + setDummyData(dataA); + + // Get the state size before accessing the value locally + size_t actualSize 
= state.getStateSize(userA, keyA); + REQUIRE(actualSize == dataA.size()); - server.stop(); + // Access locally and check not master + auto localKv = getLocalKv(); + auto localStateKv = + std::static_pointer_cast(localKv); + REQUIRE(!localStateKv->isMaster()); + + // Set the state locally and check + const std::shared_ptr& kv = + state.getKV(userA, keyA, dataA.size()); + kv->set(dataB.data()); + + std::vector actualLocal(dataA.size(), 0); + kv->get(actualLocal.data()); + REQUIRE(actualLocal == dataB); + + // Check it's not changed remotely + std::vector actualRemote = getRemoteKvValue(); + REQUIRE(actualRemote == dataA); + + // Push and check remote is updated + kv->pushFull(); + actualRemote = getRemoteKvValue(); + REQUIRE(actualRemote == dataB); } } diff --git a/tests/test/state/test_state_server.cpp b/tests/test/state/test_state_server.cpp index f3d9c7f91..2a5603dca 100644 --- a/tests/test/state/test_state_server.cpp +++ b/tests/test/state/test_state_server.cpp @@ -2,73 +2,77 @@ #include "faabric_utils.h" -#include #include #include #include #include #include +#include #include using namespace faabric::state; namespace tests { -static const char* userA = "foo"; -static const char* keyA = "bar"; -static const char* keyB = "baz"; -static std::vector dataA = { 0, 1, 2, 3, 4, 5, 6, 7 }; -static std::vector dataB = { 7, 6, 5, 4, 3, 2, 1, 0 }; -static std::string originalStateMode; - -static void setUpStateMode() +class SimpleStateServerTestFixture + : public StateTestFixture + , public ConfTestFixture { - faabric::util::SystemConfig& conf = faabric::util::getSystemConfig(); - originalStateMode = conf.stateMode; - cleanFaabric(); -} + public: + SimpleStateServerTestFixture() + : server(faabric::state::getGlobalState()) + , dataA({ 0, 1, 2, 3, 4, 5, 6, 7 }) + , dataB({ 7, 6, 5, 4, 3, 2, 1, 0 }) + { + conf.stateMode = "inmemory"; -static void resetStateMode() -{ - faabric::util::getSystemConfig().stateMode = originalStateMode; -} + server.start(); + } -std::shared_ptr getKv(const std::string& user, - const std::string& key, - size_t stateSize) -{ - State& state = getGlobalState(); + ~SimpleStateServerTestFixture() { server.stop(); } - std::shared_ptr localKv = state.getKV(user, key, stateSize); - std::shared_ptr inMemLocalKv = - std::static_pointer_cast(localKv); + std::shared_ptr getKv(const std::string& user, + const std::string& key, + size_t stateSize) + { + std::shared_ptr localKv = + state.getKV(user, key, stateSize); - return inMemLocalKv; -} + std::shared_ptr inMemLocalKv = + std::static_pointer_cast(localKv); -TEST_CASE("Test request/ response", "[state]") -{ - setUpStateMode(); + return inMemLocalKv; + } + + protected: + StateServer server; + + const char* userA = "foo"; + const char* keyA = "bar"; + const char* keyB = "baz"; - // Create server - StateServer s(getGlobalState()); - s.start(); - usleep(1000 * 100); + std::vector dataA; + std::vector dataB; +}; +TEST_CASE_METHOD(SimpleStateServerTestFixture, + "Test state request/ response", + "[state]") +{ std::vector actual(dataA.size(), 0); // Prepare a key-value with data auto kvA = getKv(userA, keyA, dataA.size()); kvA->set(dataA.data()); - // Prepare a key-value with no data + // Prepare a key-value with no data (but a size) auto kvB = getKv(userA, keyB, dataA.size()); // Prepare a key-value with same key but different data (for pushing) - std::string thisIP = faabric::util::getSystemConfig().endpointHost; + std::string thisHost = faabric::util::getSystemConfig().endpointHost; auto kvADuplicate = - InMemoryStateKeyValue(userA, keyA, 
dataB.size(), thisIP); + InMemoryStateKeyValue(userA, keyA, dataB.size(), thisHost); kvADuplicate.set(dataB.data()); StateClient client(userA, keyA, DEFAULT_STATE_HOST); @@ -133,19 +137,12 @@ TEST_CASE("Test request/ response", "[state]") REQUIRE(actualAppended == expected); } - - // Close the state client - client.close(); - - s.stop(); - - resetStateMode(); } -TEST_CASE("Test local-only push/ pull", "[state]") +TEST_CASE_METHOD(SimpleStateServerTestFixture, + "Test local-only push/ pull", + "[state]") { - setUpStateMode(); - // Create a key-value locally auto localKv = getKv(userA, keyA, dataA.size()); @@ -159,15 +156,12 @@ TEST_CASE("Test local-only push/ pull", "[state]") // Check that we get the expected size State& state = getGlobalState(); REQUIRE(state.getStateSize(userA, keyA) == dataA.size()); - - resetStateMode(); } -TEST_CASE("Test local-only append", "[state]") +TEST_CASE_METHOD(SimpleStateServerTestFixture, + "Test local-only append", + "[state]") { - setUpStateMode(); - - // Append a few chunks std::vector chunkA = { 1, 1 }; std::vector chunkB = { 2, 2, 2 }; std::vector chunkC = { 3, 3 }; @@ -187,59 +181,12 @@ TEST_CASE("Test local-only append", "[state]") kv->getAppended(actual.data(), actual.size(), 3); REQUIRE(actual == expected); - - resetStateMode(); } -TEST_CASE("Test state server as remote master", "[state]") +TEST_CASE_METHOD(SimpleStateServerTestFixture, + "Test state server with local master", + "[state]") { - setUpStateMode(); - - State& globalState = getGlobalState(); - REQUIRE(globalState.getKVCount() == 0); - - DummyStateServer server; - server.dummyData = dataA; - server.dummyUser = userA; - server.dummyKey = keyA; - server.start(); - - // Get the state size before accessing the value locally - size_t actualSize = globalState.getStateSize(userA, keyA); - REQUIRE(actualSize == dataA.size()); - - // Access locally and check not master - auto localKv = getKv(userA, keyA, dataA.size()); - REQUIRE(!localKv->isMaster()); - - // Set the state locally and check - State& state = getGlobalState(); - const std::shared_ptr& kv = - state.getKV(userA, keyA, dataA.size()); - kv->set(dataB.data()); - - std::vector actualLocal(dataA.size(), 0); - kv->get(actualLocal.data()); - REQUIRE(actualLocal == dataB); - - // Check it's not changed remotely - std::vector actualRemote = server.getRemoteKvValue(); - REQUIRE(actualRemote == dataA); - - // Push and check remote is updated - kv->pushFull(); - actualRemote = server.getRemoteKvValue(); - REQUIRE(actualRemote == dataB); - - server.stop(); - - resetStateMode(); -} - -TEST_CASE("Test state server with local master", "[state]") -{ - setUpStateMode(); - // Set and push auto localKv = getKv(userA, keyA, dataA.size()); localKv->set(dataA.data()); @@ -250,7 +197,6 @@ TEST_CASE("Test state server with local master", "[state]") localKv->set(dataB.data()); // Pull - State& state = getGlobalState(); const std::shared_ptr& kv = state.getKV(userA, keyA, dataA.size()); kv->pull(); @@ -258,7 +204,5 @@ TEST_CASE("Test state server with local master", "[state]") // Check it's still the same locally set value std::vector actual(kv->get(), kv->get() + dataA.size()); REQUIRE(actual == dataB); - - resetStateMode(); } } diff --git a/tests/test/transport/test_message_context.cpp b/tests/test/transport/test_message_context.cpp deleted file mode 100644 index 35173263e..000000000 --- a/tests/test/transport/test_message_context.cpp +++ /dev/null @@ -1,30 +0,0 @@ -#include - -#include - -using namespace faabric::transport; - -namespace tests { 
-TEST_CASE("Test global message context", "[transport]") -{ - // Get message context - MessageContext& context = getGlobalMessageContext(); - - // Context not shut down - REQUIRE(!context.isContextShutDown); - - // Close message context - REQUIRE_NOTHROW(context.close()); - - // Context is shut down - REQUIRE(context.isContextShutDown); - - // Get message context again, lazy-initialise it - MessageContext& newContext = getGlobalMessageContext(); - - // Context not shut down - REQUIRE(!newContext.isContextShutDown); - - newContext.close(); -} -} diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index ea41b3585..b42ff3e6d 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -4,135 +4,126 @@ #include #include -#include +#include +#include using namespace faabric::transport; -const std::string thisHost = "127.0.0.1"; -const int testPort = 9999; -const int testReplyPort = 9996; +#define TEST_PORT 9800 namespace tests { -TEST_CASE_METHOD(MessageContextFixture, - "Test open/close one client", - "[transport]") -{ - // Open an endpoint client, don't bind - MessageEndpoint cli(thisHost, testPort); - REQUIRE_NOTHROW(cli.open(context, SocketType::PULL, false)); - - // Open another endpoint client, bind - MessageEndpoint secondCli(thisHost, testPort); - REQUIRE_NOTHROW(secondCli.open(context, SocketType::PUSH, true)); - - // Close all endpoint clients - REQUIRE_NOTHROW(cli.close(false)); - REQUIRE_NOTHROW(secondCli.close(true)); -} -TEST_CASE_METHOD(MessageContextFixture, +TEST_CASE_METHOD(SchedulerTestFixture, "Test send/recv one message", "[transport]") { - // Open the source endpoint client, don't bind - SendMessageEndpoint src(thisHost, testPort); - src.open(context); - - // Open the destination endpoint client, bind - RecvMessageEndpoint dst(testPort); - dst.open(context); + AsyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); + AsyncRecvMessageEndpoint dst(TEST_PORT); // Send message std::string expectedMsg = "Hello world!"; - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - REQUIRE_NOTHROW(src.send(msg, expectedMsg.size())); + const uint8_t* msg = BYTES_CONST(expectedMsg.c_str()); + src.send(msg, expectedMsg.size()); // Receive message faabric::transport::Message recvMsg = dst.recv(); REQUIRE(recvMsg.size() == expectedMsg.size()); std::string actualMsg(recvMsg.data(), recvMsg.size()); REQUIRE(actualMsg == expectedMsg); +} + +TEST_CASE_METHOD(SchedulerTestFixture, + "Test send before recv is ready", + "[transport]") +{ + std::string expectedMsg = "Hello world!"; + + AsyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); + + auto latch = faabric::util::Latch::create(2); + + std::thread recvThread([&latch, expectedMsg] { + // Make sure this only runs once the send has been done + latch->wait(); + + // Receive message + AsyncRecvMessageEndpoint dst(TEST_PORT); + faabric::transport::Message recvMsg = dst.recv(); + + assert(recvMsg.size() == expectedMsg.size()); + std::string actualMsg(recvMsg.data(), recvMsg.size()); + assert(actualMsg == expectedMsg); + }); + + const uint8_t* msg = BYTES_CONST(expectedMsg.c_str()); + src.send(msg, expectedMsg.size()); + latch->wait(); - // Close endpoints - src.close(); - dst.close(); + if (recvThread.joinable()) { + recvThread.join(); + } } -TEST_CASE_METHOD(MessageContextFixture, "Test await response", "[transport]") +TEST_CASE_METHOD(SchedulerTestFixture, "Test await response", 
"[transport]") { // Prepare common message/response std::string expectedMsg = "Hello "; std::string expectedResponse = "world!"; - std::thread senderThread([this, expectedMsg, expectedResponse] { - // Open the source endpoint client, don't bind - MessageEndpointClient src(thisHost, testPort); - src.open(context); + std::thread senderThread([expectedMsg, expectedResponse] { + // Open the source endpoint client + SyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); // Send message and wait for response - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - src.send(msg, expectedMsg.size()); + std::vector bytes(BYTES_CONST(expectedMsg.c_str()), + BYTES_CONST(expectedMsg.c_str()) + + expectedMsg.size()); + + faabric::transport::Message recvMsg = + src.sendAwaitResponse(bytes.data(), bytes.size()); // Block waiting for a response - faabric::transport::Message recvMsg = src.awaitResponse(testReplyPort); assert(recvMsg.size() == expectedResponse.size()); std::string actualResponse(recvMsg.data(), recvMsg.size()); assert(actualResponse == expectedResponse); - - src.close(); }); // Receive message - RecvMessageEndpoint dst(testPort); - dst.open(context); + SyncRecvMessageEndpoint dst(TEST_PORT); faabric::transport::Message recvMsg = dst.recv(); REQUIRE(recvMsg.size() == expectedMsg.size()); std::string actualMsg(recvMsg.data(), recvMsg.size()); REQUIRE(actualMsg == expectedMsg); - // Send response, open a new endpoint for it - SendMessageEndpoint dstResponse(thisHost, testReplyPort); - dstResponse.open(context); - uint8_t msg[expectedResponse.size()]; - memcpy(msg, expectedResponse.c_str(), expectedResponse.size()); - dstResponse.send(msg, expectedResponse.size()); + // Send response + const uint8_t* msg = BYTES_CONST(expectedResponse.c_str()); + dst.sendResponse(msg, expectedResponse.size()); // Wait for sender thread if (senderThread.joinable()) { senderThread.join(); } - - // Close receiving endpoints - dst.close(); - dstResponse.close(); } -TEST_CASE_METHOD(MessageContextFixture, +TEST_CASE_METHOD(SchedulerTestFixture, "Test send/recv many messages", "[transport]") { int numMessages = 10000; std::string baseMsg = "Hello "; - std::thread senderThread([this, numMessages, baseMsg] { - // Open the source endpoint client, don't bind - SendMessageEndpoint src(thisHost, testPort); - src.open(context); + std::thread senderThread([numMessages, baseMsg] { + // Open the source endpoint client + AsyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); for (int i = 0; i < numMessages; i++) { - std::string expectedMsg = baseMsg + std::to_string(i); - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - src.send(msg, expectedMsg.size()); + std::string msgData = baseMsg + std::to_string(i); + const uint8_t* msg = BYTES_CONST(msgData.c_str()); + src.send(msg, msgData.size()); } - - src.close(); }); // Receive messages - RecvMessageEndpoint dst(testPort); - dst.open(context); + AsyncRecvMessageEndpoint dst(TEST_PORT); for (int i = 0; i < numMessages; i++) { faabric::transport::Message recvMsg = dst.recv(); // Check just a subset of the messages @@ -149,12 +140,9 @@ TEST_CASE_METHOD(MessageContextFixture, if (senderThread.joinable()) { senderThread.join(); } - - // Close the destination endpoint - dst.close(); } -TEST_CASE_METHOD(MessageContextFixture, +TEST_CASE_METHOD(SchedulerTestFixture, "Test send/recv many messages from many clients", "[transport]") { @@ -162,26 +150,20 @@ TEST_CASE_METHOD(MessageContextFixture, int numSenders = 10; 
std::string expectedMsg = "Hello from client"; std::vector senderThreads; + const uint8_t* msg = BYTES_CONST(expectedMsg.c_str()); for (int j = 0; j < numSenders; j++) { - senderThreads.emplace_back( - std::thread([this, numMessages, expectedMsg] { - // Open the source endpoint client, don't bind - SendMessageEndpoint src(thisHost, testPort); - src.open(context); - for (int i = 0; i < numMessages; i++) { - uint8_t msg[expectedMsg.size()]; - memcpy(msg, expectedMsg.c_str(), expectedMsg.size()); - src.send(msg, expectedMsg.size()); - } - - src.close(); - })); + senderThreads.emplace_back(std::thread([msg, numMessages, expectedMsg] { + // Open the source endpoint client + AsyncSendMessageEndpoint src(LOCALHOST, TEST_PORT); + for (int i = 0; i < numMessages; i++) { + src.send(msg, expectedMsg.size()); + } + })); } // Receive messages - RecvMessageEndpoint dst(testPort); - dst.open(context); + AsyncRecvMessageEndpoint dst(TEST_PORT); for (int i = 0; i < numSenders * numMessages; i++) { faabric::transport::Message recvMsg = dst.recv(); // Check just a subset of the messages @@ -198,47 +180,44 @@ TEST_CASE_METHOD(MessageContextFixture, t.join(); } } - - // Close the destination endpoint - dst.close(); } -TEST_CASE_METHOD(MessageContextFixture, +TEST_CASE_METHOD(SchedulerTestFixture, "Test can't set invalid send/recv timeouts", "[transport]") { - MessageEndpoint cli(thisHost, testPort); SECTION("Sanity check valid timeout") { - REQUIRE_NOTHROW(cli.setRecvTimeoutMs(100)); - REQUIRE_NOTHROW(cli.setSendTimeoutMs(100)); - } + AsyncSendMessageEndpoint s(LOCALHOST, TEST_PORT, 100); + AsyncRecvMessageEndpoint r(TEST_PORT, 100); - SECTION("Recv zero timeout") { REQUIRE_THROWS(cli.setRecvTimeoutMs(0)); } - - SECTION("Send zero timeout") { REQUIRE_THROWS(cli.setSendTimeoutMs(0)); } + SyncSendMessageEndpoint sB(LOCALHOST, TEST_PORT + 10, 100); + SyncRecvMessageEndpoint rB(TEST_PORT + 10, 100); + } - SECTION("Recv negative timeout") + SECTION("Recv zero timeout") { - REQUIRE_THROWS(cli.setRecvTimeoutMs(-1)); + REQUIRE_THROWS(AsyncRecvMessageEndpoint(TEST_PORT, 0)); + REQUIRE_THROWS(SyncRecvMessageEndpoint(TEST_PORT + 10, 0)); } - SECTION("Send negative timeout") + SECTION("Send zero timeout") { - REQUIRE_THROWS(cli.setSendTimeoutMs(-1)); + REQUIRE_THROWS(AsyncSendMessageEndpoint(LOCALHOST, TEST_PORT, 0)); + REQUIRE_THROWS(SyncSendMessageEndpoint(LOCALHOST, TEST_PORT + 10, 0)); } - SECTION("Recv, socket already initialised") + SECTION("Recv negative timeout") { - cli.open(context, SocketType::PULL, false); - REQUIRE_THROWS(cli.setRecvTimeoutMs(100)); + REQUIRE_THROWS(AsyncRecvMessageEndpoint(TEST_PORT, -1)); + REQUIRE_THROWS(SyncRecvMessageEndpoint(TEST_PORT + 10, -1)); } - SECTION("Send, socket already initialised") + SECTION("Send negative timeout") { - cli.open(context, SocketType::PULL, false); - REQUIRE_THROWS(cli.setSendTimeoutMs(100)); + REQUIRE_THROWS(AsyncSendMessageEndpoint(LOCALHOST, TEST_PORT, -1)); + REQUIRE_THROWS(SyncSendMessageEndpoint(LOCALHOST, TEST_PORT + 10, -1)); } } } diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 79bbaf65b..2449d55b8 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -1,150 +1,150 @@ #include +#include "faabric_utils.h" + #include +#include +#include #include +#include #include +#include using namespace faabric::transport; -const std::string thisHost = "127.0.0.1"; -const int testPort = 9999; +#define TEST_PORT_ASYNC 9998 +#define TEST_PORT_SYNC 
9999
 
 class DummyServer final : public MessageEndpointServer
 {
   public:
     DummyServer()
-      : MessageEndpointServer(testPort)
-      , messageCount(0)
+      : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC)
     {}
 
-    // Variable to keep track of the received messages
-    int messageCount;
+    std::atomic<int> messageCount = 0;
 
-    // This method is protected in the base class, as it's always called from
-    // the doRecv implementation. To ease testing, we make it public with this
-    // workaround.
-    void sendResponse(uint8_t* serialisedMsg,
-                      int size,
-                      const std::string& returnHost,
-                      int returnPort)
+  private:
+    void doAsyncRecv(int header,
+                     const uint8_t* buffer,
+                     size_t bufferSize) override
     {
-        MessageEndpointServer::sendResponse(
-          serialisedMsg, size, returnHost, returnPort);
+        messageCount++;
     }
 
-  private:
-    void doRecv(faabric::transport::Message& header,
-                faabric::transport::Message& body) override
+    std::unique_ptr<google::protobuf::Message>
+    doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override
     {
-        // Dummy server, do nothing but increment the message count
-        this->messageCount++;
+        messageCount++;
+
+        return std::make_unique<faabric::EmptyResponse>();
     }
 };
 
-class SlowServer final : public MessageEndpointServer
+class EchoServer final : public MessageEndpointServer
 {
   public:
-    int delayMs = 1000;
-
-    SlowServer()
-      : MessageEndpointServer(testPort)
+    EchoServer()
+      : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC)
     {}
 
-    void sendResponse(uint8_t* serialisedMsg,
-                      int size,
-                      const std::string& returnHost,
-                      int returnPort)
+  protected:
+    void doAsyncRecv(int header,
+                     const uint8_t* buffer,
+                     size_t bufferSize) override
     {
-        usleep(delayMs * 1000);
+        throw std::runtime_error("Echo server not expecting async recv");
     }
 
-  private:
-    void doRecv(faabric::transport::Message& header,
-                faabric::transport::Message& body) override
-    {}
+    std::unique_ptr<google::protobuf::Message>
+    doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override
+    {
+        SPDLOG_TRACE("Echo server received {} bytes", bufferSize);
+
+        auto response = std::make_unique<faabric::StatePart>();
+        response->set_data(buffer, bufferSize);
+
+        return response;
+    }
 };
 
-namespace tests {
-TEST_CASE("Test start/stop server", "[transport]")
+class SleepServer final : public MessageEndpointServer
 {
-    DummyServer server;
-    REQUIRE_NOTHROW(server.start());
+  public:
+    int delayMs = 1000;
 
-    usleep(1000 * 100);
+    SleepServer()
+      : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC)
+    {}
 
-    REQUIRE_NOTHROW(server.stop());
-}
+  protected:
+    void doAsyncRecv(int header,
+                     const uint8_t* buffer,
+                     size_t bufferSize) override
+    {
+        throw std::runtime_error("Sleep server not expecting async recv");
+    }
+
+    std::unique_ptr<google::protobuf::Message>
+    doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override
+    {
+        int* sleepTimeMs = (int*)buffer;
+        SPDLOG_DEBUG("Sleep server sleeping for {}ms", *sleepTimeMs);
+        SLEEP_MS(*sleepTimeMs);
+
+        auto response = std::make_unique<faabric::StatePart>();
+        response->set_data("Response after sleep");
+        return response;
+    }
+};
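
These three servers exercise the new split between one-way and request/response traffic. As a usage sketch, drawing only on the client calls already used in the tests below: asyncSend posts a fire-and-forget message that lands in doAsyncRecv, while syncSend blocks until doSyncRecv's protobuf response comes back.

// Usage sketch, assuming an EchoServer as defined above is already running.
// faabric::StatePart is reused purely as a convenient byte carrier, exactly
// as in the tests below.
MessageEndpointClient cli(LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC);

std::string body = "hello";

// One-way: routed to doAsyncRecv, no response expected
cli.asyncSend(0, BYTES_CONST(body.c_str()), body.size());

// Request/response: routed to doSyncRecv, response deserialised into the
// given protobuf
faabric::StatePart response;
cli.syncSend(0, BYTES_CONST(body.c_str()), body.size(), &response);
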
+
+namespace tests {
 TEST_CASE("Test send one message to server", "[transport]")
 {
-    // Start server
     DummyServer server;
     server.start();
 
-    // Open the source endpoint client, don't bind
-    auto& context = getGlobalMessageContext();
-    MessageEndpointClient src(thisHost, testPort);
-    src.open(context);
-
-    // Send message: server expects header + body
-    std::string header = "header";
-    uint8_t headerMsg[header.size()];
-    memcpy(headerMsg, header.c_str(), header.size());
-    // Mark we are sending the header
-    REQUIRE_NOTHROW(src.send(headerMsg, header.size(), true));
-    // Send the body
+    REQUIRE(server.messageCount == 0);
+
+    MessageEndpointClient cli(LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC);
+
+    // Send a message
     std::string body = "body";
-    uint8_t bodyMsg[body.size()];
-    memcpy(bodyMsg, body.c_str(), body.size());
-    src.send(bodyMsg, body.size(), false);
+    const uint8_t* bodyMsg = BYTES_CONST(body.c_str());
 
-    usleep(1000 * 300);
-    REQUIRE(server.messageCount == 1);
+    server.setAsyncLatch();
+    cli.asyncSend(0, bodyMsg, body.size());
+    server.awaitAsyncLatch();
 
-    // Close the client
-    src.close();
+    REQUIRE(server.messageCount == 1);
 
-    // Close the server
     server.stop();
 }
 
-TEST_CASE("Test send one-off response to client", "[transport]")
+TEST_CASE("Test send response to client", "[transport]")
 {
-    DummyServer server;
+    EchoServer server;
     server.start();
 
     std::string expectedMsg = "Response from server";
 
-    std::thread clientThread([expectedMsg] {
-        // Open the source endpoint client, don't bind
-        auto& context = getGlobalMessageContext();
-        MessageEndpointClient cli(thisHost, testPort);
-        cli.open(context);
-
-        Message msg = cli.awaitResponse(testPort + REPLY_PORT_OFFSET);
-        assert(msg.size() == expectedMsg.size());
-        std::string actualMsg(msg.data(), msg.size());
-        assert(actualMsg == expectedMsg);
+    // Open the source endpoint client
+    MessageEndpointClient cli(LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC);
 
-        cli.close();
-    });
+    // Send and await the response
+    faabric::StatePart response;
+    cli.syncSend(0, BYTES(expectedMsg.data()), expectedMsg.size(), &response);
 
-    uint8_t msg[expectedMsg.size()];
-    memcpy(msg, expectedMsg.c_str(), expectedMsg.size());
-    REQUIRE_NOTHROW(
-      server.sendResponse(msg, expectedMsg.size(), thisHost, testPort));
-
-    if (clientThread.joinable()) {
-        clientThread.join();
-    }
+    assert(response.data() == expectedMsg);
 
     server.stop();
 }
 
 TEST_CASE("Test multiple clients talking to one server", "[transport]")
 {
-    DummyServer server;
+    EchoServer server;
     server.start();
 
     std::vector<std::thread> clientThreads;
     int numClients = 10;
     int numMessages = 1000;
 
     for (int i = 0; i < numClients; i++) {
-        clientThreads.emplace_back(std::thread([numMessages] {
+        clientThreads.emplace_back(std::thread([i, numMessages] {
             // Prepare client
-            auto& context = getGlobalMessageContext();
-            MessageEndpointClient cli(thisHost, testPort);
-            cli.open(context);
+            MessageEndpointClient cli(
+              LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC);
 
-            std::string clientMsg = "Message from threaded client";
             for (int j = 0; j < numMessages; j++) {
-                // Send header
-                uint8_t header[clientMsg.size()];
-                memcpy(header, clientMsg.c_str(), clientMsg.size());
-                cli.send(header, clientMsg.size(), true);
-                // Send body
-                uint8_t body[clientMsg.size()];
-                memcpy(body, clientMsg.c_str(), clientMsg.size());
-                cli.send(body, clientMsg.size());
-            }
+                std::string clientMsg =
+                  fmt::format("Message {} from client {}", j, i);
 
-            usleep(1000 * 300);
+                // Send and get response
+                const uint8_t* body = BYTES_CONST(clientMsg.c_str());
+                faabric::StatePart response;
+                cli.syncSend(0, body, clientMsg.size(), &response);
 
-            cli.close();
+                std::string actual = response.data();
+                assert(actual == clientMsg);
+            }
         }));
     }
 
@@ -182,60 +178,49 @@ TEST_CASE("Test multiple clients talking to one server", "[transport]")
             t.join();
         }
     }
 
-    REQUIRE(server.messageCount == numMessages * numClients);
-
     server.stop();
 }
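
The setAsyncLatch/awaitAsyncLatch pair used in the first test above is not defined anywhere in this diff. A plausible sketch, assuming it wraps the Latch introduced in this PR (the actual implementation may well differ):

// Hypothetical sketch of the latch helpers on MessageEndpointServer: the
// server's async receive loop and the test thread both join a two-party
// latch, so awaitAsyncLatch returns only once doAsyncRecv has run.
void MessageEndpointServer::setAsyncLatch()
{
    asyncLatch = faabric::util::Latch::create(2);
}

void MessageEndpointServer::awaitAsyncLatch()
{
    asyncLatch->wait();
    asyncLatch = nullptr;
}
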
 
 TEST_CASE("Test client timeout on requests to valid server", "[transport]")
 {
-    // Start the server in the background
-    std::thread t([] {
-        SlowServer server;
-        server.start();
-
-        int threadSleep = server.delayMs + 500;
-        usleep(threadSleep * 1000);
-
-        server.stop();
-    });
-
-    // Wait for the server to start up
-    usleep(500 * 1000);
-
-    // Set up the client
-    auto& context = getGlobalMessageContext();
-    MessageEndpointClient cli(thisHost, testPort);
-
     int clientTimeout;
+    int serverSleep;
     bool expectFailure;
+
     SECTION("Long timeout no failure")
     {
         clientTimeout = 20000;
+        serverSleep = 100;
         expectFailure = false;
     }
 
     SECTION("Short timeout failure")
     {
-        clientTimeout = 100;
+        clientTimeout = 10;
+        serverSleep = 2000;
         expectFailure = true;
     }
 
-    cli.setRecvTimeoutMs(clientTimeout);
-    cli.open(context);
+    // Start the server
+    SleepServer server;
+    server.start();
+
+    // Set up the client
+    MessageEndpointClient cli(
+      LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC, clientTimeout);
+
+    uint8_t* sleepBytes = BYTES(&serverSleep);
+    faabric::StatePart response;
 
-    // Check for failure accordingly
     if (expectFailure) {
-        REQUIRE_THROWS_AS(cli.awaitResponse(testPort + REPLY_PORT_OFFSET),
+        // Check for failure
+        REQUIRE_THROWS_AS(cli.syncSend(0, sleepBytes, sizeof(int), &response),
                           MessageTimeoutException);
     } else {
-        REQUIRE_NOTHROW(cli.awaitResponse(testPort + REPLY_PORT_OFFSET));
+        cli.syncSend(0, sleepBytes, sizeof(int), &response);
+        REQUIRE(response.data() == "Response after sleep");
     }
 
-    cli.close();
-
-    if (t.joinable()) {
-        t.join();
-    }
+    server.stop();
 }
 }
diff --git a/tests/test/transport/test_mpi_message_endpoint.cpp b/tests/test/transport/test_mpi_message_endpoint.cpp
index a27ae9bf7..670062b20 100644
--- a/tests/test/transport/test_mpi_message_endpoint.cpp
+++ b/tests/test/transport/test_mpi_message_endpoint.cpp
@@ -6,27 +6,8 @@
 using namespace faabric::transport;
 
 namespace tests {
-TEST_CASE_METHOD(MessageContextFixture,
-                 "Test send and recv the hosts to rank message",
-                 "[transport]")
-{
-    // Prepare message
-    std::vector<std::string> expected = { "foo", "bar" };
-    faabric::MpiHostsToRanksMessage sendMsg;
-    *sendMsg.mutable_hosts() = { expected.begin(), expected.end() };
-    sendMpiHostRankMsg(LOCALHOST, sendMsg);
-
-    // Send message
-    faabric::MpiHostsToRanksMessage actual = recvMpiHostRankMsg();
-
-    // Checks
-    REQUIRE(actual.hosts().size() == expected.size());
-    for (int i = 0; i < actual.hosts().size(); i++) {
-        REQUIRE(actual.hosts().Get(i) == expected[i]);
-    }
-}
-
-TEST_CASE_METHOD(MessageContextFixture,
+TEST_CASE_METHOD(SchedulerTestFixture,
                  "Test send and recv an MPI message",
                  "[transport]")
 {
@@ -41,10 +22,6 @@ TEST_CASE_METHOD(MessageContextFixture,
     sendEndpoint.sendMpiMessage(expected);
     std::shared_ptr<faabric::MPIMessage> actual = recvEndpoint.recvMpiMessage();
 
-    // Checks
     REQUIRE(expected->id() == actual->id());
-
-    REQUIRE_NOTHROW(sendEndpoint.close());
-    REQUIRE_NOTHROW(recvEndpoint.close());
 }
 }
diff --git a/tests/test/util/test_barrier.cpp b/tests/test/util/test_barrier.cpp
deleted file mode 100644
index 594c0ab90..000000000
--- a/tests/test/util/test_barrier.cpp
+++ /dev/null
@@ -1,39 +0,0 @@
-#include
-#include
-#include
-#include
-#include
-
-using namespace faabric::util;
-
-namespace tests {
-TEST_CASE("Test barrier operation", "[util]")
-{
-    Barrier b(3);
-
-    REQUIRE(b.getSlotCount() == 3);
-    REQUIRE(b.getUseCount() == 0);
-
-    auto t1 = std::thread([&b] { b.wait(); });
-
-    auto t2 = std::thread([&b] { b.wait(); });
-
-    // Sleep for a bit while the threads spawn
-    usleep(500 * 1000);
-    REQUIRE(b.getSlotCount() == 1);
-
-    // Join with master to go through barrier
-    b.wait();
-
-    if (t1.joinable()) {
-        t1.join();
-    }
-
-    if (t2.joinable()) {
-        t2.join();
-    }
-
-    REQUIRE(b.getSlotCount() == 3);
-    REQUIRE(b.getUseCount() == 1);
-}
-}
diff --git a/tests/test/util/test_latch.cpp b/tests/test/util/test_latch.cpp
new file mode 100644
index 000000000..3025ed06a
--- /dev/null
+++ b/tests/test/util/test_latch.cpp
@@ -0,0 +1,41 @@
+#include
+
+#include "faabric_utils.h"
+
+#include
+#include
+#include
+
+#include
+#include
+
+using namespace faabric::util;
+
+namespace tests {
+TEST_CASE("Test latch operation", "[util]")
+{
+    auto l = Latch::create(3);
+
+    auto t1 = std::thread([l] { l->wait(); });
+    auto t2 = std::thread([l] { l->wait(); });
+
+    l->wait();
+
+    if (t1.joinable()) {
+        t1.join();
+    }
+
+    if (t2.joinable()) {
+        t2.join();
+    }
+
+    REQUIRE_THROWS(l->wait());
+}
+
+TEST_CASE("Test latch timeout", "[util]")
+{
+    int timeoutMs = 500;
+    auto l = Latch::create(2, timeoutMs);
+    REQUIRE_THROWS(l->wait());
+}
+}
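
The Latch exercised by the new tests above replaces the deleted Barrier. Its interface, inferred from the tests themselves (the default timeout value here is an assumption):

// Interface sketch inferred from test_latch.cpp above; see the real
// definition in faabric's util headers.
class Latch
{
  public:
    // Factory returning a shared_ptr so the latch can be captured by value
    // in the waiting threads
    static std::shared_ptr<Latch> create(int count, int timeoutMs = 10000);

    // Blocks until `count` parties have called wait(). Throws if the
    // timeout expires, and throws on any call beyond the first `count`,
    // since a latch is single-use (unlike the old reusable Barrier)
    void wait();
};
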
diff --git a/tests/test/util/test_queue.cpp b/tests/test/util/test_queue.cpp
index 64c4319d0..fbf1560ab 100644
--- a/tests/test/util/test_queue.cpp
+++ b/tests/test/util/test_queue.cpp
@@ -1,6 +1,9 @@
 #include
 
+#include "faabric_utils.h"
+
 #include
+#include
 #include
 
 #include
@@ -86,7 +89,7 @@ TEST_CASE("Test wait for draining queue with elements", "[util]")
     // Background thread to consume elements
     std::thread t([&q, &dequeued, nElems] {
         for (int i = 0; i < nElems; i++) {
-            std::this_thread::sleep_for(std::chrono::milliseconds(100));
+            SLEEP_MS(100);
 
             int j = q.dequeue();
             dequeued.emplace_back(j);
@@ -117,7 +120,7 @@ TEST_CASE("Test queue on non-copy-constructible object", "[util]")
     std::thread ta([&q] { q.dequeue().set_value(1); });
 
     std::thread tb([&q] {
-        usleep(500 * 1000);
+        SLEEP_MS(SHORT_TEST_TIMEOUT_MS);
         q.dequeue().set_value(2);
     });
 
diff --git a/tests/utils/faabric_utils.h b/tests/utils/faabric_utils.h
index 24f078d60..acc86cb73 100644
--- a/tests/utils/faabric_utils.h
+++ b/tests/utils/faabric_utils.h
@@ -12,6 +12,29 @@ using namespace faabric;
 
 #define SHORT_TEST_TIMEOUT_MS 1000
 
+#define REQUIRE_RETRY_MAX 5
+#define REQUIRE_RETRY_SLEEP_MS 1000
+
+#define REQUIRE_RETRY(updater, check)                                         \
+    {                                                                         \
+        {                                                                     \
+            updater;                                                          \
+        };                                                                    \
+        bool res = (check);                                                   \
+        int count = 0;                                                        \
+        while (!res && count < REQUIRE_RETRY_MAX) {                           \
+            count++;                                                          \
+            SLEEP_MS(REQUIRE_RETRY_SLEEP_MS);                                 \
+            {                                                                 \
+                updater;                                                      \
+            };                                                                \
+            res = (check);                                                    \
+        }                                                                     \
+        if (!res) {                                                           \
+            FAIL();                                                           \
+        }                                                                     \
+    }
+
 #define FAABRIC_CATCH_LOGGER                                                  \
     struct LogListener : Catch::TestEventListenerBase                         \
     {                                                                         \
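
For illustration, a hypothetical use of the new REQUIRE_RETRY macro (this snippet is not from the patch): it runs the updater, evaluates the check, and retries up to REQUIRE_RETRY_MAX times with a sleep between attempts before failing the test.

// Hypothetical test snippet: poll an asynchronously updated counter until
// it reaches the expected value, rather than sleeping a fixed interval
int actual = 0;
REQUIRE_RETRY(actual = server.messageCount, actual == 1);
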
diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h
index b102db298..eeac4aa9f 100644
--- a/tests/utils/fixtures.h
+++ b/tests/utils/fixtures.h
@@ -6,8 +6,9 @@
 #include
 #include
 #include
+#include
 #include
-#include
+#include
 #include
 #include
 #include
@@ -118,20 +119,9 @@ class ConfTestFixture
     faabric::util::SystemConfig& conf;
 };
 
-class MessageContextFixture : public SchedulerTestFixture
-{
-  protected:
-    faabric::transport::MessageContext& context;
-
-  public:
-    MessageContextFixture()
-      : context(faabric::transport::getGlobalMessageContext())
-    {}
-
-    ~MessageContextFixture() { context.close(); }
-};
-
-class MpiBaseTestFixture : public SchedulerTestFixture
+class MpiBaseTestFixture
+  : public SchedulerTestFixture
+  , public ConfTestFixture
 {
   public:
     MpiBaseTestFixture()
@@ -178,45 +168,57 @@ class MpiTestFixture : public MpiBaseTestFixture
     faabric::scheduler::MpiWorld world;
 };
 
+// Note that this fixture uses two worlds, each of which "thinks" the other
+// is remote. This is done by giving one the IP of this host, and the other
+// the localhost IP, i.e. 127.0.0.1.
 class RemoteMpiTestFixture : public MpiBaseTestFixture
 {
   public:
     RemoteMpiTestFixture()
       : thisHost(faabric::util::getSystemConfig().endpointHost)
-      , otherHost(LOCALHOST)
+      , testLatch(faabric::util::Latch::create(2))
+    {
+        otherWorld.overrideHost(otherHost);
+
+        faabric::util::setMockMode(true);
+    }
+
+    ~RemoteMpiTestFixture()
     {
-        remoteWorld.overrideHost(otherHost);
+        faabric::util::setMockMode(false);
+
+        faabric::scheduler::getMpiWorldRegistry().clear();
     }
 
-    void setWorldsSizes(int worldSize, int ranksWorldOne, int ranksWorldTwo)
+    void setWorldSizes(int worldSize, int ranksThisWorld, int ranksOtherWorld)
     {
         // Update message
         msg.set_mpiworldsize(worldSize);
 
-        // Set local ranks
-        faabric::HostResources localResources;
-        localResources.set_slots(ranksWorldOne);
-        // Account for the master rank that is already running in this world
-        localResources.set_usedslots(1);
-        // Set remote ranks
-        faabric::HostResources otherResources;
-        otherResources.set_slots(ranksWorldTwo);
-        // Note that the remaining ranks will be allocated to the world
-        // with the master host
+        // Set up the first world, holding the master rank (which already
+        // takes one slot).
+        // Note that any excess ranks will also be allocated to this world
+        // when the scheduler is overloaded.
+        faabric::HostResources thisResources;
+        thisResources.set_slots(ranksThisWorld);
+        thisResources.set_usedslots(1);
+        sch.setThisHostResources(thisResources);
 
-        std::string otherHost = LOCALHOST;
+        // Set up the other world and add it to the global set of hosts
+        faabric::HostResources otherResources;
+        otherResources.set_slots(ranksOtherWorld);
         sch.addHostToGlobalSet(otherHost);
 
-        // Mock everything to make sure the other host has resources as well
-        faabric::util::setMockMode(true);
-        sch.setThisHostResources(localResources);
+        // Queue the resource response for this other host
        faabric::scheduler::queueResourceResponse(otherHost, otherResources);
     }
 
  protected:
    std::string thisHost;
-    std::string otherHost;
+    std::string otherHost = LOCALHOST;
+
+    std::shared_ptr<faabric::util::Latch> testLatch;
 
-    faabric::scheduler::MpiWorld remoteWorld;
+    faabric::scheduler::MpiWorld otherWorld;
 };
}
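
A sketch of how a test might drive this fixture (hypothetical, since the MPI tests themselves are not part of this diff): the other world runs in a background thread, and testLatch keeps it alive until the main world is done. The msg and worldId members are assumed to come from MpiBaseTestFixture.

// Hypothetical usage of RemoteMpiTestFixture
TEST_CASE_METHOD(RemoteMpiTestFixture, "Example two-world test", "[mpi]")
{
    // Two ranks on this host, two on the "remote" one
    setWorldSizes(4, 2, 2);

    faabric::scheduler::MpiWorld thisWorld;
    thisWorld.create(msg, worldId, 4);

    std::thread otherWorldThread([this] {
        otherWorld.initialiseFromMsg(msg);

        // ... remote ranks send/recv here ...

        // Hold the world open until the main thread is finished
        testLatch->wait();
        otherWorld.destroy();
    });

    // ... local ranks send/recv here ...

    testLatch->wait();
    if (otherWorldThread.joinable()) {
        otherWorldThread.join();
    }

    thisWorld.destroy();
}
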