diff --git a/internal/client/client.go b/internal/client/client.go index d0a75045664..6d13277a97a 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -7,6 +7,8 @@ package client import ( "encoding" "fmt" + "net/netip" + "sync" "github.com/AdguardTeam/AdGuardHome/internal/whois" ) @@ -157,3 +159,85 @@ func (r *Runtime) IsEmpty() (ok bool) { r.dhcp == nil && r.hostsFile == nil } + +// RuntimeIndex stores information about runtime clients. +type RuntimeIndex struct { + // indexMu protects index. + indexMu *sync.RWMutex + + // index maps IP address to runtime client. + index map[netip.Addr]*Runtime +} + +// NewRuntimeIndex returns initialized runtime index. +func NewRuntimeIndex() (ri *RuntimeIndex) { + return &RuntimeIndex{ + indexMu: &sync.RWMutex{}, + index: map[netip.Addr]*Runtime{}, + } +} + +// Client returns the saved runtime client by ip. If no such client exists, +// returns nil. +func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime, ok bool) { + ri.indexMu.RLock() + defer ri.indexMu.RUnlock() + + rc, ok = ri.index[ip] + + return rc, ok +} + +// Add saves the runtime client by ip. +func (ri *RuntimeIndex) Add(ip netip.Addr, rc *Runtime) { + ri.indexMu.Lock() + defer ri.indexMu.Unlock() + + ri.index[ip] = rc +} + +// Size returns the number of the runtime clients. +func (ri *RuntimeIndex) Size() (n int) { + ri.indexMu.RLock() + defer ri.indexMu.RUnlock() + + return len(ri.index) +} + +// Range calls cb for each runtime client. +func (ri *RuntimeIndex) Range(cb func(ip netip.Addr, rc *Runtime) (cont bool)) { + ri.indexMu.RLock() + defer ri.indexMu.RUnlock() + + for ip, rc := range ri.index { + if !cb(ip, rc) { + return + } + } +} + +// Delete removes the runtime client by ip. +func (ri *RuntimeIndex) Delete(ip netip.Addr) { + ri.indexMu.Lock() + defer ri.indexMu.Unlock() + + delete(ri.index, ip) +} + +// DeleteBySrc removes all runtime clients that have information only from the +// specified source and returns the number of removed clients. +func (ri *RuntimeIndex) DeleteBySrc(src Source) (n int) { + ri.indexMu.Lock() + defer ri.indexMu.Unlock() + + for ip, rc := range ri.index { + rc.Unset(src) + + if rc.IsEmpty() { + delete(ri.index, ip) + n++ + } + } + + return n +} diff --git a/internal/home/clients.go b/internal/home/clients.go index fb627a2e14a..8cfc8ea3e2e 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -49,10 +49,11 @@ type clientsContainer struct { // types (string, netip.Addr, and so on). list map[string]*client.Persistent // name -> client + // clientIndex stores information about persistent clients. clientIndex *client.Index - // ipToRC maps IP addresses to runtime client information. - ipToRC map[netip.Addr]*client.Runtime + // runtimeIndex stores information about runtime clients. + runtimeIndex *client.RuntimeIndex allTags *stringutil.Set @@ -104,7 +105,7 @@ func (clients *clientsContainer) Init( } clients.list = map[string]*client.Persistent{} - clients.ipToRC = map[netip.Addr]*client.Runtime{} + clients.runtimeIndex = client.NewRuntimeIndex() clients.clientIndex = client.NewIndex() @@ -362,7 +363,7 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) return client.SourcePersistent } - rc, ok := clients.ipToRC[ip] + rc, ok := clients.runtimeIndex.Client(ip) if ok { src, _ = rc.Info() } @@ -558,10 +559,7 @@ func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtim return nil, false } - clients.lock.Lock() - defer clients.lock.Unlock() - - rc, ok = clients.ipToRC[ip] + rc, ok = clients.runtimeIndex.Client(ip) return rc, ok } @@ -733,12 +731,12 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) { return } - rc, ok := clients.ipToRC[ip] + rc, ok := clients.runtimeIndex.Client(ip) if !ok { // Create a RuntimeClient implicitly so that we don't do this check // again. rc = &client.Runtime{} - clients.ipToRC[ip] = rc + clients.runtimeIndex.Add(ip, rc) log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) } else { @@ -797,7 +795,7 @@ func (clients *clientsContainer) addHostLocked( host string, src client.Source, ) (ok bool) { - rc, ok := clients.ipToRC[ip] + rc, ok := clients.runtimeIndex.Client(ip) if !ok { if src < client.SourceDHCP { if clients.dhcp.HostByIP(ip) != "" { @@ -806,52 +804,39 @@ func (clients *clientsContainer) addHostLocked( } rc = &client.Runtime{} - clients.ipToRC[ip] = rc + clients.runtimeIndex.Add(ip, rc) } rc.SetInfo(src, []string{host}) - log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, len(clients.ipToRC)) + log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, clients.runtimeIndex.Size()) return true } -// rmHostsBySrc removes all entries that match the specified source. -func (clients *clientsContainer) rmHostsBySrc(src client.Source) { - n := 0 - for ip, rc := range clients.ipToRC { - rc.Unset(src) - if rc.IsEmpty() { - delete(clients.ipToRC, ip) - n++ - } - } - - log.Debug("clients: removed %d client aliases", n) -} - // addFromHostsFile fills the client-hostname pairing index from the system's // hosts files. func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) { clients.lock.Lock() defer clients.lock.Unlock() - clients.rmHostsBySrc(client.SourceHostsFile) + deleted := clients.runtimeIndex.DeleteBySrc(client.SourceHostsFile) + log.Debug("clients: removed %d client aliases from system hosts file", deleted) - n := 0 + added := 0 hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { // Only the first name of the first record is considered a canonical // hostname for the IP address. // // TODO(e.burkov): Consider using all the names from all the records. if clients.addHostLocked(addr, names[0], client.SourceHostsFile) { - n++ + added++ } return true }) - log.Debug("clients: added %d client aliases from system hosts file", n) + log.Debug("clients: added %d client aliases from system hosts file", added) } // addFromSystemARP adds the IP-hostname pairings from the output of the arp -a @@ -875,7 +860,8 @@ func (clients *clientsContainer) addFromSystemARP() { clients.lock.Lock() defer clients.lock.Unlock() - clients.rmHostsBySrc(client.SourceARP) + deleted := clients.runtimeIndex.DeleteBySrc(client.SourceARP) + log.Debug("clients: removed %d client aliases from arp neighborhood", deleted) added := 0 for _, n := range ns { diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 4f9cb946bff..71d7c8ee2f1 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -244,8 +244,9 @@ func TestClientsWHOIS(t *testing.T) { t.Run("new_client", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.255") clients.setWHOISInfo(ip, whois) - rc := clients.ipToRC[ip] + rc, ok := clients.runtimeIndex.Client(ip) require.NotNil(t, rc) + require.True(t, ok) assert.Equal(t, whois, rc.WHOIS()) }) @@ -256,8 +257,9 @@ func TestClientsWHOIS(t *testing.T) { assert.True(t, ok) clients.setWHOISInfo(ip, whois) - rc := clients.ipToRC[ip] + rc, ok := clients.runtimeIndex.Client(ip) require.NotNil(t, rc) + require.True(t, ok) assert.Equal(t, whois, rc.WHOIS()) }) @@ -274,8 +276,9 @@ func TestClientsWHOIS(t *testing.T) { assert.True(t, ok) clients.setWHOISInfo(ip, whois) - rc := clients.ipToRC[ip] + rc, ok := clients.runtimeIndex.Client(ip) require.Nil(t, rc) + require.False(t, ok) assert.True(t, clients.remove("client1")) }) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index b2270416e29..a0659115573 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -101,7 +101,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http data.Clients = append(data.Clients, cj) } - for ip, rc := range clients.ipToRC { + clients.runtimeIndex.Range(func(ip netip.Addr, rc *client.Runtime) (cont bool) { src, host := rc.Info() cj := runtimeClientJSON{ WHOIS: whoisOrEmpty(rc), @@ -111,7 +111,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http } data.RuntimeClients = append(data.RuntimeClients, cj) - } + + return true + }) for _, l := range clients.dhcp.Leases() { cj := runtimeClientJSON{