Skip to content

Commit

Permalink
all: slog rdns whois
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Aug 26, 2024
1 parent 738958d commit 0b1f022
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 32 deletions.
26 changes: 18 additions & 8 deletions internal/client/addrproc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"context"
"log/slog"
"net/netip"
"sync"
"time"
Expand All @@ -11,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
)

Expand Down Expand Up @@ -38,6 +40,10 @@ func (EmptyAddrProc) Close() (_ error) { return nil }

// DefaultAddrProcConfig is the configuration structure for address processors.
type DefaultAddrProcConfig struct {
// BaseLogger is used to create loggers with custom prefixes for sources of
// information about runtime clients. It must not be nil.
BaseLogger *slog.Logger

// DialContext is used to create TCP connections to WHOIS servers.
// DialContext must not be nil if [DefaultAddrProcConfig.UseWHOIS] is true.
DialContext aghnet.DialContextFunc
Expand Down Expand Up @@ -147,14 +153,15 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {

if c.UseRDNS {
p.rdns = rdns.New(&rdns.Config{
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "rdns"),
Exchanger: c.Exchanger,
CacheSize: defaultCacheSize,
CacheTTL: defaultIPTTL,
})
}

if c.UseWHOIS {
p.whois = newWHOIS(c.DialContext)
p.whois = newWHOIS(c.BaseLogger.With(slogutil.KeyPrefix, "whois"), c.DialContext)
}

go p.process(c.CatchPanics)
Expand All @@ -168,7 +175,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {

// newWHOIS returns a whois.Interface instance using the given function for
// dialing.
func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) {
func newWHOIS(logger *slog.Logger, dialFunc aghnet.DialContextFunc) (w whois.Interface) {
// TODO(s.chzhen): Consider making configurable.
const (
// defaultTimeout is the timeout for WHOIS requests.
Expand All @@ -186,6 +193,7 @@ func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) {
)

return whois.New(&whois.Config{
Logger: logger,
DialContext: dialFunc,
ServerAddr: whois.DefaultServer,
Port: whois.DefaultPort,
Expand Down Expand Up @@ -227,9 +235,11 @@ func (p *DefaultAddrProc) process(catchPanics bool) {

log.Info("clients: processing addresses")

ctx := context.TODO()

for ip := range p.clientIPs {
host := p.processRDNS(ip)
info := p.processWHOIS(ip)
host := p.processRDNS(ctx, ip)
info := p.processWHOIS(ctx, ip)

p.addrUpdater.UpdateAddress(ip, host, info)
}
Expand All @@ -239,7 +249,7 @@ func (p *DefaultAddrProc) process(catchPanics bool) {

// processRDNS resolves the clients' IP addresses using reverse DNS. host is
// empty if there were errors or if the information hasn't changed.
func (p *DefaultAddrProc) processRDNS(ip netip.Addr) (host string) {
func (p *DefaultAddrProc) processRDNS(ctx context.Context, ip netip.Addr) (host string) {
start := time.Now()
log.Debug("clients: processing %s with rdns", ip)
defer func() {
Expand All @@ -251,7 +261,7 @@ func (p *DefaultAddrProc) processRDNS(ip netip.Addr) (host string) {
return
}

host, changed := p.rdns.Process(ip)
host, changed := p.rdns.Process(ctx, ip)
if !changed {
host = ""
}
Expand All @@ -268,7 +278,7 @@ func (p *DefaultAddrProc) shouldResolve(ip netip.Addr) (ok bool) {
// processWHOIS looks up the information about clients' IP addresses in the
// WHOIS databases. info is nil if there were errors or if the information
// hasn't changed.
func (p *DefaultAddrProc) processWHOIS(ip netip.Addr) (info *whois.Info) {
func (p *DefaultAddrProc) processWHOIS(ctx context.Context, ip netip.Addr) (info *whois.Info) {
start := time.Now()
log.Debug("clients: processing %s with whois", ip)
defer func() {
Expand All @@ -277,7 +287,7 @@ func (p *DefaultAddrProc) processWHOIS(ip netip.Addr) (info *whois.Info) {

// TODO(s.chzhen): Move the timeout logic from WHOIS configuration to the
// context.
info, changed := p.whois.Process(context.Background(), ip)
info, changed := p.whois.Process(ctx, ip)
if !changed {
info = nil
}
Expand Down
3 changes: 3 additions & 0 deletions internal/client/addrproc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakenet"
Expand Down Expand Up @@ -99,6 +100,7 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
updInfoCh := make(chan *whois.Info, 1)

p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{
BaseLogger: slogutil.NewDiscardLogger(),
DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) {
panic("not implemented")
},
Expand Down Expand Up @@ -208,6 +210,7 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
updInfoCh := make(chan *whois.Info, 1)

p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{
BaseLogger: slogutil.NewDiscardLogger(),
DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) {
return whoisConn, nil
},
Expand Down
1 change: 1 addition & 0 deletions internal/dnsforward/dnsforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ func (s *Server) setupAddrProc() {
s.addrProc = client.EmptyAddrProc{}
} else {
c := s.conf.AddrProcConf
c.BaseLogger = s.logger
c.DialContext = s.DialContext
c.PrivateSubnets = s.privateNets
c.UsePrivateRDNS = s.conf.UsePrivateRDNS
Expand Down
42 changes: 32 additions & 10 deletions internal/rdns/rdns.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@
package rdns

import (
"context"
"fmt"
"log/slog"
"net/netip"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/bluele/gcache"
)

// Interface processes rDNS queries.
type Interface interface {
// Process makes rDNS request and returns domain name. changed indicates
// that domain name was updated since last request.
Process(ip netip.Addr) (host string, changed bool)
Process(ctx context.Context, ip netip.Addr) (host string, changed bool)
}

// Empty is an empty [Interface] implementation which does nothing.
Expand All @@ -24,7 +27,7 @@ type Empty struct{}
var _ Interface = (*Empty)(nil)

// Process implements the [Interface] interface for Empty.
func (Empty) Process(_ netip.Addr) (host string, changed bool) {
func (Empty) Process(_ context.Context, _ netip.Addr) (host string, changed bool) {
return "", false
}

Expand All @@ -37,6 +40,10 @@ type Exchanger interface {

// Config is the configuration structure for Default.
type Config struct {
// Logger is used for logging the operation of the reverse DNS lookup
// queries. It must not be nil.
Logger *slog.Logger

// Exchanger resolves IP addresses to domain names.
Exchanger Exchanger

Expand All @@ -50,6 +57,10 @@ type Config struct {

// Default is the default rDNS query processor.
type Default struct {
// logger is used for logging the operation of the reverse DNS lookup
// queries. It must not be nil.
logger *slog.Logger

// cache is the cache containing IP addresses of clients. An active IP
// address is resolved once again after it expires. If IP address couldn't
// be resolved, it stays here for some time to prevent further attempts to
Expand All @@ -66,6 +77,7 @@ type Default struct {
// New returns a new default rDNS query processor. conf must not be nil.
func New(conf *Config) (r *Default) {
return &Default{
logger: conf.Logger,
cache: gcache.New(conf.CacheSize).LRU().Build(),
exchanger: conf.Exchanger,
cacheTTL: conf.CacheTTL,
Expand All @@ -76,15 +88,15 @@ func New(conf *Config) (r *Default) {
var _ Interface = (*Default)(nil)

// Process implements the [Interface] interface for Default.
func (r *Default) Process(ip netip.Addr) (host string, changed bool) {
fromCache, expired := r.findInCache(ip)
func (r *Default) Process(ctx context.Context, ip netip.Addr) (host string, changed bool) {
fromCache, expired := r.findInCache(ctx, ip)
if !expired {
return fromCache, false
}

host, ttl, err := r.exchanger.Exchange(ip)
if err != nil {
log.Debug("rdns: resolving %q: %s", ip, err)
r.logger.DebugContext(ctx, "resolving ip", "ip", ip, slogutil.KeyError, err)
}

ttl = max(ttl, r.cacheTTL)
Expand All @@ -96,7 +108,7 @@ func (r *Default) Process(ip netip.Addr) (host string, changed bool) {

err = r.cache.Set(ip, item)
if err != nil {
log.Debug("rdns: cache: adding item %q: %s", ip, err)
r.logger.DebugContext(ctx, "adding item to cache", "item", ip, slogutil.KeyError, err)
}

// TODO(e.burkov): The name doesn't change if it's neither stored in cache
Expand All @@ -106,19 +118,29 @@ func (r *Default) Process(ip netip.Addr) (host string, changed bool) {

// findInCache finds domain name in the cache. expired is true if host is not
// valid anymore.
func (r *Default) findInCache(ip netip.Addr) (host string, expired bool) {
func (r *Default) findInCache(ctx context.Context, ip netip.Addr) (host string, expired bool) {
val, err := r.cache.Get(ip)
if err != nil {
if !errors.Is(err, gcache.KeyNotFoundError) {
log.Debug("rdns: cache: retrieving %q: %s", ip, err)
r.logger.DebugContext(
ctx,
"retrieving item from cache",
"item", ip,
slogutil.KeyError, err,
)
}

return "", true
}

item, ok := val.(*cacheItem)
if !ok {
log.Debug("rdns: cache: %q bad type %T", ip, val)
r.logger.DebugContext(
ctx,
"bad type of cache item",
"item", ip,
"type", fmt.Sprintf("%T", val),
)

return "", true
}
Expand Down
15 changes: 10 additions & 5 deletions internal/rdns/rdns_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
package rdns_test

import (
"context"
"net/netip"
"testing"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// testTimeout is a common timeout for tests and contexts.
const testTimeout = 1 * time.Second

func TestDefault_Process(t *testing.T) {
ip1 := netip.MustParseAddr("1.2.3.4")
revAddr1, err := netutil.IPToReversedAddr(ip1.AsSlice())
Expand Down Expand Up @@ -71,14 +76,14 @@ func TestDefault_Process(t *testing.T) {
Exchanger: &aghtest.Exchanger{OnExchange: onExchange},
})

got, changed := r.Process(tc.addr)
got, changed := r.Process(testutil.ContextWithTimeout(t, testTimeout), tc.addr)
require.True(t, changed)

assert.Equal(t, tc.want, got)
assert.Equal(t, 1, hit)

// From cache.
got, changed = r.Process(tc.addr)
got, changed = r.Process(testutil.ContextWithTimeout(t, testTimeout), tc.addr)
require.False(t, changed)

assert.Equal(t, tc.want, got)
Expand All @@ -101,7 +106,7 @@ func TestDefault_Process(t *testing.T) {
Exchanger: zeroTTLExchanger,
})

got, changed := r.Process(ip1)
got, changed := r.Process(testutil.ContextWithTimeout(t, testTimeout), ip1)
require.True(t, changed)
assert.Equal(t, revAddr1, got)

Expand All @@ -110,13 +115,13 @@ func TestDefault_Process(t *testing.T) {
}

require.EventuallyWithT(t, func(t *assert.CollectT) {
got, changed = r.Process(ip1)
got, changed = r.Process(context.TODO(), ip1)
assert.True(t, changed)
assert.Equal(t, revAddr2, got)
}, 2*cacheTTL, time.Millisecond*100)

assert.Never(t, func() (changed bool) {
_, changed = r.Process(ip1)
_, changed = r.Process(context.TODO(), ip1)

return changed
}, 2*cacheTTL, time.Millisecond*100)
Expand Down
Loading

0 comments on commit 0b1f022

Please sign in to comment.