diff --git a/tests/test/scheduler/test_exec_graph.cpp b/tests/test/scheduler/test_exec_graph.cpp index 46b04a05d..595c8eb12 100644 --- a/tests/test/scheduler/test_exec_graph.cpp +++ b/tests/test/scheduler/test_exec_graph.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -84,6 +85,13 @@ TEST_CASE_METHOD(MpiBaseTestFixture, std::vector messages(worldSize); for (int rank = 0; rank < worldSize; rank++) { messages.at(rank) = faabric::util::messageFactory("mpi", "hellompi"); + messages.at(rank).set_id(0); + messages.at(rank).set_timestamp(0); + messages.at(rank).set_finishtimestamp(0); + messages.at(rank).set_resultkey(""); + messages.at(rank).set_statuskey(""); + messages.at(rank).set_executedhost( + faabric::util::getSystemConfig().endpointHost); messages.at(rank).set_ismpi(true); messages.at(rank).set_mpiworldid(worldId); messages.at(rank).set_mpirank(rank); @@ -110,12 +118,28 @@ TEST_CASE_METHOD(MpiBaseTestFixture, for (const auto& id : sch.getChainedFunctions(msg.id())) { sch.getFunctionResult(id, 500); } + ExecGraph actual = sch.getFunctionExecGraph(msg.id()); + + // Unset the fields that we can't recreate + actual.rootNode.msg.set_id(0); + actual.rootNode.msg.set_finishtimestamp(0); + actual.rootNode.msg.set_timestamp(0); + actual.rootNode.msg.set_resultkey(""); + actual.rootNode.msg.set_statuskey(""); + actual.rootNode.msg.set_outputdata(""); + for (auto& node : actual.rootNode.children) { + node.msg.set_id(0); + node.msg.set_finishtimestamp(0); + node.msg.set_timestamp(0); + node.msg.set_resultkey(""); + node.msg.set_statuskey(""); + node.msg.set_outputdata(""); + } // Check the execution graph - ExecGraph actual = sch.getFunctionExecGraph(msg.id()); REQUIRE(countExecGraphNodes(actual) == worldSize); REQUIRE(countExecGraphNodes(expected) == worldSize); - checkExecGraphEquality(expected, actual, true); + checkExecGraphEquality(expected, actual); } } diff --git a/tests/utils/exec_graph_utils.cpp b/tests/utils/exec_graph_utils.cpp index 45788730d..a39258dd6 100644 --- a/tests/utils/exec_graph_utils.cpp +++ b/tests/utils/exec_graph_utils.cpp @@ -6,15 +6,9 @@ namespace tests { void checkExecGraphNodeEquality(const scheduler::ExecGraphNode& nodeA, - const scheduler::ExecGraphNode& nodeB, - bool isMpi) + const scheduler::ExecGraphNode& nodeB) { - // Check the message itself - if (isMpi) { - checkMpiMessageEquivalence(nodeA.msg, nodeB.msg); - } else { - checkMessageEquality(nodeA.msg, nodeB.msg); - } + checkMessageEquality(nodeA.msg, nodeB.msg); if (nodeA.children.size() != nodeB.children.size()) { FAIL(fmt::format("Children not same size: {} vs {}", @@ -24,15 +18,13 @@ void checkExecGraphNodeEquality(const scheduler::ExecGraphNode& nodeA, // Assume children are in same order for (int i = 0; i < nodeA.children.size(); i++) { - checkExecGraphNodeEquality( - nodeA.children.at(i), nodeB.children.at(i), isMpi); + checkExecGraphNodeEquality(nodeA.children.at(i), nodeB.children.at(i)); } } void checkExecGraphEquality(const scheduler::ExecGraph& graphA, - const scheduler::ExecGraph& graphB, - bool isMpi) + const scheduler::ExecGraph& graphB) { - checkExecGraphNodeEquality(graphA.rootNode, graphB.rootNode, isMpi); + checkExecGraphNodeEquality(graphA.rootNode, graphB.rootNode); } } diff --git a/tests/utils/faabric_utils.h b/tests/utils/faabric_utils.h index a7162fcaf..b2a3f042d 100644 --- a/tests/utils/faabric_utils.h +++ b/tests/utils/faabric_utils.h @@ -65,16 +65,11 @@ void cleanFaabric(); void checkMessageEquality(const faabric::Message& msgA, const faabric::Message& msgB); -void checkMpiMessageEquivalence(const faabric::Message& msgA, - const faabric::Message& msgB); - void checkExecGraphNodeEquality(const scheduler::ExecGraphNode& nodeA, - const scheduler::ExecGraphNode& nodeB, - bool isMpi = false); + const scheduler::ExecGraphNode& nodeB); void checkExecGraphEquality(const scheduler::ExecGraph& graphA, - const scheduler::ExecGraph& graphB, - bool isMpi = false); + const scheduler::ExecGraph& graphB); std::pair submitGetRequestToUrl(const std::string& host, int port, diff --git a/tests/utils/message_utils.cpp b/tests/utils/message_utils.cpp index f0220b64d..843dc9762 100644 --- a/tests/utils/message_utils.cpp +++ b/tests/utils/message_utils.cpp @@ -48,16 +48,4 @@ void checkMessageEquality(const faabric::Message& msgA, REQUIRE(msgA.sgxpolicy() == msgB.sgxpolicy()); REQUIRE(msgA.sgxresult() == msgB.sgxresult()); } - -void checkMpiMessageEquivalence(const faabric::Message& msgA, - const faabric::Message& msgB) -{ - REQUIRE(msgA.user() == msgB.user()); - REQUIRE(msgA.function() == msgB.function()); - - REQUIRE(msgA.ismpi() == msgB.ismpi()); - REQUIRE(msgA.mpiworldid() == msgB.mpiworldid()); - REQUIRE(msgA.mpirank() == msgB.mpirank()); - REQUIRE(msgA.mpiworldsize() == msgB.mpiworldsize()); -} }