Skip to content

Commit

Permalink
Pull request 324: 6723 fix caching resolver
Browse files Browse the repository at this point in the history
Updates AdguardTeam/AdGuardHome#6723.

Squashed commit of the following:

commit 9c72e49
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Feb 21 16:24:09 2024 +0300

    dnsproxytest: imp code

commit 6951c76
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Feb 21 14:22:42 2024 +0300

    all: imp file names, add dnsproxytest

commit bb5d612
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Feb 20 16:26:00 2024 +0300

    upstream: fix expire

commit 2232083
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Feb 20 15:40:59 2024 +0300

    upstream: split internal tests

commit f4d9592
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Feb 20 15:36:27 2024 +0300

    upstream: add staleness cache

commit d0c2995
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Feb 19 19:28:28 2024 +0300

    all: fix resolver cache
  • Loading branch information
EugeneOne1 committed Feb 22, 2024
1 parent d918c7f commit ad97a28
Show file tree
Hide file tree
Showing 14 changed files with 215 additions and 60 deletions.
1 change: 1 addition & 0 deletions internal/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func ResolveDialContext(
defer cancel()
}

// TODO(e.burkov): Use network properly, perhaps, pass it through options.
ips, err := r.LookupNetIP(ctx, NetworkIP, host)
if err != nil {
return nil, fmt.Errorf("resolving hostname: %w", err)
Expand Down
3 changes: 3 additions & 0 deletions internal/dnsproxytest/dnsproxytest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Package dnsproxytest provides a set of test utilities for the dnsproxy
// module.
package dnsproxytest
29 changes: 29 additions & 0 deletions internal/dnsproxytest/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package dnsproxytest

import (
"github.com/miekg/dns"
)

// FakeUpstream is a fake [Upstream] implementation for tests.
//
// TODO(e.burkov): Move this to the golibs?
type FakeUpstream struct {
OnAddress func() (addr string)
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
OnClose func() (err error)
}

// Address implements the [Upstream] interface for *FakeUpstream.
func (u *FakeUpstream) Address() (addr string) {
return u.OnAddress()
}

// Exchange implements the [Upstream] interface for *FakeUpstream.
func (u *FakeUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
return u.OnExchange(req)
}

// Close implements the [Upstream] interface for *FakeUpstream.
func (u *FakeUpstream) Close() (err error) {
return u.OnClose()
}
9 changes: 9 additions & 0 deletions internal/dnsproxytest/interface_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package dnsproxytest_test

import (
"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
"github.com/AdguardTeam/dnsproxy/upstream"
)

// type check
var _ upstream.Upstream = (*dnsproxytest.FakeUpstream)(nil)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
97 changes: 44 additions & 53 deletions upstream/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package upstream
import (
"context"
"fmt"
"math"
"net/netip"
"net/url"
"strings"
Expand Down Expand Up @@ -135,84 +136,70 @@ func (r *UpstreamResolver) LookupNetIP(

host = dns.Fqdn(strings.ToLower(host))

rr, err := r.resolveIP(ctx, network, host)
res, err := r.lookupNetIP(ctx, network, host)
if err != nil {
return []netip.Addr{}, err
}

for _, ip := range rr {
ips = append(ips, ip.addr)
}

return ips, err
return res.addrs, err
}

// ipResult reflects a single A/AAAA record from the DNS response. It's used
// to cache the results of lookups.
type ipResult struct {
addr netip.Addr
expire time.Time
addrs []netip.Addr
}

// filterExpired returns the addresses from res that are not expired yet. It
// returns nil if all the addresses are expired.
func filterExpired(res []ipResult, now time.Time) (filtered []netip.Addr) {
for _, r := range res {
if r.expire.After(now) {
filtered = append(filtered, r.addr)
}
}

return filtered
}

// resolveIP performs a DNS lookup of host and returns the result. network must
// be either [bootstrap.NetworkIP4], [bootstrap.NetworkIP6] or
// lookupNetIP performs a DNS lookup of host and returns the result. network
// must be either [bootstrap.NetworkIP4], [bootstrap.NetworkIP6], or
// [bootstrap.NetworkIP]. host must be in a lower-case FQDN form.
//
// TODO(e.burkov): Use context.
func (r *UpstreamResolver) resolveIP(
func (r *UpstreamResolver) lookupNetIP(
_ context.Context,
network bootstrap.Network,
host string,
) (rr []ipResult, err error) {
) (result *ipResult, err error) {
switch network {
case bootstrap.NetworkIP4, bootstrap.NetworkIP6:
return r.resolve(host, network)
return r.request(host, network)
case bootstrap.NetworkIP:
// Go on.
default:
return nil, fmt.Errorf("unsupported network %s", network)
return result, fmt.Errorf("unsupported network %s", network)
}

resCh := make(chan any, 2)
go r.resolveAsync(resCh, host, bootstrap.NetworkIP4)
go r.resolveAsync(resCh, host, bootstrap.NetworkIP6)

var errs []error
result = &ipResult{}

for i := 0; i < 2; i++ {
switch res := <-resCh; res := res.(type) {
case error:
errs = append(errs, res)
case []ipResult:
rr = append(rr, res...)
case *ipResult:
if result.expire.Equal(time.Time{}) || res.expire.Before(result.expire) {
result.expire = res.expire
}
result.addrs = append(result.addrs, res.addrs...)
}
}

return rr, errors.Join(errs...)
return result, errors.Join(errs...)
}

// resolve performs a single DNS lookup of host and returns all the valid
// request performs a single DNS lookup of host and returns all the valid
// addresses from the answer section of the response. network must be either
// "ip4" or "ip6". host must be in a lower-case FQDN form.
// [bootstrap.NetworkIP4], or [bootstrap.NetworkIP6]. host must be in a
// lower-case FQDN form.
//
// TODO(e.burkov): Consider NS and Extra sections when setting TTL. Check out
// what RFCs say about it.
func (r *UpstreamResolver) resolve(
host string,
n bootstrap.Network,
) (res []ipResult, err error) {
func (r *UpstreamResolver) request(host string, n bootstrap.Network) (res *ipResult, err error) {
var qtype uint16
switch n {
case bootstrap.NetworkIP4:
Expand All @@ -235,33 +222,37 @@ func (r *UpstreamResolver) resolve(
}},
}

// As per [upstream.Exchange] documentation, the response is always returned
// As per [Upstream.Exchange] documentation, the response is always returned
// if no error occurred.
resp, err := r.Exchange(req)
if err != nil {
return nil, err
return res, err
}

now := time.Now()
res = &ipResult{
expire: time.Now(),
addrs: make([]netip.Addr, 0, len(resp.Answer)),
}
var minTTL uint32 = math.MaxUint32

for _, rr := range resp.Answer {
ip := proxyutil.IPFromRR(rr)
if !ip.IsValid() {
continue
}

res = append(res, ipResult{
addr: ip,
expire: now.Add(time.Duration(rr.Header().Ttl) * time.Second),
})
res.addrs = append(res.addrs, ip)
minTTL = min(minTTL, rr.Header().Ttl)
}
res.expire = res.expire.Add(time.Duration(minTTL) * time.Second)

return res, nil
}

// resolveAsync performs a single DNS lookup and sends the result to ch. It's
// intended to be used as a goroutine.
func (r *UpstreamResolver) resolveAsync(resCh chan<- any, host, network string) {
res, err := r.resolve(host, network)
res, err := r.request(host, network)
if err != nil {
resCh <- err
} else {
Expand All @@ -279,22 +270,27 @@ type CachingResolver struct {
mu *sync.RWMutex

// cached is the set of cached results sorted by [resolveResult.name].
cached map[string][]ipResult
//
// TODO(e.burkov): Use expiration cache.
cached map[string]*ipResult
}

// NewCachingResolver creates a new caching resolver that uses r for lookups.
func NewCachingResolver(r *UpstreamResolver) (cr *CachingResolver) {
return &CachingResolver{
resolver: r,
mu: &sync.RWMutex{},
cached: map[string][]ipResult{},
cached: map[string]*ipResult{},
}
}

// type check
var _ Resolver = (*CachingResolver)(nil)

// LookupNetIP implements the [Resolver] interface for *CachingResolver.
//
// TODO(e.burkov): It may appear that several concurrent lookup results rewrite
// each other in the cache.
func (r *CachingResolver) LookupNetIP(
ctx context.Context,
network bootstrap.Network,
Expand All @@ -308,22 +304,17 @@ func (r *CachingResolver) LookupNetIP(
return addrs, nil
}

newRes, err := r.resolver.resolveIP(ctx, network, host)
newRes, err := r.resolver.lookupNetIP(ctx, network, host)
if err != nil {
return []netip.Addr{}, err
}

addrs = filterExpired(newRes, now)
if len(addrs) == 0 {
return []netip.Addr{}, nil
}

r.mu.Lock()
defer r.mu.Unlock()

r.cached[host] = newRes

return addrs, nil
return newRes.addrs, nil
}

// findCached returns the cached addresses for host if it's not expired yet, and
Expand All @@ -333,9 +324,9 @@ func (r *CachingResolver) findCached(host string, now time.Time) (addrs []netip.
defer r.mu.RUnlock()

res, ok := r.cached[host]
if !ok {
if !ok || res.expire.Before(now) {
return nil
}

return filterExpired(res, now)
return res.addrs
}
100 changes: 100 additions & 0 deletions upstream/resolver_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package upstream

import (
"context"
"net/netip"
"testing"
"time"

"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCachingResolver_staleness(t *testing.T) {
ip4 := netip.MustParseAddr("1.2.3.4")
ip6 := netip.MustParseAddr("2001:db8::1")

const (
smallTTL = 10 * time.Second
largeTTL = 1000 * time.Second

fqdn = "test.fully.qualified.name."
)

onExchange := func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = (&dns.Msg{}).SetReply(req)

hdr := dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: req.Question[0].Qtype,
Class: dns.ClassINET,
}
var rr dns.RR
switch q := req.Question[0]; q.Qtype {
case dns.TypeA:
hdr.Ttl = uint32(smallTTL.Seconds())
rr = &dns.A{Hdr: hdr, A: ip4.AsSlice()}
case dns.TypeAAAA:
hdr.Ttl = uint32(largeTTL.Seconds())
rr = &dns.AAAA{Hdr: hdr, AAAA: ip6.AsSlice()}
default:
require.Contains(testutil.PanicT{}, []uint16{dns.TypeA, dns.TypeAAAA}, q.Qtype)
}
resp.Answer = append(resp.Answer, rr)

return resp, nil
}

ups := &dnsproxytest.FakeUpstream{
OnAddress: func() (_ string) { panic("not implemented") },
OnClose: func() (_ error) { panic("not implemented") },
OnExchange: onExchange,
}

r := NewCachingResolver(&UpstreamResolver{Upstream: ups})

require.True(t, t.Run("resolve", func(t *testing.T) {
testCases := []struct {
name string
network bootstrap.Network
want []netip.Addr
}{{
name: "ip4",
network: bootstrap.NetworkIP4,
want: []netip.Addr{ip4},
}, {
name: "ip6",
network: bootstrap.NetworkIP6,
want: []netip.Addr{ip6},
}, {
name: "both",
network: bootstrap.NetworkIP,
want: []netip.Addr{ip4, ip6},
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.name != "both" {
t.Skip(`TODO(e.burkov): Bootstrap now only uses "ip" network, see TODO there.`)
}

res, err := r.LookupNetIP(context.Background(), tc.network, fqdn)
require.NoError(t, err)

assert.ElementsMatch(t, tc.want, res)
})
}
}))

t.Run("staleness", func(t *testing.T) {
cached := r.findCached(fqdn, time.Now())
require.ElementsMatch(t, []netip.Addr{ip4, ip6}, cached)

cached = r.findCached(fqdn, time.Now().Add(smallTTL))
require.Empty(t, cached)
})
}
Loading

0 comments on commit ad97a28

Please sign in to comment.