Skip to content

Commit

Permalink
Add fallback retry to daemon (#1690)
Browse files Browse the repository at this point in the history
This change adds a fallback retry to the daemon service.

this retry has a larger interval with a shorter max retry run time
then others retries
  • Loading branch information
mlsmaycon authored Mar 12, 2024
1 parent 5dde044 commit 9d213e0
Show file tree
Hide file tree
Showing 2 changed files with 272 additions and 12 deletions.
127 changes: 115 additions & 12 deletions client/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ package server
import (
"context"
"fmt"
"os"
"os/exec"
"runtime"
"strconv"
"sync"
"time"

"github.com/cenkalti/backoff/v4"

"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"

Expand All @@ -23,7 +27,17 @@ import (
"github.com/netbirdio/netbird/version"
)

const probeThreshold = time.Second * 5
const (
probeThreshold = time.Second * 5
retryInitialIntervalVar = "NB_CONN_RETRY_INTERVAL_TIME"
maxRetryIntervalVar = "NB_CONN_MAX_RETRY_INTERVAL_TIME"
maxRetryTimeVar = "NB_CONN_MAX_RETRY_TIME_TIME"
retryMultiplierVar = "NB_CONN_RETRY_MULTIPLIER"
defaultInitialRetryTime = 14 * 24 * time.Hour
defaultMaxRetryInterval = 60 * time.Minute
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
)

// Server for service control.
type Server struct {
Expand Down Expand Up @@ -125,16 +139,110 @@ func (s *Server) Start() error {
}

if !config.DisableAutoConnect {
go func() {
if err := internal.RunClientWithProbes(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe); err != nil {
log.Errorf("init connections: %v", err)
}
}()
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
}

return nil
}

// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false

go func() {
t := time.NewTicker(24 * time.Hour)
for {
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
if retryStarted {

mgmtState := statusRecorder.GetManagementState()
signalState := statusRecorder.GetSignalState()
if mgmtState.Connected && signalState.Connected {
log.Tracef("resetting status")
retryStarted = false
} else {
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
}
}
}
}
}()

runOperation := func() error {
log.Tracef("running client connection")
err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
}

if config.DisableAutoConnect {
return backoff.Permanent(err)
}

if !retryStarted {
retryStarted = true
backOff.Reset()
}

log.Tracef("client connection exited")
return fmt.Errorf("client connection exited")
}

err := backoff.Retry(runOperation, backOff)
if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled {
log.Errorf("received an error when trying to connect: %v", err)
} else {
log.Tracef("retry canceled")
}
}

// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier

if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}

return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}

// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
}

// loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
var status internal.StatusType
Expand Down Expand Up @@ -445,12 +553,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)

go func() {
if err := internal.RunClientWithProbes(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe); err != nil {
log.Errorf("run client connection: %v", err)
return
}
}()
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)

return &proto.UpResponse{}, nil
}
Expand Down
157 changes: 157 additions & 0 deletions client/server/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package server

import (
"context"
"net"
"testing"
"time"

log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"

"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)

var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}

kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)

// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
// we will use a management server started via to simulate the server and capture the number of retries
func TestConnectWithRetryRuns(t *testing.T) {
// start the signal server
_, signalAddr, err := startSignal()
if err != nil {
t.Fatalf("failed to start signal server: %v", err)
}

counter := 0
// start the management server
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
if err != nil {
t.Fatalf("failed to start management server: %v", err)
}

ctx := internal.CtxInitState(context.Background())

ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
defer cancel()
// create new server
s := New(ctx, t.TempDir()+"/config.json", "debug")
s.latestConfigInput.ManagementURL = "http://" + mgmtAddr
config, err := internal.UpdateOrCreateConfig(s.latestConfigInput)
if err != nil {
t.Fatalf("failed to create config: %v", err)
}
s.config = config

s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
t.Setenv(retryInitialIntervalVar, "1s")
t.Setenv(maxRetryIntervalVar, "2s")
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")

s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
}

type mockServer struct {
mgmtProto.ManagementServiceServer
counter *int
}

func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
*m.counter++
return m.ManagementServiceServer.Login(ctx, req)
}

func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
t.Helper()
dataDir := t.TempDir()

config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
Signal: &server.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: dataDir,
HttpConfig: nil,
}

lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := server.NewStoreFromJson(config.Datadir, nil)
if err != nil {
return nil, "", err
}

peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}
mock := &mockServer{
ManagementServiceServer: mgmtServer,
counter: counter,
}
mgmtProto.RegisterManagementServiceServer(s, mock)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()

return s, lis.Addr().String(), nil
}

func startSignal() (*grpc.Server, string, error) {
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))

lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}

proto.RegisterSignalExchangeServer(s, signalServer.NewServer())

go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()

return s, lis.Addr().String(), nil
}

0 comments on commit 9d213e0

Please sign in to comment.