From ad97a28de18313b61eefeb1962502a858997b052 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Thu, 22 Feb 2024 15:53:07 +0300 Subject: [PATCH] Pull request 324: 6723 fix caching resolver Updates AdguardTeam/AdGuardHome#6723. Squashed commit of the following: commit 9c72e491f85fadfe3d28632058cf81c61a6b8a64 Author: Eugene Burkov Date: Wed Feb 21 16:24:09 2024 +0300 dnsproxytest: imp code commit 6951c768c3758cd5c9b8918ac32c5dab06428350 Author: Eugene Burkov Date: Wed Feb 21 14:22:42 2024 +0300 all: imp file names, add dnsproxytest commit bb5d6124979e86d81286a0bed1996165a19c9a5a Author: Eugene Burkov Date: Tue Feb 20 16:26:00 2024 +0300 upstream: fix expire commit 223208303878b611e1d068bfedf2c4c08db60fff Author: Eugene Burkov Date: Tue Feb 20 15:40:59 2024 +0300 upstream: split internal tests commit f4d95928291dfdc25a173fd306eea056d71dd506 Author: Eugene Burkov Date: Tue Feb 20 15:36:27 2024 +0300 upstream: add staleness cache commit d0c29959dfbd0c545506d24bf9f8c0483e2f3faf Author: Eugene Burkov Date: Mon Feb 19 19:28:28 2024 +0300 all: fix resolver cache --- internal/bootstrap/bootstrap.go | 1 + internal/dnsproxytest/dnsproxytest.go | 3 + internal/dnsproxytest/interface.go | 29 +++++ internal/dnsproxytest/interface_test.go | 9 ++ ...rypt_test.go => dnscrypt_internal_test.go} | 0 .../{doh_test.go => doh_internal_test.go} | 0 .../{doq_test.go => doq_internal_test.go} | 0 .../{dot_test.go => dot_internal_test.go} | 0 ...llel_test.go => parallel_internal_test.go} | 0 .../{plain_test.go => plain_internal_test.go} | 0 upstream/resolver.go | 97 ++++++++--------- upstream/resolver_internal_test.go | 100 ++++++++++++++++++ upstream/resolver_test.go | 34 ++++-- ...ream_test.go => upstream_internal_test.go} | 2 +- 14 files changed, 215 insertions(+), 60 deletions(-) create mode 100644 internal/dnsproxytest/dnsproxytest.go create mode 100644 internal/dnsproxytest/interface.go create mode 100644 internal/dnsproxytest/interface_test.go rename upstream/{dnscrypt_test.go => dnscrypt_internal_test.go} (100%) rename upstream/{doh_test.go => doh_internal_test.go} (100%) rename upstream/{doq_test.go => doq_internal_test.go} (100%) rename upstream/{dot_test.go => dot_internal_test.go} (100%) rename upstream/{parallel_test.go => parallel_internal_test.go} (100%) rename upstream/{plain_test.go => plain_internal_test.go} (100%) create mode 100644 upstream/resolver_internal_test.go rename upstream/{upstream_test.go => upstream_internal_test.go} (99%) diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go index 4e897fe2c..5231a0b98 100644 --- a/internal/bootstrap/bootstrap.go +++ b/internal/bootstrap/bootstrap.go @@ -71,6 +71,7 @@ func ResolveDialContext( defer cancel() } + // TODO(e.burkov): Use network properly, perhaps, pass it through options. ips, err := r.LookupNetIP(ctx, NetworkIP, host) if err != nil { return nil, fmt.Errorf("resolving hostname: %w", err) diff --git a/internal/dnsproxytest/dnsproxytest.go b/internal/dnsproxytest/dnsproxytest.go new file mode 100644 index 000000000..8a9be2084 --- /dev/null +++ b/internal/dnsproxytest/dnsproxytest.go @@ -0,0 +1,3 @@ +// Package dnsproxytest provides a set of test utilities for the dnsproxy +// module. +package dnsproxytest diff --git a/internal/dnsproxytest/interface.go b/internal/dnsproxytest/interface.go new file mode 100644 index 000000000..a600c122a --- /dev/null +++ b/internal/dnsproxytest/interface.go @@ -0,0 +1,29 @@ +package dnsproxytest + +import ( + "github.com/miekg/dns" +) + +// FakeUpstream is a fake [Upstream] implementation for tests. +// +// TODO(e.burkov): Move this to the golibs? +type FakeUpstream struct { + OnAddress func() (addr string) + OnExchange func(req *dns.Msg) (resp *dns.Msg, err error) + OnClose func() (err error) +} + +// Address implements the [Upstream] interface for *FakeUpstream. +func (u *FakeUpstream) Address() (addr string) { + return u.OnAddress() +} + +// Exchange implements the [Upstream] interface for *FakeUpstream. +func (u *FakeUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + return u.OnExchange(req) +} + +// Close implements the [Upstream] interface for *FakeUpstream. +func (u *FakeUpstream) Close() (err error) { + return u.OnClose() +} diff --git a/internal/dnsproxytest/interface_test.go b/internal/dnsproxytest/interface_test.go new file mode 100644 index 000000000..3d231bc8e --- /dev/null +++ b/internal/dnsproxytest/interface_test.go @@ -0,0 +1,9 @@ +package dnsproxytest_test + +import ( + "github.com/AdguardTeam/dnsproxy/internal/dnsproxytest" + "github.com/AdguardTeam/dnsproxy/upstream" +) + +// type check +var _ upstream.Upstream = (*dnsproxytest.FakeUpstream)(nil) diff --git a/upstream/dnscrypt_test.go b/upstream/dnscrypt_internal_test.go similarity index 100% rename from upstream/dnscrypt_test.go rename to upstream/dnscrypt_internal_test.go diff --git a/upstream/doh_test.go b/upstream/doh_internal_test.go similarity index 100% rename from upstream/doh_test.go rename to upstream/doh_internal_test.go diff --git a/upstream/doq_test.go b/upstream/doq_internal_test.go similarity index 100% rename from upstream/doq_test.go rename to upstream/doq_internal_test.go diff --git a/upstream/dot_test.go b/upstream/dot_internal_test.go similarity index 100% rename from upstream/dot_test.go rename to upstream/dot_internal_test.go diff --git a/upstream/parallel_test.go b/upstream/parallel_internal_test.go similarity index 100% rename from upstream/parallel_test.go rename to upstream/parallel_internal_test.go diff --git a/upstream/plain_test.go b/upstream/plain_internal_test.go similarity index 100% rename from upstream/plain_test.go rename to upstream/plain_internal_test.go diff --git a/upstream/resolver.go b/upstream/resolver.go index 5f98c800d..cc613fa7f 100644 --- a/upstream/resolver.go +++ b/upstream/resolver.go @@ -3,6 +3,7 @@ package upstream import ( "context" "fmt" + "math" "net/netip" "net/url" "strings" @@ -135,54 +136,38 @@ func (r *UpstreamResolver) LookupNetIP( host = dns.Fqdn(strings.ToLower(host)) - rr, err := r.resolveIP(ctx, network, host) + res, err := r.lookupNetIP(ctx, network, host) if err != nil { return []netip.Addr{}, err } - for _, ip := range rr { - ips = append(ips, ip.addr) - } - - return ips, err + return res.addrs, err } // ipResult reflects a single A/AAAA record from the DNS response. It's used // to cache the results of lookups. type ipResult struct { - addr netip.Addr expire time.Time + addrs []netip.Addr } -// filterExpired returns the addresses from res that are not expired yet. It -// returns nil if all the addresses are expired. -func filterExpired(res []ipResult, now time.Time) (filtered []netip.Addr) { - for _, r := range res { - if r.expire.After(now) { - filtered = append(filtered, r.addr) - } - } - - return filtered -} - -// resolveIP performs a DNS lookup of host and returns the result. network must -// be either [bootstrap.NetworkIP4], [bootstrap.NetworkIP6] or +// lookupNetIP performs a DNS lookup of host and returns the result. network +// must be either [bootstrap.NetworkIP4], [bootstrap.NetworkIP6], or // [bootstrap.NetworkIP]. host must be in a lower-case FQDN form. // // TODO(e.burkov): Use context. -func (r *UpstreamResolver) resolveIP( +func (r *UpstreamResolver) lookupNetIP( _ context.Context, network bootstrap.Network, host string, -) (rr []ipResult, err error) { +) (result *ipResult, err error) { switch network { case bootstrap.NetworkIP4, bootstrap.NetworkIP6: - return r.resolve(host, network) + return r.request(host, network) case bootstrap.NetworkIP: // Go on. default: - return nil, fmt.Errorf("unsupported network %s", network) + return result, fmt.Errorf("unsupported network %s", network) } resCh := make(chan any, 2) @@ -190,29 +175,31 @@ func (r *UpstreamResolver) resolveIP( go r.resolveAsync(resCh, host, bootstrap.NetworkIP6) var errs []error + result = &ipResult{} for i := 0; i < 2; i++ { switch res := <-resCh; res := res.(type) { case error: errs = append(errs, res) - case []ipResult: - rr = append(rr, res...) + case *ipResult: + if result.expire.Equal(time.Time{}) || res.expire.Before(result.expire) { + result.expire = res.expire + } + result.addrs = append(result.addrs, res.addrs...) } } - return rr, errors.Join(errs...) + return result, errors.Join(errs...) } -// resolve performs a single DNS lookup of host and returns all the valid +// request performs a single DNS lookup of host and returns all the valid // addresses from the answer section of the response. network must be either -// "ip4" or "ip6". host must be in a lower-case FQDN form. +// [bootstrap.NetworkIP4], or [bootstrap.NetworkIP6]. host must be in a +// lower-case FQDN form. // // TODO(e.burkov): Consider NS and Extra sections when setting TTL. Check out // what RFCs say about it. -func (r *UpstreamResolver) resolve( - host string, - n bootstrap.Network, -) (res []ipResult, err error) { +func (r *UpstreamResolver) request(host string, n bootstrap.Network) (res *ipResult, err error) { var qtype uint16 switch n { case bootstrap.NetworkIP4: @@ -235,25 +222,29 @@ func (r *UpstreamResolver) resolve( }}, } - // As per [upstream.Exchange] documentation, the response is always returned + // As per [Upstream.Exchange] documentation, the response is always returned // if no error occurred. resp, err := r.Exchange(req) if err != nil { - return nil, err + return res, err } - now := time.Now() + res = &ipResult{ + expire: time.Now(), + addrs: make([]netip.Addr, 0, len(resp.Answer)), + } + var minTTL uint32 = math.MaxUint32 + for _, rr := range resp.Answer { ip := proxyutil.IPFromRR(rr) if !ip.IsValid() { continue } - res = append(res, ipResult{ - addr: ip, - expire: now.Add(time.Duration(rr.Header().Ttl) * time.Second), - }) + res.addrs = append(res.addrs, ip) + minTTL = min(minTTL, rr.Header().Ttl) } + res.expire = res.expire.Add(time.Duration(minTTL) * time.Second) return res, nil } @@ -261,7 +252,7 @@ func (r *UpstreamResolver) resolve( // resolveAsync performs a single DNS lookup and sends the result to ch. It's // intended to be used as a goroutine. func (r *UpstreamResolver) resolveAsync(resCh chan<- any, host, network string) { - res, err := r.resolve(host, network) + res, err := r.request(host, network) if err != nil { resCh <- err } else { @@ -279,7 +270,9 @@ type CachingResolver struct { mu *sync.RWMutex // cached is the set of cached results sorted by [resolveResult.name]. - cached map[string][]ipResult + // + // TODO(e.burkov): Use expiration cache. + cached map[string]*ipResult } // NewCachingResolver creates a new caching resolver that uses r for lookups. @@ -287,7 +280,7 @@ func NewCachingResolver(r *UpstreamResolver) (cr *CachingResolver) { return &CachingResolver{ resolver: r, mu: &sync.RWMutex{}, - cached: map[string][]ipResult{}, + cached: map[string]*ipResult{}, } } @@ -295,6 +288,9 @@ func NewCachingResolver(r *UpstreamResolver) (cr *CachingResolver) { var _ Resolver = (*CachingResolver)(nil) // LookupNetIP implements the [Resolver] interface for *CachingResolver. +// +// TODO(e.burkov): It may appear that several concurrent lookup results rewrite +// each other in the cache. func (r *CachingResolver) LookupNetIP( ctx context.Context, network bootstrap.Network, @@ -308,22 +304,17 @@ func (r *CachingResolver) LookupNetIP( return addrs, nil } - newRes, err := r.resolver.resolveIP(ctx, network, host) + newRes, err := r.resolver.lookupNetIP(ctx, network, host) if err != nil { return []netip.Addr{}, err } - addrs = filterExpired(newRes, now) - if len(addrs) == 0 { - return []netip.Addr{}, nil - } - r.mu.Lock() defer r.mu.Unlock() r.cached[host] = newRes - return addrs, nil + return newRes.addrs, nil } // findCached returns the cached addresses for host if it's not expired yet, and @@ -333,9 +324,9 @@ func (r *CachingResolver) findCached(host string, now time.Time) (addrs []netip. defer r.mu.RUnlock() res, ok := r.cached[host] - if !ok { + if !ok || res.expire.Before(now) { return nil } - return filterExpired(res, now) + return res.addrs } diff --git a/upstream/resolver_internal_test.go b/upstream/resolver_internal_test.go new file mode 100644 index 000000000..bd2690fde --- /dev/null +++ b/upstream/resolver_internal_test.go @@ -0,0 +1,100 @@ +package upstream + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/AdguardTeam/dnsproxy/internal/bootstrap" + "github.com/AdguardTeam/dnsproxy/internal/dnsproxytest" + "github.com/AdguardTeam/golibs/testutil" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCachingResolver_staleness(t *testing.T) { + ip4 := netip.MustParseAddr("1.2.3.4") + ip6 := netip.MustParseAddr("2001:db8::1") + + const ( + smallTTL = 10 * time.Second + largeTTL = 1000 * time.Second + + fqdn = "test.fully.qualified.name." + ) + + onExchange := func(req *dns.Msg) (resp *dns.Msg, err error) { + resp = (&dns.Msg{}).SetReply(req) + + hdr := dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: req.Question[0].Qtype, + Class: dns.ClassINET, + } + var rr dns.RR + switch q := req.Question[0]; q.Qtype { + case dns.TypeA: + hdr.Ttl = uint32(smallTTL.Seconds()) + rr = &dns.A{Hdr: hdr, A: ip4.AsSlice()} + case dns.TypeAAAA: + hdr.Ttl = uint32(largeTTL.Seconds()) + rr = &dns.AAAA{Hdr: hdr, AAAA: ip6.AsSlice()} + default: + require.Contains(testutil.PanicT{}, []uint16{dns.TypeA, dns.TypeAAAA}, q.Qtype) + } + resp.Answer = append(resp.Answer, rr) + + return resp, nil + } + + ups := &dnsproxytest.FakeUpstream{ + OnAddress: func() (_ string) { panic("not implemented") }, + OnClose: func() (_ error) { panic("not implemented") }, + OnExchange: onExchange, + } + + r := NewCachingResolver(&UpstreamResolver{Upstream: ups}) + + require.True(t, t.Run("resolve", func(t *testing.T) { + testCases := []struct { + name string + network bootstrap.Network + want []netip.Addr + }{{ + name: "ip4", + network: bootstrap.NetworkIP4, + want: []netip.Addr{ip4}, + }, { + name: "ip6", + network: bootstrap.NetworkIP6, + want: []netip.Addr{ip6}, + }, { + name: "both", + network: bootstrap.NetworkIP, + want: []netip.Addr{ip4, ip6}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.name != "both" { + t.Skip(`TODO(e.burkov): Bootstrap now only uses "ip" network, see TODO there.`) + } + + res, err := r.LookupNetIP(context.Background(), tc.network, fqdn) + require.NoError(t, err) + + assert.ElementsMatch(t, tc.want, res) + }) + } + })) + + t.Run("staleness", func(t *testing.T) { + cached := r.findCached(fqdn, time.Now()) + require.ElementsMatch(t, []netip.Addr{ip4, ip6}, cached) + + cached = r.findCached(fqdn, time.Now().Add(smallTTL)) + require.Empty(t, cached) + }) +} diff --git a/upstream/resolver_test.go b/upstream/resolver_test.go index 505b141dd..f22360d07 100644 --- a/upstream/resolver_test.go +++ b/upstream/resolver_test.go @@ -1,18 +1,40 @@ -package upstream +package upstream_test import ( "context" + "net/netip" "testing" "time" + "github.com/AdguardTeam/dnsproxy/internal/dnsproxytest" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewUpstreamResolver(t *testing.T) { - r, err := NewUpstreamResolver("1.1.1.1:53", &Options{Timeout: 3 * time.Second}) - require.NoError(t, err) + ups := &dnsproxytest.FakeUpstream{ + OnAddress: func() (_ string) { panic("not implemented") }, + OnClose: func() (_ error) { panic("not implemented") }, + OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + resp = (&dns.Msg{}).SetReply(req) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 60, + }, + A: netip.MustParseAddr("1.2.3.4").AsSlice(), + }} + + return resp, nil + }, + } + + r := &upstream.UpstreamResolver{Upstream: ups} ipAddrs, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com") require.NoError(t, err) @@ -21,7 +43,7 @@ func TestNewUpstreamResolver(t *testing.T) { } func TestNewUpstreamResolver_validity(t *testing.T) { - withTimeoutOpt := &Options{Timeout: 3 * time.Second} + withTimeoutOpt := &upstream.Options{Timeout: 3 * time.Second} testCases := []struct { name string @@ -71,10 +93,10 @@ func TestNewUpstreamResolver_validity(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - r, err := NewUpstreamResolver(tc.addr, withTimeoutOpt) + r, err := upstream.NewUpstreamResolver(tc.addr, withTimeoutOpt) if tc.wantErrMsg != "" { assert.Equal(t, tc.wantErrMsg, err.Error()) - if nberr := (&NotBootstrapError{}); errors.As(err, &nberr) { + if nberr := (&upstream.NotBootstrapError{}); errors.As(err, &nberr) { assert.NotNil(t, r) } diff --git a/upstream/upstream_test.go b/upstream/upstream_internal_test.go similarity index 99% rename from upstream/upstream_test.go rename to upstream/upstream_internal_test.go index 193d7ba3f..043300f5d 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_internal_test.go @@ -25,7 +25,7 @@ import ( "github.com/stretchr/testify/require" ) -// TODO(ameshkov): make tests here not depend on external servers. +// TODO(ameshkov): Make tests here not depend on external servers. func TestMain(m *testing.M) { testutil.DiscardLogOutput(m)