Skip to content

Commit

Permalink
Enhance DNS failover reliability (netbirdio#1637)
Browse files Browse the repository at this point in the history
* Fix using wrong array index in log to avoid potential panic

* Increase gRPC connection timeout and add the timeout resolv.conf option

This makes sure the dns client is able to failover to a second
configured nameserver, if present. That is the case then when using the
dns `file` manager and a resolv.conf file generated for netbird.

* On file backup restore, remove the first NS if it's the netbird NS

* Bump dns mangager discovery message from debug to info to ease debugging
  • Loading branch information
lixmal authored Mar 1, 2024
1 parent 30fe193 commit c12b583
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 15 deletions.
27 changes: 20 additions & 7 deletions client/internal/dns/file_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/netip"
"os"
"strings"
"time"

log "github.com/sirupsen/logrus"
)
Expand All @@ -23,10 +24,16 @@ const (
fileMaxNumberOfSearchDomains = 6
)

const (
dnsFailoverTimeout = 4 * time.Second
dnsFailoverAttempts = 1
)

type fileConfigurator struct {
repair *repair

originalPerms os.FileMode
originalPerms os.FileMode
nbNameserverIP string
}

func newFileConfigurator() (hostManager, error) {
Expand Down Expand Up @@ -64,7 +71,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
}

nbSearchDomains := searchDomains(config)
nbNameserverIP := config.ServerIP
f.nbNameserverIP = config.ServerIP

resolvConf, err := parseBackupResolvConf()
if err != nil {
Expand All @@ -73,22 +80,23 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {

f.repair.stopWatchFileChanges()

err = f.updateConfig(nbSearchDomains, nbNameserverIP, resolvConf)
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
if err != nil {
return err
}
f.repair.watchFileChanges(nbSearchDomains, nbNameserverIP)
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
return nil
}

func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
nameServers := generateNsList(nbNameserverIP, cfg)

options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
buf := prepareResolvConfContent(
searchDomainList,
nameServers,
cfg.others)
options)

log.Debugf("creating managed file %s", defaultResolvConfPath)
err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
Expand Down Expand Up @@ -131,7 +139,12 @@ func (f *fileConfigurator) backup() error {
}

func (f *fileConfigurator) restore() error {
err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
if err != nil {
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
}

err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
if err != nil {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
}
Expand All @@ -157,7 +170,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
// not a valid first nameserver -> restore
if err != nil {
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[1], err)
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
return restoreResolvConfFile()
}

Expand Down
63 changes: 63 additions & 0 deletions client/internal/dns/file_parser_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package dns
import (
"fmt"
"os"
"regexp"
"strings"

log "github.com/sirupsen/logrus"
Expand All @@ -14,6 +15,9 @@ const (
defaultResolvConfPath = "/etc/resolv.conf"
)

var timeoutRegex = regexp.MustCompile(`timeout:\d+`)
var attemptsRegex = regexp.MustCompile(`attempts:\d+`)

type resolvConf struct {
nameServers []string
searchDomains []string
Expand Down Expand Up @@ -103,3 +107,62 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
}
return rconf, nil
}

// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist,
// otherwise it adds a new option with timeout and attempts.
func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string {
configs := make([]string, len(input))
copy(configs, input)

for i, config := range configs {
if strings.HasPrefix(config, "options") {
config = strings.ReplaceAll(config, "rotate", "")
config = strings.Join(strings.Fields(config), " ")

if strings.Contains(config, "timeout:") {
config = timeoutRegex.ReplaceAllString(config, fmt.Sprintf("timeout:%d", timeout))
} else {
config = strings.Replace(config, "options ", fmt.Sprintf("options timeout:%d ", timeout), 1)
}

if strings.Contains(config, "attempts:") {
config = attemptsRegex.ReplaceAllString(config, fmt.Sprintf("attempts:%d", attempts))
} else {
config = strings.Replace(config, "options ", fmt.Sprintf("options attempts:%d ", attempts), 1)
}

configs[i] = config
return configs
}
}

return append(configs, fmt.Sprintf("options timeout:%d attempts:%d", timeout, attempts))
}

// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
// and writes the file back to the original location
func removeFirstNbNameserver(filename, nameserverIP string) error {
resolvConf, err := parseResolvConfFile(filename)
if err != nil {
return fmt.Errorf("parse backup resolv.conf: %w", err)
}
content, err := os.ReadFile(filename)
if err != nil {
return fmt.Errorf("read %s: %w", filename, err)
}

if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP {
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)

stat, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("stat %s: %w", filename, err)
}
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
return fmt.Errorf("write %s: %w", filename, err)
}

}

