From 9d213e0b54dc3c9005590b376d8c6d465ae42f09 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Tue, 12 Mar 2024 18:05:41 +0100
Subject: [PATCH] Add fallback retry to daemon (#1690)

This change adds a fallback retry to the daemon service. This retry has a
larger interval and a shorter maximum retry run time than the other retries.
---
 client/server/server.go      | 127 +++++++++++++++++++++++++---
 client/server/server_test.go | 157 +++++++++++++++++++++++++++++++++++
 2 files changed, 272 insertions(+), 12 deletions(-)
 create mode 100644 client/server/server_test.go

diff --git a/client/server/server.go b/client/server/server.go
index fc1e4cc2642..90b5bcb642c 100644
--- a/client/server/server.go
+++ b/client/server/server.go
@@ -3,11 +3,15 @@ package server
 import (
 	"context"
 	"fmt"
+	"os"
 	"os/exec"
 	"runtime"
+	"strconv"
 	"sync"
 	"time"

+	"github.com/cenkalti/backoff/v4"
+
 	"github.com/netbirdio/netbird/client/internal/auth"

 	"github.com/netbirdio/netbird/client/system"
@@ -23,7 +27,17 @@ import (
 	"github.com/netbirdio/netbird/version"
 )

-const probeThreshold = time.Second * 5
+const (
+	probeThreshold          = time.Second * 5
+	retryInitialIntervalVar = "NB_CONN_RETRY_INTERVAL_TIME"
+	maxRetryIntervalVar     = "NB_CONN_MAX_RETRY_INTERVAL_TIME"
+	maxRetryTimeVar         = "NB_CONN_MAX_RETRY_TIME_TIME"
+	retryMultiplierVar      = "NB_CONN_RETRY_MULTIPLIER"
+	defaultInitialRetryTime = 14 * 24 * time.Hour
+	defaultMaxRetryInterval = 60 * time.Minute
+	defaultMaxRetryTime     = 14 * 24 * time.Hour
+	defaultRetryMultiplier  = 1.7
+)

 // Server for service control.
 type Server struct {
@@ -125,16 +139,110 @@ func (s *Server) Start() error {
 	}

 	if !config.DisableAutoConnect {
-		go func() {
-			if err := internal.RunClientWithProbes(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe); err != nil {
-				log.Errorf("init connections: %v", err)
-			}
-		}()
+		go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
 	}

 	return nil
 }

+// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as an
+// additional mechanism to keep the client connected even when the connection is lost.
+// We cancel the retry if the client receives a stop or down command, or if auto connect is disabled.
+func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
+	mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe) {
+	backOff := getConnectWithBackoff(ctx)
+	retryStarted := false
+
+	go func() {
+		t := time.NewTicker(24 * time.Hour)
+		for {
+			select {
+			case <-ctx.Done():
+				t.Stop()
+				return
+			case <-t.C:
+				if retryStarted {
+					mgmtState := statusRecorder.GetManagementState()
+					signalState := statusRecorder.GetSignalState()
+					if mgmtState.Connected && signalState.Connected {
+						log.Tracef("resetting status")
+						retryStarted = false
+					} else {
+						log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
+					}
+				}
+			}
+		}
+	}()
+
+	runOperation := func() error {
+		log.Tracef("running client connection")
+		err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
+		if err != nil {
+			log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
+		}
+
+		if config.DisableAutoConnect {
+			return backoff.Permanent(err)
+		}
+
+		if !retryStarted {
+			retryStarted = true
+			backOff.Reset()
+		}
+
+		log.Tracef("client connection exited")
+		return fmt.Errorf("client connection exited")
+	}
+
+	err := backoff.Retry(runOperation, backOff)
+	if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled {
+		log.Errorf("received an error when trying to connect: %v", err)
+	} else {
+		log.Tracef("retry canceled")
+	}
+}
+
+// getConnectWithBackoff returns a backoff with an exponential strategy for connection retries
+func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
+	initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
+	maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
+	maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
+	multiplier := defaultRetryMultiplier
+
+	if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
+		// parse the multiplier from the environment variable string value to float64
+		value, err := strconv.ParseFloat(envValue, 64)
+		if err != nil {
+			log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
+		} else {
+			multiplier = value
+		}
+	}
+
+	return backoff.WithContext(&backoff.ExponentialBackOff{
+		InitialInterval:     initialInterval,
+		RandomizationFactor: 1,
+		Multiplier:          multiplier,
+		MaxInterval:         maxInterval,
+		MaxElapsedTime:      maxElapsedTime, // 14 days by default
+		Stop:                backoff.Stop,
+		Clock:               backoff.SystemClock,
+	}, ctx)
+}
+
+// parseEnvDuration parses the environment variable and returns the duration
+func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
+	if envValue := os.Getenv(envVar); envValue != "" {
+		if duration, err := time.ParseDuration(envValue); err == nil {
+			return duration
+		}
+		log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
+	}
+	return defaultDuration
+}
+
 // loginAttempt attempts to login using the provided information. It returns a status in case something fails.
 func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
 	var status internal.StatusType
@@ -445,12 +553,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
 	s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
 	s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)

-	go func() {
-		if err := internal.RunClientWithProbes(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe); err != nil {
-			log.Errorf("run client connection: %v", err)
-			return
-		}
-	}()
+	go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)

 	return &proto.UpResponse{}, nil
 }
diff --git a/client/server/server_test.go b/client/server/server_test.go
new file mode 100644
index 00000000000..79a22002311
--- /dev/null
+++ b/client/server/server_test.go
@@ -0,0 +1,157 @@
+package server
+
+import (
+	"context"
+	"net"
+	"testing"
+	"time"
+
+	log "github.com/sirupsen/logrus"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/keepalive"
+
+	"github.com/netbirdio/netbird/client/internal"
+	"github.com/netbirdio/netbird/client/internal/peer"
+	mgmtProto "github.com/netbirdio/netbird/management/proto"
+	"github.com/netbirdio/netbird/management/server"
+	"github.com/netbirdio/netbird/management/server/activity"
+	"github.com/netbirdio/netbird/signal/proto"
+	signalServer "github.com/netbirdio/netbird/signal/server"
+)
+
+var (
+	kaep = keepalive.EnforcementPolicy{
+		MinTime:             15 * time.Second,
+		PermitWithoutStream: true,
+	}
+
+	kasp = keepalive.ServerParameters{
+		MaxConnectionIdle:     15 * time.Second,
+		MaxConnectionAgeGrace: 5 * time.Second,
+		Time:                  5 * time.Second,
+		Timeout:               2 * time.Second,
+	}
+)
+
+// TestConnectWithRetryRuns checks that the connectWithRetryRuns function runs and retries according to the times
+// specified via the environment variables.
+// We use a management server started via startManagement to simulate the server and capture the number of retries.
+func TestConnectWithRetryRuns(t *testing.T) {
+	// start the signal server
+	_, signalAddr, err := startSignal()
+	if err != nil {
+		t.Fatalf("failed to start signal server: %v", err)
+	}
+
+	counter := 0
+	// start the management server
+	_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
+	if err != nil {
+		t.Fatalf("failed to start management server: %v", err)
+	}
+
+	ctx := internal.CtxInitState(context.Background())
+
+	ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
+	defer cancel()
+	// create new server
+	s := New(ctx, t.TempDir()+"/config.json", "debug")
+	s.latestConfigInput.ManagementURL = "http://" + mgmtAddr
+	config, err := internal.UpdateOrCreateConfig(s.latestConfigInput)
+	if err != nil {
+		t.Fatalf("failed to create config: %v", err)
+	}
+	s.config = config
+
+	s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
+	t.Setenv(retryInitialIntervalVar, "1s")
+	t.Setenv(maxRetryIntervalVar, "2s")
+	t.Setenv(maxRetryTimeVar, "5s")
+	t.Setenv(retryMultiplierVar, "1")
+
+	s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
+	if counter < 3 {
+		t.Fatalf("expected counter > 2, got %d", counter)
+	}
+}
+
+type mockServer struct {
+	mgmtProto.ManagementServiceServer
+	counter *int
+}
+
+func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
+	*m.counter++
+	return m.ManagementServiceServer.Login(ctx, req)
+}
+
+func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
+	t.Helper()
+	dataDir := t.TempDir()
+
+	config := &server.Config{
+		Stuns:      []*server.Host{},
+		TURNConfig: &server.TURNConfig{},
+		Signal: &server.Host{
+			Proto: "http",
+			URI:   signalAddr,
+		},
+		Datadir:    dataDir,
+		HttpConfig: nil,
+	}
+
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		return nil, "", err
+	}
+	s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
+	store, err := server.NewStoreFromJson(config.Datadir, nil)
+	if err != nil {
+		return nil, "", err
+	}
+
+	peersUpdateManager := server.NewPeersUpdateManager(nil)
+	eventStore := &activity.InMemoryEventStore{}
+	accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false)
+	if err != nil {
+		return nil, "", err
+	}
+	turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
+	mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
+	if err != nil {
+		return nil, "", err
+	}
+	mock := &mockServer{
+		ManagementServiceServer: mgmtServer,
+		counter:                 counter,
+	}
+	mgmtProto.RegisterManagementServiceServer(s, mock)
+	go func() {
+		if err = s.Serve(lis); err != nil {
+			log.Fatalf("failed to serve: %v", err)
+		}
+	}()
+
+	return s, lis.Addr().String(), nil
+}
+
+func startSignal() (*grpc.Server, string, error) {
+	s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
+
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		log.Fatalf("failed to listen: %v", err)
+	}
+
+	proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
+
+	go func() {
+		if err = s.Serve(lis); err != nil {
+			log.Fatalf("failed to serve: %v", err)
+		}
+	}()
+
+	return s, lis.Addr().String(), nil
+}
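
Note on tuning: the retry loop added above is driven by the NB_CONN_* environment variables read in
getConnectWithBackoff. The standalone sketch below (not part of the patch; the file name, variable names, and
values are illustrative, borrowed from the test's t.Setenv calls rather than the much larger daemon defaults)
shows how those knobs feed the cenkalti/backoff/v4 retry loop that connectWithRetryRuns relies on.

// retry_sketch.go - illustrative only, not part of the patch.
package main

import (
	"errors"
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v4"
)

func main() {
	// Mirrors the backoff built by getConnectWithBackoff, using the short test values.
	expBackOff := &backoff.ExponentialBackOff{
		InitialInterval:     time.Second,     // NB_CONN_RETRY_INTERVAL_TIME=1s
		RandomizationFactor: 1,               // each wait is picked between 0 and 2x the current interval
		Multiplier:          1,               // NB_CONN_RETRY_MULTIPLIER=1
		MaxInterval:         2 * time.Second, // NB_CONN_MAX_RETRY_INTERVAL_TIME=2s
		MaxElapsedTime:      5 * time.Second, // NB_CONN_MAX_RETRY_TIME_TIME=5s
		Stop:                backoff.Stop,
		Clock:               backoff.SystemClock,
	}

	attempts := 0
	operation := func() error {
		attempts++
		// stand-in for internal.RunClientWithProbes returning after a lost connection
		return errors.New("client connection exited")
	}

	// backoff.Retry resets the backoff and keeps calling the operation until it succeeds,
	// returns a backoff.Permanent error, or MaxElapsedTime is exceeded.
	err := backoff.Retry(operation, expBackOff)
	fmt.Printf("gave up after %d attempts: %v\n", attempts, err)
}

With RandomizationFactor set to 1, each wait is fully jittered between zero and twice the current interval, so
clients that lost connectivity at the same time do not all reconnect in lockstep.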