From 638288ec9bd2e9fbadc458fdf6d32f386d8fa41e Mon Sep 17 00:00:00 2001 From: Dimitry Kolyshev Date: Tue, 12 Dec 2023 16:21:14 +0300 Subject: [PATCH 1/7] Pull request: all: upd golibs Squashed commit of the following: commit 9b7e21eaf062a569a259d2279f785542e6fadb37 Author: Dimitry Kolyshev Date: Tue Dec 12 10:05:31 2023 +0200 proxy: imp code commit d5b40a4a88c8fbd5edf1c207a6f7e3d0fad16d0e Author: Dimitry Kolyshev Date: Mon Dec 11 12:02:26 2023 +0200 all: upd golibs --- README.md | 2 +- go.mod | 2 +- go.sum | 4 +-- main.go | 4 +-- proxy/config.go | 2 +- proxy/proxy.go | 12 ++++----- proxy/sema.go | 58 ---------------------------------------- proxy/server.go | 8 +++--- proxy/server_dnscrypt.go | 16 +++++++---- proxy/server_quic.go | 42 ++++++++++++++++++++--------- proxy/server_tcp.go | 16 ++++++++--- proxy/server_udp.go | 17 +++++++++--- 12 files changed, 82 insertions(+), 101 deletions(-) delete mode 100644 proxy/sema.go diff --git a/README.md b/README.md index 7c515ca39..f636d2514 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ Application Options: --ipv6-disabled If specified, all AAAA requests will be replied with NoError RCode and empty answer --bogus-nxdomain= Transform the responses containing at least a single IP that matches specified addresses and CIDRs into NXDOMAIN. Can be specified multiple times. --udp-buf-size= Set the size of the UDP buffer in bytes. A value <= 0 will use the system default. - --max-go-routines= Set the maximum number of go routines. A value <= 0 will not not set a maximum. + --max-go-routines= Set the maximum number of go routines. A zero value will not not set a maximum. --pprof If present, exposes pprof information on localhost:6060. --version Prints the program version diff --git a/go.mod b/go.mod index 011b011cf..1aef0f28b 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/dnsproxy go 1.20 require ( - github.com/AdguardTeam/golibs v0.18.0 + github.com/AdguardTeam/golibs v0.18.1 github.com/ameshkov/dnscrypt/v2 v2.2.7 github.com/ameshkov/dnsstamps v1.0.3 github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 diff --git a/go.sum b/go.sum index d9a30c3d3..4fa5ec0e2 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/AdguardTeam/golibs v0.18.0 h1:ckS2YK7t2Ub6UkXl0fnreVaM15Zb07Hh1gmFqttjpWg= -github.com/AdguardTeam/golibs v0.18.0/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U= +github.com/AdguardTeam/golibs v0.18.1 h1:6u0fvrIj2qjUsRdbIGJ9AR0g5QRSWdKIo/DYl3tp5aM= +github.com/AdguardTeam/golibs v0.18.1/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= diff --git a/main.go b/main.go index 8e5b731e2..dc53a047d 100644 --- a/main.go +++ b/main.go @@ -196,8 +196,8 @@ type Options struct { // UDP buffer size value UDPBufferSize int `yaml:"udp-buf-size" long:"udp-buf-size" description:"Set the size of the UDP buffer in bytes. A value <= 0 will use the system default."` - // The maximum number of go routines - MaxGoRoutines int `yaml:"max-go-routines" long:"max-go-routines" description:"Set the maximum number of go routines. A value <= 0 will not not set a maximum."` + // MaxGoRoutines is the maximum number of goroutines. + MaxGoRoutines uint `yaml:"max-go-routines" long:"max-go-routines" description:"Set the maximum number of go routines. A zero value will not not set a maximum."` // Pprof defines whether the pprof information needs to be exposed via // localhost:6060 or not. diff --git a/proxy/config.go b/proxy/config.go index 93b8831a5..33c0fd296 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -182,7 +182,7 @@ type Config struct { // TODO(a.garipov): Rename this to something like // “MaxDNSRequestGoroutines” in a later major version, as it doesn't // actually limit all goroutines. - MaxGoroutines int + MaxGoroutines uint // The size of the read buffer on the underlying socket. Larger read buffers can handle // larger bursts of requests before packets get dropped. diff --git a/proxy/proxy.go b/proxy/proxy.go index b6e2740c6..5aa8b4bd1 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -19,6 +19,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/syncutil" "github.com/ameshkov/dnscrypt/v2" "github.com/miekg/dns" gocache "github.com/patrickmn/go-cache" @@ -166,7 +167,7 @@ type Proxy struct { // RWMutex protects the whole proxy. sync.RWMutex - // requestGoroutinesSema limits the number of simultaneous requests. + // requestsSema limits the number of simultaneous requests. // // TODO(a.garipov): Currently we have to pass this exact semaphore to // the workers, to prevent races on restart. In the future we will need @@ -174,7 +175,7 @@ type Proxy struct { // states. // // See also: https://github.com/AdguardTeam/AdGuardHome/issues/2242. - requestGoroutinesSema semaphore + requestsSema syncutil.Semaphore // Config is the proxy configuration. // @@ -195,12 +196,9 @@ func (p *Proxy) Init() (err error) { if p.MaxGoroutines > 0 { log.Info("dnsproxy: max goroutines is set to %d", p.MaxGoroutines) - p.requestGoroutinesSema, err = newChanSemaphore(p.MaxGoroutines) - if err != nil { - return fmt.Errorf("can't init semaphore: %w", err) - } + p.requestsSema = syncutil.NewChanSemaphore(p.MaxGoroutines) } else { - p.requestGoroutinesSema = newNoopSemaphore() + p.requestsSema = syncutil.EmptySemaphore{} } p.udpOOBSize = proxynetutil.UDPGetOOBSize() diff --git a/proxy/sema.go b/proxy/sema.go deleted file mode 100644 index a909f40e3..000000000 --- a/proxy/sema.go +++ /dev/null @@ -1,58 +0,0 @@ -package proxy - -import ( - "fmt" -) - -// semaphore is the semaphore interface. acquire will block until the -// resource can be acquired. release never blocks. -type semaphore interface { - acquire() - release() -} - -// noopSemaphore is a semaphore that has no limit. -type noopSemaphore struct{} - -// acquire implements the semaphore interface for noopSemaphore. -func (noopSemaphore) acquire() {} - -// release implements the semaphore interface for noopSemaphore. -func (noopSemaphore) release() {} - -// newNoopSemaphore returns a new noopSemaphore. -func newNoopSemaphore() (s semaphore) { return noopSemaphore{} } - -// sig is an alias for struct{} to type less. -type sig = struct{} - -// chanSemaphore is a channel-based semaphore. -type chanSemaphore struct { - c chan sig -} - -// acquire implements the semaphore interface for *chanSemaphore. -func (c *chanSemaphore) acquire() { - c.c <- sig{} -} - -// release implements the semaphore interface for *chanSemaphore. -func (c *chanSemaphore) release() { - select { - case <-c.c: - default: - } -} - -// newChanSemaphore returns a new chanSemaphore with the provided -// maximum resource number. maxRes must be greater than zero. -func newChanSemaphore(maxRes int) (s semaphore, err error) { - if maxRes < 1 { - return nil, fmt.Errorf("bad maxRes: %d", maxRes) - } - - s = &chanSemaphore{ - c: make(chan sig, maxRes), - } - return s, nil -} diff --git a/proxy/server.go b/proxy/server.go index 68bebe19a..a01b2bc78 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -44,15 +44,15 @@ func (p *Proxy) startListeners(ctx context.Context) error { } for _, l := range p.udpListen { - go p.udpPacketLoop(l, p.requestGoroutinesSema) + go p.udpPacketLoop(l, p.requestsSema) } for _, l := range p.tcpListen { - go p.tcpPacketLoop(l, ProtoTCP, p.requestGoroutinesSema) + go p.tcpPacketLoop(l, ProtoTCP, p.requestsSema) } for _, l := range p.tlsListen { - go p.tcpPacketLoop(l, ProtoTLS, p.requestGoroutinesSema) + go p.tcpPacketLoop(l, ProtoTLS, p.requestsSema) } for _, l := range p.httpsListen { @@ -64,7 +64,7 @@ func (p *Proxy) startListeners(ctx context.Context) error { } for _, l := range p.quicListen { - go p.quicPacketLoop(l, p.requestGoroutinesSema) + go p.quicPacketLoop(l, p.requestsSema) } for _, l := range p.dnsCryptUDPListen { diff --git a/proxy/server_dnscrypt.go b/proxy/server_dnscrypt.go index cdd1abeae..92d04d1f2 100644 --- a/proxy/server_dnscrypt.go +++ b/proxy/server_dnscrypt.go @@ -1,12 +1,14 @@ package proxy import ( + "context" "fmt" "net" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/syncutil" "github.com/ameshkov/dnscrypt/v2" "github.com/miekg/dns" ) @@ -28,7 +30,7 @@ func (p *Proxy) createDNSCryptListeners() (err error) { Handler: &dnsCryptHandler{ proxy: p, - requestGoroutinesSema: p.requestGoroutinesSema, + reqSema: p.requestsSema, }, } @@ -61,20 +63,24 @@ func (p *Proxy) createDNSCryptListeners() (err error) { type dnsCryptHandler struct { proxy *Proxy - requestGoroutinesSema semaphore + reqSema syncutil.Semaphore } // compile-time type check var _ dnscrypt.Handler = &dnsCryptHandler{} // ServeDNS - processes the DNS query -func (h *dnsCryptHandler) ServeDNS(rw dnscrypt.ResponseWriter, req *dns.Msg) error { +func (h *dnsCryptHandler) ServeDNS(rw dnscrypt.ResponseWriter, req *dns.Msg) (err error) { d := h.proxy.newDNSContext(ProtoDNSCrypt, req) d.Addr = netutil.NetAddrToAddrPort(rw.RemoteAddr()) d.DNSCryptResponseWriter = rw - h.requestGoroutinesSema.acquire() - defer h.requestGoroutinesSema.release() + // TODO(d.kolyshev): Pass and use context from above. + err = h.reqSema.Acquire(context.Background()) + if err != nil { + return fmt.Errorf("dnsproxy: dnscrypt: acquiring semaphore: %w", err) + } + defer h.reqSema.Release() return h.proxy.handleDNSRequest(d) } diff --git a/proxy/server_quic.go b/proxy/server_quic.go index 9936b18a0..6695d3f55 100644 --- a/proxy/server_quic.go +++ b/proxy/server_quic.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/syncutil" "github.com/bluele/gcache" "github.com/miekg/dns" "github.com/quic-go/quic-go" @@ -82,11 +83,12 @@ func (p *Proxy) createQUICListeners() error { // quicPacketLoop listens for incoming QUIC packets. // -// See also the comment on Proxy.requestGoroutinesSema. -func (p *Proxy) quicPacketLoop(l *quic.EarlyListener, requestGoroutinesSema semaphore) { +// See also the comment on Proxy.requestsSema. +func (p *Proxy) quicPacketLoop(l *quic.EarlyListener, reqSema syncutil.Semaphore) { log.Info("Entering the DNS-over-QUIC listener loop on %s", l.Addr()) for { - conn, err := l.Accept(context.Background()) + ctx := context.Background() + conn, err := l.Accept(ctx) if err != nil { if isQUICErrorForDebugLog(err) { log.Debug("accepting quic conn: closed or timed out: %s", err) @@ -97,10 +99,16 @@ func (p *Proxy) quicPacketLoop(l *quic.EarlyListener, requestGoroutinesSema sema break } - requestGoroutinesSema.acquire() + err = reqSema.Acquire(ctx) + if err != nil { + log.Error("dnsproxy: quic: acquiring semaphore: %s", err) + + break + } go func() { - p.handleQUICConnection(conn, requestGoroutinesSema) - requestGoroutinesSema.release() + defer reqSema.Release() + + p.handleQUICConnection(conn, reqSema) }() } } @@ -108,15 +116,17 @@ func (p *Proxy) quicPacketLoop(l *quic.EarlyListener, requestGoroutinesSema sema // handleQUICConnection handles a new QUIC connection. It waits for new streams // and passes them to handleQUICStream. // -// See also the comment on Proxy.requestGoroutinesSema. -func (p *Proxy) handleQUICConnection(conn quic.Connection, requestGoroutinesSema semaphore) { +// See also the comment on Proxy.requestsSema. +func (p *Proxy) handleQUICConnection(conn quic.Connection, reqSema syncutil.Semaphore) { for { + ctx := context.Background() + // The stub to resolver DNS traffic follows a simple pattern in which // the client sends a query, and the server provides a response. This // design specifies that for each subsequent query on a QUIC connection // the client MUST select the next available client-initiated // bidirectional stream. - stream, err := conn.AcceptStream(context.Background()) + stream, err := conn.AcceptStream(ctx) if err != nil { if isQUICErrorForDebugLog(err) { log.Debug("accepting quic stream: closed or timed out: %s", err) @@ -130,16 +140,24 @@ func (p *Proxy) handleQUICConnection(conn quic.Connection, requestGoroutinesSema return } - requestGoroutinesSema.acquire() + err = reqSema.Acquire(ctx) + if err != nil { + log.Error("dnsproxy: quic: acquiring semaphore: %s", err) + + // Close the connection to make sure resources are freed. + closeQUICConn(conn, DoQCodeNoError) + + return + } go func() { + defer reqSema.Release() + p.handleQUICStream(stream, conn) // The server MUST send the response(s) on the same stream and MUST // indicate, after the last response, through the STREAM FIN // mechanism that no further data will be sent on that stream. _ = stream.Close() - - requestGoroutinesSema.release() }() } } diff --git a/proxy/server_tcp.go b/proxy/server_tcp.go index 780123786..2edfe7473 100644 --- a/proxy/server_tcp.go +++ b/proxy/server_tcp.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/syncutil" "github.com/miekg/dns" ) @@ -60,8 +61,8 @@ func (p *Proxy) createTLSListeners() (err error) { // tcpPacketLoop listens for incoming TCP packets. proto must be either "tcp" // or "tls". // -// See also the comment on Proxy.requestGoroutinesSema. -func (p *Proxy) tcpPacketLoop(l net.Listener, proto Proto, requestGoroutinesSema semaphore) { +// See also the comment on Proxy.requestsSema. +func (p *Proxy) tcpPacketLoop(l net.Listener, proto Proto, reqSema syncutil.Semaphore) { log.Info("dnsproxy: entering %s listener loop on %s", proto, l.Addr()) for { @@ -76,10 +77,17 @@ func (p *Proxy) tcpPacketLoop(l net.Listener, proto Proto, requestGoroutinesSema break } - requestGoroutinesSema.acquire() + // TODO(d.kolyshev): Pass and use context from above. + err = reqSema.Acquire(context.Background()) + if err != nil { + log.Error("dnsproxy: tcp: acquiring semaphore: %s", err) + + break + } go func() { + defer reqSema.Release() + p.handleTCPConnection(clientConn, proto) - requestGoroutinesSema.release() }() } } diff --git a/proxy/server_udp.go b/proxy/server_udp.go index 677f023e6..672ad1539 100644 --- a/proxy/server_udp.go +++ b/proxy/server_udp.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/syncutil" "github.com/miekg/dns" ) @@ -60,8 +61,8 @@ func (p *Proxy) udpCreate(ctx context.Context, udpAddr *net.UDPAddr) (*net.UDPCo // udpPacketLoop listens for incoming UDP packets. // -// See also the comment on Proxy.requestGoroutinesSema. -func (p *Proxy) udpPacketLoop(conn *net.UDPConn, requestGoroutinesSema semaphore) { +// See also the comment on Proxy.requestsSema. +func (p *Proxy) udpPacketLoop(conn *net.UDPConn, reqSema syncutil.Semaphore) { log.Info("dnsproxy: entering udp listener loop on %s", conn.LocalAddr()) b := make([]byte, dns.MaxMsgSize) @@ -79,10 +80,18 @@ func (p *Proxy) udpPacketLoop(conn *net.UDPConn, requestGoroutinesSema semaphore // we need the contents to survive the call because we're handling them in goroutine packet := make([]byte, n) copy(packet, b) - requestGoroutinesSema.acquire() + + // TODO(d.kolyshev): Pass and use context from above. + sErr := reqSema.Acquire(context.Background()) + if sErr != nil { + log.Error("dnsproxy: udp: acquiring semaphore: %s", sErr) + + break + } go func() { + defer reqSema.Release() + p.udpHandlePacket(packet, localIP, remoteAddr, conn) - requestGoroutinesSema.release() }() } if err != nil { From 06d548fc2fb39145c2a56a81a4073bbf1eb04c39 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Mon, 18 Dec 2023 18:40:53 +0300 Subject: [PATCH 2/7] Pull request 309: 6480 imp load balance Updates AdguardTeam/AdGuardHome#6480. Squashed commit of the following: commit 053c830f80f3369dca8a383aedc9833291dae56a Merge: 8f6b648 638288e Author: Eugene Burkov Date: Mon Dec 18 18:27:31 2023 +0300 Merge branch 'master' into 6480-imp-load-balance commit 8f6b648490cd575d8746dfd72b02a20ac09f65e0 Author: Eugene Burkov Date: Mon Dec 18 15:30:28 2023 +0300 proxy: imp docs commit 7927132f7aa0728511587c31452b6aa134a8a728 Author: Eugene Burkov Date: Mon Dec 18 15:17:41 2023 +0300 proxy: unexport clock commit 5cbb56e09035274bd3170e1c1e239b708b5e2a2b Author: Eugene Burkov Date: Mon Dec 18 15:01:01 2023 +0300 proxy: imp code commit df5863b326ec341a800502699e10513adbf59802 Author: Eugene Burkov Date: Mon Dec 18 12:47:34 2023 +0300 proxy: return algo commit 2ad8274fdfb2b96ad2f672105d36266d5ee4675a Author: Eugene Burkov Date: Fri Dec 15 19:57:14 2023 +0300 proxy: imp code, algo commit 068bcd89cad9bc54bea658aba3f1cf6e57159491 Author: Eugene Burkov Date: Fri Dec 15 16:18:48 2023 +0300 proxy: imp code, change algo commit 932f00f358e0c2e1c1b699106c9c0c16674ed1a7 Author: Eugene Burkov Date: Thu Dec 14 18:50:02 2023 +0300 proxy: imp docs commit 62d9b57515828290a611801caebca1beeb0ce2c3 Author: Eugene Burkov Date: Thu Dec 14 16:36:25 2023 +0300 proxy: imp test commit 90aa395ec644b64263aa589ab8c5c363c3d021ec Author: Eugene Burkov Date: Thu Dec 14 15:00:37 2023 +0300 proxy: imp tests commit 69ba821d0b86fc3536d7d6212b703f9c2ddba10c Author: Eugene Burkov Date: Wed Dec 13 19:26:56 2023 +0300 proxy: imp algo commit d7bf7b9ef4c1edcb73f69d96a68c9ece38c675ea Author: Eugene Burkov Date: Thu Dec 7 19:28:53 2023 +0300 all: fix nil deref commit bb4ceb66a769eb4ed2c89a99936bec6a4467759e Merge: cc26532 d6ebaac Author: Eugene Burkov Date: Thu Dec 7 19:16:34 2023 +0300 Merge branch 'master' into 6480-test-load-balance commit cc2653274cf2bbf8ebdc787a399faa585d52b6e5 Author: Eugene Burkov Date: Thu Dec 7 18:46:17 2023 +0300 proxy: test load balancing --- go.mod | 1 + go.sum | 2 + proxy/cache_test.go | 7 +- proxy/clock.go | 21 +++ proxy/dns64.go | 2 +- proxy/exchange.go | 146 +++++++++++++------- proxy/exchange_internal_test.go | 229 ++++++++++++++++++++++++++++++++ proxy/proxy.go | 21 ++- proxy/proxy_test.go | 94 ++++--------- 9 files changed, 392 insertions(+), 131 deletions(-) create mode 100644 proxy/clock.go create mode 100644 proxy/exchange_internal_test.go diff --git a/go.mod b/go.mod index 1aef0f28b..d64b2a0d0 100644 --- a/go.mod +++ b/go.mod @@ -36,5 +36,6 @@ require ( golang.org/x/mod v0.12.0 // indirect golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.13.0 // indirect + gonum.org/v1/gonum v0.14.0 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect ) diff --git a/go.sum b/go.sum index 4fa5ec0e2..3e0a8f0bd 100644 --- a/go.sum +++ b/go.sum @@ -68,6 +68,8 @@ golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= +gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= diff --git a/proxy/cache_test.go b/proxy/cache_test.go index d607d0729..d3ad46a6d 100644 --- a/proxy/cache_test.go +++ b/proxy/cache_test.go @@ -22,9 +22,10 @@ const testCacheSize = 4096 const testUpsAddr = "https://upstream.address" -var upstreamWithAddr = &funcUpstream{ - exchangeFunc: func(m *dns.Msg) (resp *dns.Msg, err error) { panic("not implemented") }, - addressFunc: func() (addr string) { return testUpsAddr }, +var upstreamWithAddr = &fakeUpstream{ + onExchange: func(m *dns.Msg) (resp *dns.Msg, err error) { panic("not implemented") }, + onClose: func() (err error) { panic("not implemented") }, + onAddress: func() (addr string) { return testUpsAddr }, } func TestServeCached(t *testing.T) { diff --git a/proxy/clock.go b/proxy/clock.go new file mode 100644 index 000000000..3727dbd63 --- /dev/null +++ b/proxy/clock.go @@ -0,0 +1,21 @@ +package proxy + +import "time" + +// clock is the interface for provider of current time. It's used to simplify +// testing. +// +// TODO(e.burkov): Move to golibs. +type clock interface { + // Now returns the current local time. + Now() (now time.Time) +} + +// type check +var _ clock = realClock{} + +// realClock is the [clock] which actually uses the [time] package. +type realClock struct{} + +// Now implements the [clock] interface for RealClock. +func (realClock) Now() (now time.Time) { return time.Now() } diff --git a/proxy/dns64.go b/proxy/dns64.go index 88d06d9bd..89960fc48 100644 --- a/proxy/dns64.go +++ b/proxy/dns64.go @@ -312,7 +312,7 @@ func (p *Proxy) performDNS64( host := origReq.Question[0].Name log.Debug("proxy: received an empty aaaa response for %q, checking dns64", host) - dns64Resp, u, err := p.exchange(dns64Req, upstreams) + dns64Resp, u, err := p.exchangeUpstreams(dns64Req, upstreams) if err != nil { log.Error("proxy: dns64 request failed: %s", err) diff --git a/proxy/exchange.go b/proxy/exchange.go index bbe259710..4323e7c69 100644 --- a/proxy/exchange.go +++ b/proxy/exchange.go @@ -7,70 +7,72 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" - "golang.org/x/exp/slices" + "gonum.org/v1/gonum/stat/sampleuv" ) -// exchange -- sends DNS query to the upstream DNS server and returns the response -func (p *Proxy) exchange(req *dns.Msg, upstreams []upstream.Upstream) (reply *dns.Msg, u upstream.Upstream, err error) { - qtype := req.Question[0].Qtype - if p.UpstreamMode == UModeFastestAddr && (qtype == dns.TypeA || qtype == dns.TypeAAAA) { - reply, u, err = p.fastestAddr.ExchangeFastest(req, upstreams) - return - } - - if p.UpstreamMode == UModeParallel { - reply, u, err = upstream.ExchangeParallel(upstreams, req) - return +// exchangeUpstreams resolves req using the given upstreams. It returns the DNS +// response, the upstream that successfully resolved the request, and the error +// if any. +func (p *Proxy) exchangeUpstreams( + req *dns.Msg, + ups []upstream.Upstream, +) (resp *dns.Msg, u upstream.Upstream, err error) { + switch p.UpstreamMode { + case UModeParallel: + return upstream.ExchangeParallel(ups, req) + case UModeFastestAddr: + switch req.Question[0].Qtype { + case dns.TypeA, dns.TypeAAAA: + return p.fastestAddr.ExchangeFastest(req, ups) + default: + // Go on to the load-balancing mode. + } + default: + // Go on to the load-balancing mode. } - // UModeLoadBalance goes below + if len(ups) == 1 { + u = ups[0] + resp, _, err = exchange(u, req, p.time) + // TODO(e.burkov): p.updateRTT(u.Address(), elapsed) - if len(upstreams) == 1 { - u = upstreams[0] - reply, _, err = exchangeWithUpstream(u, req) - return + return resp, u, err } - // sort upstreams by rtt from fast to slow - sortedUpstreams := p.getSortedUpstreams(upstreams) + w := sampleuv.NewWeighted(p.calcWeights(ups), p.randSrc) + var errs []error + for i, ok := w.Take(); ok; i, ok = w.Take() { + u = ups[i] - errs := []error{} - for _, dnsUpstream := range sortedUpstreams { - var elapsed int - reply, elapsed, err = exchangeWithUpstream(dnsUpstream, req) + var elapsed time.Duration + resp, elapsed, err = exchange(u, req, p.time) if err == nil { - p.updateRTT(dnsUpstream.Address(), elapsed) + p.updateRTT(u.Address(), elapsed) - return reply, dnsUpstream, err + return resp, u, nil } errs = append(errs, err) - p.updateRTT(dnsUpstream.Address(), int(defaultTimeout/time.Millisecond)) + + // TODO(e.burkov): Use the actual configured timeout or, perhaps, the + // actual measured elapsed time. + p.updateRTT(u.Address(), defaultTimeout) } + // TODO(e.burkov): Use [errors.Join]. return nil, nil, errors.List("all upstreams failed to exchange request", errs...) } -func (p *Proxy) getSortedUpstreams(u []upstream.Upstream) []upstream.Upstream { - // clone upstreams list to avoid race conditions - clone := slices.Clone(u) - - p.rttLock.Lock() - defer p.rttLock.Unlock() +// exchange returns the result of the DNS request exchange with the given +// upstream and the elapsed time in milliseconds. It uses the given clock to +// measure the request duration. +func exchange(u upstream.Upstream, req *dns.Msg, c clock) (resp *dns.Msg, dur time.Duration, err error) { + startTime := c.Now() - slices.SortFunc(clone, func(a, b upstream.Upstream) (res int) { - // TODO(d.kolyshev): Use upstreams for sort comparing. - return p.upstreamRTTStats[a.Address()] - p.upstreamRTTStats[b.Address()] - }) - - return clone -} - -// exchangeWithUpstream returns result of Exchange with elapsed time -func exchangeWithUpstream(u upstream.Upstream, req *dns.Msg) (*dns.Msg, int, error) { - startTime := time.Now() reply, err := u.Exchange(req) - elapsed := time.Since(startTime) + + // Don't use [time.Since] because it uses [time.Now]. + dur = c.Now().Sub(startTime) addr := u.Address() if err != nil { @@ -78,7 +80,7 @@ func exchangeWithUpstream(u upstream.Upstream, req *dns.Msg) (*dns.Msg, int, err "dnsproxy: upstream %s failed to exchange %s in %s: %s", addr, req.Question[0].String(), - elapsed, + dur, err, ) } else { @@ -86,21 +88,63 @@ func exchangeWithUpstream(u upstream.Upstream, req *dns.Msg) (*dns.Msg, int, err "dnsproxy: upstream %s successfully finished exchange of %s; elapsed %s", addr, req.Question[0].String(), - elapsed, + dur, ) } - return reply, int(elapsed.Milliseconds()), err + return reply, dur, err +} + +// upstreamRTTStats is the statistics for a single upstream's round-trip time. +type upstreamRTTStats struct { + // rttSum is the sum of all the round-trip times in microseconds. The + // float64 type is used since it's capable of representing about 285 years + // in microseconds. + rttSum float64 + + // reqNum is the number of requests to the upstream. The float64 type is + // used since to avoid unnecessary type conversions. + reqNum float64 +} + +// update returns updated stats after adding given RTT. +func (stats upstreamRTTStats) update(rtt time.Duration) (updated upstreamRTTStats) { + return upstreamRTTStats{ + rttSum: stats.rttSum + float64(rtt.Microseconds()), + reqNum: stats.reqNum + 1, + } +} + +// calcWeights returns the slice of weights, each corresponding to the upstream +// with the same index in the given slice. +func (p *Proxy) calcWeights(ups []upstream.Upstream) (weights []float64) { + weights = make([]float64, 0, len(ups)) + + p.rttLock.Lock() + defer p.rttLock.Unlock() + + for _, u := range ups { + stat := p.upstreamRTTStats[u.Address()] + if stat.rttSum == 0 || stat.reqNum == 0 { + // Use 1 as the default weight. + weights = append(weights, 1) + } else { + weights = append(weights, 1/(stat.rttSum/stat.reqNum)) + } + } + + return weights } -// updateRTT updates the round-trip time in upstreamRTTStats for given address. -func (p *Proxy) updateRTT(address string, rtt int) { +// updateRTT updates the round-trip time in [upstreamRTTStats] for given +// address. +func (p *Proxy) updateRTT(address string, rtt time.Duration) { p.rttLock.Lock() defer p.rttLock.Unlock() if p.upstreamRTTStats == nil { - p.upstreamRTTStats = map[string]int{} + p.upstreamRTTStats = map[string]upstreamRTTStats{} } - p.upstreamRTTStats[address] = (p.upstreamRTTStats[address] + rtt) / 2 + p.upstreamRTTStats[address] = p.upstreamRTTStats[address].update(rtt) } diff --git a/proxy/exchange_internal_test.go b/proxy/exchange_internal_test.go new file mode 100644 index 000000000..3a890bdb8 --- /dev/null +++ b/proxy/exchange_internal_test.go @@ -0,0 +1,229 @@ +package proxy + +import ( + "net/netip" + "sync" + "testing" + "time" + + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/netutil" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "golang.org/x/exp/rand" +) + +// fakeClock is the function-based implementation of the [clock] interface. +type fakeClock struct { + onNow func() (now time.Time) +} + +// type check +var _ clock = (*fakeClock)(nil) + +// Now implements the [clock] interface for *fakeClock. +func (c *fakeClock) Now() (now time.Time) { return c.onNow() } + +// newUpstreamWithErrorRate returns an [upstream.Upstream] that responds with an +// error every [rate] requests. The returned upstream isn't safe for concurrent +// use. +func newUpstreamWithErrorRate(rate uint, name string) (u upstream.Upstream) { + var n uint + + return &fakeUpstream{ + onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + n++ + if n%rate == 0 { + return nil, assert.AnError + } + + return (&dns.Msg{}).SetReply(req), nil + }, + onAddress: func() (addr string) { return name }, + onClose: func() (_ error) { panic("not implemented") }, + } +} + +// measuredUpstream is an [upstream.Upstream] that increments the counter every +// time it's used. +type measuredUpstream struct { + // Upstream is embedded here to avoid implementing all the methods. + upstream.Upstream + + // stats is the statistics collector for current upstream. + stats map[string]int64 +} + +// type check +var _ upstream.Upstream = measuredUpstream{} + +// Exchange implements the [upstream.Upstream] interface for measuredUpstream. +func (u measuredUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + u.stats[u.Address()]++ + + return u.Upstream.Exchange(req) +} + +func TestProxy_Exchange_loadBalance(t *testing.T) { + // Make the test deterministic. + randSrc := rand.NewSource(42) + + const ( + testRTT = 1 * time.Second + requestsNum = 10_000 + ) + + // zeroingClock returns the value of currentNow and sets it back to + // zeroTime, so that all the calls since the second one return the same zero + // value until currentNow is modified elsewhere. + zeroTime := time.Unix(0, 0) + currentNow := zeroTime + zeroingClock := &fakeClock{ + onNow: func() (now time.Time) { + now, currentNow = currentNow, zeroTime + + return now + }, + } + constClock := &fakeClock{ + onNow: func() (now time.Time) { + now, currentNow = currentNow, currentNow.Add(testRTT/50) + + return now + }, + } + + fastUps := &fakeUpstream{ + onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + currentNow = zeroTime.Add(testRTT / 100) + + return (&dns.Msg{}).SetReply(req), nil + }, + onAddress: func() (addr string) { return "fast" }, + onClose: func() (_ error) { panic("not implemented") }, + } + slowerUps := &fakeUpstream{ + onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + currentNow = zeroTime.Add(testRTT / 10) + + return (&dns.Msg{}).SetReply(req), nil + }, + onAddress: func() (addr string) { return "slower" }, + onClose: func() (_ error) { panic("not implemented") }, + } + slowestUps := &fakeUpstream{ + onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + currentNow = zeroTime.Add(testRTT / 2) + + return (&dns.Msg{}).SetReply(req), nil + }, + onAddress: func() (addr string) { return "slowest" }, + onClose: func() (_ error) { panic("not implemented") }, + } + + err1Ups := &fakeUpstream{ + onExchange: func(_ *dns.Msg) (r *dns.Msg, err error) { return nil, assert.AnError }, + onAddress: func() (addr string) { return "error1" }, + onClose: func() (_ error) { panic("not implemented") }, + } + err2Ups := &fakeUpstream{ + onExchange: func(_ *dns.Msg) (r *dns.Msg, err error) { return nil, assert.AnError }, + onAddress: func() (addr string) { return "error2" }, + onClose: func() (_ error) { panic("not implemented") }, + } + + singleError := &sync.Once{} + // fastestUps responds with an error on the first request. + fastestUps := &fakeUpstream{ + onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + singleError.Do(func() { err = assert.AnError }) + currentNow = zeroTime.Add(testRTT / 200) + + return (&dns.Msg{}).SetReply(req), err + }, + onAddress: func() (addr string) { return "fastest" }, + onClose: func() (_ error) { panic("not implemented") }, + } + + each200 := newUpstreamWithErrorRate(200, "each_200") + each100 := newUpstreamWithErrorRate(100, "each_100") + each50 := newUpstreamWithErrorRate(50, "each_50") + + testCases := []struct { + wantStat map[string]int64 + clock clock + name string + servers []upstream.Upstream + }{{ + wantStat: map[string]int64{ + fastUps.Address(): 8917, + slowerUps.Address(): 911, + slowestUps.Address(): 172, + }, + clock: zeroingClock, + name: "all_good", + servers: []upstream.Upstream{slowestUps, slowerUps, fastUps}, + }, { + wantStat: map[string]int64{ + fastUps.Address(): 9081, + slowerUps.Address(): 919, + err1Ups.Address(): 7, + }, + clock: zeroingClock, + name: "one_bad", + servers: []upstream.Upstream{fastUps, err1Ups, slowerUps}, + }, { + wantStat: map[string]int64{ + err1Ups.Address(): requestsNum, + err2Ups.Address(): requestsNum, + }, + clock: zeroingClock, + name: "all_bad", + servers: []upstream.Upstream{err2Ups, err1Ups}, + }, { + wantStat: map[string]int64{ + fastUps.Address(): 7803, + slowerUps.Address(): 833, + fastestUps.Address(): 1365, + }, + clock: zeroingClock, + name: "error_once", + servers: []upstream.Upstream{fastUps, slowerUps, fastestUps}, + }, { + wantStat: map[string]int64{ + each200.Address(): 5316, + each100.Address(): 3090, + each50.Address(): 1683, + }, + clock: constClock, + name: "error_each_nth", + servers: []upstream.Upstream{each200, each100, each50}, + }} + + req := createTestMessage() + cli := netip.AddrPortFrom(netutil.IPv4Localhost(), 1234) + + for _, tc := range testCases { + p := createTestProxy(t, nil) + p.UpstreamConfig.Upstreams = nil + p.time = tc.clock + p.randSrc = randSrc + wantStat := tc.wantStat + + stats := map[string]int64{} + for _, s := range tc.servers { + p.UpstreamConfig.Upstreams = append(p.UpstreamConfig.Upstreams, measuredUpstream{ + Upstream: s, + stats: stats, + }) + } + + t.Run(tc.name, func(t *testing.T) { + for i := 0; i < requestsNum; i++ { + _ = p.Resolve(&DNSContext{Req: req, Addr: cli}) + } + + assert.Equal(t, wantStat, stats) + }) + } +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 5aa8b4bd1..d99136de7 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -25,6 +25,7 @@ import ( gocache "github.com/patrickmn/go-cache" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" + "golang.org/x/exp/rand" "golang.org/x/exp/slices" ) @@ -57,7 +58,8 @@ const ( UnqualifiedNames = "unqualified_names" ) -// Proxy combines the proxy server state and configuration +// Proxy combines the proxy server state and configuration. It must not be used +// until initialized with [Proxy.Init]. // // TODO(a.garipov): Consider extracting conf blocks for better fieldalignment. type Proxy struct { @@ -110,9 +112,10 @@ type Proxy struct { // Upstream // -- - // upstreamRTTStats is a map of upstream addresses and their rtt. Used to - // sort upstreams by their latency. - upstreamRTTStats map[string]int + // upstreamRTTStats maps the upstream address to its round-trip time + // statistics. It's holds the statistics for all upstreams to perform a + // weighted random selection when using the load balancing mode. + upstreamRTTStats map[string]upstreamRTTStats // rttLock protects upstreamRTTStats. rttLock sync.Mutex @@ -177,6 +180,12 @@ type Proxy struct { // See also: https://github.com/AdguardTeam/AdGuardHome/issues/2242. requestsSema syncutil.Semaphore + // time provides the current time. + time clock + + // randSrc provides the source of randomness. + randSrc rand.Source + // Config is the proxy configuration. // // TODO(a.garipov): Remove this embed and create a proper initializer. @@ -236,6 +245,8 @@ func (p *Proxy) Init() (err error) { p.RatelimitWhitelist = slices.Clone(p.RatelimitWhitelist) slices.SortFunc(p.RatelimitWhitelist, netip.Addr.Compare) + p.time = realClock{} + return nil } @@ -537,7 +548,7 @@ func (p *Proxy) replyFromUpstream(d *DNSContext) (ok bool, err error) { src := "upstream" // Perform the DNS request. - resp, u, err := p.exchange(req, upstreams) + resp, u, err := p.exchangeUpstreams(req, upstreams) if dns64Ups := p.performDNS64(req, resp, upstreams); dns64Ups != nil { u = dns64Ups } else if p.isBogusNXDomain(resp) { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 53d7b4f44..ba85db0f2 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -201,7 +201,7 @@ func TestProxy_Resolve_dnssecCache(t *testing.T) { Signature: "c29tZSBycnNpZyByZWxhdGVkIHN0dWZm", } - p := &Proxy{} + p := createTestProxy(t, nil) p.UpstreamConfig = &UpstreamConfig{ Upstreams: []upstream.Upstream{&testDNSSECUpstream{ a: a, @@ -309,47 +309,6 @@ func TestProxy_Resolve_dnssecCache(t *testing.T) { } } -func TestUpstreamsSort(t *testing.T) { - testProxy := createTestProxy(t, nil) - upstreams := []upstream.Upstream{} - - // there are 4 upstreams in configuration - config := []string{"1.2.3.4", "1.1.1.1", "2.3.4.5", "8.8.8.8"} - for _, u := range config { - up, err := upstream.AddressToUpstream(u, &upstream.Options{Timeout: 1 * time.Second}) - if err != nil { - t.Fatalf("Failed to create %s upstream: %s", u, err) - } - upstreams = append(upstreams, up) - } - - upstreamRTTStats := map[string]int{} - upstreamRTTStats["1.1.1.1:53"] = 10 - upstreamRTTStats["2.3.4.5:53"] = 20 - upstreamRTTStats["1.2.3.4:53"] = 30 - testProxy.upstreamRTTStats = upstreamRTTStats - - sortedUpstreams := testProxy.getSortedUpstreams(upstreams) - - // upstream without rtt stats means `zero rtt`; this upstream should be the first one after sorting - if sortedUpstreams[0].Address() != "8.8.8.8:53" { - t.Fatalf("wrong sort algorithm!") - } - - // upstreams with rtt stats should be sorted from fast to slow - if sortedUpstreams[1].Address() != "1.1.1.1:53" { - t.Fatalf("wrong sort algorithm!") - } - - if sortedUpstreams[2].Address() != "2.3.4.5:53" { - t.Fatalf("wrong sort algorithm!") - } - - if sortedUpstreams[3].Address() != "1.2.3.4:53" { - t.Fatalf("wrong sort algorithm!") - } -} - func TestExchangeWithReservedDomains(t *testing.T) { dnsProxy := createTestProxy(t, nil) @@ -786,38 +745,25 @@ func TestNoQuestion(t *testing.T) { assert.Equal(t, dns.RcodeServerFailure, r.Rcode) } -// funcUpstream is a mock upstream implementation to simplify testing. It +// fakeUpstream is a mock upstream implementation to simplify testing. It // allows assigning custom Exchange and Address methods. -type funcUpstream struct { - exchangeFunc func(m *dns.Msg) (resp *dns.Msg, err error) - addressFunc func() (addr string) +type fakeUpstream struct { + onExchange func(m *dns.Msg) (resp *dns.Msg, err error) + onAddress func() (addr string) + onClose func() (err error) } // type check -var _ upstream.Upstream = (*funcUpstream)(nil) +var _ upstream.Upstream = (*fakeUpstream)(nil) // Exchange implements upstream.Upstream interface for *funcUpstream. -func (wu *funcUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) { - if wu.exchangeFunc == nil { - return nil, nil - } - - return wu.exchangeFunc(m) -} +func (u *fakeUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return u.onExchange(m) } // Address implements upstream.Upstream interface for *funcUpstream. -func (wu *funcUpstream) Address() (addr string) { - if wu.addressFunc == nil { - return "stub" - } - - return wu.addressFunc() -} +func (u *fakeUpstream) Address() (addr string) { return u.onAddress() } // Close implements upstream.Upstream interface for *funcUpstream. -func (wu *funcUpstream) Close() (err error) { - return nil -} +func (u *fakeUpstream) Close() (err error) { return u.onClose() } func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) { dnsProxy := createTestProxy(t, nil) @@ -841,8 +787,10 @@ func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) { return resp, nil } - u := &funcUpstream{ - exchangeFunc: exchangeFunc, + u := &fakeUpstream{ + onExchange: exchangeFunc, + onAddress: func() (addr string) { return "stub" }, + onClose: func() error { panic("not implemented") }, } d := &DNSContext{ @@ -929,8 +877,10 @@ func TestExchangeCustomUpstreamConfigCache(t *testing.T) { return resp, nil } - u := &funcUpstream{ - exchangeFunc: exchangeFunc, + u := &fakeUpstream{ + onExchange: exchangeFunc, + onAddress: func() (addr string) { return "stub" }, + onClose: func() error { panic("not implemented") }, } customUpstreamConfig := NewCustomUpstreamConfig( @@ -1212,10 +1162,12 @@ func getFreePort() uint { return port } -func createTestProxy(t *testing.T, tlsConfig *tls.Config) *Proxy { +func createTestProxy(t *testing.T, tlsConfig *tls.Config) (p *Proxy) { t.Helper() - p := Proxy{} + p = &Proxy{ + time: realClock{}, + } if ip := net.ParseIP(listenIP); tlsConfig != nil { p.TLSListenAddr = []*net.TCPAddr{{IP: ip, Port: 0}} @@ -1241,7 +1193,7 @@ func createTestProxy(t *testing.T, tlsConfig *tls.Config) *Proxy { p.RatelimitSubnetLenIPv4 = 24 p.RatelimitSubnetLenIPv6 = 64 - return &p + return p } func sendTestMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup) { From a87a3dfd6b737144a16ef92d130319d12184f129 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Fri, 29 Dec 2023 15:27:08 +0300 Subject: [PATCH 3/7] Pull request 311: 6574 timeout log level Updates AdguardTeam/AdGuardHome#6574. Squashed commit of the following: commit 039f567e9f1cb36b189fd60362744c4e10c0a6ce Author: Eugene Burkov Date: Fri Dec 29 16:12:01 2023 +0500 upstream: imp log commit 6643fc0d0d0a581db0e89f109847981aa02b8dd8 Author: Eugene Burkov Date: Thu Dec 28 17:09:23 2023 +0500 upstream: fix net, imp doc commit 648565de0472bbe118f694ab3d9cc662377b47e6 Author: Eugene Burkov Date: Wed Dec 27 15:35:33 2023 +0300 upstream: actually add doc commit baa4e607ae5ccd0a1cf88a18cb71a7f8b43ca366 Author: Eugene Burkov Date: Wed Dec 27 15:34:05 2023 +0300 upstream: fix import, imp doc commit e9591096134bdf35bedc9fc428e956cea0d4d773 Author: Eugene Burkov Date: Wed Dec 27 15:28:41 2023 +0300 upstream: imp logging --- upstream/doh.go | 18 ++++++++--------- upstream/quic.go | 9 +++++++-- upstream/upstream.go | 46 +++++++++++++++++++++++++++++++++++--------- 3 files changed, 52 insertions(+), 21 deletions(-) diff --git a/upstream/doh.go b/upstream/doh.go index ad8e24790..e85cb6feb 100644 --- a/upstream/doh.go +++ b/upstream/doh.go @@ -220,10 +220,9 @@ func (p *dnsOverHTTPS) exchangeHTTPS(client *http.Client, req *dns.Msg) (resp *d } logBegin(p.addrRedacted, n, req) - resp, err = p.exchangeHTTPSClient(client, req) - logFinish(p.addrRedacted, n, err) + defer func() { logFinish(p.addrRedacted, n, err) }() - return resp, err + return p.exchangeHTTPSClient(client, req) } // exchangeHTTPSClient sends the DNS query to a DoH resolver using the specified @@ -277,13 +276,12 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient( } if httpResp.StatusCode != http.StatusOK { - return nil, - fmt.Errorf( - "expected status %d, got %d from %s", - http.StatusOK, - httpResp.StatusCode, - p.addrRedacted, - ) + return nil, fmt.Errorf( + "expected status %d, got %d from %s", + http.StatusOK, + httpResp.StatusCode, + p.addrRedacted, + ) } resp = &dns.Msg{} diff --git a/upstream/quic.go b/upstream/quic.go index 9768b2d5c..8e57778f4 100644 --- a/upstream/quic.go +++ b/upstream/quic.go @@ -193,7 +193,12 @@ func (p *dnsOverQUIC) Close() (err error) { // exchangeQUIC attempts to open a QUIC connection, send the DNS message // through it and return the response it got from the server. -func (p *dnsOverQUIC) exchangeQUIC(m *dns.Msg) (resp *dns.Msg, err error) { +func (p *dnsOverQUIC) exchangeQUIC(req *dns.Msg) (resp *dns.Msg, err error) { + addr := p.Address() + + logBegin(addr, networkUDP, req) + defer func() { logFinish(addr, networkUDP, err) }() + var conn quic.Connection conn, err = p.getConnection(true) if err != nil { @@ -201,7 +206,7 @@ func (p *dnsOverQUIC) exchangeQUIC(m *dns.Msg) (resp *dns.Msg, err error) { } var buf []byte - buf, err = m.Pack() + buf, err = req.Pack() if err != nil { return nil, fmt.Errorf("failed to pack DNS message for DoQ: %w", err) } diff --git a/upstream/upstream.go b/upstream/upstream.go index cb2e3f79a..5b147ea76 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -11,12 +11,14 @@ import ( "net" "net/netip" "net/url" + "os" "strconv" "strings" "sync/atomic" "time" "github.com/AdguardTeam/dnsproxy/internal/bootstrap" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/ameshkov/dnscrypt/v2" @@ -274,25 +276,51 @@ func addPort(u *url.URL, port uint16) { // logBegin logs the start of DNS request resolution. It should be called right // before dialing the connection to the upstream. n is the [network] that will // be used to send the request. -func logBegin(upsAddr string, n network, req *dns.Msg) { - qtype := "" - target := "" +func logBegin(addr string, n network, req *dns.Msg) { + var qtype dns.Type + var qname string if len(req.Question) != 0 { - qtype = dns.Type(req.Question[0].Qtype).String() - target = req.Question[0].Name + qtype = dns.Type(req.Question[0].Qtype) + qname = req.Question[0].Name } - log.Debug("dnsproxy: %s: sending request over %s: %s %s", upsAddr, n, qtype, target) + log.Debug("dnsproxy: sending request to %s over %s: %s %q", addr, n, qtype, qname) } -// Write to log about the result of DNS request -func logFinish(upsAddr string, n network, err error) { +// logFinish logs the end of DNS request resolution. It should be called right +// after receiving the response from the upstream or the failing action. n is +// the [network] that was used to send the request. +func logFinish(addr string, n network, err error) { + logRoutine := log.Debug + status := "ok" if err != nil { status = err.Error() + if isTimeout(err) { + // Notify user about the timeout. + logRoutine = log.Error + } } - log.Debug("dnsproxy: %s: response received over %s: %q", upsAddr, n, status) + logRoutine("dnsproxy: %s: response received over %s: %q", addr, n, status) +} + +// isTimeout returns true if err is a timeout error. +// +// TODO(e.burkov): Move to golibs. +func isTimeout(err error) (ok bool) { + var netErr net.Error + switch { + case + errors.Is(err, context.Canceled), + errors.Is(err, context.DeadlineExceeded), + errors.Is(err, os.ErrDeadlineExceeded): + return true + case errors.As(err, &netErr): + return netErr.Timeout() + default: + return false + } } // DialerInitializer returns the handler that it creates. All the subsequent From 04571e60c201e0fc446c148966fe234dde46760d Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Thu, 11 Jan 2024 15:24:51 +0300 Subject: [PATCH 4/7] Pull request 312: 6321 boot ttl vol.1 Updates AdguardTeam/AdGuardHome#6321. Squashed commit of the following: commit fb8f98b6fec99e0a2b668ee62e6e63f974e5b801 Merge: 329e3e3 a87a3df Author: Eugene Burkov Date: Thu Jan 11 14:14:06 2024 +0300 Merge branch 'master' into 6321-boot-ttl-vol.1 commit 329e3e3a2ad1cdecc29b5fcedc4f16f80c8e9521 Author: Eugene Burkov Date: Tue Jan 9 23:27:41 2024 +0300 all: imp docs commit 91cc3f0f77f4c39f6a13fbdceb0486ca695cd07e Author: Eugene Burkov Date: Fri Dec 29 17:59:36 2023 +0500 all: imp code, rm redundant changes commit f0df8c2fbdfa2db47d268397db95cd3c904053cc Author: Eugene Burkov Date: Thu Dec 28 19:20:07 2023 +0500 all: imp code commit c1fd087dd1e563d02e67ebffedf31657e73325fe Author: Eugene Burkov Date: Wed Dec 27 18:23:38 2023 +0300 all: move code, use new types commit e1d94052a1c243e5d59cad36e0b081a71f7c2dda Author: Eugene Burkov Date: Wed Dec 27 18:17:08 2023 +0300 upstream: add separate test --- internal/bootstrap/bootstrap.go | 75 ++++---- internal/bootstrap/bootstrap_test.go | 8 +- internal/bootstrap/error.go | 6 + internal/bootstrap/resolver.go | 66 ++++++-- internal/netutil/netutil.go | 36 ++++ proxy/proxy_test.go | 4 +- upstream/resolver.go | 244 +++++++++++++++++++-------- upstream/upstream.go | 31 +--- upstream/upstream_test.go | 33 ++-- 9 files changed, 342 insertions(+), 161 deletions(-) create mode 100644 internal/bootstrap/error.go diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go index 1084cbefa..ae08de9d1 100644 --- a/internal/bootstrap/bootstrap.go +++ b/internal/bootstrap/bootstrap.go @@ -14,20 +14,42 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "golang.org/x/exp/slices" +) + +// Network is a network type for use in [Resolver]'s methods. +type Network = string + +const ( + // NetworkIP is a network type for both address families. + NetworkIP Network = "ip" + + // NetworkIP4 is a network type for IPv4 address family. + NetworkIP4 Network = "ip4" + + // NetworkIP6 is a network type for IPv6 address family. + NetworkIP6 Network = "ip6" + + // NetworkTCP is a network type for TCP connections. + NetworkTCP Network = "tcp" + + // NetworkUDP is a network type for UDP connections. + NetworkUDP Network = "udp" ) // DialHandler is a dial function for creating unencrypted network connections // to the upstream server. It establishes the connection to the server -// specified at initialization and ignores the addr. -type DialHandler func(ctx context.Context, network, addr string) (conn net.Conn, err error) +// specified at initialization and ignores the addr. network must be one of +// [NetworkTCP] or [NetworkUDP]. +type DialHandler func(ctx context.Context, network Network, addr string) (conn net.Conn, err error) // ResolveDialContext returns a DialHandler that uses addresses resolved from u // using resolver. u must not be nil. func ResolveDialContext( u *url.URL, timeout time.Duration, - resolver Resolver, - preferIPv6 bool, + r Resolver, + preferV6 bool, ) (h DialHandler, err error) { defer func() { err = errors.Annotate(err, "dialing %q: %w", u.Host) }() @@ -38,7 +60,7 @@ func ResolveDialContext( return nil, err } - if resolver == nil { + if r == nil { return nil, fmt.Errorf("resolver is nil: %w", ErrNoResolvers) } @@ -49,21 +71,20 @@ func ResolveDialContext( defer cancel() } - ips, err := resolver.LookupNetIP(ctx, "ip", host) + ips, err := r.LookupNetIP(ctx, NetworkIP, host) if err != nil { return nil, fmt.Errorf("resolving hostname: %w", err) } - proxynetutil.SortNetIPAddrs(ips, preferIPv6) + if preferV6 { + slices.SortStableFunc(ips, proxynetutil.PreferIPv6) + } else { + slices.SortStableFunc(ips, proxynetutil.PreferIPv4) + } addrs := make([]string, 0, len(ips)) for _, ip := range ips { - if !ip.IsValid() { - // All invalid addresses should be in the tail after sorting. - break - } - - addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)).String()) + addrs = append(addrs, netip.AddrPortFrom(ip, port).String()) } return NewDialContext(timeout, addrs...), nil @@ -71,14 +92,7 @@ func ResolveDialContext( // NewDialContext returns a DialHandler that dials addrs and returns the first // successful connection. At least a single addr should be specified. -// -// TODO(e.burkov): Consider using [Resolver] instead of -// [upstream.Options.Bootstrap] and [upstream.Options.ServerIPAddrs]. func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) { - dialer := &net.Dialer{ - Timeout: timeout, - } - l := len(addrs) if l == 0 { log.Debug("bootstrap: no addresses to dial") @@ -88,9 +102,11 @@ func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) { } } - // TODO(e.burkov): Check IPv6 preference here. + dialer := &net.Dialer{ + Timeout: timeout, + } - return func(ctx context.Context, network, _ string) (conn net.Conn, err error) { + return func(ctx context.Context, network Network, _ string) (conn net.Conn, err error) { var errs []error // Return first succeeded connection. Note that we're using addrs @@ -101,17 +117,18 @@ func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) { start := time.Now() conn, err = dialer.DialContext(ctx, network, addr) elapsed := time.Since(start) - if err == nil { - log.Debug("bootstrap: connection to %s succeeded in %s", addr, elapsed) + if err != nil { + log.Debug("bootstrap: connection to %s failed in %s: %s", addr, elapsed, err) + errs = append(errs, err) - return conn, nil + continue } - log.Debug("bootstrap: connection to %s failed in %s: %s", addr, elapsed, err) - errs = append(errs, err) + log.Debug("bootstrap: connection to %s succeeded in %s", addr, elapsed) + + return conn, nil } - // TODO(e.burkov): Use errors.Join in Go 1.20. - return nil, errors.List("all dialers failed", errs...) + return nil, errors.Join(errs...) } } diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go index e25f45c5f..53a0db11d 100644 --- a/internal/bootstrap/bootstrap_test.go +++ b/internal/bootstrap/bootstrap_test.go @@ -87,7 +87,7 @@ func TestResolveDialContext(t *testing.T) { network string, host string, ) (addrs []netip.Addr, err error) { - require.Equal(pt, "ip", network) + require.Equal(pt, bootstrap.NetworkIP, network) require.Equal(pt, hostname, host) return tc.addresses, nil @@ -103,7 +103,7 @@ func TestResolveDialContext(t *testing.T) { ) require.NoError(t, err) - conn, err := dialContext(context.Background(), "tcp", "") + conn, err := dialContext(context.Background(), bootstrap.NetworkTCP, "") require.NoError(t, err) expected, ok := testutil.RequireReceive(t, sig, testTimeout) @@ -120,7 +120,7 @@ func TestResolveDialContext(t *testing.T) { network string, host string, ) (addrs []netip.Addr, err error) { - require.Equal(pt, "ip", network) + require.Equal(pt, bootstrap.NetworkIP, network) require.Equal(pt, hostname, host) return nil, nil @@ -135,7 +135,7 @@ func TestResolveDialContext(t *testing.T) { ) require.NoError(t, err) - _, err = dialContext(context.Background(), "tcp", "") + _, err = dialContext(context.Background(), bootstrap.NetworkTCP, "") testutil.AssertErrorMsg(t, "no addresses", err) }) diff --git a/internal/bootstrap/error.go b/internal/bootstrap/error.go new file mode 100644 index 000000000..9f65e8226 --- /dev/null +++ b/internal/bootstrap/error.go @@ -0,0 +1,6 @@ +package bootstrap + +import "github.com/AdguardTeam/golibs/errors" + +// ErrNoResolvers is returned when zero resolvers specified. +const ErrNoResolvers errors.Error = "no resolvers specified" diff --git a/internal/bootstrap/resolver.go b/internal/bootstrap/resolver.go index b2c57c76f..9891adc25 100644 --- a/internal/bootstrap/resolver.go +++ b/internal/bootstrap/resolver.go @@ -8,22 +8,21 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "golang.org/x/exp/slices" ) -// Resolver resolves the hostnames to IP addresses. +// Resolver resolves the hostnames to IP addresses. Note, that [net.Resolver] +// from standard library also implements this interface. type Resolver interface { - // LookupNetIP looks up the IP addresses for the given host. network must - // be one of "ip", "ip4" or "ip6". The response may be empty even if err is - // nil. - LookupNetIP(ctx context.Context, network, host string) (addrs []netip.Addr, err error) + // LookupNetIP looks up the IP addresses for the given host. network should + // be one of [NetworkIP], [NetworkIP4] or [NetworkIP6]. The response may be + // empty even if err is nil. All the addrs must be valid. + LookupNetIP(ctx context.Context, network Network, host string) (addrs []netip.Addr, err error) } // type check var _ Resolver = &net.Resolver{} -// ErrNoResolvers is returned when zero resolvers specified. -const ErrNoResolvers errors.Error = "no resolvers specified" - // ParallelResolver is a slice of resolvers that are queried concurrently. The // first successful response is returned. type ParallelResolver []Resolver @@ -34,7 +33,7 @@ var _ Resolver = ParallelResolver(nil) // LookupNetIP implements the [Resolver] interface for ParallelResolver. func (r ParallelResolver) LookupNetIP( ctx context.Context, - network string, + network Network, host string, ) (addrs []netip.Addr, err error) { resolversNum := len(r) @@ -48,7 +47,7 @@ func (r ParallelResolver) LookupNetIP( } // Size of channel must accommodate results of lookups from all resolvers, - // sending into channel will be block otherwise. + // sending into channel will block otherwise. ch := make(chan any, resolversNum) for _, rslv := range r { go lookupAsync(ctx, rslv, network, host, ch) @@ -97,3 +96,50 @@ func lookup(ctx context.Context, r Resolver, network, host string) (addrs []neti return addrs, err } + +// ConsequentResolver is a slice of resolvers that are queried in order until +// the first successful non-empty response, as opposed to just successful +// response requirement in [ParallelResolver]. +type ConsequentResolver []Resolver + +// type check +var _ Resolver = ConsequentResolver(nil) + +// LookupNetIP implements the [Resolver] interface for ConsequentResolver. +func (resolvers ConsequentResolver) LookupNetIP( + ctx context.Context, + network Network, + host string, +) (addrs []netip.Addr, err error) { + if len(resolvers) == 0 { + return nil, ErrNoResolvers + } + + var errs []error + for _, r := range resolvers { + addrs, err = r.LookupNetIP(ctx, network, host) + if err == nil && len(addrs) > 0 { + return addrs, nil + } + + errs = append(errs, err) + } + + return nil, errors.Join(errs...) +} + +// StaticResolver is a resolver which always responds with an underlying slice +// of IP addresses regardless of host and network. +type StaticResolver []netip.Addr + +// type check +var _ Resolver = StaticResolver(nil) + +// LookupNetIP implements the [Resolver] interface for StaticResolver. +func (r StaticResolver) LookupNetIP( + _ context.Context, + _ Network, + _ string, +) (addrs []netip.Addr, err error) { + return slices.Clone(r), nil +} diff --git a/internal/netutil/netutil.go b/internal/netutil/netutil.go index 6a1719841..21faae213 100644 --- a/internal/netutil/netutil.go +++ b/internal/netutil/netutil.go @@ -12,6 +12,42 @@ import ( "golang.org/x/exp/slices" ) +// PreferIPv4 compares two addresses, preferring IPv4 addresses over IPv6 ones. +// Invalid addresses are sorted near the end. +func PreferIPv4(a, b netip.Addr) (res int) { + if !a.IsValid() { + return 1 + } else if !b.IsValid() { + return -1 + } + + if aIs4 := a.Is4(); aIs4 == b.Is4() { + return a.Compare(b) + } else if aIs4 { + return -1 + } + + return 1 +} + +// PreferIPv6 compares two addresses, preferring IPv6 addresses over IPv4 ones. +// Invalid addresses are sorted near the end. +func PreferIPv6(a, b netip.Addr) (res int) { + if !a.IsValid() { + return 1 + } else if !b.IsValid() { + return -1 + } + + if aIs6 := a.Is6(); aIs6 == b.Is6() { + return a.Compare(b) + } else if aIs6 { + return -1 + } + + return 1 +} + // SortNetIPAddrs sorts addrs in accordance with the protocol preferences. // Invalid addresses are sorted near the end. Zones are ignored. func SortNetIPAddrs(addrs []netip.Addr, preferIPv6 bool) { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index ba85db0f2..5a813ee70 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -329,7 +329,7 @@ func TestExchangeWithReservedDomains(t *testing.T) { upstreams, &upstream.Options{ InsecureSkipVerify: false, - Bootstrap: googleRslv, + Bootstrap: upstream.NewCachingResolver(googleRslv), Timeout: 1 * time.Second, }, ) @@ -412,7 +412,7 @@ func TestOneByOneUpstreamsExchange(t *testing.T) { u, err = upstream.AddressToUpstream( line, &upstream.Options{ - Bootstrap: googleRslv, + Bootstrap: upstream.NewCachingResolver(googleRslv), Timeout: timeOut, }, ) diff --git a/upstream/resolver.go b/upstream/resolver.go index ab017d975..5f98c800d 100644 --- a/upstream/resolver.go +++ b/upstream/resolver.go @@ -5,20 +5,35 @@ import ( "fmt" "net/netip" "net/url" + "strings" + "sync" + "time" "github.com/AdguardTeam/dnsproxy/internal/bootstrap" "github.com/AdguardTeam/dnsproxy/proxyutil" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" - "golang.org/x/exp/slices" ) -// Resolver is an alias for the internal [bootstrap.Resolver] to allow custom -// implementations. Note, that the [net.Resolver] from standard library also -// implements this interface. +// Resolver resolves the hostnames to IP addresses. Note, that [net.Resolver] +// from standard library also implements this interface. type Resolver = bootstrap.Resolver +// StaticResolver is a resolver which always responds with an underlying slice +// of IP addresses. +type StaticResolver = bootstrap.StaticResolver + +// ParallelResolver is a slice of resolvers that are queried concurrently until +// the first successful response is returned, as opposed to all resolvers being +// queried in order in [ConsequentResolver]. +type ParallelResolver = bootstrap.ParallelResolver + +// ConsequentResolver is a slice of resolvers that are queried in order until +// the first successful non-empty response, as opposed to just successful +// response requirement in [ParallelResolver]. +type ConsequentResolver = bootstrap.ConsequentResolver + // UpstreamResolver is a wrapper around Upstream that implements the // [bootstrap.Resolver] interface. type UpstreamResolver struct { @@ -105,57 +120,107 @@ func validateBootstrap(u Upstream) (err error) { // type check var _ Resolver = &UpstreamResolver{} -// LookupNetIP implements the [Resolver] interface for upstreamResolver. +// LookupNetIP implements the [Resolver] interface for *UpstreamResolver. It +// doesn't consider the TTL of the DNS records. // -// TODO(e.burkov): Use context. +// TODO(e.burkov): Investigate why the empty slice is returned instead of nil. func (r *UpstreamResolver) LookupNetIP( - _ context.Context, - network string, + ctx context.Context, + network bootstrap.Network, host string, ) (ips []netip.Addr, err error) { if host == "" { return nil, nil } - switch network { - case "ip4", "ip6": - host = dns.Fqdn(host) - ips, err = r.resolve(host, network) - case "ip": - host = dns.Fqdn(host) - resCh := make(chan any, 2) - go r.resolveAsync(resCh, host, "ip4") - go r.resolveAsync(resCh, host, "ip6") - - var errs []error - for i := 0; i < 2; i++ { - switch res := <-resCh; res := res.(type) { - case error: - errs = append(errs, res) - case []netip.Addr: - ips = append(ips, res...) - } + host = dns.Fqdn(strings.ToLower(host)) + + rr, err := r.resolveIP(ctx, network, host) + if err != nil { + return []netip.Addr{}, err + } + + for _, ip := range rr { + ips = append(ips, ip.addr) + } + + return ips, err +} + +// ipResult reflects a single A/AAAA record from the DNS response. It's used +// to cache the results of lookups. +type ipResult struct { + addr netip.Addr + expire time.Time +} + +// filterExpired returns the addresses from res that are not expired yet. It +// returns nil if all the addresses are expired. +func filterExpired(res []ipResult, now time.Time) (filtered []netip.Addr) { + for _, r := range res { + if r.expire.After(now) { + filtered = append(filtered, r.addr) } + } - err = errors.Join(errs...) + return filtered +} + +// resolveIP performs a DNS lookup of host and returns the result. network must +// be either [bootstrap.NetworkIP4], [bootstrap.NetworkIP6] or +// [bootstrap.NetworkIP]. host must be in a lower-case FQDN form. +// +// TODO(e.burkov): Use context. +func (r *UpstreamResolver) resolveIP( + _ context.Context, + network bootstrap.Network, + host string, +) (rr []ipResult, err error) { + switch network { + case bootstrap.NetworkIP4, bootstrap.NetworkIP6: + return r.resolve(host, network) + case bootstrap.NetworkIP: + // Go on. default: - return []netip.Addr{}, fmt.Errorf("unsupported network %s", network) + return nil, fmt.Errorf("unsupported network %s", network) } - if len(ips) == 0 { - ips = []netip.Addr{} + resCh := make(chan any, 2) + go r.resolveAsync(resCh, host, bootstrap.NetworkIP4) + go r.resolveAsync(resCh, host, bootstrap.NetworkIP6) + + var errs []error + + for i := 0; i < 2; i++ { + switch res := <-resCh; res := res.(type) { + case error: + errs = append(errs, res) + case []ipResult: + rr = append(rr, res...) + } } - return ips, err + return rr, errors.Join(errs...) } // resolve performs a single DNS lookup of host and returns all the valid // addresses from the answer section of the response. network must be either -// "ip4" or "ip6". -func (r *UpstreamResolver) resolve(host, network string) (addrs []netip.Addr, err error) { - qtype := dns.TypeA - if network == "ip6" { +// "ip4" or "ip6". host must be in a lower-case FQDN form. +// +// TODO(e.burkov): Consider NS and Extra sections when setting TTL. Check out +// what RFCs say about it. +func (r *UpstreamResolver) resolve( + host string, + n bootstrap.Network, +) (res []ipResult, err error) { + var qtype uint16 + switch n { + case bootstrap.NetworkIP4: + qtype = dns.TypeA + case bootstrap.NetworkIP6: qtype = dns.TypeAAAA + default: + panic(fmt.Sprintf("unsupported network %q", n)) } req := &dns.Msg{ @@ -170,78 +235,107 @@ func (r *UpstreamResolver) resolve(host, network string) (addrs []netip.Addr, er }}, } - resp, err := r.Upstream.Exchange(req) - if err != nil || resp == nil { + // As per [upstream.Exchange] documentation, the response is always returned + // if no error occurred. + resp, err := r.Exchange(req) + if err != nil { return nil, err } + now := time.Now() for _, rr := range resp.Answer { - if addr := proxyutil.IPFromRR(rr); addr.IsValid() { - addrs = append(addrs, addr) + ip := proxyutil.IPFromRR(rr) + if !ip.IsValid() { + continue } + + res = append(res, ipResult{ + addr: ip, + expire: now.Add(time.Duration(rr.Header().Ttl) * time.Second), + }) } - return addrs, nil + return res, nil } // resolveAsync performs a single DNS lookup and sends the result to ch. It's // intended to be used as a goroutine. func (r *UpstreamResolver) resolveAsync(resCh chan<- any, host, network string) { - resp, err := r.resolve(host, network) + res, err := r.resolve(host, network) if err != nil { resCh <- err } else { - resCh <- resp + resCh <- res } } -// StaticResolver is a resolver which always responds with an underlying slice -// of IP addresses. -type StaticResolver []netip.Addr +// CachingResolver is a [Resolver] that caches the results of lookups. It's +// required to be created with [NewCachingResolver]. +type CachingResolver struct { + // resolver is the underlying resolver to use for lookups. + resolver *UpstreamResolver -// type check -var _ Resolver = StaticResolver(nil) + // mu protects cached and it's elements. + mu *sync.RWMutex -// LookupNetIP implements the [Resolver] interface for StaticResolver. -func (r StaticResolver) LookupNetIP( - ctx context.Context, - network string, - host string, -) (addrs []netip.Addr, err error) { - return slices.Clone(r), nil + // cached is the set of cached results sorted by [resolveResult.name]. + cached map[string][]ipResult } -// ConsequentResolver is a slice of resolvers that are queried in order until -// the first successful non-empty response, as opposed to just successful -// response requirement in [ParallelResolver]. -type ConsequentResolver []Resolver +// NewCachingResolver creates a new caching resolver that uses r for lookups. +func NewCachingResolver(r *UpstreamResolver) (cr *CachingResolver) { + return &CachingResolver{ + resolver: r, + mu: &sync.RWMutex{}, + cached: map[string][]ipResult{}, + } +} // type check -var _ Resolver = ConsequentResolver(nil) +var _ Resolver = (*CachingResolver)(nil) -// LookupNetIP implements the [Resolver] interface for ConsequentResolver. -func (resolvers ConsequentResolver) LookupNetIP( +// LookupNetIP implements the [Resolver] interface for *CachingResolver. +func (r *CachingResolver) LookupNetIP( ctx context.Context, - network string, + network bootstrap.Network, host string, ) (addrs []netip.Addr, err error) { - if len(resolvers) == 0 { - return nil, bootstrap.ErrNoResolvers + now := time.Now() + host = dns.Fqdn(strings.ToLower(host)) + + addrs = r.findCached(host, now) + if addrs != nil { + return addrs, nil } - var errs []error - for _, r := range resolvers { - addrs, err = r.LookupNetIP(ctx, network, host) - if err == nil && len(addrs) > 0 { - return addrs, nil - } + newRes, err := r.resolver.resolveIP(ctx, network, host) + if err != nil { + return []netip.Addr{}, err + } - errs = append(errs, err) + addrs = filterExpired(newRes, now) + if len(addrs) == 0 { + return []netip.Addr{}, nil } - return nil, errors.Join(errs...) + r.mu.Lock() + defer r.mu.Unlock() + + r.cached[host] = newRes + + return addrs, nil } -// ParallelResolver is an alias for the internal [bootstrap.ParallelResolver] to -// allow it's usage outside of the module. -type ParallelResolver = bootstrap.ParallelResolver +// findCached returns the cached addresses for host if it's not expired yet, and +// the corresponding cached result, if any. +func (r *CachingResolver) findCached(host string, now time.Time) (addrs []netip.Addr) { + r.mu.RLock() + defer r.mu.RUnlock() + + res, ok := r.cached[host] + if !ok { + return nil + } + + return filterExpired(res, now) +} diff --git a/upstream/upstream.go b/upstream/upstream.go index 5b147ea76..9cd345d87 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -14,7 +14,6 @@ import ( "os" "strconv" "strings" - "sync/atomic" "time" "github.com/AdguardTeam/dnsproxy/internal/bootstrap" @@ -323,9 +322,7 @@ func isTimeout(err error) (ok bool) { } } -// DialerInitializer returns the handler that it creates. All the subsequent -// calls to it, except the first one, will return the same handler so that -// resolving will be performed only once. +// DialerInitializer returns the handler that it creates. type DialerInitializer func() (handler bootstrap.DialHandler, err error) // newDialerInitializer creates an initializer of the dialer that will dial the @@ -335,7 +332,9 @@ func newDialerInitializer(u *url.URL, opts *Options) (di DialerInitializer) { // Don't resolve the address of the server since it's already an IP. handler := bootstrap.NewDialContext(opts.Timeout, u.Host) - return func() (bootstrap.DialHandler, error) { return handler, nil } + return func() (h bootstrap.DialHandler, dialerErr error) { + return handler, nil + } } boot := opts.Bootstrap @@ -344,27 +343,7 @@ func newDialerInitializer(u *url.URL, opts *Options) (di DialerInitializer) { boot = net.DefaultResolver } - var dialHandler atomic.Pointer[bootstrap.DialHandler] - return func() (h bootstrap.DialHandler, err error) { - // Check if the dial handler has already been created. - if hPtr := dialHandler.Load(); hPtr != nil { - return *hPtr, nil - } - - // TODO(e.burkov): It may appear that several exchanges will try to - // resolve the upstream hostname at the same time. Currently, the last - // successful value will be stored in dialHandler, but ideally we should - // resolve only once. - h, err = bootstrap.ResolveDialContext(u, opts.Timeout, boot, opts.PreferIPv6) - if err != nil { - return nil, fmt.Errorf("creating dial handler: %w", err) - } - - if !dialHandler.CompareAndSwap(nil, &h) { - return *dialHandler.Load(), nil - } - - return h, nil + return bootstrap.ResolveDialContext(u, opts.Timeout, boot, opts.PreferIPv6) } } diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go index f56569253..0ad53b660 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_test.go @@ -55,7 +55,7 @@ func TestUpstream_bootstrapTimeout(t *testing.T) { // Create an upstream that uses this faulty bootstrap. u, err := AddressToUpstream("tls://random-domain-name", &Options{ - Bootstrap: rslv, + Bootstrap: NewCachingResolver(rslv), Timeout: timeout, }) require.NoError(t, err) @@ -114,17 +114,20 @@ func TestUpstreams(t *testing.T) { }) require.NoError(t, err) + googleBoot := NewCachingResolver(googleRslv) + cloudflareBoot := NewCachingResolver(cloudflareRslv) + upstreams := []struct { bootstrap Resolver address string }{{ - bootstrap: googleRslv, + bootstrap: googleBoot, address: "8.8.8.8:53", }, { bootstrap: nil, address: "1.1.1.1", }, { - bootstrap: cloudflareRslv, + bootstrap: cloudflareBoot, address: "1.1.1.1", }, { bootstrap: nil, @@ -139,19 +142,19 @@ func TestUpstreams(t *testing.T) { bootstrap: nil, address: "tls://9.9.9.9:853", }, { - bootstrap: googleRslv, + bootstrap: googleBoot, address: "tls://dns.adguard.com", }, { - bootstrap: googleRslv, + bootstrap: googleBoot, address: "tls://dns.adguard.com:853", }, { - bootstrap: googleRslv, + bootstrap: googleBoot, address: "tls://dns.adguard.com:853", }, { bootstrap: nil, address: "tls://one.one.one.one", }, { - bootstrap: googleRslv, + bootstrap: googleBoot, address: "https://1dot1dot1dot1.cloudflare-dns.com/dns-query", }, { bootstrap: nil, @@ -165,11 +168,11 @@ func TestUpstreams(t *testing.T) { address: "sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", }, { // AdGuard Family (DNSCrypt) - bootstrap: googleRslv, + bootstrap: googleBoot, address: "sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMjo1NDQzILgxXdexS27jIKRw3C7Wsao5jMnlhvhdRUXWuMm1AFq6ITIuZG5zY3J5cHQuZmFtaWx5Lm5zMS5hZGd1YXJkLmNvbQ", }, { // Cloudflare DNS (DNS-over-HTTPS) - bootstrap: googleRslv, + bootstrap: googleBoot, address: "sdns://AgcAAAAAAAAABzEuMC4wLjGgENk8mGSlIfMGXMOlIlCcKvq7AVgcrZxtjon911-ep0cg63Ul-I8NlFj4GplQGb_TTLiczclX57DvMV8Q-JdjgRgSZG5zLmNsb3VkZmxhcmUuY29tCi9kbnMtcXVlcnk", }, { // Google (Plain) @@ -177,11 +180,11 @@ func TestUpstreams(t *testing.T) { address: "sdns://AAcAAAAAAAAABzguOC44Ljg", }, { // AdGuard DNS (DNS-over-TLS) - bootstrap: googleRslv, + bootstrap: googleBoot, address: "sdns://AwAAAAAAAAAAAAAPZG5zLmFkZ3VhcmQuY29t", }, { // AdGuard DNS (DNS-over-QUIC) - bootstrap: googleRslv, + bootstrap: googleBoot, address: "sdns://BAcAAAAAAAAAAAAXZG5zLmFkZ3VhcmQtZG5zLmNvbTo3ODQ", }, { // Cloudflare DNS (DNS-over-HTTPS) @@ -189,7 +192,7 @@ func TestUpstreams(t *testing.T) { address: "https://1.1.1.1/dns-query", }, { // AdGuard DNS (DNS-over-QUIC) - bootstrap: googleRslv, + bootstrap: googleBoot, address: "quic://dns.adguard-dns.com", }, { // Google DNS (HTTP3) @@ -215,7 +218,7 @@ func TestAddressToUpstream(t *testing.T) { cloudflareRslv, err := NewUpstreamResolver("1.1.1.1", nil) require.NoError(t, err) - opt := &Options{Bootstrap: cloudflareRslv} + opt := &Options{Bootstrap: NewCachingResolver(cloudflareRslv)} testCases := []struct { addr string @@ -314,7 +317,7 @@ func TestUpstreamDoTBootstrap(t *testing.T) { require.NoError(t, err) u, err := AddressToUpstream(tc.address, &Options{ - Bootstrap: rslv, + Bootstrap: NewCachingResolver(rslv), Timeout: timeout, }) require.NoErrorf(t, err, "failed to generate upstream from address %s", tc.address) @@ -361,7 +364,7 @@ func TestUpstreamsInvalidBootstrap(t *testing.T) { }) require.NoError(t, err) - rslv = append(rslv, r) + rslv = append(rslv, NewCachingResolver(r)) } u, err := AddressToUpstream(tc.address, &Options{ From f1ceef03ad5be7272faeab541b7b71dd18d89f4e Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Fri, 12 Jan 2024 13:09:22 +0300 Subject: [PATCH 5/7] Pull request 314: 6321 boot ttl vol.2 Squashed commit of the following: commit 8b0d6af26f07c858f6fe564963eb823c3f10f619 Author: Eugene Burkov Date: Thu Jan 11 15:41:30 2024 +0300 main: use caching resolvers --- main.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main.go b/main.go index dc53a047d..efcd73e3f 100644 --- a/main.go +++ b/main.go @@ -465,13 +465,13 @@ func initBootstrap(bootstraps []string, opts *upstream.Options) (r upstream.Reso var resolvers []upstream.Resolver for i, b := range bootstraps { - var resolver upstream.Resolver - resolver, err = upstream.NewUpstreamResolver(b, opts) + var ur *upstream.UpstreamResolver + ur, err = upstream.NewUpstreamResolver(b, opts) if err != nil { return nil, fmt.Errorf("creating bootstrap resolver at index %d: %w", i, err) } - resolvers = append(resolvers, resolver) + resolvers = append(resolvers, upstream.NewCachingResolver(ur)) } switch len(resolvers) { From edb394b25c873f5201c60294a60034e25cb08566 Mon Sep 17 00:00:00 2001 From: Dimitry Kolyshev Date: Tue, 16 Jan 2024 11:32:06 +0300 Subject: [PATCH 6/7] Pull request: AG-28961-upd-golibs Squashed commit of the following: commit 75ef9755582205d0ff7fcd2e9d3e224266d248d9 Merge: 6e9453f f1ceef0 Author: Dimitry Kolyshev Date: Tue Jan 16 10:02:11 2024 +0200 Merge remote-tracking branch 'origin/master' into AG-28961-upd-golibs commit 6e9453fa6969d4ec772d68da38ea74243d225ead Author: Dimitry Kolyshev Date: Mon Jan 15 12:35:10 2024 +0200 all: upd golibs commit 42fbf390a6209298b92dbbd131b9c3ecfee56e34 Author: Dimitry Kolyshev Date: Mon Jan 15 11:24:46 2024 +0200 proxy: imp code commit 1b3dde756306c4fa3d8271284dc940aae21540a2 Author: Dimitry Kolyshev Date: Mon Jan 15 11:02:20 2024 +0200 proxy: imp code commit 129935dbd2efeb068b4875b3ff570ffcf95217dc Author: Dimitry Kolyshev Date: Wed Jan 10 14:49:36 2024 +0200 proxy: imp code commit 9205a0c8dc534b367828508a0bac8a2a0f4b875a Author: Dimitry Kolyshev Date: Wed Jan 10 14:46:27 2024 +0200 proxy: conf commit 7d8200024fd34b4387aa3c519b65a5a7693f78d0 Author: Dimitry Kolyshev Date: Wed Jan 10 13:56:55 2024 +0200 proxy: slices commit 4e50cd17aebba713b2be8065e860154972fd7b2f Author: Dimitry Kolyshev Date: Wed Jan 10 13:53:06 2024 +0200 proxy: addr commit 46ee8f7df6c53d6cf00bd1bc6656303e2af19580 Author: Dimitry Kolyshev Date: Wed Jan 10 13:42:57 2024 +0200 proxy: conf --- go.mod | 16 ++++++++-------- go.sum | 34 +++++++++++++++++----------------- main.go | 12 ++++++++---- proxy/config.go | 9 ++++----- proxy/dns64.go | 13 ++++--------- proxy/proxy.go | 22 +++++++--------------- proxy/proxy_test.go | 5 ++++- proxy/proxycache.go | 16 ++++++++++++++-- proxy/server_https.go | 3 +-- proxy/server_https_test.go | 8 ++++---- 10 files changed, 71 insertions(+), 67 deletions(-) diff --git a/go.mod b/go.mod index d64b2a0d0..7de938790 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/dnsproxy go 1.20 require ( - github.com/AdguardTeam/golibs v0.18.1 + github.com/AdguardTeam/golibs v0.19.0 github.com/ameshkov/dnscrypt/v2 v2.2.7 github.com/ameshkov/dnsstamps v1.0.3 github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 @@ -13,9 +13,9 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/quic-go/quic-go v0.39.1 github.com/stretchr/testify v1.8.4 - golang.org/x/exp v0.0.0-20230905200255-921286631fa9 - golang.org/x/net v0.17.0 - golang.org/x/sys v0.13.0 + golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 + golang.org/x/net v0.19.0 + golang.org/x/sys v0.15.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -32,10 +32,10 @@ require ( github.com/quic-go/qpack v0.4.0 // indirect github.com/quic-go/qtls-go1-20 v0.3.4 // indirect go.uber.org/mock v0.3.0 // indirect - golang.org/x/crypto v0.14.0 // indirect - golang.org/x/mod v0.12.0 // indirect - golang.org/x/text v0.13.0 // indirect - golang.org/x/tools v0.13.0 // indirect + golang.org/x/crypto v0.16.0 // indirect + golang.org/x/mod v0.14.0 // indirect + golang.org/x/text v0.14.0 // indirect + golang.org/x/tools v0.16.0 // indirect gonum.org/v1/gonum v0.14.0 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect ) diff --git a/go.sum b/go.sum index 3e0a8f0bd..7d0876d13 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/AdguardTeam/golibs v0.18.1 h1:6u0fvrIj2qjUsRdbIGJ9AR0g5QRSWdKIo/DYl3tp5aM= -github.com/AdguardTeam/golibs v0.18.1/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U= +github.com/AdguardTeam/golibs v0.19.0 h1:y/x+Xn3pDg1ZfQ+QEZapPJqaeVYUIMp/EODMtVhn7PM= +github.com/AdguardTeam/golibs v0.19.0/go.mod h1:3WunclLLfrVAq7fYQRhd6f168FHOEMssnipVXCxDL/w= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= @@ -52,22 +52,22 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= -golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= +golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 h1:+iq7lrkxmFNBM7xx+Rae2W6uyPfhPeDWD+n+JgppptE= +golang.org/x/exp v0.0.0-20231219180239-dc181d75b848/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= +golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= -golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.16.0 h1:GO788SKMRunPIBCXiQyo2AaexLstOrVhuAL5YwsckQM= +golang.org/x/tools v0.16.0/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= diff --git a/main.go b/main.go index efcd73e3f..b6f23706f 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,7 @@ import ( "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/mathutil" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/osutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/ameshkov/dnscrypt/v2" @@ -345,10 +346,13 @@ func createProxyConfig(options *Options) (conf proxy.Config) { CacheOptimistic: options.CacheOptimistic, RefuseAny: options.RefuseAny, HTTP3: options.HTTP3, - // TODO(e.burkov): The following CIDRs are aimed to match any - // address. This is not quite proper approach to be used by - // default so think about configuring it. - TrustedProxies: []string{"0.0.0.0/0", "::0/0"}, + // TODO(e.burkov): The following CIDRs are aimed to match any address. + // This is not quite proper approach to be used by default so think + // about configuring it. + TrustedProxies: netutil.SliceSubnetSet{ + netip.MustParsePrefix("0.0.0.0/0"), + netip.MustParsePrefix("::0/0"), + }, EnableEDNSClientSubnet: options.EnableEDNSSubnet, UDPBufferSize: options.UDPBufferSize, HTTPSServerName: options.HTTPSServerName, diff --git a/proxy/config.go b/proxy/config.go index 33c0fd296..7aa907a88 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -86,11 +86,10 @@ type Config struct { // RefuseAny makes proxy refuse the requests of type ANY. RefuseAny bool - // TrustedProxies is the list of IP addresses and CIDR networks to - // detect proxy servers addresses the DoH requests from which should be - // handled. The value of nil or an empty slice for this field makes - // Proxy not trust any address. - TrustedProxies []string + // TrustedProxies is the trusted list of CIDR networks to detect proxy + // servers addresses from where the DoH requests should be handled. The + // value of nil makes Proxy not trust any address. + TrustedProxies netutil.SubnetSet // Upstream DNS servers and their settings // -- diff --git a/proxy/dns64.go b/proxy/dns64.go index 89960fc48..64148e14f 100644 --- a/proxy/dns64.go +++ b/proxy/dns64.go @@ -213,7 +213,7 @@ func (p *Proxy) withinDNS64(ip netip.Addr) (ok bool) { return false } -// shouldStripDNS64 returns true if DNS64 is enabled and ip has either one of +// shouldStripDNS64 returns true if DNS64 is enabled and addr has either one of // custom DNS64 prefixes or the Well-Known one. This is intended to be used // with PTR requests. // @@ -223,21 +223,16 @@ func (p *Proxy) withinDNS64(ip netip.Addr) (ok bool) { // DNS64. // // See https://datatracker.ietf.org/doc/html/rfc6147#section-5.3.1. -func (p *Proxy) shouldStripDNS64(ip net.IP) (ok bool) { +func (p *Proxy) shouldStripDNS64(addr netip.Addr) (ok bool) { if len(p.dns64Prefs) == 0 { return false } - addr, err := netutil.IPToAddr(ip, netutil.AddrFamilyIPv6) - if err != nil { - return false - } - switch { case p.withinDNS64(addr): - log.Debug("proxy: %s is within DNS64 custom prefix set", ip) + log.Debug("proxy: %s is within DNS64 custom prefix set", addr) case dns64WellKnownPref.Contains(addr): - log.Debug("proxy: %s is within DNS64 well-known prefix", ip) + log.Debug("proxy: %s is within DNS64 well-known prefix", addr) default: return false } diff --git a/proxy/proxy.go b/proxy/proxy.go index d99136de7..151b79955 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -137,9 +137,6 @@ type Proxy struct { // ratelimitLock protects ratelimitBuckets. ratelimitLock sync.Mutex - // proxyVerifier checks if the proxy is in the trusted list. - proxyVerifier netutil.SubnetSet - // DNS cache // -- @@ -229,14 +226,6 @@ func (p *Proxy) Init() (err error) { } } - var trusted []*net.IPNet - trusted, err = netutil.ParseSubnets(p.TrustedProxies...) - if err != nil { - return fmt.Errorf("initializing subnet detector for proxies verifying: %w", err) - } - - p.proxyVerifier = netutil.SliceSubnetSet(trusted) - err = p.setupDNS64() if err != nil { return fmt.Errorf("setting up DNS64: %w", err) @@ -528,7 +517,7 @@ func (p *Proxy) selectUpstreams(d *DNSContext) (upstreams []upstream.Upstream) { // TODO(e.burkov): Detect against the actual configured subnet set. // Perhaps, even much earlier. - if !netutil.IsLocallyServedAddr(d.Addr.Addr()) { + if !netutil.IsLocallyServed(d.Addr.Addr()) { return nil } @@ -719,12 +708,15 @@ func (dctx *DNSContext) processECS(cliIP net.IP) { } } - // Set ECS. + var cliAddr netip.Addr if cliIP == nil { - cliIP = dctx.Addr.Addr().AsSlice() + cliAddr = dctx.Addr.Addr() + cliIP = cliAddr.AsSlice() + } else { + cliAddr, _ = netip.AddrFromSlice(cliIP) } - if !netutil.IsSpecialPurpose(cliIP) { + if !netutil.IsSpecialPurpose(cliAddr) { // A Stub Resolver MUST set SCOPE PREFIX-LENGTH to 0. See RFC 7871 // Section 6. dctx.ReqECS = setECS(dctx.Req, cliIP, 0) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 5a813ee70..c59367d6d 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1188,7 +1188,10 @@ func createTestProxy(t *testing.T, tlsConfig *tls.Config) (p *Proxy) { p.UpstreamConfig = &UpstreamConfig{} p.UpstreamConfig.Upstreams = append(upstreams, dnsUpstream) - p.TrustedProxies = []string{"0.0.0.0/0", "::0/0"} + p.TrustedProxies = netutil.SliceSubnetSet{ + netip.MustParsePrefix("0.0.0.0/0"), + netip.MustParsePrefix("::0/0"), + } p.RatelimitSubnetLenIPv4 = 24 p.RatelimitSubnetLenIPv6 = 64 diff --git a/proxy/proxycache.go b/proxy/proxycache.go index c17ed98e2..391990f86 100644 --- a/proxy/proxycache.go +++ b/proxy/proxycache.go @@ -4,7 +4,7 @@ import ( "net" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/netutil" + "golang.org/x/exp/slices" ) // cacheForContext returns cache object for the given context. @@ -52,7 +52,7 @@ func (p *Proxy) replyFromCache(d *DNSContext) (hit bool) { minCtxClone := &DNSContext{ // It is only read inside the optimistic resolver. CustomUpstreamConfig: d.CustomUpstreamConfig, - ReqECS: netutil.CloneIPNet(d.ReqECS), + ReqECS: cloneIPNet(d.ReqECS), } if d.Req != nil { minCtxClone.Req = d.Req.Copy() @@ -65,6 +65,18 @@ func (p *Proxy) replyFromCache(d *DNSContext) (hit bool) { return hit } +// cloneIPNet returns a deep clone of n. +func cloneIPNet(n *net.IPNet) (clone *net.IPNet) { + if n == nil { + return nil + } + + return &net.IPNet{ + IP: slices.Clone(n.IP), + Mask: slices.Clone(n.Mask), + } +} + // cacheResp stores the response from d in general or subnet cache. In case the // cache is present in d, it's used first. func (p *Proxy) cacheResp(d *DNSContext) { diff --git a/proxy/server_https.go b/proxy/server_https.go index 1175590db..f842eeeb8 100644 --- a/proxy/server_https.go +++ b/proxy/server_https.go @@ -162,8 +162,7 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { if prx.IsValid() { log.Debug("dnsproxy: request came from proxy server %s", prx) - // TODO(s.chzhen): Consider using []netip.Prefix. - if !p.proxyVerifier.Contains(prx.Addr().AsSlice()) { + if !p.TrustedProxies.Contains(prx.Addr()) { log.Debug("dnsproxy: proxy %s is not trusted, using original remote addr", prx) d.Addr = prx } diff --git a/proxy/server_https_test.go b/proxy/server_https_test.go index cf81bca93..ab9440760 100644 --- a/proxy/server_https_test.go +++ b/proxy/server_https_test.go @@ -66,7 +66,7 @@ func TestProxy_trustedProxies(t *testing.T) { proxyAddr = netip.MustParseAddr("127.0.0.1") ) - doRequest := func(t *testing.T, addr string, expectedClientIP netip.Addr) { + doRequest := func(t *testing.T, addr, expectedClientIP netip.Addr) { // Prepare the proxy server. tlsConf, caPem := createServerTLSConfig(t) dnsProxy := createTestProxy(t, tlsConf) @@ -82,7 +82,7 @@ func TestProxy_trustedProxies(t *testing.T) { msg := createTestMessage() - dnsProxy.TrustedProxies = []string{addr} + dnsProxy.TrustedProxies = netip.PrefixFrom(addr, addr.BitLen()) // Start listening. serr := dnsProxy.Start() @@ -100,11 +100,11 @@ func TestProxy_trustedProxies(t *testing.T) { } t.Run("success", func(t *testing.T) { - doRequest(t, proxyAddr.String(), clientAddr) + doRequest(t, proxyAddr, clientAddr) }) t.Run("not_in_trusted", func(t *testing.T) { - doRequest(t, "127.0.0.2", proxyAddr) + doRequest(t, netip.MustParseAddr("127.0.0.2"), proxyAddr) }) } From a2506cdb8217517800f9f935075ffe2236a20cf6 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Tue, 23 Jan 2024 16:32:32 +0300 Subject: [PATCH 7/7] Pull request 315: 373 fix goroutines leak Updates #373. Squashed commit of the following: commit 0632b4f36b2eee75c97c550e0e6b96169dec34c5 Author: Eugene Burkov Date: Tue Jan 23 16:21:41 2024 +0300 upstream: imp code, logging commit cea34d59ea26298325c5a12d377b0649950d9460 Author: Eugene Burkov Date: Tue Jan 23 15:50:53 2024 +0300 upstream: use mutex. imp logging --- upstream/quic.go | 74 ++++++++++++++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/upstream/quic.go b/upstream/quic.go index 8e57778f4..9c69b3b7f 100644 --- a/upstream/quic.go +++ b/upstream/quic.go @@ -78,13 +78,13 @@ type dnsOverQUIC struct { bytesPool *sync.Pool // quicConfigMu protects quicConfig. - quicConfigMu sync.Mutex + quicConfigMu *sync.Mutex // connMu protects conn. - connMu sync.RWMutex + connMu *sync.Mutex // bytesPoolGuard protects bytesPool. - bytesPoolMu sync.Mutex + bytesPoolMu *sync.Mutex // timeout is the timeout for the upstream connection. timeout time.Duration @@ -118,7 +118,10 @@ func newDoQ(addr *url.URL, opts *Options) (u Upstream, err error) { VerifyConnection: opts.VerifyConnection, NextProtos: compatProtoDQ, }, - timeout: opts.Timeout, + quicConfigMu: &sync.Mutex{}, + connMu: &sync.Mutex{}, + bytesPoolMu: &sync.Mutex{}, + timeout: opts.Timeout, } runtime.SetFinalizer(u, (*dnsOverQUIC).Close) @@ -159,7 +162,7 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { // connection was closed (due to inactivity for example) AND the server // refuses to open a 0-RTT connection. for i := 0; hasConnection && p.shouldRetry(err) && i < 2; i++ { - log.Debug("re-creating the QUIC connection and retrying due to %v", err) + log.Debug("dnsproxy: re-creating the QUIC connection and retrying due to %v", err) // Close the active connection to make sure we'll try to re-connect. p.closeConnWithError(err) @@ -214,7 +217,7 @@ func (p *dnsOverQUIC) exchangeQUIC(req *dns.Msg) (resp *dns.Msg, err error) { var stream quic.Stream stream, err = p.openStream(conn) if err != nil { - return nil, err + return nil, fmt.Errorf("opening stream: %w", err) } _, err = stream.Write(proxyutil.AddPrefix(buf)) @@ -226,7 +229,10 @@ func (p *dnsOverQUIC) exchangeQUIC(req *dns.Msg) (resp *dns.Msg, err error) { // indicate through the STREAM FIN mechanism that no further data will // be sent on that stream. Note, that stream.Close() closes the // write-direction of the stream, but does not prevent reading from it. - _ = stream.Close() + err = stream.Close() + if err != nil { + log.Debug("dnsproxy: closing quic stream: %s", err) + } return p.readMsg(stream) } @@ -259,29 +265,30 @@ func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) { // argument controls whether we should try to use the existing cached // connection. If it is false, we will forcibly create a new connection and // close the existing one if needed. -func (p *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { +func (p *dnsOverQUIC) getConnection(useCached bool) (c quic.Connection, err error) { var conn quic.Connection - p.connMu.RLock() - conn = p.conn - if conn != nil && useCached { - p.connMu.RUnlock() - - return conn, nil - } - if conn != nil { - // we're recreating the connection, let's create a new one. - _ = conn.CloseWithError(QUICCodeNoError, "") - } - p.connMu.RUnlock() p.connMu.Lock() defer p.connMu.Unlock() - var err error + conn = p.conn + if conn != nil { + if useCached { + return conn, nil + } + + // We're recreating the connection, let's create a new one. + err = conn.CloseWithError(QUICCodeNoError, "") + if err != nil { + log.Debug("dnsproxy: closing stale connection: %s", err) + } + } + conn, err = p.openConnection() if err != nil { return nil, err } + p.conn = conn return conn, nil @@ -320,7 +327,9 @@ func (p *dnsOverQUIC) openStream(conn quic.Connection) (quic.Stream, error) { defer cancel() stream, err := conn.OpenStreamSync(ctx) - if err == nil { + if err != nil { + log.Debug("dnsproxy: opening quic stream: %s", err) + } else { return stream, nil } @@ -330,15 +339,16 @@ func (p *dnsOverQUIC) openStream(conn quic.Connection) (quic.Stream, error) { if err != nil { return nil, err } + // Open a new stream. return newConn.OpenStreamSync(ctx) } -// openConnection opens a new QUIC connection. +// openConnection dials a new QUIC connection. func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { dialContext, err := p.getDialer() if err != nil { - return nil, fmt.Errorf("failed to bootstrap QUIC connection: %w", err) + return nil, fmt.Errorf("bootstrapping %s: %w", p.addr, err) } // we're using bootstrapped address instead of what's passed to the function @@ -346,14 +356,18 @@ func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { // what IP is actually reachable (when there're v4/v6 addresses). rawConn, err := dialContext(context.Background(), "udp", "") if err != nil { - return nil, fmt.Errorf("failed to open a QUIC connection: %w", err) + return nil, fmt.Errorf("dialing raw connection to %s: %w", p.addr, err) + } + + // It's never actually used. + err = rawConn.Close() + if err != nil { + log.Debug("dnsproxy: closing raw connection for %s: %s", p.addr, err) } - // It's never actually used - _ = rawConn.Close() udpConn, ok := rawConn.(*net.UDPConn) if !ok { - return nil, fmt.Errorf("failed to open connection to %s", p.addr) + return nil, fmt.Errorf("unexpected type %T of connection; should be %T", rawConn, udpConn) } addr := udpConn.RemoteAddr().String() @@ -363,7 +377,7 @@ func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { conn, err = quic.DialAddrEarly(ctx, addr, p.tlsConf.Clone(), p.getQUICConfig()) if err != nil { - return nil, fmt.Errorf("opening quic connection to %s: %w", p.addr, err) + return nil, fmt.Errorf("dialing quic connection to %s: %w", p.addr, err) } return conn, nil @@ -393,7 +407,7 @@ func (p *dnsOverQUIC) closeConnWithError(err error) { err = p.conn.CloseWithError(code, "") if err != nil { - log.Error("failed to close the conn: %v", err) + log.Error("dnsproxy: failed to close the conn: %v", err) } p.conn = nil }