diff --git a/upstream/resolver_internal_test.go b/upstream/resolver_internal_test.go new file mode 100644 index 000000000..d2e5bdae2 --- /dev/null +++ b/upstream/resolver_internal_test.go @@ -0,0 +1,101 @@ +package upstream + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/AdguardTeam/dnsproxy/internal/bootstrap" + "github.com/AdguardTeam/golibs/testutil" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCachingResolver_staleness(t *testing.T) { + ip4 := netip.MustParseAddr("1.2.3.4") + ip6 := netip.MustParseAddr("2001:db8::1") + + const ( + smallTTL = 10 * time.Second + largeTTL = 1000 * time.Second + + fqdn = "cloudflare-dns.com." + ) + + onExchange := func(req *dns.Msg) (resp *dns.Msg, err error) { + resp = (&dns.Msg{}).SetReply(req) + + hdr := dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: req.Question[0].Qtype, + Class: dns.ClassINET, + } + var rr dns.RR + switch q := req.Question[0]; q.Qtype { + case dns.TypeA: + hdr.Ttl = uint32(smallTTL.Seconds()) + rr = &dns.A{Hdr: hdr, A: ip4.AsSlice()} + case dns.TypeAAAA: + hdr.Ttl = uint32(largeTTL.Seconds()) + rr = &dns.AAAA{Hdr: hdr, AAAA: ip6.AsSlice()} + default: + require.Contains(testutil.PanicT{}, []uint16{dns.TypeA, dns.TypeAAAA}, q.Qtype) + } + resp.Answer = append(resp.Answer, rr) + + return resp, nil + } + + ups := &FakeUpstream{ + OnAddress: func() (_ string) { panic("not implemented") }, + OnClose: func() (_ error) { panic("not implemented") }, + OnExchange: onExchange, + } + + r := NewCachingResolver(&UpstreamResolver{Upstream: ups}) + + require.True(t, t.Run("resolve", func(t *testing.T) { + testCases := []struct { + name string + network bootstrap.Network + want []netip.Addr + }{{ + name: "ip4", + network: bootstrap.NetworkIP4, + want: []netip.Addr{ip4}, + }, { + name: "ip6", + network: bootstrap.NetworkIP6, + want: []netip.Addr{ip6}, + }, { + name: "both", + network: bootstrap.NetworkIP, + want: []netip.Addr{ip4, ip6}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.name != "both" { + t.Skip(`TODO(e.burkov): Bootstrap now only uses "ip" network, see TODO there.`) + } + + res, err := r.LookupNetIP(context.Background(), tc.network, fqdn) + require.NoError(t, err) + + assert.ElementsMatch(t, tc.want, res) + }) + } + })) + + t.Run("staleness", func(t *testing.T) { + r.mu.Lock() + defer r.mu.Unlock() + + require.Contains(t, r.cached, fqdn) + + cached := r.cached[fqdn] + assert.Less(t, time.Until(cached.expire), smallTTL) + }) +} diff --git a/upstream/resolver_test.go b/upstream/resolver_test.go index db1e8feb2..eb0257cf1 100644 --- a/upstream/resolver_test.go +++ b/upstream/resolver_test.go @@ -1,4 +1,4 @@ -package upstream +package upstream_test import ( "context" @@ -6,24 +6,34 @@ import ( "testing" "time" - "github.com/AdguardTeam/dnsproxy/internal/bootstrap" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewUpstreamResolver(t *testing.T) { - ups := &FakeUpstream{ + ups := &upstream.FakeUpstream{ OnAddress: func() (_ string) { panic("not implemented") }, OnClose: func() (_ error) { panic("not implemented") }, OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { - return respondToTestMessage(req), nil + resp = (&dns.Msg{}).SetReply(req) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 60, + }, + A: netip.MustParseAddr("1.2.3.4").AsSlice(), + }} + + return resp, nil }, } - r := &UpstreamResolver{Upstream: ups} + r := &upstream.UpstreamResolver{Upstream: ups} ipAddrs, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com") require.NoError(t, err) @@ -32,7 +42,7 @@ func TestNewUpstreamResolver(t *testing.T) { } func TestNewUpstreamResolver_validity(t *testing.T) { - withTimeoutOpt := &Options{Timeout: 3 * time.Second} + withTimeoutOpt := &upstream.Options{Timeout: 3 * time.Second} testCases := []struct { name string @@ -82,10 +92,10 @@ func TestNewUpstreamResolver_validity(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - r, err := NewUpstreamResolver(tc.addr, withTimeoutOpt) + r, err := upstream.NewUpstreamResolver(tc.addr, withTimeoutOpt) if tc.wantErrMsg != "" { assert.Equal(t, tc.wantErrMsg, err.Error()) - if nberr := (&NotBootstrapError{}); errors.As(err, &nberr) { + if nberr := (&upstream.NotBootstrapError{}); errors.As(err, &nberr) { assert.NotNil(t, r) } @@ -101,90 +111,3 @@ func TestNewUpstreamResolver_validity(t *testing.T) { }) } } - -func TestCachingResolver_staleness(t *testing.T) { - ip4 := netip.MustParseAddr("1.2.3.4") - ip6 := netip.MustParseAddr("2001:db8::1") - - const ( - smallTTL = 10 * time.Second - largeTTL = 1000 * time.Second - - fqdn = "cloudflare-dns.com." - ) - - onExchange := func(req *dns.Msg) (resp *dns.Msg, err error) { - resp = (&dns.Msg{}).SetReply(req) - - hdr := dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: req.Question[0].Qtype, - Class: dns.ClassINET, - } - var rr dns.RR - switch q := req.Question[0]; q.Qtype { - case dns.TypeA: - hdr.Ttl = uint32(smallTTL.Seconds()) - rr = &dns.A{Hdr: hdr, A: ip4.AsSlice()} - case dns.TypeAAAA: - hdr.Ttl = uint32(largeTTL.Seconds()) - rr = &dns.AAAA{Hdr: hdr, AAAA: ip6.AsSlice()} - default: - require.Contains(testutil.PanicT{}, []uint16{dns.TypeA, dns.TypeAAAA}, q.Qtype) - } - resp.Answer = append(resp.Answer, rr) - - return resp, nil - } - - ups := &FakeUpstream{ - OnAddress: func() (_ string) { panic("not implemented") }, - OnClose: func() (_ error) { panic("not implemented") }, - OnExchange: onExchange, - } - - r := NewCachingResolver(&UpstreamResolver{Upstream: ups}) - - require.True(t, t.Run("resolve", func(t *testing.T) { - testCases := []struct { - name string - network bootstrap.Network - want []netip.Addr - }{{ - name: "ip4", - network: bootstrap.NetworkIP4, - want: []netip.Addr{ip4}, - }, { - name: "ip6", - network: bootstrap.NetworkIP6, - want: []netip.Addr{ip6}, - }, { - name: "both", - network: bootstrap.NetworkIP, - want: []netip.Addr{ip4, ip6}, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if tc.name != "both" { - t.Skip(`TODO(e.burkov): Bootstrap now only uses "ip" network, see TODO there.`) - } - - res, err := r.LookupNetIP(context.Background(), tc.network, fqdn) - require.NoError(t, err) - - assert.ElementsMatch(t, tc.want, res) - }) - } - })) - - t.Run("staleness", func(t *testing.T) { - r.mu.Lock() - defer r.mu.Unlock() - - require.Contains(t, r.cached, fqdn) - - cached := r.cached[fqdn] - assert.Less(t, time.Until(cached.expire), smallTTL) - }) -}