Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release an RxNodeComponent edge on error #327

Merged
1 change: 1 addition & 0 deletions cpp/mrc/include/mrc/node/rx_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ class RxNodeComponent : public WritableProvider<InputT>, public WritableAcceptor
this->get_writable_edge()->await_write(std::move(message));
},
[this](std::exception_ptr ptr) {
WritableAcceptor<OutputT>::release_edge_connection();
runnable::Context::get_runtime_context().set_exception(std::move(ptr));
},
[this]() {
Expand Down
74 changes: 74 additions & 0 deletions cpp/mrc/tests/test_edges.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "mrc/node/operators/combine_latest.hpp"
#include "mrc/node/operators/node_component.hpp"
#include "mrc/node/operators/router.hpp"
#include "mrc/node/rx_node.hpp"
#include "mrc/node/sink_channel_owner.hpp"
#include "mrc/node/sink_properties.hpp"
#include "mrc/node/source_channel_owner.hpp"
Expand Down Expand Up @@ -278,6 +279,31 @@ class TestNodeComponent : public NodeComponent<T, T>
}
};

template <typename T>
class TestRxNodeComponent : public RxNodeComponent<T, T>
{
using base_t = node::RxNodeComponent<T, T>;

public:
using typename base_t::stream_fn_t;

void make_stream(stream_fn_t fn)
{
return base_t::make_stream([this, fn](auto&&... args) {
m_stream_fn_called = true;
return fn(std::forward<decltype(args)>(args)...);
});
}

~TestRxNodeComponent() override
{
// Debug print
VLOG(10) << "Destroying TestRxNodeComponent";
}

bool m_stream_fn_called = false;
};

template <typename T>
class TestSinkComponent : public WritableProvider<T>
{
Expand Down Expand Up @@ -517,6 +543,26 @@ TEST_F(TestEdges, SourceToNodeComponentToSinkComponent)
source->run();
}

TEST_F(TestEdges, SourceToRxNodeComponentToSinkComponent)
{
auto source = std::make_shared<node::TestSource<int>>();
auto node = std::make_shared<node::TestRxNodeComponent<int>>();
auto sink = std::make_shared<node::TestSinkComponent<int>>();

mrc::make_edge(*source, *node);
mrc::make_edge(*node, *sink);

node->make_stream([=](rxcpp::observable<int> input) {
return input.map([](int i) {
return i * 2;
});
});

source->run();

EXPECT_TRUE(node->m_stream_fn_called);
}

TEST_F(TestEdges, SourceComponentToNodeToSinkComponent)
{
auto source = std::make_shared<node::TestSourceComponent<int>>();
Expand Down Expand Up @@ -825,6 +871,10 @@ TEST_F(TestEdges, CreateAndDestroy)
auto x = std::make_shared<node::TestNodeComponent<int>>();
}

{
auto x = std::make_shared<node::TestRxNodeComponent<int>>();
}

{
auto x = std::make_shared<node::TestSinkComponent<int>>();
}
Expand Down Expand Up @@ -927,4 +977,28 @@ TEST_F(TestEdges, EdgeTapWithSpliceComponent)
source->run();
sink->run();
}

TEST_F(TestEdges, EdgeTapWithSpliceRxComponent)
{
auto source = std::make_shared<node::TestSource<int>>();
auto node = std::make_shared<node::TestRxNodeComponent<int>>();
auto sink = std::make_shared<node::TestSink<int>>();

// Original edge
mrc::make_edge(*source, *sink);

node->make_stream([=](rxcpp::observable<int> input) {
return input.map([](int i) {
return i * 2;
});
});

// Tap edge
mrc::edge::EdgeBuilder::splice_edge<int>(*source, *sink, *node, *node);

source->run();
sink->run();

EXPECT_TRUE(node->m_stream_fn_called);
}
} // namespace mrc
50 changes: 50 additions & 0 deletions cpp/mrc/tests/test_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,56 @@ TEST_F(TestNode, NodePrologueEpilogue)
EXPECT_EQ(epilogue_tap_sum, 20);
}

TEST_F(TestNode, RxNodeComponentThrows)
{
auto p = pipeline::make_pipeline();
std::atomic<int> throw_count = 0;
std::atomic<int> sink_call_count = 0;
std::atomic<int> complete_count = 0;

auto my_segment = p->make_segment("test_segment", [&](segment::Builder& seg) {
auto source = seg.make_source<int>("source", [&](rxcpp::subscriber<int>& s) {
s.on_next(1);
s.on_next(2);
s.on_next(3);
s.on_completed();
});

auto node_comp = seg.make_node_component<int, int>("node", rxcpp::operators::map([&](int i) -> int {
++throw_count;
throw std::runtime_error("test");
return 0;
}));

auto sink = seg.make_sink<int>(
"sinkInt",
[&](const int& x) {
++sink_call_count;
DVLOG(1) << "Sink got value: '" << x << "'" << std::endl;
},
[&]() {
++complete_count;
DVLOG(1) << "Sink on_completed" << std::endl;
});

seg.make_edge(source, node_comp);
seg.make_edge(node_comp, sink);
});

auto options = std::make_unique<Options>();
options->topology().user_cpuset("0");

Executor exec(std::move(options));
exec.register_pipeline(std::move(p));
exec.start();

EXPECT_THROW(exec.join(), std::runtime_error);

EXPECT_EQ(throw_count, 1);
EXPECT_EQ(sink_call_count, 0);
EXPECT_EQ(complete_count, 0);
}

// the parallel tests:
// - SourceMultiThread
// - SinkMultiThread
Expand Down