Skip to content

Commit

Permalink
dhcpsvc: add db
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Jul 8, 2024
1 parent 9a6dd0d commit 83ec7c5
Show file tree
Hide file tree
Showing 9 changed files with 364 additions and 10 deletions.
8 changes: 7 additions & 1 deletion internal/dhcpsvc/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dhcpsvc
import (
"fmt"
"log/slog"
"os"
"time"

"github.com/AdguardTeam/golibs/errors"
Expand All @@ -23,7 +24,8 @@ type Config struct {
// clients' hostnames.
LocalDomainName string

// TODO(e.burkov): Add DB path.
// DBFilePath is the path to the database file containing the DHCP leases.
DBFilePath string

// ICMPTimeout is the timeout for checking another DHCP server's presence.
ICMPTimeout time.Duration
Expand Down Expand Up @@ -64,6 +66,10 @@ func (conf *Config) Validate() (err error) {
errs = append(errs, err)
}

if _, err = os.Stat(conf.DBFilePath); err != nil && !errors.Is(err, os.ErrNotExist) {
errs = append(errs, fmt.Errorf("db file path %q: %w", conf.DBFilePath, err))
}

if len(conf.Interfaces) == 0 {
errs = append(errs, errNoInterfaces)

Expand Down
9 changes: 9 additions & 0 deletions internal/dhcpsvc/config_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package dhcpsvc_test

import (
"path/filepath"
"testing"

"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/testutil"
)

func TestConfig_Validate(t *testing.T) {
leasesPath := filepath.Join(t.TempDir(), "leases.json")

testCases := []struct {
name string
conf *dhcpsvc.Config
Expand All @@ -25,13 +28,15 @@ func TestConfig_Validate(t *testing.T) {
conf: &dhcpsvc.Config{
Enabled: true,
Interfaces: testInterfaceConf,
DBFilePath: leasesPath,
},
wantErrMsg: `bad domain name "": domain name is empty`,
}, {
conf: &dhcpsvc.Config{
Enabled: true,
LocalDomainName: testLocalTLD,
Interfaces: nil,
DBFilePath: leasesPath,
},
name: "no_interfaces",
wantErrMsg: "no interfaces specified",
Expand All @@ -40,6 +45,7 @@ func TestConfig_Validate(t *testing.T) {
Enabled: true,
LocalDomainName: testLocalTLD,
Interfaces: nil,
DBFilePath: leasesPath,
},
name: "no_interfaces",
wantErrMsg: "no interfaces specified",
Expand All @@ -50,6 +56,7 @@ func TestConfig_Validate(t *testing.T) {
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
"eth0": nil,
},
DBFilePath: leasesPath,
},
name: "nil_interface",
wantErrMsg: `interface "eth0": config is nil`,
Expand All @@ -63,6 +70,7 @@ func TestConfig_Validate(t *testing.T) {
IPv6: &dhcpsvc.IPv6Config{Enabled: false},
},
},
DBFilePath: leasesPath,
},
name: "nil_ipv4",
wantErrMsg: `interface "eth0": ipv4: config is nil`,
Expand All @@ -76,6 +84,7 @@ func TestConfig_Validate(t *testing.T) {
IPv6: nil,
},
},
DBFilePath: leasesPath,
},
name: "nil_ipv6",
wantErrMsg: `interface "eth0": ipv6: config is nil`,
Expand Down
197 changes: 197 additions & 0 deletions internal/dhcpsvc/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package dhcpsvc

import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"slices"
"strings"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/google/renameio/v2/maybe"
)

// dataVersion is the current version of the stored DHCP leases structure.
const dataVersion = 1

// dataLeases is the structure of the stored DHCP leases.
type dataLeases struct {
// Leases is the list containing stored DHCP leases.
Leases []*dbLease `json:"leases"`

// Version is the current version of the structure.
Version int `json:"version"`
}

// dbLease is the structure of stored lease.
type dbLease struct {
Expiry string `json:"expires"`
IP netip.Addr `json:"ip"`
Hostname string `json:"hostname"`
HWAddr string `json:"mac"`
IsStatic bool `json:"static"`
}

// compareNames returns the result of comparing the hostnames of dl and other
// lexicographically.
func (dl *dbLease) compareNames(other *dbLease) (res int) {
return strings.Compare(dl.Hostname, other.Hostname)
}

// fromLease converts *Lease to *dbLease.
func fromLease(l *Lease) (dl *dbLease) {
var expiryStr string
if !l.IsStatic {
// The front-end is waiting for RFC 3999 format of the time value. It
// also shouldn't got an Expiry field for static leases.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2692.
expiryStr = l.Expiry.Format(time.RFC3339)
}

return &dbLease{
Expiry: expiryStr,
Hostname: l.Hostname,
HWAddr: l.HWAddr.String(),
IP: l.IP,
IsStatic: l.IsStatic,
}
}

