Skip to content

Commit

Permalink
upstream: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Jan 11, 2024
1 parent 23d43c8 commit f9e7372
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 21 deletions.
46 changes: 33 additions & 13 deletions upstream/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net"
"net/netip"
"net/url"
"strconv"
"strings"
"sync/atomic"
"time"
Expand Down Expand Up @@ -181,32 +182,51 @@ func AddressToUpstream(addr string, opts *Options) (u Upstream, err error) {
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", addr, err)
}
} else {
uu = &url.URL{
Scheme: "udp",
Host: addr,
}
}

// TODO(s.chzhen): Validate hostname. Consider DNS Stamp.
return urlToUpstream(uu, opts)
err = validateUpstreamURL(uu)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}

uu = &url.URL{
Scheme: "udp",
Host: addr,
return urlToUpstream(uu, opts)
}

// validateUpstreamURL returns an error if the upstream URL is not valid.
func validateUpstreamURL(u *url.URL) (err error) {
if u.Scheme == "sdns" {
return nil
}

_, err = netip.ParseAddr(addr)
if err == nil {
return urlToUpstream(uu, opts)
host := u.Host
h, port, splitErr := net.SplitHostPort(host)
if splitErr == nil {
// Validate port.
_, err = strconv.ParseUint(port, 10, 16)
if err != nil {
return fmt.Errorf("invalid port %s: %w", port, err)
}

host = h
}

_, err = netip.ParseAddrPort(addr)
_, err = netip.ParseAddr(host)
if err == nil {
return urlToUpstream(uu, opts)
return nil
}

err = netutil.ValidateHostname(addr)
err = netutil.ValidateHostname(host)
if err != nil {
return nil, fmt.Errorf("invalid address %s: %w", addr, err)
return fmt.Errorf("invalid address %s: %w", host, err)
}

return urlToUpstream(uu, opts)
return nil
}

// urlToUpstream converts uu to an Upstream using opts.
Expand Down
47 changes: 39 additions & 8 deletions upstream/upstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,14 @@ func TestAddressToUpstream(t *testing.T) {
addr: "1.1.1.1",
opt: nil,
want: "1.1.1.1:53",
}, {
addr: "1.1.1.1:5353",
opt: nil,
want: "1.1.1.1:5353",
}, {
addr: "one:5353",
opt: nil,
want: "one:5353",
}, {
addr: "one.one.one.one",
opt: nil,
Expand Down Expand Up @@ -271,17 +279,18 @@ func TestAddressToUpstream_bads(t *testing.T) {
wantErrMsg: "unsupported url scheme: asdf",
}, {
addr: "12345.1.1.1:1234567",
wantErrMsg: `invalid address 12345.1.1.1:1234567: bad hostname ` +
`"12345.1.1.1:1234567": bad top-level domain name label "1:1234567": ` +
`bad top-level domain name label rune ':'`,
wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` +
`value out of range`,
}, {
addr: ":1234567",
wantErrMsg: `invalid address :1234567: bad hostname ":1234567": bad top-level ` +
`domain name label ":1234567": bad top-level domain name label rune ':'`,
wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` +
`value out of range`,
}, {
addr: "host:",
wantErrMsg: `invalid address host:: bad hostname "host:": bad top-level ` +
`domain name label "host:": bad top-level domain name label rune ':'`,
addr: "host:",
wantErrMsg: `invalid port : strconv.ParseUint: parsing "": invalid syntax`,
}, {
addr: ":53",
wantErrMsg: `invalid address : bad hostname "": hostname is empty`,
}, {
addr: "!!!",
wantErrMsg: `invalid address !!!: bad hostname "!!!": bad top-level domain name ` +
Expand All @@ -290,6 +299,28 @@ func TestAddressToUpstream_bads(t *testing.T) {
addr: "123",
wantErrMsg: `invalid address 123: bad hostname "123": bad top-level domain name ` +
`label "123": all octets are numeric`,
}, {
addr: "tcp://12345.1.1.1:1234567",
wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` +
`value out of range`,
}, {
addr: "tcp://:1234567",
wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` +
`value out of range`,
}, {
addr: "tcp://host:",
wantErrMsg: `invalid port : strconv.ParseUint: parsing "": invalid syntax`,
}, {
addr: "tcp://:53",
wantErrMsg: `invalid address : bad hostname "": hostname is empty`,
}, {
addr: "tcp://!!!",
wantErrMsg: `invalid address !!!: bad hostname "!!!": bad top-level domain name ` +
`label "!!!": bad top-level domain name label rune '!'`,
}, {
addr: "tcp://123",
wantErrMsg: `invalid address 123: bad hostname "123": bad top-level domain name ` +
`label "123": all octets are numeric`,
}}

for _, tc := range testCases {
Expand Down

0 comments on commit f9e7372

Please sign in to comment.