From e9d8f7f1ce1dece10185d228271221884a2d713e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 2 Dec 2024 11:32:05 +0100 Subject: [PATCH] Fix Sonar issues Reduce cognitive complexity --- relay/client/client.go | 3 +- relay/client/dialer/race_dialer.go | 48 ++++++++++++++++--------- relay/client/dialer/race_dialer_test.go | 18 ++++++---- 3 files changed, 46 insertions(+), 23 deletions(-) diff --git a/relay/client/client.go b/relay/client/client.go index 45ecef6ca9c..4e7c58ab7e8 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -261,7 +261,8 @@ func (c *Client) Close() error { } func (c *Client) connect() error { - conn, err := dialer.RaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) + rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) + conn, err := rd.Dial() if err != nil { return err } diff --git a/relay/client/dialer/race_dialer.go b/relay/client/dialer/race_dialer.go index bf4aa56605c..8ffb84d318c 100644 --- a/relay/client/dialer/race_dialer.go +++ b/relay/client/dialer/race_dialer.go @@ -24,44 +24,51 @@ type dialResult struct { Err error } -func RaceDial(log *log.Entry, serverURL string, dialerFns ...DialerFn) (net.Conn, error) { - connChan := make(chan dialResult, len(dialerFns)) +type RaceDial struct { + log *log.Entry + serverURL string + dialerFns []DialerFn +} + +func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialerFn) *RaceDial { + return &RaceDial{ + log: log, + serverURL: serverURL, + dialerFns: dialerFns, + } +} + +func (r *RaceDial) Dial() (net.Conn, error) { + connChan := make(chan dialResult, len(r.dialerFns)) winnerConn := make(chan net.Conn, 1) abortCtx, abort := context.WithCancel(context.Background()) defer abort() - for _, d := range dialerFns { - go func() { - ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout) - defer cancel() - - log.Infof("dialing Relay server via %s", d.Protocol()) - conn, err := d.Dial(ctx, serverURL) - connChan <- dialResult{Conn: conn, Protocol: d.Protocol(), Err: err} - }() + for _, dfn := range r.dialerFns { + go r.dial(dfn, abortCtx, connChan) } go func() { var hasWinner bool - for i := 0; i < len(dialerFns); i++ { + for i := 0; i < len(r.dialerFns); i++ { dr := <-connChan if dr.Err != nil { if errors.Is(dr.Err, context.Canceled) { - log.Infof("connection attempt aborted via: %s", dr.Protocol) + r.log.Infof("connection attempt aborted via: %s", dr.Protocol) } else { - log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err) + r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err) } continue } if hasWinner { if cerr := dr.Conn.Close(); cerr != nil { - log.Warnf("failed to close connection via %s: %s", dr.Protocol, cerr) + r.log.Warnf("failed to close connection via %s: %s", dr.Protocol, cerr) } continue } - log.Infof("successfully dialed via: %s", dr.Protocol) + r.log.Infof("successfully dialed via: %s", dr.Protocol) abort() hasWinner = true @@ -76,3 +83,12 @@ func RaceDial(log *log.Entry, serverURL string, dialerFns ...DialerFn) (net.Conn } return conn, nil } + +func (r *RaceDial) dial(dfn DialerFn, abortCtx context.Context, connChan chan dialResult) { + ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout) + defer cancel() + + r.log.Infof("dialing Relay server via %s", dfn.Protocol()) + conn, err := dfn.Dial(ctx, r.serverURL) + connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err} +} diff --git a/relay/client/dialer/race_dialer_test.go b/relay/client/dialer/race_dialer_test.go index 4aef8df1ee5..a092f1061f1 100644 --- a/relay/client/dialer/race_dialer_test.go +++ b/relay/client/dialer/race_dialer_test.go @@ -77,7 +77,8 @@ func TestRaceDialEmptyDialers(t *testing.T) { logger := logrus.NewEntry(logrus.New()) serverURL := "test.server.com" - conn, err := RaceDial(logger, serverURL) + rd := NewRaceDial(logger, serverURL) + conn, err := rd.Dial() if err == nil { t.Errorf("Expected an error with empty dialers, got nil") } @@ -102,7 +103,8 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) { protocolStr: proto, } - conn, err := RaceDial(logger, serverURL, mockDialer) + rd := NewRaceDial(logger, serverURL, mockDialer) + conn, err := rd.Dial() if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -134,7 +136,8 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) { protocolStr: "proto2", } - conn, err := RaceDial(logger, serverURL, mockDialer1, mockDialer2) + rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + conn, err := rd.Dial() if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -156,7 +159,8 @@ func TestRaceDialTimeout(t *testing.T) { protocolStr: "proto1", } - conn, err := RaceDial(logger, serverURL, mockDialer) + rd := NewRaceDial(logger, serverURL, mockDialer) + conn, err := rd.Dial() if err == nil { t.Errorf("Expected an error, got nil") } @@ -183,7 +187,8 @@ func TestRaceDialAllDialersFail(t *testing.T) { protocolStr: "protocol2", } - conn, err := RaceDial(logger, serverURL, mockDialer1, mockDialer2) + rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + conn, err := rd.Dial() if err == nil { t.Errorf("Expected an error, got nil") } @@ -224,7 +229,8 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) { protocolStr: proto2, } - conn, err := RaceDial(logger, serverURL, mockDialer1, mockDialer2) + rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + conn, err := rd.Dial() if err != nil { t.Errorf("Expected no error, got %v", err) }