diff --git a/internal/netutil/udp.go b/internal/netutil/udp.go index f0ff4d0b5..181f64a55 100644 --- a/internal/netutil/udp.go +++ b/internal/netutil/udp.go @@ -19,6 +19,8 @@ func UDPSetOptions(c *net.UDPConn) (err error) { // UDPRead reads the message from conn using buf and receives a control-message // payload of size udpOOBSize from it. It returns the number of bytes copied // into buf and the source address of the message. +// +// TODO(s.chzhen): Consider using netip.Addr. func UDPRead( conn *net.UDPConn, buf []byte, @@ -28,6 +30,8 @@ func UDPRead( } // UDPWrite writes the data to the remoteAddr using conn. +// +// TODO(s.chzhen): Consider using netip.Addr. func UDPWrite( data []byte, conn *net.UDPConn, diff --git a/main.go b/main.go index 4c1821c1d..fbad255bb 100644 --- a/main.go +++ b/main.go @@ -23,7 +23,6 @@ import ( "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/mathutil" - "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/ameshkov/dnscrypt/v2" goFlags "github.com/jessevdk/go-flags" @@ -153,7 +152,7 @@ type Options struct { // RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for // rate limiting requests - RatelimitSubnetLenIPv6 int `yaml:"ratelimit-subnet-len-ipv6" long:"ratelimit-subnet-len-ipv6" description:"Ratelimit subnet length for IPv6." default:"64"` + RatelimitSubnetLenIPv6 int `yaml:"ratelimit-subnet-len-ipv6" long:"ratelimit-subnet-len-ipv6" description:"Ratelimit subnet length for IPv6." default:"56"` // If true, refuse ANY requests RefuseAny bool `yaml:"refuse-any" long:"refuse-any" description:"If specified, refuse ANY requests" optional:"yes" optional-value:"true"` @@ -336,8 +335,8 @@ func runPprof(options *Options) { // createProxyConfig creates proxy.Config from the command line arguments func createProxyConfig(options *Options) (conf proxy.Config) { conf = proxy.Config{ - RatelimitSubnetMaskIPv4: net.CIDRMask(options.RatelimitSubnetLenIPv4, netutil.IPv4BitLen), - RatelimitSubnetMaskIPv6: net.CIDRMask(options.RatelimitSubnetLenIPv6, netutil.IPv6BitLen), + RatelimitSubnetLenIPv4: options.RatelimitSubnetLenIPv4, + RatelimitSubnetLenIPv6: options.RatelimitSubnetLenIPv6, Ratelimit: options.Ratelimit, CacheEnabled: options.Cache, diff --git a/proxy/cache_test.go b/proxy/cache_test.go index 5f0de6ce7..d607d0729 100644 --- a/proxy/cache_test.go +++ b/proxy/cache_test.go @@ -2,6 +2,7 @@ package proxy import ( "net" + "net/netip" "strings" "sync" "testing" @@ -324,7 +325,7 @@ func TestCacheExpirationWithTTLOverride(t *testing.T) { t.Run("replace_min", func(t *testing.T) { d.Req = createHostTestMessage("host") - d.Addr = &net.TCPAddr{} + d.Addr = netip.AddrPort{} u.ans = []dns.RR{&dns.A{ Hdr: dns.RR_Header{ @@ -348,7 +349,7 @@ func TestCacheExpirationWithTTLOverride(t *testing.T) { t.Run("replace_max", func(t *testing.T) { d.Req = createHostTestMessage("host2") - d.Addr = &net.TCPAddr{} + d.Addr = netip.AddrPort{} u.ans = []dns.RR{&dns.A{ Hdr: dns.RR_Header{ diff --git a/proxy/config.go b/proxy/config.go index 8166f7bec..93b8831a5 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -65,18 +65,26 @@ type Config struct { // Rate-limiting and anti-DNS amplification measures // -- + // + // TODO(s.chzhen): Extract ratelimit settings to a separate structure. - // RatelimitSubnetMaskIPv4 is a subnet mask for IPv4 addresses used for + // RatelimitSubnetLenIPv4 is a subnet length for IPv4 addresses used for // rate limiting requests. - RatelimitSubnetMaskIPv4 net.IPMask + RatelimitSubnetLenIPv4 int - // RatelimitSubnetMaskIPv6 is a subnet mask for IPv6 addresses used for + // RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for // rate limiting requests. - RatelimitSubnetMaskIPv6 net.IPMask + RatelimitSubnetLenIPv6 int + + // Ratelimit is a maximum number of requests per second from a given IP (0 + // to disable). + Ratelimit int + + // RatelimitWhitelist is a list of IP addresses excluded from rate limiting. + RatelimitWhitelist []netip.Addr - Ratelimit int // max number of requests per second from a given IP (0 to disable) - RatelimitWhitelist []string // a list of whitelisted client IP addresses - RefuseAny bool // if true, refuse ANY requests + // RefuseAny makes proxy refuse the requests of type ANY. + RefuseAny bool // TrustedProxies is the list of IP addresses and CIDR networks to // detect proxy servers addresses the DoH requests from which should be @@ -240,22 +248,27 @@ func (p *Proxy) validateRatelimit() (err error) { return nil } - if p.RatelimitSubnetMaskIPv4 == nil { - return errors.Error("ipv4 subnet mask is nil") + err = checkInclusion(p.RatelimitSubnetLenIPv4, 0, netutil.IPv4BitLen) + if err != nil { + return fmt.Errorf("ratelimit subnet len ipv4 is invalid: %w", err) } - _, bits := p.RatelimitSubnetMaskIPv4.Size() - if bits != netutil.IPv4BitLen { - return fmt.Errorf("ipv4 subnet mask must contain %d bits, got %d", netutil.IPv4BitLen, bits) + err = checkInclusion(p.RatelimitSubnetLenIPv6, 0, netutil.IPv6BitLen) + if err != nil { + return fmt.Errorf("ratelimit subnet len ipv6 is invalid: %w", err) } - if p.RatelimitSubnetMaskIPv6 == nil { - return errors.Error("ipv6 subnet is nil") - } + return nil +} - _, bits = p.RatelimitSubnetMaskIPv6.Size() - if bits != netutil.IPv6BitLen { - return fmt.Errorf("ipv6 subnet mask must contain %d bits, got %d", netutil.IPv6BitLen, bits) +// checkInclusion returns an error if a n is not in the inclusive range between +// minN and maxN. +func checkInclusion(n, minN, maxN int) (err error) { + switch { + case n < minN: + return fmt.Errorf("value %d less than min %d", n, minN) + case n > maxN: + return fmt.Errorf("value %d greater than max %d", n, maxN) } return nil @@ -268,14 +281,11 @@ func (p *Proxy) logConfigInfo() { } if p.Ratelimit > 0 { - sizeV4, _ := p.RatelimitSubnetMaskIPv4.Size() - sizeV6, _ := p.RatelimitSubnetMaskIPv6.Size() - log.Info( "Ratelimit is enabled and set to %d rps, IPv4 subnet mask len %d, IPv6 subnet mask len %d", p.Ratelimit, - sizeV4, - sizeV6, + p.RatelimitSubnetLenIPv4, + p.RatelimitSubnetLenIPv6, ) } diff --git a/proxy/dns64_test.go b/proxy/dns64_test.go index 5bba35d5d..b007a4eff 100644 --- a/proxy/dns64_test.go +++ b/proxy/dns64_test.go @@ -156,10 +156,7 @@ func TestProxy_Resolve_dns64(t *testing.T) { require.NoError(t, err) ptrGlobDomain = dns.Fqdn(ptrGlobDomain) - cliIP := &net.TCPAddr{ - IP: net.IP{192, 168, 1, 1}, - Port: 1234, - } + cliAddrPort := netip.MustParseAddrPort("192.168.1.1:1234") const ( sectionAnswer = iota @@ -354,7 +351,7 @@ func TestProxy_Resolve_dns64(t *testing.T) { req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype) dctx := &DNSContext{ Req: req, - Addr: cliIP, + Addr: cliAddrPort, } err = p.Resolve(dctx) diff --git a/proxy/dnscontext.go b/proxy/dnscontext.go index c1770d0e0..a6e366d98 100644 --- a/proxy/dnscontext.go +++ b/proxy/dnscontext.go @@ -28,9 +28,7 @@ type DNSContext struct { QUICStream quic.Stream // Addr is the address of the client. - // - // TODO(s.chzhen): Use [netip.AddrPort]. - Addr net.Addr + Addr netip.AddrPort // Upstream is the upstream that resolved the request. In case of cached // response it's nil. diff --git a/proxy/proxy.go b/proxy/proxy.go index 8932a1801..b6e2740c6 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -24,6 +24,7 @@ import ( gocache "github.com/patrickmn/go-cache" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" + "golang.org/x/exp/slices" ) const ( @@ -183,6 +184,7 @@ type Proxy struct { // Init populates fields of p but does not start listeners. func (p *Proxy) Init() (err error) { + // TODO(s.chzhen): Consider moving to [Proxy.validateConfig]. err = p.validateBasicAuth() if err != nil { return fmt.Errorf("basic auth: %w", err) @@ -233,6 +235,9 @@ func (p *Proxy) Init() (err error) { return fmt.Errorf("setting up DNS64: %w", err) } + p.RatelimitWhitelist = slices.Clone(p.RatelimitWhitelist) + slices.SortFunc(p.RatelimitWhitelist, netip.Addr.Compare) + return nil } @@ -512,10 +517,9 @@ func (p *Proxy) selectUpstreams(d *DNSContext) (upstreams []upstream.Upstream) { return nil } - ip, _ := netutil.IPAndPortFromAddr(d.Addr) // TODO(e.burkov): Detect against the actual configured subnet set. // Perhaps, even much earlier. - if !netutil.IsLocallyServed(ip) { + if !netutil.IsLocallyServedAddr(d.Addr.Addr()) { return nil } @@ -708,10 +712,7 @@ func (dctx *DNSContext) processECS(cliIP net.IP) { // Set ECS. if cliIP == nil { - cliIP, _ = netutil.IPAndPortFromAddr(dctx.Addr) - if cliIP == nil { - return - } + cliIP = dctx.Addr.Addr().AsSlice() } if !netutil.IsSpecialPurpose(cliIP) { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 5957ccb74..53d7b4f44 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -852,10 +852,8 @@ func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) { 0, false, ), - Req: createHostTestMessage("host"), - Addr: &net.TCPAddr{ - IP: net.IP{1, 2, 3, 0}, - }, + Req: createHostTestMessage("host"), + Addr: netip.MustParseAddrPort("1.2.3.0:1234"), } var err error @@ -893,7 +891,7 @@ func TestExchangeCustomUpstreamConfig(t *testing.T) { false, ), Req: createHostTestMessage("host"), - Addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 0}}, + Addr: netip.MustParseAddrPort("1.2.3.0:1234"), } err = prx.Resolve(&d) @@ -945,7 +943,7 @@ func TestExchangeCustomUpstreamConfigCache(t *testing.T) { d := DNSContext{ CustomUpstreamConfig: customUpstreamConfig, Req: createHostTestMessage("host"), - Addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 0}}, + Addr: netip.MustParseAddrPort("1.2.3.0:1234"), } err = prx.Resolve(&d) @@ -1036,7 +1034,7 @@ func TestECSProxy(t *testing.T) { t.Run("cache_subnet", func(t *testing.T) { d := DNSContext{ Req: createHostTestMessage("host"), - Addr: &net.TCPAddr{IP: ip1230}, + Addr: netip.MustParseAddrPort("1.2.3.0:1234"), } err = prx.Resolve(&d) @@ -1049,7 +1047,7 @@ func TestECSProxy(t *testing.T) { t.Run("serve_subnet_cache", func(t *testing.T) { d := DNSContext{ Req: createHostTestMessage("host"), - Addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 1}}, + Addr: netip.MustParseAddrPort("1.2.3.1:1234"), } u.ans, u.ecsIP, u.ecsReqIP = nil, nil, nil @@ -1063,7 +1061,7 @@ func TestECSProxy(t *testing.T) { t.Run("another_subnet", func(t *testing.T) { d := DNSContext{ Req: createHostTestMessage("host"), - Addr: &net.TCPAddr{IP: ip2230}, + Addr: netip.MustParseAddrPort("2.2.3.0:1234"), } u.ans = []dns.RR{&dns.A{ Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60}, @@ -1081,7 +1079,7 @@ func TestECSProxy(t *testing.T) { t.Run("cache_general", func(t *testing.T) { d := DNSContext{ Req: createHostTestMessage("host"), - Addr: &net.TCPAddr{IP: net.IP{127, 0, 0, 1}}, + Addr: netip.MustParseAddrPort("127.0.0.1:1234"), } u.ans = []dns.RR{&dns.A{ Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60}, @@ -1099,7 +1097,7 @@ func TestECSProxy(t *testing.T) { t.Run("serve_general_cache", func(t *testing.T) { d := DNSContext{ Req: createHostTestMessage("host"), - Addr: &net.TCPAddr{IP: net.IP{127, 0, 0, 2}}, + Addr: netip.MustParseAddrPort("127.0.0.2:1234"), } u.ans, u.ecsIP, u.ecsReqIP = nil, nil, nil @@ -1138,7 +1136,7 @@ func TestECSProxyCacheMinMaxTTL(t *testing.T) { // first request d := DNSContext{ Req: createHostTestMessage("host"), - Addr: &net.TCPAddr{IP: clientIP}, + Addr: netip.MustParseAddrPort("1.2.3.0:1234"), } err = prx.Resolve(&d) require.NoError(t, err) @@ -1156,9 +1154,7 @@ func TestECSProxyCacheMinMaxTTL(t *testing.T) { // 2nd request clientIP = net.IP{1, 2, 4, 0} d.Req = createHostTestMessage("host") - d.Addr = &net.TCPAddr{ - IP: clientIP, - } + d.Addr = netip.MustParseAddrPort("1.2.4.0:1234") u.ans = []dns.RR{&dns.A{ Hdr: dns.RR_Header{ Rrtype: dns.TypeA, @@ -1242,8 +1238,8 @@ func createTestProxy(t *testing.T, tlsConfig *tls.Config) *Proxy { p.TrustedProxies = []string{"0.0.0.0/0", "::0/0"} - p.RatelimitSubnetMaskIPv4 = net.CIDRMask(24, 32) - p.RatelimitSubnetMaskIPv6 = net.CIDRMask(64, 128) + p.RatelimitSubnetLenIPv4 = 24 + p.RatelimitSubnetLenIPv6 = 64 return &p } diff --git a/proxy/ratelimit.go b/proxy/ratelimit.go index 793ba3703..b2f07dc5c 100644 --- a/proxy/ratelimit.go +++ b/proxy/ratelimit.go @@ -1,11 +1,10 @@ package proxy import ( - "net" + "net/netip" "time" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/netutil" rate "github.com/beefsack/go-rate" gocache "github.com/patrickmn/go-cache" "golang.org/x/exp/slices" @@ -28,44 +27,33 @@ func (p *Proxy) limiterForIP(ip string) interface{} { return value } -// isRatelimited checks if the specified IP is ratelimited. -func (p *Proxy) isRatelimited(addr net.Addr) (ok bool) { +func (p *Proxy) isRatelimited(addr netip.Addr) (ok bool) { if p.Ratelimit <= 0 { // The ratelimit is disabled. return false } - ip, _ := netutil.IPAndPortFromAddr(addr) - if ip == nil { - log.Printf("failed to split %v into host/port", addr) - + addr = addr.Unmap() + // Already sorted by [Proxy.Init]. + _, ok = slices.BinarySearchFunc(p.RatelimitWhitelist, addr, netip.Addr.Compare) + if ok { return false } - ipStr := ip.String() - - if len(p.RatelimitWhitelist) > 0 { - slices.Sort(p.RatelimitWhitelist) - _, ok = slices.BinarySearch(p.RatelimitWhitelist, ipStr) - if ok { - // Don't ratelimit if the IP is allowlisted. - return false - } - } - - if ip.To4() != nil { - ip = ip.Mask(p.RatelimitSubnetMaskIPv4) + var pref netip.Prefix + if addr.Is4() { + pref = netip.PrefixFrom(addr, p.RatelimitSubnetLenIPv4) } else { - ip = ip.Mask(p.RatelimitSubnetMaskIPv6) + pref = netip.PrefixFrom(addr, p.RatelimitSubnetLenIPv6) } + pref = pref.Masked() // TODO(s.chzhen): Improve caching. Decrease allocations. - ipStr = ip.String() - + ipStr := pref.Addr().String() value := p.limiterForIP(ipStr) rl, ok := value.(*rate.RateLimiter) if !ok { - log.Printf("SHOULD NOT HAPPEN: %T found in ratelimit cache", value) + log.Error("dnsproxy: %T found in ratelimit cache", value) return false } diff --git a/proxy/ratelimit_test.go b/proxy/ratelimit_test.go index 2af129516..64402c17b 100644 --- a/proxy/ratelimit_test.go +++ b/proxy/ratelimit_test.go @@ -1,7 +1,7 @@ package proxy import ( - "net" + "net/netip" "testing" "time" @@ -52,7 +52,7 @@ func TestRatelimiting(t *testing.T) { p := Proxy{} p.Ratelimit = 1 - addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1232} + addr := netip.MustParseAddr("127.0.0.1") limited := p.isRatelimited(addr) @@ -71,9 +71,13 @@ func TestWhitelist(t *testing.T) { // rate limit is 1 per sec with whitelist p := Proxy{} p.Ratelimit = 1 - p.RatelimitWhitelist = []string{"127.0.0.1", "127.0.0.2", "127.0.0.125"} + p.RatelimitWhitelist = []netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("127.0.0.2"), + netip.MustParseAddr("127.0.0.125"), + } - addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1232} + addr := netip.MustParseAddr("127.0.0.1") limited := p.isRatelimited(addr) diff --git a/proxy/server.go b/proxy/server.go index e6ec5ddda..68bebe19a 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -101,7 +101,7 @@ func (p *Proxy) handleDNSRequest(d *DNSContext) error { } // ratelimit based on IP only, protects CPU cycles and outbound connections - if d.Proto == ProtoUDP && p.isRatelimited(d.Addr) { + if d.Proto == ProtoUDP && p.isRatelimited(d.Addr.Addr()) { log.Tracef("Ratelimiting %v based on IP only", d.Addr) return nil // do nothing, don't reply, we got ratelimited } diff --git a/proxy/server_dnscrypt.go b/proxy/server_dnscrypt.go index fa658cd41..cdd1abeae 100644 --- a/proxy/server_dnscrypt.go +++ b/proxy/server_dnscrypt.go @@ -6,6 +6,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/ameshkov/dnscrypt/v2" "github.com/miekg/dns" ) @@ -69,7 +70,7 @@ var _ dnscrypt.Handler = &dnsCryptHandler{} // ServeDNS - processes the DNS query func (h *dnsCryptHandler) ServeDNS(rw dnscrypt.ResponseWriter, req *dns.Msg) error { d := h.proxy.newDNSContext(ProtoDNSCrypt, req) - d.Addr = rw.RemoteAddr() + d.Addr = netutil.NetAddrToAddrPort(rw.RemoteAddr()) d.DNSCryptResponseWriter = rw h.requestGoroutinesSema.acquire() diff --git a/proxy/server_https.go b/proxy/server_https.go index 5031d6c63..1175590db 100644 --- a/proxy/server_https.go +++ b/proxy/server_https.go @@ -7,13 +7,12 @@ import ( "io" "net" "net/http" + "net/netip" "net/url" - "strconv" "strings" "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" @@ -160,11 +159,12 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { d.HTTPRequest = r d.HTTPResponseWriter = w - if prx != nil { - ip, _ := netutil.IPAndPortFromAddr(prx) + if prx.IsValid() { log.Debug("dnsproxy: request came from proxy server %s", prx) - if !p.proxyVerifier.Contains(ip) { - log.Debug("dnsproxy: proxy %s is not trusted, using original remote addr", ip) + + // TODO(s.chzhen): Consider using []netip.Prefix. + if !p.proxyVerifier.Contains(prx.Addr().AsSlice()) { + log.Debug("dnsproxy: proxy %s is not trusted, using original remote addr", prx) d.Addr = prx } } @@ -181,7 +181,7 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (p *Proxy) checkBasicAuth( w http.ResponseWriter, r *http.Request, - raddr net.Addr, + raddr netip.AddrPort, ) (shouldHandle bool) { ui := p.Config.Userinfo if ui == nil { @@ -241,73 +241,54 @@ func (p *Proxy) respondHTTPS(d *DNSContext) (err error) { } // realIPFromHdrs extracts the actual client's IP address from the first -// suitable r's header. It returns nil if r doesn't contain any information -// about real client's IP address. Current headers priority is: +// suitable r's header. It returns an error if r doesn't contain any +// information about real client's IP address. Current headers priority is: // // 1. [httphdr.CFConnectingIP] // 2. [httphdr.TrueClientIP] // 3. [httphdr.XRealIP] // 4. [httphdr.XForwardedFor] -func realIPFromHdrs(r *http.Request) (realIP net.IP) { +func realIPFromHdrs(r *http.Request) (realIP netip.Addr, err error) { for _, h := range []string{ httphdr.CFConnectingIP, httphdr.TrueClientIP, httphdr.XRealIP, } { - realIP = net.ParseIP(strings.TrimSpace(r.Header.Get(h))) - if realIP != nil { - return realIP + realIP, err = netip.ParseAddr(strings.TrimSpace(r.Header.Get(h))) + if err == nil { + return realIP, nil } } xff := r.Header.Get(httphdr.XForwardedFor) firstComma := strings.IndexByte(xff, ',') - if firstComma == -1 { - return net.ParseIP(strings.TrimSpace(xff)) + if firstComma > 0 { + xff = xff[:firstComma] } - return net.ParseIP(strings.TrimSpace(xff[:firstComma])) + return netip.ParseAddr(strings.TrimSpace(xff)) } // remoteAddr returns the real client's address and the IP address of the latest // proxy server if any. -func remoteAddr(r *http.Request) (addr, prx net.Addr, err error) { - var hostStr, portStr string - if hostStr, portStr, err = net.SplitHostPort(r.RemoteAddr); err != nil { - return nil, nil, err +func remoteAddr(r *http.Request) (addr, prx netip.AddrPort, err error) { + host, err := netip.ParseAddrPort(r.RemoteAddr) + if err != nil { + return netip.AddrPort{}, netip.AddrPort{}, err } - var port int - if port, err = strconv.Atoi(portStr); err != nil { - return nil, nil, err - } + realIP, err := realIPFromHdrs(r) + if err != nil { + log.Debug("dnsproxy: getting ip address from http request: %s", err) - host := net.ParseIP(hostStr) - if host == nil { - return nil, nil, fmt.Errorf("invalid ip: %s", hostStr) + return host, netip.AddrPort{}, nil } - h3 := r.Context().Value(http3.ServerContextKey) != nil - - if realIP := realIPFromHdrs(r); realIP != nil { - log.Tracef("Using IP address from HTTP request: %s", realIP) + log.Debug("dnsproxy: using ip address from http request: %s", realIP) - // TODO(a.garipov): Add port if we can get it from headers like - // X-Real-Port, X-Forwarded-Port, etc. - if h3 { - addr = &net.UDPAddr{IP: realIP, Port: 0} - prx = &net.UDPAddr{IP: host, Port: port} - } else { - addr = &net.TCPAddr{IP: realIP, Port: 0} - prx = &net.TCPAddr{IP: host, Port: port} - } - - return addr, prx, nil - } - - if h3 { - return &net.UDPAddr{IP: host, Port: port}, nil, nil - } + // TODO(a.garipov): Add port if we can get it from headers like X-Real-Port, + // X-Forwarded-Port, etc. + addr = netip.AddrPortFrom(realIP, 0) - return &net.TCPAddr{IP: host, Port: port}, nil, nil + return addr, host, nil } diff --git a/proxy/server_https_test.go b/proxy/server_https_test.go index 70b1138c5..cf81bca93 100644 --- a/proxy/server_https_test.go +++ b/proxy/server_https_test.go @@ -9,11 +9,11 @@ import ( "io" "net" "net/http" + "net/netip" "net/url" "strings" "testing" - "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/quic-go/quic-go" @@ -61,16 +61,19 @@ func TestHttpsProxy(t *testing.T) { } func TestProxy_trustedProxies(t *testing.T) { - clientIP, proxyIP := net.IP{1, 2, 3, 4}, net.IP{127, 0, 0, 1} + var ( + clientAddr = netip.MustParseAddr("1.2.3.4") + proxyAddr = netip.MustParseAddr("127.0.0.1") + ) - doRequest := func(t *testing.T, proxyAddr string, expectedClientIP net.IP) { + doRequest := func(t *testing.T, addr string, expectedClientIP netip.Addr) { // Prepare the proxy server. tlsConf, caPem := createServerTLSConfig(t) dnsProxy := createTestProxy(t, tlsConf) - var gotAddr net.Addr + var gotAddr netip.Addr dnsProxy.RequestHandler = func(_ *Proxy, d *DNSContext) (err error) { - gotAddr = d.Addr + gotAddr = d.Addr.Addr() return dnsProxy.Resolve(d) } @@ -79,7 +82,7 @@ func TestProxy_trustedProxies(t *testing.T) { msg := createTestMessage() - dnsProxy.TrustedProxies = []string{proxyAddr} + dnsProxy.TrustedProxies = []string{addr} // Start listening. serr := dnsProxy.Start() @@ -87,51 +90,59 @@ func TestProxy_trustedProxies(t *testing.T) { testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop) hdrs := map[string]string{ - "X-Forwarded-For": strings.Join([]string{clientIP.String(), proxyIP.String()}, ","), + "X-Forwarded-For": strings.Join([]string{clientAddr.String(), proxyAddr.String()}, ","), } resp := sendTestDoHMessage(t, client, msg, hdrs) requireResponse(t, msg, resp) - ip, _ := netutil.IPAndPortFromAddr(gotAddr) - require.True(t, ip.Equal(expectedClientIP)) + require.Equal(t, expectedClientIP, gotAddr) } t.Run("success", func(t *testing.T) { - doRequest(t, proxyIP.String(), clientIP) + doRequest(t, proxyAddr.String(), clientAddr) }) t.Run("not_in_trusted", func(t *testing.T) { - doRequest(t, "127.0.0.2", proxyIP) + doRequest(t, "127.0.0.2", proxyAddr) }) } func TestAddrsFromRequest(t *testing.T) { - theIP, anotherIP := net.IP{1, 2, 3, 4}, net.IP{1, 2, 3, 5} - theIPStr, anotherIPStr := theIP.String(), anotherIP.String() + var ( + theIP = netip.AddrFrom4([4]byte{1, 2, 3, 4}) + anotherIP = netip.AddrFrom4([4]byte{1, 2, 3, 5}) + + theIPStr = theIP.String() + anotherIPStr = anotherIP.String() + ) testCases := []struct { - name string - hdrs map[string]string - wantIP net.IP + name string + hdrs map[string]string + wantIP netip.Addr + wantErr string }{{ name: "cf-connecting-ip", hdrs: map[string]string{ "CF-Connecting-IP": theIPStr, }, - wantIP: theIP, + wantIP: theIP, + wantErr: "", }, { name: "true-client-ip", hdrs: map[string]string{ "True-Client-IP": theIPStr, }, - wantIP: theIP, + wantIP: theIP, + wantErr: "", }, { name: "x-real-ip", hdrs: map[string]string{ "X-Real-IP": theIPStr, }, - wantIP: theIP, + wantIP: theIP, + wantErr: "", }, { name: "no_any", hdrs: map[string]string{ @@ -139,7 +150,8 @@ func TestAddrsFromRequest(t *testing.T) { "True-Client-IP": "invalid", "X-Real-IP": "invalid", }, - wantIP: nil, + wantIP: netip.Addr{}, + wantErr: `ParseAddr(""): unable to parse IP`, }, { name: "priority", hdrs: map[string]string{ @@ -148,43 +160,50 @@ func TestAddrsFromRequest(t *testing.T) { "X-Real-IP": anotherIPStr, "CF-Connecting-IP": theIPStr, }, - wantIP: theIP, + wantIP: theIP, + wantErr: "", }, { name: "x-forwarded-for_simple", hdrs: map[string]string{ "X-Forwarded-For": strings.Join([]string{anotherIPStr, theIPStr}, ","), }, - wantIP: anotherIP, + wantIP: anotherIP, + wantErr: "", }, { name: "x-forwarded-for_single", hdrs: map[string]string{ "X-Forwarded-For": theIPStr, }, - wantIP: theIP, + wantIP: theIP, + wantErr: "", }, { name: "x-forwarded-for_invalid_proxy", hdrs: map[string]string{ "X-Forwarded-For": strings.Join([]string{theIPStr, "invalid"}, ","), }, - wantIP: theIP, + wantIP: theIP, + wantErr: "", }, { name: "x-forwarded-for_empty", hdrs: map[string]string{ "X-Forwarded-For": "", }, - wantIP: nil, + wantIP: netip.Addr{}, + wantErr: `ParseAddr(""): unable to parse IP`, }, { name: "x-forwarded-for_redundant_spaces", hdrs: map[string]string{ "X-Forwarded-For": " " + theIPStr + " ,\t" + anotherIPStr, }, - wantIP: theIP, + wantIP: theIP, + wantErr: "", }, { name: "cf-connecting-ip_redundant_spaces", hdrs: map[string]string{ "CF-Connecting-IP": " " + theIPStr + "\t", }, - wantIP: theIP, + wantIP: theIP, + wantErr: "", }} for _, tc := range testCases { @@ -196,31 +215,44 @@ func TestAddrsFromRequest(t *testing.T) { } t.Run(tc.name, func(t *testing.T) { - ip := realIPFromHdrs(r) - assert.True(t, tc.wantIP.Equal(ip)) + var ip netip.Addr + ip, err = realIPFromHdrs(r) + testutil.AssertErrorMsg(t, tc.wantErr, err) + + assert.Equal(t, tc.wantIP, ip) }) } } func TestRemoteAddr(t *testing.T) { - theIP, anotherIP, thirdIP := net.IP{1, 2, 3, 4}, net.IP{1, 2, 3, 5}, net.IP{1, 2, 3, 6} - theIPStr, anotherIPStr, thirdIPStr := theIP.String(), anotherIP.String(), thirdIP.String() - rAddr := &net.TCPAddr{IP: theIP, Port: 1} + const thePort = 4321 + + var ( + theIP = netip.AddrFrom4([4]byte{1, 2, 3, 4}) + anotherIP = netip.AddrFrom4([4]byte{1, 2, 3, 5}) + thirdIP = netip.AddrFrom4([4]byte{1, 2, 3, 6}) + + theIPStr = theIP.String() + anotherIPStr = anotherIP.String() + thirdIPStr = thirdIP.String() + ) + + rAddr := netip.AddrPortFrom(theIP, thePort) testCases := []struct { name string remoteAddr string hdrs map[string]string wantErr string - wantIP net.IP - wantProxy net.IP + wantIP netip.AddrPort + wantProxy netip.AddrPort }{{ name: "no_proxy", remoteAddr: rAddr.String(), hdrs: nil, wantErr: "", - wantIP: theIP, - wantProxy: nil, + wantIP: netip.AddrPortFrom(theIP, thePort), + wantProxy: netip.AddrPort{}, }, { name: "proxied_with_cloudflare", remoteAddr: rAddr.String(), @@ -228,8 +260,8 @@ func TestRemoteAddr(t *testing.T) { "CF-Connecting-IP": anotherIPStr, }, wantErr: "", - wantIP: anotherIP, - wantProxy: theIP, + wantIP: netip.AddrPortFrom(anotherIP, 0), + wantProxy: netip.AddrPortFrom(theIP, thePort), }, { name: "proxied_once", remoteAddr: rAddr.String(), @@ -237,8 +269,8 @@ func TestRemoteAddr(t *testing.T) { "X-Forwarded-For": anotherIPStr, }, wantErr: "", - wantIP: anotherIP, - wantProxy: theIP, + wantIP: netip.AddrPortFrom(anotherIP, 0), + wantProxy: netip.AddrPortFrom(theIP, thePort), }, { name: "proxied_multiple", remoteAddr: rAddr.String(), @@ -246,38 +278,38 @@ func TestRemoteAddr(t *testing.T) { "X-Forwarded-For": strings.Join([]string{anotherIPStr, thirdIPStr}, ","), }, wantErr: "", - wantIP: anotherIP, - wantProxy: theIP, + wantIP: netip.AddrPortFrom(anotherIP, 0), + wantProxy: netip.AddrPortFrom(theIP, thePort), }, { name: "no_port", remoteAddr: theIPStr, hdrs: nil, - wantErr: "address " + theIPStr + ": missing port in address", - wantIP: nil, - wantProxy: nil, + wantErr: "not an ip:port", + wantIP: netip.AddrPort{}, + wantProxy: netip.AddrPort{}, }, { name: "bad_port", remoteAddr: theIPStr + ":notport", hdrs: nil, - wantErr: "strconv.Atoi: parsing \"notport\": invalid syntax", - wantIP: nil, - wantProxy: nil, + wantErr: `invalid port "notport" parsing "1.2.3.4:notport"`, + wantIP: netip.AddrPort{}, + wantProxy: netip.AddrPort{}, }, { name: "bad_host", remoteAddr: "host:1", hdrs: nil, - wantErr: "invalid ip: host", - wantIP: nil, - wantProxy: nil, + wantErr: `ParseAddr("host"): unable to parse IP`, + wantIP: netip.AddrPort{}, + wantProxy: netip.AddrPort{}, }, { name: "bad_proxied_host", remoteAddr: "host:1", hdrs: map[string]string{ "CF-Connecting-IP": theIPStr, }, - wantErr: "invalid ip: host", - wantIP: nil, - wantProxy: nil, + wantErr: `ParseAddr("host"): unable to parse IP`, + wantIP: netip.AddrPort{}, + wantProxy: netip.AddrPort{}, }} for _, tc := range testCases { @@ -290,20 +322,17 @@ func TestRemoteAddr(t *testing.T) { } t.Run(tc.name, func(t *testing.T) { - addr, prx, aErr := remoteAddr(r) + var addr, prx netip.AddrPort + addr, prx, err = remoteAddr(r) if tc.wantErr != "" { - assert.Equal(t, tc.wantErr, aErr.Error()) + testutil.AssertErrorMsg(t, tc.wantErr, err) return } - require.NoError(t, aErr) - - ip, _ := netutil.IPAndPortFromAddr(addr) - assert.True(t, ip.Equal(tc.wantIP)) - - prxIP, _ := netutil.IPAndPortFromAddr(prx) - assert.True(t, tc.wantProxy.Equal(prxIP)) + require.NoError(t, err) + assert.Equal(t, tc.wantIP, addr) + assert.Equal(t, tc.wantProxy, prx) }) } } diff --git a/proxy/server_quic.go b/proxy/server_quic.go index 658d75e2f..9936b18a0 100644 --- a/proxy/server_quic.go +++ b/proxy/server_quic.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxyutil" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/bluele/gcache" "github.com/miekg/dns" "github.com/quic-go/quic-go" @@ -203,7 +204,7 @@ func (p *Proxy) handleQUICStream(stream quic.Stream, conn quic.Connection) { } d := p.newDNSContext(ProtoQUIC, req) - d.Addr = conn.RemoteAddr() + d.Addr = netutil.NetAddrToAddrPort(conn.RemoteAddr()) d.QUICStream = stream d.QUICConnection = conn d.DoQVersion = doqVersion diff --git a/proxy/server_tcp.go b/proxy/server_tcp.go index 749a6ecd6..780123786 100644 --- a/proxy/server_tcp.go +++ b/proxy/server_tcp.go @@ -12,6 +12,7 @@ import ( proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" ) @@ -126,7 +127,7 @@ func (p *Proxy) handleTCPConnection(conn net.Conn, proto Proto) { } d := p.newDNSContext(proto, req) - d.Addr = conn.RemoteAddr() + d.Addr = netutil.NetAddrToAddrPort(conn.RemoteAddr()) d.Conn = conn err = p.handleDNSRequest(d) diff --git a/proxy/server_udp.go b/proxy/server_udp.go index 70096b4dc..677f023e6 100644 --- a/proxy/server_udp.go +++ b/proxy/server_udp.go @@ -9,6 +9,7 @@ import ( proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" ) @@ -114,7 +115,7 @@ func (p *Proxy) udpHandlePacket( } d := p.newDNSContext(ProtoUDP, req) - d.Addr = remoteAddr + d.Addr = netutil.NetAddrToAddrPort(remoteAddr) d.Conn = conn d.localIP = localIP @@ -139,7 +140,7 @@ func (p *Proxy) respondUDP(d *DNSContext) error { } conn := d.Conn.(*net.UDPConn) - rAddr := d.Addr.(*net.UDPAddr) + rAddr := net.UDPAddrFromAddrPort(d.Addr) n, err := proxynetutil.UDPWrite(bytes, conn, rAddr, d.localIP) if err != nil { if errors.Is(err, net.ErrClosed) {