From 140deff3d1d59fd757e5a738008643be38fef54d Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Fri, 18 Aug 2023 17:07:31 +0200 Subject: [PATCH] emit ChannelClose and ChannelOpen when a client connects and disconnects (#71) --- channel.go | 8 +- channel_accepter.go => channel_provider.go | 20 +- endpoint.go | 8 +- endpoint_broadcast_test.go | 2 +- endpoint_client.go | 26 +- endpoint_client_test.go | 17 +- endpoint_custom_test.go | 41 +-- endpoint_serial.go | 21 +- endpoint_serial_test.go | 327 ++++++++++++------- endpoint_server.go | 2 +- endpoint_server_test.go | 2 +- node.go | 33 +- node_test.go | 25 +- pkg/autoreconnector/auto_reconnector.go | 171 ---------- pkg/autoreconnector/auto_reconnector_test.go | 108 ------ pkg/reconnector/reconnector.go | 113 +++++++ pkg/reconnector/reconnector_test.go | 66 ++++ 17 files changed, 486 insertions(+), 504 deletions(-) rename channel_accepter.go => channel_provider.go (51%) delete mode 100644 pkg/autoreconnector/auto_reconnector.go delete mode 100644 pkg/autoreconnector/auto_reconnector_test.go create mode 100644 pkg/reconnector/reconnector.go create mode 100644 pkg/reconnector/reconnector_test.go diff --git a/channel.go b/channel.go index 4af54e212..6e9324a00 100644 --- a/channel.go +++ b/channel.go @@ -10,9 +10,7 @@ import ( ) const ( - // this is low in order to avoid accumulating messages - // when a channel is reconnecting - writeBufferSize = 8 + writeBufferSize = 64 ) func randomByte() (byte, error) { @@ -93,12 +91,12 @@ func (ch *Channel) close() { func (ch *Channel) start() { ch.running = true - ch.n.channelsWg.Add(1) + ch.n.wg.Add(1) go ch.run() } func (ch *Channel) run() { - defer ch.n.channelsWg.Done() + defer ch.n.wg.Done() readerDone := make(chan struct{}) go ch.runReader(readerDone) diff --git a/channel_accepter.go b/channel_provider.go similarity index 51% rename from channel_accepter.go rename to channel_provider.go index 3132d35fe..d08c893e1 100644 --- a/channel_accepter.go +++ b/channel_provider.go @@ -4,32 +4,32 @@ import ( "fmt" ) -type channelAccepter struct { +type channelProvider struct { n *Node - eca endpointChannelAccepter + eca endpointChannelProvider } -func newChannelAccepter(n *Node, eca endpointChannelAccepter) (*channelAccepter, error) { - return &channelAccepter{ +func newChannelProvider(n *Node, eca endpointChannelProvider) (*channelProvider, error) { + return &channelProvider{ n: n, eca: eca, }, nil } -func (ca *channelAccepter) close() { +func (ca *channelProvider) close() { ca.eca.close() } -func (ca *channelAccepter) start() { - ca.n.channelAcceptersWg.Add(1) +func (ca *channelProvider) start() { + ca.n.wg.Add(1) go ca.run() } -func (ca *channelAccepter) run() { - defer ca.n.channelAcceptersWg.Done() +func (ca *channelProvider) run() { + defer ca.n.wg.Done() for { - label, rwc, err := ca.eca.accept() + label, rwc, err := ca.eca.provide() if err != nil { if err != errTerminated { panic("errTerminated is the only error allowed here") diff --git a/endpoint.go b/endpoint.go index 832b244df..e16021a98 100644 --- a/endpoint.go +++ b/endpoint.go @@ -18,7 +18,7 @@ type Endpoint interface { // a endpoint must also implement one of the following: // - endpointChannelSingle -// - endpointChannelAccepter +// - endpointChannelProvider // endpointChannelSingle is an endpoint that provides a single channel. // Read() must not return any error unless Close() is called. @@ -28,9 +28,9 @@ type endpointChannelSingle interface { io.ReadWriteCloser } -// endpointChannelAccepter is an endpoint that provides multiple channels. -type endpointChannelAccepter interface { +// endpointChannelProvider is an endpoint that provides multiple channels. +type endpointChannelProvider interface { Endpoint close() - accept() (string, io.ReadWriteCloser, error) + provide() (string, io.ReadWriteCloser, error) } diff --git a/endpoint_broadcast_test.go b/endpoint_broadcast_test.go index b489d88a0..4ce21e859 100644 --- a/endpoint_broadcast_test.go +++ b/endpoint_broadcast_test.go @@ -67,7 +67,7 @@ func TestEndpointBroadcast(t *testing.T) { }) require.NoError(t, err) - for i := 0; i < 3; i++ { + for i := 0; i < 3; i++ { //nolint:dupl msg := &MessageHeartbeat{ Type: 1, Autopilot: 2, diff --git a/endpoint_client.go b/endpoint_client.go index cfbe0be9a..e859f47e4 100644 --- a/endpoint_client.go +++ b/endpoint_client.go @@ -6,7 +6,7 @@ import ( "io" "net" - "github.com/bluenviron/gomavlib/v2/pkg/autoreconnector" + "github.com/bluenviron/gomavlib/v2/pkg/reconnector" "github.com/bluenviron/gomavlib/v2/pkg/timednetconn" ) @@ -56,8 +56,8 @@ func (conf EndpointUDPClient) init(node *Node) (Endpoint, error) { } type endpointClient struct { - conf endpointClientConf - io.ReadWriteCloser + conf endpointClientConf + reconnector *reconnector.Reconnector } func initEndpointClient(node *Node, conf endpointClientConf) (Endpoint, error) { @@ -68,11 +68,8 @@ func initEndpointClient(node *Node, conf endpointClientConf) (Endpoint, error) { t := &endpointClient{ conf: conf, - ReadWriteCloser: autoreconnector.New( + reconnector: reconnector.New( func(ctx context.Context) (io.ReadWriteCloser, error) { - // solve address and connect - // in UDP, the only possible error is a DNS failure - // in TCP, the handshake must be completed network := func() string { if conf.isUDP() { return "udp4" @@ -80,6 +77,8 @@ func initEndpointClient(node *Node, conf endpointClientConf) (Endpoint, error) { return "tcp4" }() + // in UDP, the only possible error is a DNS failure + // in TCP, the handshake must be completed timedContext, timedContextClose := context.WithTimeout(ctx, node.conf.ReadTimeout) nconn, err := (&net.Dialer{}).DialContext(timedContext, network, conf.getAddress()) timedContextClose() @@ -106,6 +105,19 @@ func (t *endpointClient) Conf() EndpointConf { return t.conf } +func (t *endpointClient) close() { + t.reconnector.Close() +} + +func (t *endpointClient) provide() (string, io.ReadWriteCloser, error) { + conn, ok := t.reconnector.Reconnect() + if !ok { + return "", nil, errTerminated + } + + return t.label(), conn, nil +} + func (t *endpointClient) label() string { return fmt.Sprintf("%s:%s", func() string { if t.conf.isUDP() { diff --git a/endpoint_client_test.go b/endpoint_client_test.go index ecbcba706..be005e39a 100644 --- a/endpoint_client_test.go +++ b/endpoint_client_test.go @@ -13,7 +13,7 @@ import ( "github.com/bluenviron/gomavlib/v2/pkg/frame" ) -var _ endpointChannelSingle = (*endpointClient)(nil) +var _ endpointChannelProvider = (*endpointClient)(nil) func TestEndpointClient(t *testing.T) { for _, ca := range []string{"tcp", "udp"} { @@ -32,15 +32,11 @@ func TestEndpointClient(t *testing.T) { } defer ln.Close() - connected := make(chan struct{}) - go func() { conn, err := ln.Accept() require.NoError(t, err) defer conn.Close() - close(connected) - dialectRW, err := dialect.NewReadWriter(testDialect) require.NoError(t, err) @@ -104,12 +100,6 @@ func TestEndpointClient(t *testing.T) { Channel: evt.(*EventChannelOpen).Channel, }, evt) - if ca == "tcp" { - <-connected - } else { - time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status - } - for i := 0; i < 3; i++ { node.WriteMessageAll(&MessageHeartbeat{ Type: 1, @@ -152,7 +142,6 @@ func TestEndpointClientIdleTimeout(t *testing.T) { require.NoError(t, err) defer ln.Close() - connected := make(chan struct{}) closed := make(chan struct{}) reconnected := make(chan struct{}) @@ -160,8 +149,6 @@ func TestEndpointClientIdleTimeout(t *testing.T) { conn, err := ln.Accept() require.NoError(t, err) - close(connected) - dialectRW, err := dialect.NewReadWriter(testDialect) require.NoError(t, err) @@ -227,8 +214,6 @@ func TestEndpointClientIdleTimeout(t *testing.T) { Channel: evt.(*EventChannelOpen).Channel, }, evt) - <-connected - node.WriteMessageAll(&MessageHeartbeat{ Type: 1, Autopilot: 2, diff --git a/endpoint_custom_test.go b/endpoint_custom_test.go index 93c85fc75..114f7fd0b 100644 --- a/endpoint_custom_test.go +++ b/endpoint_custom_test.go @@ -1,7 +1,6 @@ package gomavlib import ( - "bytes" "errors" "io" "testing" @@ -14,39 +13,39 @@ import ( var _ endpointChannelSingle = (*endpointCustom)(nil) -type dummyEndpoint struct { +type dummyReadWriter struct { chOut chan []byte chIn chan []byte chReadErr chan struct{} } -func newDummyEndpoint() *dummyEndpoint { - return &dummyEndpoint{ +func newDummyReadWriterPair() (*dummyReadWriter, *dummyReadWriter) { + one := &dummyReadWriter{ chOut: make(chan []byte), chIn: make(chan []byte), chReadErr: make(chan struct{}), } -} -func (e *dummyEndpoint) simulateReadError() { - close(e.chReadErr) -} + two := &dummyReadWriter{ + chOut: one.chIn, + chIn: one.chOut, + chReadErr: make(chan struct{}), + } -func (e *dummyEndpoint) push(buf []byte) { - e.chOut <- buf + return one, two } -func (e *dummyEndpoint) pull() []byte { - return <-e.chIn +func (e *dummyReadWriter) simulateReadError() { + close(e.chReadErr) } -func (e *dummyEndpoint) Close() error { +func (e *dummyReadWriter) Close() error { close(e.chOut) close(e.chIn) return nil } -func (e *dummyEndpoint) Read(p []byte) (int, error) { +func (e *dummyReadWriter) Read(p []byte) (int, error) { select { case buf, ok := <-e.chOut: if !ok { @@ -58,19 +57,19 @@ func (e *dummyEndpoint) Read(p []byte) (int, error) { } } -func (e *dummyEndpoint) Write(p []byte) (int, error) { +func (e *dummyReadWriter) Write(p []byte) (int, error) { e.chIn <- p return len(p), nil } func TestEndpointCustom(t *testing.T) { - de := newDummyEndpoint() + remote, local := newDummyReadWriterPair() node, err := NewNode(NodeConf{ Dialect: testDialect, OutVersion: V2, OutSystemID: 10, - Endpoints: []EndpointConf{EndpointCustom{de}}, + Endpoints: []EndpointConf{EndpointCustom{remote}}, HeartbeatDisable: true, }) require.NoError(t, err) @@ -84,10 +83,8 @@ func TestEndpointCustom(t *testing.T) { dialectRW, err := dialect.NewReadWriter(testDialect) require.NoError(t, err) - var buf bytes.Buffer - rw, err := frame.NewReadWriter(frame.ReadWriterConf{ - ReadWriter: &buf, + ReadWriter: local, DialectRW: dialectRW, OutVersion: frame.V2, OutSystemID: 11, @@ -105,8 +102,6 @@ func TestEndpointCustom(t *testing.T) { } err = rw.WriteMessage(msg) require.NoError(t, err) - de.push(buf.Bytes()) - buf.Reset() evt = <-node.Events() require.Equal(t, &EventFrame{ @@ -130,8 +125,6 @@ func TestEndpointCustom(t *testing.T) { } node.WriteMessageAll(msg) - buf2 := de.pull() - buf.Write(buf2) fr, err := rw.Read() require.NoError(t, err) require.Equal(t, &frame.V2Frame{ diff --git a/endpoint_serial.go b/endpoint_serial.go index af8479826..c6ea587ac 100644 --- a/endpoint_serial.go +++ b/endpoint_serial.go @@ -6,7 +6,7 @@ import ( "github.com/tarm/serial" - "github.com/bluenviron/gomavlib/v2/pkg/autoreconnector" + "github.com/bluenviron/gomavlib/v2/pkg/reconnector" ) var serialOpenFunc = func(device string, baud int) (io.ReadWriteCloser, error) { @@ -26,8 +26,8 @@ type EndpointSerial struct { } type endpointSerial struct { - conf EndpointConf - io.ReadWriteCloser + conf EndpointConf + reconnector *reconnector.Reconnector } func (conf EndpointSerial) init(_ *Node) (Endpoint, error) { @@ -40,7 +40,7 @@ func (conf EndpointSerial) init(_ *Node) (Endpoint, error) { t := &endpointSerial{ conf: conf, - ReadWriteCloser: autoreconnector.New( + reconnector: reconnector.New( func(ctx context.Context) (io.ReadWriteCloser, error) { return serialOpenFunc(conf.Device, conf.Baud) }, @@ -56,6 +56,15 @@ func (t *endpointSerial) Conf() EndpointConf { return t.conf } -func (t *endpointSerial) label() string { - return "serial" +func (t *endpointSerial) close() { + t.reconnector.Close() +} + +func (t *endpointSerial) provide() (string, io.ReadWriteCloser, error) { + conn, ok := t.reconnector.Reconnect() + if !ok { + return "", nil, errTerminated + } + + return "serial", conn, nil } diff --git a/endpoint_serial_test.go b/endpoint_serial_test.go index 09deafcb7..2925697a3 100644 --- a/endpoint_serial_test.go +++ b/endpoint_serial_test.go @@ -1,7 +1,6 @@ package gomavlib import ( - "bytes" "io" "testing" @@ -11,14 +10,66 @@ import ( "github.com/bluenviron/gomavlib/v2/pkg/frame" ) -var _ endpointChannelSingle = (*endpointSerial)(nil) +var _ endpointChannelProvider = (*endpointSerial)(nil) func TestEndpointSerial(t *testing.T) { - endpointCreated := make(chan *dummyEndpoint, 1) + done := make(chan struct{}) + first := false + serialOpenFunc = func(name string, baud int) (io.ReadWriteCloser, error) { - de := newDummyEndpoint() - endpointCreated <- de - return de, nil + remote, local := newDummyReadWriterPair() + + // skip first call to serialOpenFunc() + if !first { + first = true + return remote, nil + } + + go func() { + dialectRW, err := dialect.NewReadWriter(testDialect) + require.NoError(t, err) + + rw, err := frame.NewReadWriter(frame.ReadWriterConf{ + ReadWriter: local, + DialectRW: dialectRW, + OutVersion: frame.V2, + OutSystemID: 11, + }) + require.NoError(t, err) + + for i := 0; i < 3; i++ { + err = rw.WriteMessage(&MessageHeartbeat{ + Type: 1, + Autopilot: 2, + BaseMode: 3, + CustomMode: 6, + SystemStatus: 4, + MavlinkVersion: 5, + }) + require.NoError(t, err) + + fr, err := rw.Read() + require.NoError(t, err) + require.Equal(t, &frame.V2Frame{ + SequenceID: byte(i), + SystemID: 10, + ComponentID: 1, + Message: &MessageHeartbeat{ + Type: 6, + Autopilot: 5, + BaseMode: 4, + CustomMode: 3, + SystemStatus: 2, + MavlinkVersion: 1, + }, + Checksum: fr.GetChecksum(), + }, fr) + } + + close(done) + }() + + return remote, nil } node, err := NewNode(NodeConf{ @@ -34,84 +85,133 @@ func TestEndpointSerial(t *testing.T) { require.NoError(t, err) defer node.Close() - <-endpointCreated - evt := <-node.Events() require.Equal(t, &EventChannelOpen{ Channel: evt.(*EventChannelOpen).Channel, }, evt) - de := <-endpointCreated - - dialectRW, err := dialect.NewReadWriter(testDialect) - require.NoError(t, err) - - var buf bytes.Buffer - - rw, err := frame.NewReadWriter(frame.ReadWriterConf{ - ReadWriter: &buf, - DialectRW: dialectRW, - OutVersion: frame.V2, - OutSystemID: 11, - }) - require.NoError(t, err) - - for i := 0; i < 3; i++ { //nolint:dupl - msg := &MessageHeartbeat{ - Type: 1, - Autopilot: 2, - BaseMode: 3, - CustomMode: 6, - SystemStatus: 4, - MavlinkVersion: 5, - } - err = rw.WriteMessage(msg) - require.NoError(t, err) - de.push(buf.Bytes()) - buf.Reset() - - evt = <-node.Events() + for i := 0; i < 3; i++ { + evt := <-node.Events() require.Equal(t, &EventFrame{ Frame: &frame.V2Frame{ SequenceID: byte(i), SystemID: 11, ComponentID: 1, - Message: msg, - Checksum: evt.(*EventFrame).Frame.GetChecksum(), + Message: &MessageHeartbeat{ + Type: 1, + Autopilot: 2, + BaseMode: 3, + CustomMode: 6, + SystemStatus: 4, + MavlinkVersion: 5, + }, + Checksum: evt.(*EventFrame).Frame.GetChecksum(), }, Channel: evt.(*EventFrame).Channel, }, evt) - msg = &MessageHeartbeat{ + node.WriteMessageAll(&MessageHeartbeat{ Type: 6, Autopilot: 5, BaseMode: 4, CustomMode: 3, SystemStatus: 2, MavlinkVersion: 1, - } - node.WriteMessageAll(msg) - - buf2 := de.pull() - buf.Write(buf2) - fr, err := rw.Read() - require.NoError(t, err) - require.Equal(t, &frame.V2Frame{ - SequenceID: byte(i), - SystemID: 10, - ComponentID: 1, - Message: msg, - Checksum: fr.GetChecksum(), - }, fr) + }) } + + <-done } func TestEndpointSerialReconnect(t *testing.T) { - endpointCreated := make(chan *dummyEndpoint, 1) + done := make(chan struct{}) + count := 0 + serialOpenFunc = func(name string, baud int) (io.ReadWriteCloser, error) { - de := newDummyEndpoint() - endpointCreated <- de - return de, nil + remote, local := newDummyReadWriterPair() + + switch count { + case 0: // skip first call to serialOpenFunc() + + case 1: + go func() { + dialectRW, err := dialect.NewReadWriter(testDialect) + require.NoError(t, err) + + rw, err := frame.NewReadWriter(frame.ReadWriterConf{ + ReadWriter: local, + DialectRW: dialectRW, + OutVersion: frame.V2, + OutSystemID: 11, + }) + require.NoError(t, err) + + err = rw.WriteMessage(&MessageHeartbeat{ + Type: 1, + Autopilot: 2, + BaseMode: 3, + CustomMode: 6, + SystemStatus: 4, + MavlinkVersion: 5, + }) + require.NoError(t, err) + + fr, err := rw.Read() + require.NoError(t, err) + require.Equal(t, &frame.V2Frame{ + SequenceID: 0, + SystemID: 10, + ComponentID: 1, + Message: &MessageHeartbeat{ + Type: 6, + Autopilot: 5, + BaseMode: 4, + CustomMode: 3, + SystemStatus: 2, + MavlinkVersion: 1, + }, + Checksum: fr.GetChecksum(), + }, fr) + + remote.simulateReadError() + }() + + case 2: + go func() { + dialectRW, err := dialect.NewReadWriter(testDialect) + require.NoError(t, err) + + rw, err := frame.NewReadWriter(frame.ReadWriterConf{ + ReadWriter: local, + DialectRW: dialectRW, + OutVersion: frame.V2, + OutSystemID: 11, + }) + require.NoError(t, err) + + fr, err := rw.Read() + require.NoError(t, err) + require.Equal(t, &frame.V2Frame{ + SequenceID: 0, + SystemID: 10, + ComponentID: 1, + Message: &MessageHeartbeat{ + Type: 7, + Autopilot: 5, + BaseMode: 4, + CustomMode: 3, + SystemStatus: 2, + MavlinkVersion: 1, + }, + Checksum: fr.GetChecksum(), + }, fr) + + close(done) + }() + } + + count++ + return remote, nil } node, err := NewNode(NodeConf{ @@ -127,82 +227,57 @@ func TestEndpointSerialReconnect(t *testing.T) { require.NoError(t, err) defer node.Close() - <-endpointCreated - evt := <-node.Events() require.Equal(t, &EventChannelOpen{ Channel: evt.(*EventChannelOpen).Channel, }, evt) - de := <-endpointCreated - - dialectRW, err := dialect.NewReadWriter(testDialect) - require.NoError(t, err) - - var buf bytes.Buffer + evt = <-node.Events() + require.Equal(t, &EventFrame{ + Frame: &frame.V2Frame{ + SequenceID: 0, + SystemID: 11, + ComponentID: 1, + Message: &MessageHeartbeat{ + Type: 1, + Autopilot: 2, + BaseMode: 3, + CustomMode: 6, + SystemStatus: 4, + MavlinkVersion: 5, + }, + Checksum: evt.(*EventFrame).Frame.GetChecksum(), + }, + Channel: evt.(*EventFrame).Channel, + }, evt) - rw, err := frame.NewReadWriter(frame.ReadWriterConf{ - ReadWriter: &buf, - DialectRW: dialectRW, - OutVersion: frame.V2, - OutSystemID: 11, + node.WriteMessageAll(&MessageHeartbeat{ + Type: 6, + Autopilot: 5, + BaseMode: 4, + CustomMode: 3, + SystemStatus: 2, + MavlinkVersion: 1, }) - require.NoError(t, err) - for i := 0; i < 2; i++ { - msg := &MessageHeartbeat{ - Type: 1, - Autopilot: 2, - BaseMode: 3, - CustomMode: 6, - SystemStatus: 4, - MavlinkVersion: 5, - } - err = rw.WriteMessage(msg) - require.NoError(t, err) - de.push(buf.Bytes()) - buf.Reset() + evt = <-node.Events() + require.Equal(t, &EventChannelClose{ + Channel: evt.(*EventChannelClose).Channel, + }, evt) - evt := <-node.Events() - require.Equal(t, &EventFrame{ - Frame: &frame.V2Frame{ - SequenceID: byte(i), - SystemID: 11, - ComponentID: 1, - Message: msg, - Checksum: evt.(*EventFrame).Frame.GetChecksum(), - }, - Channel: evt.(*EventFrame).Channel, - }, evt) - } + evt = <-node.Events() + require.Equal(t, &EventChannelOpen{ + Channel: evt.(*EventChannelOpen).Channel, + }, evt) - de.simulateReadError() - de = <-endpointCreated - - for i := 0; i < 2; i++ { - msg := &MessageHeartbeat{ - Type: 1, - Autopilot: 2, - BaseMode: 3, - CustomMode: 6, - SystemStatus: 4, - MavlinkVersion: 5, - } - err = rw.WriteMessage(msg) - require.NoError(t, err) - de.chOut <- buf.Bytes() - buf.Reset() + node.WriteMessageAll(&MessageHeartbeat{ + Type: 7, + Autopilot: 5, + BaseMode: 4, + CustomMode: 3, + SystemStatus: 2, + MavlinkVersion: 1, + }) - evt := <-node.Events() - require.Equal(t, &EventFrame{ - Frame: &frame.V2Frame{ - SequenceID: 2 + byte(i), - SystemID: 11, - ComponentID: 1, - Message: msg, - Checksum: evt.(*EventFrame).Frame.GetChecksum(), - }, - Channel: evt.(*EventFrame).Channel, - }, evt) - } + <-done } diff --git a/endpoint_server.go b/endpoint_server.go index 85fd13754..d39ac1de3 100644 --- a/endpoint_server.go +++ b/endpoint_server.go @@ -113,7 +113,7 @@ func (t *endpointServer) close() { t.listener.Close() } -func (t *endpointServer) accept() (string, io.ReadWriteCloser, error) { +func (t *endpointServer) provide() (string, io.ReadWriteCloser, error) { nconn, err := t.listener.Accept() // wait termination, do not report errors if err != nil { diff --git a/endpoint_server_test.go b/endpoint_server_test.go index ccbf4668f..3096b7343 100644 --- a/endpoint_server_test.go +++ b/endpoint_server_test.go @@ -11,7 +11,7 @@ import ( "github.com/bluenviron/gomavlib/v2/pkg/frame" ) -var _ endpointChannelAccepter = (*endpointServer)(nil) +var _ endpointChannelProvider = (*endpointServer)(nil) func TestEndpointServer(t *testing.T) { for _, ca := range []string{"tcp", "udp"} { diff --git a/node.go b/node.go index e97420cca..579456e35 100644 --- a/node.go +++ b/node.go @@ -90,14 +90,13 @@ type NodeConf struct { // Node is a high-level Mavlink encoder and decoder that works with endpoints. type Node struct { - conf NodeConf - dialectRW *dialect.ReadWriter - channelAccepters map[*channelAccepter]struct{} - channelAcceptersWg sync.WaitGroup - channels map[*Channel]struct{} - channelsWg sync.WaitGroup - nodeHeartbeat *nodeHeartbeat - nodeStreamRequest *nodeStreamRequest + conf NodeConf + dialectRW *dialect.ReadWriter + wg sync.WaitGroup + channelProviders map[*channelProvider]struct{} + channels map[*Channel]struct{} + nodeHeartbeat *nodeHeartbeat + nodeStreamRequest *nodeStreamRequest // in chNewChannel chan *Channel @@ -167,7 +166,7 @@ func NewNode(conf NodeConf) (*Node, error) { n := &Node{ conf: conf, dialectRW: dialectRW, - channelAccepters: make(map[*channelAccepter]struct{}), + channelProviders: make(map[*channelProvider]struct{}), channels: make(map[*Channel]struct{}), chNewChannel: make(chan *Channel), chCloseChannel: make(chan *Channel), @@ -183,7 +182,7 @@ func NewNode(conf NodeConf) (*Node, error) { for ch := range n.channels { ch.close() } - for ca := range n.channelAccepters { + for ca := range n.channelProviders { ca.close() } } @@ -197,14 +196,14 @@ func NewNode(conf NodeConf) (*Node, error) { } switch ttp := tp.(type) { - case endpointChannelAccepter: - ca, err := newChannelAccepter(n, ttp) + case endpointChannelProvider: + ca, err := newChannelProvider(n, ttp) if err != nil { closeExisting() return nil, err } - n.channelAccepters[ca] = struct{}{} + n.channelProviders[ca] = struct{}{} case endpointChannelSingle: ch, err := newChannel(n, ttp, ttp.label(), ttp) @@ -235,7 +234,7 @@ func NewNode(conf NodeConf) (*Node, error) { ch.start() } - for ca := range n.channelAccepters { + for ca := range n.channelProviders { ca.start() } @@ -307,15 +306,15 @@ outer: n.nodeStreamRequest.close() } - for ca := range n.channelAccepters { + for ca := range n.channelProviders { ca.close() } - n.channelAcceptersWg.Wait() for ch := range n.channels { ch.close() } - n.channelsWg.Wait() + + n.wg.Wait() close(n.chEvent) } diff --git a/node_test.go b/node_test.go index db7ab1abb..3ea6ab9af 100644 --- a/node_test.go +++ b/node_test.go @@ -4,7 +4,6 @@ import ( "bytes" "sync" "testing" - "time" "github.com/stretchr/testify/require" @@ -97,7 +96,10 @@ func TestNodeCloseInLoop(t *testing.T) { require.NoError(t, err) defer node2.Close() - time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status + evt := <-node2.Events() + require.Equal(t, &EventChannelOpen{ + Channel: evt.(*EventChannelOpen).Channel, + }, evt) node2.WriteMessageAll(testMessage) @@ -368,7 +370,10 @@ func TestNodeWriteMessageInLoop(t *testing.T) { require.NoError(t, err) defer node2.Close() - time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status + evt := <-node2.Events() + require.Equal(t, &EventChannelOpen{ + Channel: evt.(*EventChannelOpen).Channel, + }, evt) node2.WriteMessageAll(testMessage) @@ -414,12 +419,15 @@ func TestNodeSignature(t *testing.T) { require.NoError(t, err) defer node2.Close() - time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status + evt := <-node2.Events() + require.Equal(t, &EventChannelOpen{ + Channel: evt.(*EventChannelOpen).Channel, + }, evt) node2.WriteMessageAll(testMessage) <-node1.Events() - evt := <-node1.Events() + evt = <-node1.Events() fr, ok := evt.(*EventFrame) require.Equal(t, true, ok) require.Equal(t, &EventFrame{ @@ -526,7 +534,10 @@ func TestNodeFixFrame(t *testing.T) { require.NoError(t, err) defer node2.Close() - time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status + evt := <-node2.Events() + require.Equal(t, &EventChannelOpen{ + Channel: evt.(*EventChannelOpen).Channel, + }, evt) fra := &frame.V2Frame{ SequenceID: 13, @@ -544,7 +555,7 @@ func TestNodeFixFrame(t *testing.T) { node2.WriteFrameAll(fra) <-node1.Events() - evt := <-node1.Events() + evt = <-node1.Events() fr, ok := evt.(*EventFrame) require.Equal(t, true, ok) require.Equal(t, &EventFrame{ diff --git a/pkg/autoreconnector/auto_reconnector.go b/pkg/autoreconnector/auto_reconnector.go deleted file mode 100644 index 47d394eab..000000000 --- a/pkg/autoreconnector/auto_reconnector.go +++ /dev/null @@ -1,171 +0,0 @@ -// Package autoreconnector contains a io.ReadWriteCloser wrapper that implements automatic reconnection. -package autoreconnector - -import ( - "context" - "errors" - "io" - "sync" - "time" -) - -var ( - reconnectPeriod = 2 * time.Second - errTerminated = errors.New("terminated") - errReconnecting = errors.New("reconnecting") -) - -type state int - -const ( - stateInitial state = iota - stateReconnecting - stateConnected - stateTerminated -) - -type autoReconnector struct { - connect func(context.Context) (io.ReadWriteCloser, error) - - mutex sync.Mutex - state state - conn io.ReadWriteCloser - connectCtx context.Context - connectCtxCancel func() -} - -// New returns a io.ReadWriterCloser that implements auto-reconnection. -func New( - connect func(context.Context) (io.ReadWriteCloser, error), -) io.ReadWriteCloser { - a := &autoReconnector{ - connect: connect, - } - - a.resetConnection() - - return a -} - -func (a *autoReconnector) Close() error { - a.mutex.Lock() - defer a.mutex.Unlock() - - a.state = stateTerminated - - if a.connectCtxCancel != nil { - a.connectCtxCancel() - } - - if a.conn != nil { - a.conn.Close() - a.conn = nil - } - - return nil -} - -func (a *autoReconnector) getConnection() (io.ReadWriteCloser, context.Context, error) { - a.mutex.Lock() - defer a.mutex.Unlock() - - switch a.state { - case stateTerminated: - return nil, nil, errTerminated - - case stateReconnecting: - return nil, a.connectCtx, errReconnecting - - default: - return a.conn, nil, nil - } -} - -func (a *autoReconnector) resetConnection() { - a.mutex.Lock() - defer a.mutex.Unlock() - - switch a.state { - case stateTerminated, stateReconnecting: - return - } - - a.state = stateReconnecting - - if a.conn != nil { - a.conn.Close() - a.conn = nil - } - - a.connectCtx, a.connectCtxCancel = context.WithCancel(context.Background()) - - go func() { - for { - newConn, err := a.connect(a.connectCtx) - if err == nil { - a.setConn(newConn) - return - } - - select { - case <-time.After(reconnectPeriod): - case <-a.connectCtx.Done(): - return - } - } - }() -} - -func (a *autoReconnector) setConn(newConn io.ReadWriteCloser) { - a.mutex.Lock() - defer a.mutex.Unlock() - - if a.state != stateReconnecting { - newConn.Close() - return - } - - a.connectCtxCancel() - a.connectCtxCancel = nil - - a.conn = newConn - a.state = stateConnected -} - -func (a *autoReconnector) Read(p []byte) (int, error) { - for { - curConn, connectCtx, err := a.getConnection() - if err == errReconnecting { - <-connectCtx.Done() - continue - } - if err != nil { - return 0, err - } - - n, err := curConn.Read(p) - - if n == 0 { - a.resetConnection() - continue - } - - return n, err - } -} - -func (a *autoReconnector) Write(p []byte) (int, error) { - curConn, _, err := a.getConnection() - if err != nil { - return 0, err - } - - n, err := curConn.Write(p) - - if n == 0 { - a.resetConnection() - return 0, errReconnecting - } - - return n, err -} diff --git a/pkg/autoreconnector/auto_reconnector_test.go b/pkg/autoreconnector/auto_reconnector_test.go deleted file mode 100644 index b2efe4ffe..000000000 --- a/pkg/autoreconnector/auto_reconnector_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package autoreconnector - -import ( - "context" - "io" - "net" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestReconnect(t *testing.T) { - ln, err := net.Listen("tcp", "localhost:6657") - require.NoError(t, err) - defer ln.Close() - - go func() { - for i := 0; i < 2; i++ { - conn, err := ln.Accept() - require.NoError(t, err) - - _, err = conn.Write([]byte{0x05 + byte(i)}) - require.NoError(t, err) - - conn.Close() - } - }() - - a := New( - func(ctx context.Context) (io.ReadWriteCloser, error) { - return (&net.Dialer{}).DialContext(ctx, "tcp", "localhost:6657") - }, - ) - defer a.Close() - - p := make([]byte, 1) - n, err := a.Read(p) - require.NoError(t, err) - require.Equal(t, []byte{0x05}, p[:n]) - - p = make([]byte, 1) - n, err = a.Read(p) - require.NoError(t, err) - require.Equal(t, []byte{0x06}, p[:n]) -} - -func TestCloseWhileWorking(t *testing.T) { - for _, ca := range []string{"read", "write"} { - t.Run(ca, func(t *testing.T) { - ln, err := net.Listen("tcp", "localhost:6657") - require.NoError(t, err) - defer ln.Close() - - serverDone := make(chan struct{}) - - go func() { - conn, err := ln.Accept() - require.NoError(t, err) - - b := make([]byte, 1) - if ca == "read" { - conn.Read(b) //nolint:errcheck - } else { - conn.Write(b) //nolint:errcheck - } - - conn.Close() - close(serverDone) - }() - - a := New( - func(ctx context.Context) (io.ReadWriteCloser, error) { - return (&net.Dialer{}).DialContext(ctx, "tcp", "localhost:6657") - }, - ) - - workDone := make(chan struct{}) - - go func() { - defer close(workDone) - - p := make([]byte, 1) - if ca == "read" { - a.Read(p) //nolint:errcheck - } else { - a.Write(p) //nolint:errcheck - } - }() - - time.Sleep(500 * time.Millisecond) - a.Close() - <-workDone - <-serverDone - }) - } -} - -func TestWriteAfterClose(t *testing.T) { - a := New( - func(ctx context.Context) (io.ReadWriteCloser, error) { - return (&net.Dialer{}).DialContext(ctx, "tcp", "localhost:6657") - }, - ) - a.Close() - _, err := a.Write([]byte{1, 2, 3, 4}) - require.EqualError(t, err, "terminated") -} diff --git a/pkg/reconnector/reconnector.go b/pkg/reconnector/reconnector.go new file mode 100644 index 000000000..112da12dd --- /dev/null +++ b/pkg/reconnector/reconnector.go @@ -0,0 +1,113 @@ +// Package reconnector allows to perform automatic reconnections. +package reconnector + +import ( + "context" + "io" + "sync" + "time" +) + +var reconnectPeriod = 2 * time.Second + +// ConnectFunc is the prototype of the callback passed to New() +type ConnectFunc func(context.Context) (io.ReadWriteCloser, error) + +type connWithContext struct { + rwc io.ReadWriteCloser + mutex sync.Mutex + ctx context.Context + ctxCancel func() +} + +func newConnWithContext(rwc io.ReadWriteCloser) *connWithContext { + ctx, ctxCancel := context.WithCancel(context.Background()) + + return &connWithContext{ + rwc: rwc, + ctx: ctx, + ctxCancel: ctxCancel, + } +} + +func (c *connWithContext) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + + select { + case <-c.ctx.Done(): + return nil + default: + } + + c.ctxCancel() + + return c.rwc.Close() +} + +func (c *connWithContext) Read(p []byte) (int, error) { + n, err := c.rwc.Read(p) + if n == 0 { + c.Close() //nolint:errcheck + } + return n, err +} + +func (c *connWithContext) Write(p []byte) (int, error) { + n, err := c.rwc.Write(p) + if n == 0 { + c.Close() //nolint:errcheck + } + return n, err +} + +// Reconnector allocws to perform automatic reconnections. +type Reconnector struct { + connect ConnectFunc + + ctx context.Context + ctxCancel func() + curConn *connWithContext +} + +// New allocates a Reconnector. +func New(connect ConnectFunc) *Reconnector { + ctx, ctxCancel := context.WithCancel(context.Background()) + + return &Reconnector{ + connect: connect, + ctx: ctx, + ctxCancel: ctxCancel, + } +} + +// Close closes a reconnector. +func (a *Reconnector) Close() { + a.ctxCancel() +} + +// Reconnect returns the next working connection. +func (a *Reconnector) Reconnect() (io.ReadWriteCloser, bool) { + if a.curConn != nil { + select { + case <-a.curConn.ctx.Done(): + case <-a.ctx.Done(): + return nil, false + } + } + + for { + conn, err := a.connect(a.ctx) + if err != nil { + select { + case <-time.After(reconnectPeriod): + continue + case <-a.ctx.Done(): + return nil, false + } + } + + a.curConn = newConnWithContext(conn) + return a.curConn, true + } +} diff --git a/pkg/reconnector/reconnector_test.go b/pkg/reconnector/reconnector_test.go new file mode 100644 index 000000000..87eddbce6 --- /dev/null +++ b/pkg/reconnector/reconnector_test.go @@ -0,0 +1,66 @@ +package reconnector + +import ( + "bytes" + "context" + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +type dummyRWC struct { + bytes.Buffer + closed bool +} + +func (dummyRWC) Close() error { + return nil +} + +func (d *dummyRWC) Read(p []byte) (int, error) { + return d.Buffer.Read(p) +} + +func (d *dummyRWC) Write(p []byte) (int, error) { + if d.closed { + return 0, fmt.Errorf("closed") + } + return d.Buffer.Write(p) +} + +func TestReconnector(t *testing.T) { + var buf dummyRWC + + r := New( + func(ctx context.Context) (io.ReadWriteCloser, error) { + return &buf, nil + }, + ) + + conn, ok := r.Reconnect() + require.Equal(t, true, ok) + + buf.Buffer.Write([]byte{1}) + + recv := make([]byte, 1) + _, err := conn.Read(recv) + require.NoError(t, err) + require.Equal(t, byte(1), recv[0]) + + _, err = conn.Read(recv) + require.Equal(t, io.EOF, err) + + buf.closed = true + _, err = conn.Write(recv) + require.Error(t, err) + + _, ok = r.Reconnect() + require.Equal(t, true, ok) + + r.Close() + + _, ok = r.Reconnect() + require.Equal(t, false, ok) +}