// toLease converts dl to *Lease.
func (dl *dbLease) toLease() (l *Lease, err error) {
mac, err := net.ParseMAC(dl.HWAddr)
if err != nil {
return nil, fmt.Errorf("parsing hardware address: %w", err)
}

expiry := time.Time{}
if !dl.IsStatic {
expiry, err = time.Parse(time.RFC3339, dl.Expiry)
if err != nil {
return nil, fmt.Errorf("parsing expiry time: %w", err)
}
}

return &Lease{
Expiry: expiry,
IP: dl.IP,
Hostname: dl.Hostname,
HWAddr: mac,
IsStatic: dl.IsStatic,
}, nil
}

// dbLoad loads stored leases. It must only be called before the service has
// been started.
func (srv *DHCPServer) dbLoad(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "loading db: %w") }()

file, err := os.Open(srv.dbFilePath)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("reading db: %w", err)
}

srv.logger.DebugContext(ctx, "no db file found")

return nil
}

dl := &dataLeases{}
err = json.NewDecoder(file).Decode(dl)
if err != nil {
return fmt.Errorf("decoding db: %w", err)
}

srv.resetLeases()
srv.addDBLeases(ctx, dl.Leases)

return nil
}

// addDBLeases adds leases to the server.
func (srv *DHCPServer) addDBLeases(ctx context.Context, leases []*dbLease) {
const logMsg = "loading lease"

var v4, v6 uint
for i, l := range leases {
var lease *Lease
lease, err := l.toLease()
if err != nil {
srv.logger.DebugContext(ctx, logMsg, "idx", i, slogutil.KeyError, err)

continue
}

addr := l.IP
iface, err := srv.ifaceForAddr(addr)
if err != nil {
srv.logger.DebugContext(ctx, logMsg, "idx", i, slogutil.KeyError, err)

continue
}

err = srv.leases.add(lease, iface)
if err != nil {
srv.logger.DebugContext(ctx, logMsg, "idx", i, slogutil.KeyError, err)

continue
}

if lease.IP.Is4() {
v4++
} else {
v6++
}
}

srv.logger.InfoContext(
ctx,
"loaded leases",
"v4", v4,
"v6", v6,
"total", len(leases),
)
}

// writeDB writes leases to the database file. It expects the
// [DHCPServer.leasesMu] to be locked.
func (srv *DHCPServer) dbStore(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "writing db: %w") }()

dl := &dataLeases{
// Avoid writing "null" into the database file if there are no leases.
Leases: make([]*dbLease, 0, srv.leases.len()),
Version: dataVersion,
}

srv.leases.rangeLeases(func(l *Lease) (cont bool) {
lease := fromLease(l)
i, _ := slices.BinarySearchFunc(dl.Leases, lease, (*dbLease).compareNames)
dl.Leases = slices.Insert(dl.Leases, i, lease)

return true
})

buf, err := json.Marshal(dl)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}

err = maybe.WriteFile(srv.dbFilePath, buf, 0o644)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}

srv.logger.InfoContext(ctx, "stored leases", "num", len(dl.Leases), "file", srv.dbFilePath)

return nil
}
55 changes: 55 additions & 0 deletions internal/dhcpsvc/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package dhcpsvc_test

import (
"net/netip"
"path/filepath"
"testing"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestServer_loadDatabase(t *testing.T) {
leasesPath := filepath.Join("testdata", t.Name(), "leases.json")

ipv4Conf := &dhcpsvc.IPv4Config{
Enabled: true,
GatewayIP: netip.MustParseAddr("192.168.0.1"),
SubnetMask: netip.MustParseAddr("255.255.255.0"),
RangeStart: netip.MustParseAddr("192.168.0.2"),
RangeEnd: netip.MustParseAddr("192.168.0.254"),
LeaseDuration: 1 * time.Hour,
}
conf := &dhcpsvc.Config{
Enabled: true,
LocalDomainName: "local",
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
"eth0": {
IPv4: ipv4Conf,
IPv6: &dhcpsvc.IPv6Config{Enabled: false},
},
},
DBFilePath: leasesPath,
Logger: discardLog,
}

ctx := testutil.ContextWithTimeout(t, testTimeout)

srv, err := dhcpsvc.New(ctx, conf)
require.NoError(t, err)

expiry, err := time.Parse(time.RFC3339, "2042-01-02T03:04:05Z")
require.NoError(t, err)

wantLeases := []*dhcpsvc.Lease{{
Expiry: expiry,
IP: netip.MustParseAddr("192.168.0.3"),
Hostname: "example.host",
HWAddr: mustParseMAC(t, "AA:AA:AA:AA:AA:AA"),
IsStatic: false,
}}
assert.Equal(t, wantLeases, srv.Leases())
}
15 changes: 15 additions & 0 deletions internal/dhcpsvc/leaseindex.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,18 @@ func (idx *leaseIndex) update(l *Lease, iface *netInterface) (err error) {

return nil
}

// rangeLeases calls f for each lease in idx in an unspecified order until f
// returns false.
func (idx *leaseIndex) rangeLeases(f func(l *Lease) (cont bool)) {
for _, l := range idx.byName {
if !f(l) {
break
}
}
}

// len returns the number of leases in idx.
func (idx *leaseIndex) len() (l uint) {
return uint(len(idx.byAddr))
}
Loading

0 comments on commit 83ec7c5

Please sign in to comment.