diff --git a/demo/uring/UringNetCat.cxx b/demo/uring/UringNetCat.cxx index f874b3d5..929ad352 100644 --- a/demo/uring/UringNetCat.cxx +++ b/demo/uring/UringNetCat.cxx @@ -8,6 +8,7 @@ #include "event/net/BufferedSocket.hxx" #include "event/net/ConnectSocket.hxx" #include "event/Loop.hxx" +#include "event/ShutdownListener.hxx" #include "system/Error.hxx" #include "util/PrintException.hxx" @@ -18,6 +19,7 @@ #include class NetCat final : ConnectSocketHandler, BufferedSocketHandler { + ShutdownListener shutdown_listener; ConnectSocket connect_socket; BufferedSocket socket; @@ -26,8 +28,10 @@ class NetCat final : ConnectSocketHandler, BufferedSocketHandler { public: [[nodiscard]] explicit NetCat(EventLoop &event_loop) noexcept - :connect_socket(event_loop, *this), socket(event_loop) + :shutdown_listener(event_loop, BIND_THIS_METHOD(OnShutdown)), + connect_socket(event_loop, *this), socket(event_loop) { + shutdown_listener.Enable(); } auto &GetEventLoop() const noexcept { @@ -44,6 +48,14 @@ class NetCat final : ConnectSocketHandler, BufferedSocketHandler { } private: + void OnShutdown() noexcept { + if (connect_socket.IsPending()) + connect_socket.Cancel(); + else + socket.Close(); + GetEventLoop().SetVolatile(); + } + /* virtual methods from class ConnectSocketHandler */ void OnSocketConnectSuccess(UniqueSocketDescriptor fd) noexcept override { socket.Init(fd.Release(), FdType::FD_TCP, std::chrono::minutes{1}, *this); @@ -51,6 +63,7 @@ class NetCat final : ConnectSocketHandler, BufferedSocketHandler { } void OnSocketConnectError(std::exception_ptr e) noexcept override { + shutdown_listener.Disable(); GetEventLoop().SetVolatile(); error = std::move(e); } @@ -69,6 +82,7 @@ class NetCat final : ConnectSocketHandler, BufferedSocketHandler { } bool OnBufferedEnd() override { + shutdown_listener.Disable(); GetEventLoop().SetVolatile(); return true; } @@ -78,6 +92,7 @@ class NetCat final : ConnectSocketHandler, BufferedSocketHandler { } void OnBufferedError(std::exception_ptr e) noexcept override { + shutdown_listener.Disable(); GetEventLoop().SetVolatile(); error = std::move(e); }