diff --git a/internal/aghalg/orderedmap.go b/internal/aghalg/sortedmap.go similarity index 96% rename from internal/aghalg/orderedmap.go rename to internal/aghalg/sortedmap.go index 5e15cf0733f..7f068b0ce31 100644 --- a/internal/aghalg/orderedmap.go +++ b/internal/aghalg/sortedmap.go @@ -1,6 +1,7 @@ package aghalg import ( + "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) @@ -18,7 +19,7 @@ type SortedMap[K comparable, V any] struct { // TODO(s.chzhen): Use cmp.Compare in Go 1.21. func NewSortedMap[K comparable, V any](cmp func(a, b K) (res int)) SortedMap[K, V] { return SortedMap[K, V]{ - vals: make(map[K]V), + vals: map[K]V{}, cmp: cmp, } } @@ -69,7 +70,7 @@ func (m *SortedMap[K, V]) Clear() { // TODO(s.chzhen): Use built-in clear in Go 1.21. m.keys = nil - m.vals = make(map[K]V) + maps.Clear(m.vals) } // Range calls cb for each element of the map, sorted by m.cmp. If cb returns diff --git a/internal/aghalg/orderedmap_test.go b/internal/aghalg/sortedmap_test.go similarity index 100% rename from internal/aghalg/orderedmap_test.go rename to internal/aghalg/sortedmap_test.go diff --git a/internal/home/client.go b/internal/home/client.go index 64e9b67742d..0e1366d63cd 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -30,6 +30,16 @@ func NewUID() (uid UID, err error) { return UID(uuidv7), err } +// MustNewUID is a wrapper around [NewUID] that panics if there is an error. +func MustNewUID() (uid UID) { + uid, err := NewUID() + if err != nil { + panic(fmt.Errorf("unexpected uuidv7 error: %w", err)) + } + + return uid +} + // type check var _ encoding.TextMarshaler = UID{} diff --git a/internal/home/clientindex.go b/internal/home/clientindex.go index 68f61a85731..c217f3efc57 100644 --- a/internal/home/clientindex.go +++ b/internal/home/clientindex.go @@ -1,6 +1,7 @@ package home import ( + "fmt" "net" "net/netip" @@ -10,23 +11,24 @@ import ( // macKey contains MAC as byte array of 6, 8, or 20 bytes. type macKey any +// macToKey converts mac into key of type macKey, which is used as the key of +// the [clientIndex.macToUID]. mac must be valid MAC address. func macToKey(mac net.HardwareAddr) (key macKey) { switch len(mac) { case 6: - arr := [6]byte{} - copy(arr[:], mac[:]) + arr := *(*[6]byte)(mac) return arr case 8: - arr := [8]byte{} - copy(arr[:], mac[:]) + arr := *(*[8]byte)(mac) return arr - default: - arr := [20]byte{} - copy(arr[:], mac[:]) + case 20: + arr := *(*[20]byte)(mac) return arr + default: + panic("invalid mac address") } } @@ -54,7 +56,8 @@ func NewClientIndex() (ci *clientIndex) { } } -// add stores information about a persistent client in the index. +// add stores information about a persistent client in the index. c must +// contain UID. func (ci *clientIndex) add(c *persistentClient) { for _, id := range c.ClientIDs { ci.clientIDToUID[id] = c.UID @@ -76,26 +79,57 @@ func (ci *clientIndex) add(c *persistentClient) { ci.uidToClient[c.UID] = c } -// contains returns true if the index contains a persistent client with at least -// a single identifier contained by c. -func (ci *clientIndex) contains(c *persistentClient) (ok bool) { +// clashes returns an error if the index contains a different persistent client +// with at least a single identifier contained by c. +func (ci *clientIndex) clashes(c *persistentClient) (err error) { for _, id := range c.ClientIDs { - _, ok = ci.clientIDToUID[id] - if ok { - return true + existing, ok := ci.clientIDToUID[id] + if ok && existing != c.UID { + p := ci.uidToClient[existing] + + return fmt.Errorf("another client %q uses the same ID %q", p.Name, id) } } + p, ip := ci.clashesIP(c) + if p != nil { + return fmt.Errorf("another client %q uses the same IP %q", p.Name, ip) + } + + p, s := ci.clashesSubnet(c) + if p != nil { + return fmt.Errorf("another client %q uses the same subnet %q", p.Name, s) + } + + p, mac := ci.clashesMAC(c) + if p != nil { + return fmt.Errorf("another client %q uses the same MAC %q", p.Name, mac) + } + + return nil +} + +// clashesIP returns a previous client with the same IP address as c. +func (ci *clientIndex) clashesIP(c *persistentClient) (p *persistentClient, ip netip.Addr) { for _, ip := range c.IPs { - _, ok = ci.ipToUID[ip] - if ok { - return true + existing, ok := ci.ipToUID[ip] + if ok && existing != c.UID { + return ci.uidToClient[existing], ip } } - for _, pref := range c.Subnets { - ci.subnetToUID.Range(func(p netip.Prefix, _ UID) (cont bool) { - if pref == p { + return nil, netip.Addr{} +} + +// clashesSubnet returns a previous client with the same subnet as c. +func (ci *clientIndex) clashesSubnet(c *persistentClient) (p *persistentClient, s netip.Prefix) { + var existing UID + var ok bool + + for _, s = range c.Subnets { + ci.subnetToUID.Range(func(p netip.Prefix, uid UID) (cont bool) { + if s == p { + existing = uid ok = true return false @@ -104,20 +138,25 @@ func (ci *clientIndex) contains(c *persistentClient) (ok bool) { return true }) - if ok { - return true + if ok && existing != c.UID { + return ci.uidToClient[existing], s } } - for _, mac := range c.MACs { + return nil, netip.Prefix{} +} + +// clashesMAC returns a previous client with the same MAC address as c. +func (ci *clientIndex) clashesMAC(c *persistentClient) (p *persistentClient, mac net.HardwareAddr) { + for _, mac = range c.MACs { k := macToKey(mac) - _, ok = ci.macToUID[k] - if ok { - return true + existing, ok := ci.macToUID[k] + if ok && existing != c.UID { + return ci.uidToClient[existing], mac } } - return false + return nil, nil } // find finds persistent client by string representation of the client ID, IP diff --git a/internal/home/clientindex_internal_test.go b/internal/home/clientindex_internal_test.go index 5eef753368c..868b2fc68ea 100644 --- a/internal/home/clientindex_internal_test.go +++ b/internal/home/clientindex_internal_test.go @@ -1,6 +1,7 @@ package home import ( + "net/netip" "testing" "github.com/AdguardTeam/AdGuardHome/internal/filtering" @@ -97,11 +98,19 @@ func TestClientIndex(t *testing.T) { }) t.Run("contains_delete", func(t *testing.T) { - ok := ci.contains(client1) - require.True(t, ok) + err := ci.clashes(client1) + require.NoError(t, err) + + dup := &persistentClient{ + Name: "client_with_the_same_ip_as_client1", + IPs: []netip.Addr{netip.MustParseAddr(cliIP1)}, + UID: MustNewUID(), + } + err = ci.clashes(dup) + require.Error(t, err) ci.del(client1) - ok = ci.contains(client1) - require.False(t, ok) + err = ci.clashes(dup) + require.NoError(t, err) }) } diff --git a/internal/home/clients.go b/internal/home/clients.go index 07256fea246..f4b67d42a4e 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -47,8 +47,9 @@ type DHCP interface { type clientsContainer struct { // TODO(a.garipov): Perhaps use a number of separate indices for different // types (string, netip.Addr, and so on). - list map[string]*persistentClient // name -> client - idIndex map[string]*persistentClient // ID -> client + list map[string]*persistentClient // name -> client + + clientIndex *clientIndex // ipToRC maps IP addresses to runtime client information. ipToRC map[netip.Addr]*client.Runtime @@ -103,9 +104,10 @@ func (clients *clientsContainer) Init( } clients.list = map[string]*persistentClient{} - clients.idIndex = map[string]*persistentClient{} clients.ipToRC = map[netip.Addr]*client.Runtime{} + clients.clientIndex = NewClientIndex() + clients.allTags = stringutil.NewSet(clientTags...) // TODO(e.burkov): Use [dhcpsvc] implementation when it's ready. @@ -518,7 +520,7 @@ func (clients *clientsContainer) UpstreamConfigByID( // findLocked searches for a client by its ID. clients.lock is expected to be // locked. func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok bool) { - c, ok = clients.idIndex[id] + c, ok = clients.clientIndex.find(id) if ok { return c, true } @@ -528,14 +530,6 @@ func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok return nil, false } - for _, c = range clients.list { - for _, subnet := range c.Subnets { - if subnet.Contains(ip) { - return c, true - } - } - } - // TODO(e.burkov): Iterate through clients.list only once. return clients.findDHCP(ip) } @@ -639,18 +633,15 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) { } // check ID index - ids := c.ids() - for _, id := range ids { - var c2 *persistentClient - c2, ok = clients.idIndex[id] - if ok { - return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name) - } + err = clients.clientIndex.clashes(c) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return false, err } clients.addLocked(c) - log.Debug("clients: added %q: ID:%q [%d]", c.Name, ids, len(clients.list)) + log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.ids(), len(clients.list)) return true, nil } @@ -661,9 +652,7 @@ func (clients *clientsContainer) addLocked(c *persistentClient) { clients.list[c.Name] = c // update ID index - for _, id := range c.ids() { - clients.idIndex[id] = c - } + clients.clientIndex.add(c) } // remove removes a client. ok is false if there is no such client. @@ -693,9 +682,7 @@ func (clients *clientsContainer) removeLocked(c *persistentClient) { delete(clients.list, c.Name) // Update the ID index. - for _, id := range c.ids() { - delete(clients.idIndex, id) - } + clients.clientIndex.del(c) } // update updates a client by its name. @@ -725,11 +712,10 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) { } // Check the ID index. - for _, id := range c.ids() { - existing, ok := clients.idIndex[id] - if ok && existing != prev { - return fmt.Errorf("id %q is used by client with name %q", id, existing.Name) - } + err = clients.clientIndex.clashes(c) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err } clients.removeLocked(prev) diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 07332ecf78b..5e3f6d507c6 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -68,6 +68,7 @@ func TestClients(t *testing.T) { c := &persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{cli1IP, cliIPv6}, } @@ -78,6 +79,7 @@ func TestClients(t *testing.T) { c = &persistentClient{ Name: "client2", + UID: MustNewUID(), IPs: []netip.Addr{cli2IP}, } diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go index 820b22a6bf8..592971c8a76 100644 --- a/internal/home/dns_internal_test.go +++ b/internal/home/dns_internal_test.go @@ -12,6 +12,19 @@ import ( var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) +func idIndex(m map[string]*persistentClient) (ci *clientIndex) { + ci = NewClientIndex() + + for id, c := range m { + c.ClientIDs = []string{id} + c.UID = MustNewUID() + + ci.add(c) + } + + return ci +} + func TestApplyAdditionalFiltering(t *testing.T) { var err error @@ -22,7 +35,7 @@ func TestApplyAdditionalFiltering(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.idIndex = map[string]*persistentClient{ + Context.clients.clientIndex = idIndex(map[string]*persistentClient{ "default": { UseOwnSettings: false, safeSearchConf: filtering.SafeSearchConfig{Enabled: false}, @@ -44,7 +57,7 @@ func TestApplyAdditionalFiltering(t *testing.T) { SafeBrowsingEnabled: false, ParentalEnabled: false, }, - } + }) testCases := []struct { name string @@ -108,7 +121,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.idIndex = map[string]*persistentClient{ + Context.clients.clientIndex = idIndex(map[string]*persistentClient{ "default": { UseOwnBlockedServices: false, }, @@ -139,7 +152,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, UseOwnBlockedServices: true, }, - } + }) testCases := []struct { name string