From 6bdb8e95ea2066f17bd5665f90105cf0e0f0a5e0 Mon Sep 17 00:00:00 2001 From: Dimitry Kolyshev Date: Wed, 29 Nov 2023 11:02:21 +0200 Subject: [PATCH] upstream: imp code --- upstream/parallel_test.go | 105 +++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 47 deletions(-) diff --git a/upstream/parallel_test.go b/upstream/parallel_test.go index 76acce6d8..0a36ed82a 100644 --- a/upstream/parallel_test.go +++ b/upstream/parallel_test.go @@ -3,12 +3,13 @@ package upstream import ( "context" "fmt" - "net" + "net/netip" "testing" "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -71,45 +72,52 @@ func TestLookupParallel(t *testing.T) { } func TestLookupParallelEmpty(t *testing.T) { - u1 := testUpstream{} - u2 := testUpstream{} - - resolvers := []Resolver{} - resolvers = append(resolvers, &UpstreamResolver{Upstream: &u1}) - resolvers = append(resolvers, &UpstreamResolver{Upstream: &u2}) + resolvers := []Resolver{ + &UpstreamResolver{Upstream: &testUpstream{}}, + &UpstreamResolver{Upstream: &testUpstream{}}, + } ctx, cancel := context.WithTimeout(context.TODO(), timeout) defer cancel() - a, err := LookupParallel(ctx, resolvers, "google.com") - assert.Nil(t, err) - assert.Equal(t, 0, len(a)) + + addrs, err := LookupParallel(ctx, resolvers, "google.com") + require.NoError(t, err) + assert.Len(t, addrs, 0) } func TestExchangeParallelEmpty(t *testing.T) { - u1 := testUpstream{} - u1.empty = true - u2 := testUpstream{} - u2.empty = true - u := []Upstream{&u1, &u2} + ups := []Upstream{ + &testUpstream{empty: true}, + &testUpstream{empty: true}, + } req := createTestMessage() - a, up, err := ExchangeParallel(u, req) - assert.NotNil(t, err) - assert.Nil(t, a) + resp, up, err := ExchangeParallel(ups, req) + require.Error(t, err) + + assert.Nil(t, resp) assert.Nil(t, up) } +// testUpstream represents a mock upstream structure. type testUpstream struct { - a net.IP - err bool + // addr is a mock A record IP address to be returned. + addr netip.Addr + + // err is a mock error to be returned. + err bool + + // empty indicates if a nil response is returned. empty bool - sleep time.Duration // a delay before response + + // sleep is a delay before response. + sleep time.Duration } // type check var _ Upstream = (*testUpstream)(nil) -// Exchange implements the Upstream interface for *testUpstream. +// Exchange implements the [Upstream] interface for *testUpstream. func (u *testUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { if u.sleep != 0 { time.Sleep(u.sleep) @@ -119,52 +127,55 @@ func (u *testUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { return nil, nil } + if u.err { + return nil, fmt.Errorf("upstream error") + } + resp = &dns.Msg{} resp.SetReply(req) - if len(u.a) != 0 { - a := dns.A{} - a.A = u.a - resp.Answer = append(resp.Answer, &a) - } + if u.addr != (netip.Addr{}) { + a := dns.A{ + A: u.addr.AsSlice(), + } - if u.err { - return nil, fmt.Errorf("upstream error") + resp.Answer = append(resp.Answer, &a) } return resp, nil } -// Address implements the Upstream interface for *testUpstream. +// Address implements the [Upstream] interface for *testUpstream. func (u *testUpstream) Address() (addr string) { return "" } -// Close implements the Upstream interface for *testUpstream. +// Close implements the [Upstream] interface for *testUpstream. func (u *testUpstream) Close() (err error) { return nil } func TestExchangeAll(t *testing.T) { - u1 := testUpstream{} - u1.a = net.ParseIP("1.1.1.1") - u1.sleep = 100 * time.Millisecond - - u2 := testUpstream{} - u2.err = true - - u3 := testUpstream{} - u3.a = net.ParseIP("3.3.3.3") + delayedAnsAddr := netip.MustParseAddr("1.1.1.1") + ansAddr := netip.MustParseAddr("3.3.3.3") + + ups := []Upstream{&testUpstream{ + addr: delayedAnsAddr, + sleep: 100 * time.Millisecond, + }, &testUpstream{ + err: true, + }, &testUpstream{ + addr: ansAddr, + }} - ups := []Upstream{&u1, &u2, &u3} req := createHostTestMessage("test.org") res, err := ExchangeAll(ups, req) - assert.True(t, err == nil) - assert.True(t, len(res) == 2) + require.NoError(t, err) + assert.Len(t, res, 2) - a := res[0].Resp.Answer[0].(*dns.A) - assert.True(t, a.A.To4().Equal(net.ParseIP("3.3.3.3").To4())) + ans := res[0].Resp.Answer[0].(*dns.A) + assert.Equal(t, ansAddr.AsSlice(), []byte(ans.A)) - a = res[1].Resp.Answer[0].(*dns.A) - assert.True(t, a.A.To4().Equal(net.ParseIP("1.1.1.1").To4())) + ans = res[1].Resp.Answer[0].(*dns.A) + assert.Equal(t, delayedAnsAddr.AsSlice(), []byte(ans.A)) }