diff --git a/go.mod b/go.mod index f00594f48..a16604f33 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/dnsproxy go 1.21.8 require ( - github.com/AdguardTeam/golibs v0.21.0 + github.com/AdguardTeam/golibs v0.22.0 github.com/ameshkov/dnscrypt/v2 v2.2.7 github.com/ameshkov/dnsstamps v1.0.3 github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 diff --git a/go.sum b/go.sum index bf9ab0b62..c412f0ffc 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/AdguardTeam/golibs v0.21.0 h1:0swWyNaHTmT7aMwffKd9d54g4wBd8Oaj0fl+5l/PRdE= -github.com/AdguardTeam/golibs v0.21.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI= +github.com/AdguardTeam/golibs v0.22.0 h1:wvT/UFIT8XIBfMabnD3LcDRiorx8J0lc3A/bzD6OX7c= +github.com/AdguardTeam/golibs v0.22.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= diff --git a/proxy/config.go b/proxy/config.go index 28af6876b..94702c791 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -8,6 +8,7 @@ import ( "net/url" "time" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" @@ -255,24 +256,17 @@ func (p *Proxy) validateConfig() (err error) { return fmt.Errorf("validating general upstreams: %w", err) } - if !p.UsePrivateRDNS { - err = p.PrivateRDNSUpstreamConfig.validate() - if errors.Is(err, errNoDefaultUpstreams) { - // Allow [Proxy.PrivateRDNSUpstreamConfig] to be nil, but not empty. - err = nil - } - } else { - err = p.PrivateRDNSUpstreamConfig.ValidatePrivateness(p.privateNets) - } + err = ValidatePrivateConfig(p.PrivateRDNSUpstreamConfig, p.privateNets) if err != nil { - return fmt.Errorf("validating private RDNS upstreams: %w", err) + if p.UsePrivateRDNS || errors.Is(err, upstream.ErrNoUpstreams) { + return fmt.Errorf("validating private RDNS upstreams: %w", err) + } } // Allow [Proxy.Fallbacks] to be nil, but not empty. nil means not to use // fallbacks at all. - err = p.Fallbacks.validate() - if err != nil && !errors.Is(err, errNoDefaultUpstreams) { + if errors.Is(err, upstream.ErrNoUpstreams) { return fmt.Errorf("validating fallbacks: %w", err) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 37a3e3b13..18c963fb5 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -165,9 +165,7 @@ type Proxy struct { // udpOOBSize is the size of the out-of-band data for UDP connections. udpOOBSize int - // counter is the counter of messages. It must only be incremented - // atomically, so it must be the first member of the struct to make sure - // that it has a 64-bit alignment. + // counter counts message contexts created with [Proxy.newDNSContext]. counter atomic.Uint64 // RWMutex protects the whole proxy. diff --git a/proxy/server.go b/proxy/server.go index 161d52b89..0934526ef 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -133,11 +133,6 @@ func (p *Proxy) handleDNSRequest(d *DNSContext) (err error) { d.Res = p.validateRequest(d) if d.Res == nil { - // TODO(e.burkov): Remove it since the configs are validated. - if len(p.UpstreamConfig.Upstreams) == 0 { - panic(errNoDefaultUpstreams) - } - defer func() { err = errors.Annotate(err, "handling request: %w") }() if p.RequestHandler != nil { diff --git a/proxy/upstreams.go b/proxy/upstreams.go index 7ff9cf02d..52f1afa31 100644 --- a/proxy/upstreams.go +++ b/proxy/upstreams.go @@ -312,52 +312,48 @@ func (p *configParser) includeToReserved(dnsUpstream upstream.Upstream, domains } } -// errNoDefaultUpstreams is returned when no default upstreams specified within -// a [Config.UpstreamConfig]. -const errNoDefaultUpstreams errors.Error = "no default upstreams specified" - // validate returns an error if the upstreams aren't configured properly. c -// considered valid if it contains at least a single default upstream. Nil c, -// as well as c with no default upstreams causes [ErrNoDefaultUpstreams]. Empty -// c causes [upstream.ErrNoUpstreams]. +// considered valid if it contains at least a single default upstream. Empty c +// causes [upstream.ErrNoUpstreams]. func (uc *UpstreamConfig) validate() (err error) { + const ( + nilErr errors.Error = errors.Error("upstream config is nil") + emptyErr errors.Error = errors.Error("no default upstreams specified") + ) + switch { case uc == nil: - return fmt.Errorf("%w; uc is nil", errNoDefaultUpstreams) + return nilErr case len(uc.Upstreams) > 0: return nil case len(uc.DomainReservedUpstreams) == 0 && len(uc.SpecifiedDomainUpstreams) == 0: return upstream.ErrNoUpstreams default: - return errNoDefaultUpstreams + return emptyErr } } -// ValidatePrivateness returns an error if uc isn't valid, or, treated as +// ValidatePrivateConfig returns an error if uc isn't valid, or, treated as // private upstreams configuration, contains specifications for invalid domains. -// -// TODO(e.burkov): !! Should it really be exported. -func (uc *UpstreamConfig) ValidatePrivateness(privateSubnets netutil.SubnetSet) (err error) { +func ValidatePrivateConfig(uc *UpstreamConfig, privateSubnets netutil.SubnetSet) (err error) { if err = uc.validate(); err != nil { + // Don't wrap the error since it's informative enough as is. return err } var errs []error rangeFunc := func(domain string, _ []upstream.Upstream) (ok bool) { - switch domain { - case "in-addr.arpa.", "ip6.arpa.": - return true - default: - // Go on. - } - pref, extErr := netutil.ExtractReversedAddr(domain) switch { case extErr != nil: // Don't wrap the error since it's informative enough as is. errs = append(errs, extErr) + case pref.Bits() == 0: + // Allow private subnets for subdomains of the root domain. case !privateSubnets.Contains(pref.Addr()): errs = append(errs, fmt.Errorf("reversed subnet in %q is not private", domain)) + default: + // Go on. } return true diff --git a/proxy/upstreams_test.go b/proxy/upstreams_test.go index 49cad428b..827aee251 100644 --- a/proxy/upstreams_test.go +++ b/proxy/upstreams_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" @@ -148,7 +149,7 @@ func TestUpstreamConfig_Validate(t *testing.T) { }, }, { name: "no_default", - wantValidateErr: errNoDefaultUpstreams, + wantValidateErr: errors.Error("no default upstreams specified"), in: []string{ "[/domain.example/]udp://upstream.example:53", "[/another.domain.example/]#", @@ -165,11 +166,11 @@ func TestUpstreamConfig_Validate(t *testing.T) { } t.Run("actual_nil", func(t *testing.T) { - assert.ErrorIs(t, (*UpstreamConfig)(nil).validate(), errNoDefaultUpstreams) + assert.ErrorIs(t, (*UpstreamConfig)(nil).validate(), errors.Error("upstream config is nil")) }) } -func TestUpstreamConfig_ValidatePrivateness(t *testing.T) { +func TestValidatePrivateConfig(t *testing.T) { ss := netutil.SubnetSetFunc(netutil.IsLocallyServed) testCases := []struct { @@ -226,8 +227,7 @@ func TestUpstreamConfig_ValidatePrivateness(t *testing.T) { upsConf, err := ParseUpstreamsConfig(set, nil) require.NoError(t, err) - err = upsConf.ValidatePrivateness(ss) - testutil.AssertErrorMsg(t, tc.wantErr, err) + testutil.AssertErrorMsg(t, tc.wantErr, ValidatePrivateConfig(upsConf, ss)) }) } }