return nil
}
130 changes: 130 additions & 0 deletions client/internal/dns/file_parser_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_parseResolvConf(t *testing.T) {
Expand Down Expand Up @@ -172,3 +174,131 @@ nameserver 192.168.0.1
t.Errorf("unexpected resolv.conf content: %v", cfg)
}
}

func TestPrepareOptionsWithTimeout(t *testing.T) {
tests := []struct {
name string
others []string
timeout int
attempts int
expected []string
}{
{
name: "Append new options with timeout and attempts",
others: []string{"some config"},
timeout: 2,
attempts: 2,
expected: []string{"some config", "options timeout:2 attempts:2"},
},
{
name: "Modify existing options to exclude rotate and include timeout and attempts",
others: []string{"some config", "options rotate someother"},
timeout: 3,
attempts: 2,
expected: []string{"some config", "options attempts:2 timeout:3 someother"},
},
{
name: "Existing options with timeout and attempts are updated",
others: []string{"some config", "options timeout:4 attempts:3"},
timeout: 5,
attempts: 4,
expected: []string{"some config", "options timeout:5 attempts:4"},
},
{
name: "Modify existing options, add missing attempts before timeout",
others: []string{"some config", "options timeout:4"},
timeout: 4,
attempts: 3,
expected: []string{"some config", "options attempts:3 timeout:4"},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := prepareOptionsWithTimeout(tc.others, tc.timeout, tc.attempts)
assert.Equal(t, tc.expected, result)
})
}
}

func TestRemoveFirstNbNameserver(t *testing.T) {
testCases := []struct {
name string
content string
ipToRemove string
expected string
}{
{
name: "Unrelated nameservers with comments and options",
content: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
ipToRemove: "9.9.9.9",
expected: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
},
{
name: "First nameserver matches",
content: `search example.com
nameserver 9.9.9.9
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
ipToRemove: "9.9.9.9",
expected: `search example.com
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
},
{
name: "Target IP not the first nameserver",
// nolint:dupword
content: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
ipToRemove: "9.9.9.9",
// nolint:dupword
expected: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
},
{
name: "Only nameserver matches",
content: `options debug
nameserver 9.9.9.9
search localdomain`,
ipToRemove: "9.9.9.9",
expected: `options debug
nameserver 9.9.9.9
search localdomain`,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "resolv.conf")
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
assert.NoError(t, err)

err = removeFirstNbNameserver(tempFile, tc.ipToRemove)
assert.NoError(t, err)

content, err := os.ReadFile(tempFile)
assert.NoError(t, err)

assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
})
}
}
2 changes: 1 addition & 1 deletion client/internal/dns/host_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func newHostManager(wgInterface string) (hostManager, error) {
return nil, err
}

log.Debugf("discovered mode is: %s", osManager)
log.Infof("System DNS manager discovered: %s", osManager)
return newHostManagerFromType(wgInterface, osManager)
}

Expand Down
4 changes: 3 additions & 1 deletion client/internal/dns/resolvconf_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
searchDomainList := searchDomains(config)
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)

options := prepareOptionsWithTimeout(r.othersConfigs, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)

buf := prepareResolvConfContent(
searchDomainList,
append([]string{config.ServerIP}, r.originalNameServers...),
r.othersConfigs)
options)

// create a backup for unclean shutdown detection before the resolv.conf is changed
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {
Expand Down
6 changes: 4 additions & 2 deletions management/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"github.com/netbirdio/netbird/management/proto"
)

const ConnectTimeout = 10 * time.Second

// ConnStateNotifier is a wrapper interface of the status recorders
type ConnStateNotifier interface {
MarkManagementDisconnected(error)
Expand All @@ -49,7 +51,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
}

mgmCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
Expand Down Expand Up @@ -318,7 +320,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
log.Errorf("failed to encrypt message: %s", err)
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Expand Down
7 changes: 3 additions & 4 deletions signal/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ import (
"google.golang.org/grpc/status"

"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/signal/proto"
)

const defaultSendTimeout = 5 * time.Second

// ConnStateNotifier is a wrapper interface of the status recorder
type ConnStateNotifier interface {
MarkSignalDisconnected(error)
Expand Down Expand Up @@ -71,7 +70,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
}

sigCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout)
defer cancel()
conn, err := grpc.DialContext(
sigCtx,
Expand Down Expand Up @@ -353,7 +352,7 @@ func (c *GrpcClient) Send(msg *proto.Message) error {
return err
}

attemptTimeout := defaultSendTimeout
attemptTimeout := client.ConnectTimeout

for attempt := 0; attempt < 4; attempt++ {
if attempt > 1 {
Expand Down

0 comments on commit c12b583

Please sign in to comment.