diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp index 83d5223c3..0192bf9a6 100644 --- a/src/scheduler/MpiWorld.cpp +++ b/src/scheduler/MpiWorld.cpp @@ -194,6 +194,9 @@ void MpiWorld::create(const faabric::Message& call, int newId, int newSize) msg.set_mpiworldsize(size); // Log chained functions to generate execution graphs sch.logChainedFunction(call.id(), msg.id()); + SPDLOG_INFO("Logging chained call (MsgId-WorldId: {}-{}, Rank: {}/{}) -> (MsgId-WorldId: {}-{}, Rank: {}/{})", + call.id(), call.mpiworldid(), call.mpirank(), call.mpiworldsize(), + msg.id(), msg.mpiworldid(), msg.mpirank(), msg.mpiworldsize()); } std::vector<std::string> executedAt; diff --git a/tests/test/scheduler/test_exec_graph.cpp b/tests/test/scheduler/test_exec_graph.cpp index 5154663d1..4a30d6b94 100644 --- a/tests/test/scheduler/test_exec_graph.cpp +++ b/tests/test/scheduler/test_exec_graph.cpp @@ -3,8 +3,10 @@ #include "faabric_utils.h" #include <faabric/redis/Redis.h> +#include <faabric/scheduler/MpiWorld.h> #include <faabric/scheduler/Scheduler.h> #include <faabric/util/environment.h> +#include <faabric/util/logging.h> using namespace scheduler; @@ -67,4 +69,52 @@ TEST_CASE("Test execution graph", "[scheduler]") checkExecGraphEquality(expected, actual); } + +TEST_CASE_METHOD(MpiBaseTestFixture, + "Test MPI execution graph", + "[mpi][scheduler]") +{ + faabric::scheduler::MpiWorld world; + + // Build the message vector to reconstruct the graph + std::vector<faabric::Message> messages(worldSize); + for (int rank = 0; rank < worldSize; rank++) { + messages.at(rank) = faabric::util::messageFactory("mpi", "hellompi"); + messages.at(rank).set_mpirank(rank); + } + + world.create(msg, worldId, worldSize); + + world.destroy(); + + // Build expected graph + ExecGraphNode nodeB1 = { .msg = messages.at(1) }; + ExecGraphNode nodeB2 = { .msg = messages.at(2) }; + ExecGraphNode nodeB3 = { .msg = messages.at(3) }; + ExecGraphNode nodeB4 = { .msg = messages.at(4) }; + + ExecGraphNode nodeA = { .msg = messages.at(0), + .children = { nodeB1, nodeB2, nodeB3, nodeB4 } }; + + ExecGraph expected{ .rootNode = nodeA }; + + // Check the execution graph + ExecGraph actual = sch.getFunctionExecGraph(msg.id()); + REQUIRE(countExecGraphNodes(actual) == worldSize); + REQUIRE(countExecGraphNodes(expected) == worldSize); + + // Print contents of actual + SPDLOG_INFO("Actual root node. World id: {}, Rank: {}/{}", + actual.rootNode.msg.mpiworldid(), + actual.rootNode.msg.mpirank(), + actual.rootNode.msg.mpiworldsize()); + for (const auto& node : actual.rootNode.children) { + SPDLOG_INFO("Actual children node. World id: {}, Rank: {}/{}", + node.msg.mpiworldid(), + node.msg.mpirank(), + node.msg.mpiworldsize()); + } + + checkExecGraphEquality(expected, actual, true); +} } diff --git a/tests/utils/exec_graph_utils.cpp b/tests/utils/exec_graph_utils.cpp index 6178fcfaa..1b8457047 100644 --- a/tests/utils/exec_graph_utils.cpp +++ b/tests/utils/exec_graph_utils.cpp @@ -6,10 +6,15 @@ namespace tests { void checkExecGraphNodeEquality(const scheduler::ExecGraphNode& nodeA, - const scheduler::ExecGraphNode& nodeB) + const scheduler::ExecGraphNode& nodeB, + bool isMpi) { // Check the message itself - checkMessageEquality(nodeA.msg, nodeB.msg); + if (isMpi) { + REQUIRE(nodeA.msg.mpirank() == nodeB.msg.mpirank()); + } else { + checkMessageEquality(nodeA.msg, nodeB.msg); + } if (nodeA.children.size() != nodeB.children.size()) { FAIL(fmt::format("Children not same size: {} vs {}", @@ -19,13 +24,15 @@ void checkExecGraphNodeEquality(const scheduler::ExecGraphNode& nodeA, // Assume children are in same order for (int i = 0; i < nodeA.children.size(); i++) { - checkExecGraphNodeEquality(nodeA.children.at(i), nodeB.children.at(i)); + checkExecGraphNodeEquality( + nodeA.children.at(i), nodeB.children.at(i), isMpi); } } void checkExecGraphEquality(const scheduler::ExecGraph& graphA, - const scheduler::ExecGraph& graphB) + const scheduler::ExecGraph& graphB, + bool isMpi) { - checkExecGraphNodeEquality(graphA.rootNode, graphB.rootNode); + checkExecGraphNodeEquality(graphA.rootNode, graphB.rootNode, isMpi); } } diff --git a/tests/utils/faabric_utils.h b/tests/utils/faabric_utils.h index b2a3f042d..27b933b73 100644 --- a/tests/utils/faabric_utils.h +++ b/tests/utils/faabric_utils.h @@ -66,10 +66,12 @@ void checkMessageEquality(const faabric::Message& msgA, const faabric::Message& msgB); void checkExecGraphNodeEquality(const scheduler::ExecGraphNode& nodeA, - const scheduler::ExecGraphNode& nodeB); + const scheduler::ExecGraphNode& nodeB, + bool isMpi = false); void checkExecGraphEquality(const scheduler::ExecGraph& graphA, - const scheduler::ExecGraph& graphB); + const scheduler::ExecGraph& graphB, + bool isMpi = false); std::pair<int, std::string> submitGetRequestToUrl(const std::string& host, int port,