Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse original resolv.conf #1270

Merged
merged 17 commits into from
Nov 3, 2023
220 changes: 160 additions & 60 deletions client/internal/dns/file_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,25 @@
package dns

import (
"bufio"
"bytes"
"fmt"
"os"
"strings"

log "github.com/sirupsen/logrus"
)

const (
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
fileGeneratedResolvConfSearchBeginContent = "search "
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
fileGeneratedResolvConfSearchBeginContent + "%s\n\n" +
"%s\n"
)
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + `
# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n"

const (
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
fileMaxLineCharsLimit = 256
fileMaxNumberOfSearchDomains = 6
)

var fileSearchLineBeginCharCount = len(fileGeneratedResolvConfSearchBeginContent)
fileMaxLineCharsLimit = 256
fileMaxNumberOfSearchDomains = 6
)

type fileConfigurator struct {
originalPerms os.FileMode
Expand Down Expand Up @@ -55,58 +51,39 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
}
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
}
managerType, err := getOSDNSManagerType()
if err != nil {
return err
}
switch managerType {
case fileManager, netbirdManager:
if !backupFileExist {
err = f.backup()
if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file")
}
}
default:
// todo improve this and maybe restart DNS manager from scratch
return fmt.Errorf("something happened and file manager is not your preferred host dns configurator, restart the agent")
}

var searchDomains string
appendedDomains := 0
for _, dConf := range config.domains {
if dConf.matchOnly || dConf.disabled {
continue
}
if appendedDomains >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
continue
}
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
// lets log all skipped domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
continue
if !backupFileExist {
err = f.backup()
if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file")
}

searchDomains += " " + dConf.domain
appendedDomains++
}

originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
searchDomainList := searchDomains(config)

originalSearchDomains, nameServers, others, err := originalDNSConfigs(fileDefaultResolvConfBackupLocation)
if err != nil {
log.Errorf("Could not read existing resolv.conf")
log.Error(err)
}
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)

searchDomainList = mergeSearchDomains(searchDomainList, originalSearchDomains)

buf := prepareResolvConfContent(
searchDomainList,
append([]string{config.serverIP}, nameServers...),
others)

log.Debugf("creating managed file %s", defaultResolvConfPath)
err = os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
if err != nil {
err = f.restore()
if err != nil {
restoreErr := f.restore()
if restoreErr != nil {
log.Errorf("attempt to restore default file failed with error: %s", err)
}
return err
return fmt.Errorf("got an creating resolver file %s. Error: %s", defaultResolvConfPath, err)
}
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, appendedDomains, searchDomains)

log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
return nil
}

Expand Down Expand Up @@ -138,15 +115,138 @@ func (f *fileConfigurator) restore() error {
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
}

func writeDNSConfig(content, fileName string, permissions os.FileMode) error {
log.Debugf("creating managed file %s", fileName)
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
var buf bytes.Buffer
buf.WriteString(content)
err := os.WriteFile(fileName, buf.Bytes(), permissions)
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)

for _, cfgLine := range others {
buf.WriteString(cfgLine)
buf.WriteString("\n")
}

if len(searchDomains) > 0 {
buf.WriteString("search ")
buf.WriteString(strings.Join(searchDomains, " "))
buf.WriteString("\n")
}

for _, ns := range nameServers {
buf.WriteString("nameserver ")
buf.WriteString(ns)
buf.WriteString("\n")
}
return buf
}

func searchDomains(config hostDNSConfig) []string {
listOfDomains := make([]string, 0)
for _, dConf := range config.domains {
if dConf.matchOnly || dConf.disabled {
continue
}

listOfDomains = append(listOfDomains, dConf.domain)
}
return listOfDomains
}

func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, others []string, err error) {
file, err := os.Open(resolvconfFile)
if err != nil {
return fmt.Errorf("got an creating resolver file %s. Error: %s", fileName, err)
err = fmt.Errorf(`could not read existing resolv.conf`)
return
}
return nil
defer file.Close()

reader := bufio.NewReader(file)

for {
lineBytes, isPrefix, readErr := reader.ReadLine()
if readErr != nil {
break
}

if isPrefix {
err = fmt.Errorf(`resolv.conf line too long`)
return
}

line := strings.TrimSpace(string(lineBytes))

if strings.HasPrefix(line, "#") {
continue
}

if strings.HasPrefix(line, "domain") {
continue
}

if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
line = strings.ReplaceAll(line, "rotate", "")
splitLines := strings.Fields(line)
if len(splitLines) == 1 {
continue
}
line = strings.Join(splitLines, " ")
}

if strings.HasPrefix(line, "search") {
splitLines := strings.Fields(line)
if len(splitLines) < 2 {
continue
}

searchDomains = splitLines[1:]
continue
}

if strings.HasPrefix(line, "nameserver") {
splitLines := strings.Fields(line)
if len(splitLines) != 2 {
continue
}
nameServers = append(nameServers, splitLines[1])
continue
}

others = append(others, line)
}
return
}

// merge search domains lists and cut off the list if it is too long
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
lineSize := len("search")
searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains))

lineSize = validateAndFillSearchDomains(lineSize, &searchDomainsList, searchDomains)
_ = validateAndFillSearchDomains(lineSize, &searchDomainsList, originalSearchDomains)

return searchDomainsList
}

// validateAndFillSearchDomains checks if the search domains list is not too long and if the line is not too long
// extend s slice with vs elements
// return with the number of characters in the searchDomains line
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
for _, sd := range vs {
tmpCharsNumber := initialLineChars + 1 + len(sd)
if tmpCharsNumber > fileMaxLineCharsLimit {
// lets log all skipped domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
continue
}

initialLineChars = tmpCharsNumber

if len(*s) >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
continue
}
*s = append(*s, sd)
}
return initialLineChars
}

func copyFile(src, dest string) error {
Expand Down
62 changes: 62 additions & 0 deletions client/internal/dns/file_linux_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package dns

import (
"fmt"
"testing"
)

func Test_mergeSearchDomains(t *testing.T) {
searchDomains := []string{"a", "b"}
originDomains := []string{"a", "b"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 4 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4)
}
}

func Test_mergeSearchTooMuchDomains(t *testing.T) {
searchDomains := []string{"a", "b", "c", "d", "e", "f", "g"}
originDomains := []string{"h", "i"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 6 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
}
}

func Test_mergeSearchTooMuchDomainsInOrigin(t *testing.T) {
searchDomains := []string{"a", "b"}
originDomains := []string{"c", "d", "e", "f", "g"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 6 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
}
}

func Test_mergeSearchTooLongDomain(t *testing.T) {
searchDomains := []string{getLongLine()}
originDomains := []string{"b"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 1 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
}

searchDomains = []string{"b"}
originDomains = []string{getLongLine()}

mergedDomains = mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 1 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
}
}

func getLongLine() string {
x := "search "
for {
for i := 0; i <= 9; i++ {
if len(x) > fileMaxLineCharsLimit {
return x
}
x = fmt.Sprintf("%s%d", x, i)
}
}
}
Loading