Skip to content

Commit

Permalink
Add interface option (#65)
Browse files Browse the repository at this point in the history
* speedtest: move source bound Dialer setup to newDialerAddressBound().

* Add "--interface" option.
  • Loading branch information
MagnaboscoL authored Sep 10, 2024
1 parent 6103965 commit 11183cb
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 49 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ GLOBAL OPTIONS:
--server-json value Use an alternative server list from remote JSON file
--local-json value Use an alternative server list from local JSON file,
or read from stdin with "--local-json -".
--source SOURCE SOURCE IP address to bind to
--source SOURCE SOURCE IP address to bind to. Incompatible with --interface.
--interface INTERFACE The name of the network interface to bind to. Example: "enp0s3".
Not supported on Windows and incompatible with --source.
Implies --no-icmp.
--timeout TIMEOUT HTTP TIMEOUT in seconds. (default: 15)
--duration value Upload and download test duration in seconds (default: 15)
--chunks value Chunks to download from server, chunk size depends on server configuration (default: 100)
Expand Down
1 change: 1 addition & 0 deletions defs/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
OptionExclude = "exclude"
OptionServerJSON = "server-json"
OptionSource = "source"
OptionInterface = "interface"
OptionTimeout = "timeout"
OptionChunks = "chunks"
OptionUploadSize = "upload-size"
Expand Down
4 changes: 4 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ func main() {
Name: defs.OptionSource,
Usage: "`SOURCE` IP address to bind to",
},
&cli.StringFlag{
Name: defs.OptionInterface,
Usage: "network INTERFACE to bind to",
},
&cli.IntFlag{
Name: defs.OptionTimeout,
Usage: "HTTP `TIMEOUT` in seconds.",
Expand Down
4 changes: 2 additions & 2 deletions speedtest/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const (
)

// doSpeedTest is where the actual speed test happens
func doSpeedTest(c *cli.Context, servers []defs.Server, telemetryServer defs.TelemetryServer, network string, silent bool) error {
func doSpeedTest(c *cli.Context, servers []defs.Server, telemetryServer defs.TelemetryServer, network string, silent bool, noICMP bool) error {
if serverCount := len(servers); serverCount > 1 {
log.Infof("Testing against %d servers", serverCount)
}
Expand Down Expand Up @@ -70,7 +70,7 @@ func doSpeedTest(c *cli.Context, servers []defs.Server, telemetryServer defs.Tel
}

// skip ICMP if option given
currentServer.NoICMP = c.Bool(defs.OptionNoICMP)
currentServer.NoICMP = noICMP

p, jitter, err := currentServer.ICMPPingAndJitter(pingCount, c.String(defs.OptionSource), network)
if err != nil {
Expand Down
117 changes: 71 additions & 46 deletions speedtest/speedtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ func SpeedTest(c *cli.Context) error {
return nil
}

if c.String(defs.OptionSource) != "" && c.String(defs.OptionInterface) != "" {
return fmt.Errorf("incompatible options '%s' and '%s'", defs.OptionSource, defs.OptionInterface)
}

// set CSV delimiter
gocsv.TagSeparator = c.String(defs.OptionCSVDelimiter)

Expand Down Expand Up @@ -138,6 +142,8 @@ func SpeedTest(c *cli.Context) error {
return errors.New("invalid concurrent requests setting")
}

noICMP := c.Bool(defs.OptionNoICMP)

// HTTP requests timeout
http.DefaultClient.Timeout = time.Duration(c.Int(defs.OptionTimeout)) * time.Second

Expand All @@ -157,57 +163,48 @@ func SpeedTest(c *cli.Context) error {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: c.Bool(defs.OptionSkipCertVerify)}

// bind to source IP address if given, or if ipv4/ipv6 is forced
if src := c.String(defs.OptionSource); src != "" || (forceIPv4 || forceIPv6) {
var localTCPAddr *net.TCPAddr
if src != "" {
// first we parse the IP to see if it's valid
addr, err := net.ResolveIPAddr(network, src)
if err != nil {
if strings.Contains(err.Error(), "no suitable address") {
if forceIPv6 {
log.Errorf("Address %s is not a valid IPv6 address", src)
} else {
log.Errorf("Address %s is not a valid IPv4 address", src)
}
} else {
log.Errorf("Error parsing source IP: %s", err)
}
return err
}

log.Debugf("Using %s as source IP", src)
localTCPAddr = &net.TCPAddr{IP: addr.IP}
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
// bind to source IP address if given
if src := c.String(defs.OptionSource); src != "" {
var err error
dialer, err = newDialerAddressBound(src, network)
if err != nil {
return err
}
}

var dialContext func(context.Context, string, string) (net.Conn, error)
defaultDialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
// bind to interface if given
if iface := c.String(defs.OptionInterface); iface != "" {
var err error
dialer, err = newDialerInterfaceBound(iface)
if err != nil {
return err
}
// ICMP ping does not support interface binding.
noICMP = true
}

if localTCPAddr != nil {
defaultDialer.LocalAddr = localTCPAddr
// enforce if ipv4/ipv6 is forced
var dialContext func(context.Context, string, string) (net.Conn, error)
switch {
case forceIPv4:
dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) {
return dialer.DialContext(ctx, "tcp4", address)
}

switch {
case forceIPv4:
dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) {
return defaultDialer.DialContext(ctx, "tcp4", address)
}
case forceIPv6:
dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) {
return defaultDialer.DialContext(ctx, "tcp6", address)
}
default:
dialContext = defaultDialer.DialContext
case forceIPv6:
dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) {
return dialer.DialContext(ctx, "tcp6", address)
}

// set default HTTP client's Transport to the one that binds the source address
// this is modified from http.DefaultTransport
transport.DialContext = dialContext
default:
dialContext = dialer.DialContext
}

// set default HTTP client's Transport to the one that binds the source address
// this is modified from http.DefaultTransport
transport.DialContext = dialContext
http.DefaultClient.Transport = transport

// load server list
Expand Down Expand Up @@ -258,7 +255,7 @@ func SpeedTest(c *cli.Context) error {

// if --server is given, do speed tests with all of them
if len(c.IntSlice(defs.OptionServer)) > 0 {
return doSpeedTest(c, servers, telemetryServer, network, silent)
return doSpeedTest(c, servers, telemetryServer, network, silent, noICMP)
} else {
// else select the fastest server from the list
log.Info("Selecting the fastest server based on ping")
Expand All @@ -272,7 +269,7 @@ func SpeedTest(c *cli.Context) error {

// spawn 10 concurrent pingers
for i := 0; i < 10; i++ {
go pingWorker(jobs, results, &wg, c.String(defs.OptionSource), network, c.Bool(defs.OptionNoICMP))
go pingWorker(jobs, results, &wg, c.String(defs.OptionSource), network, noICMP)
}

// send ping jobs to workers
Expand Down Expand Up @@ -309,7 +306,7 @@ func SpeedTest(c *cli.Context) error {
}

// do speed test on the server
return doSpeedTest(c, []defs.Server{servers[serverIdx]}, telemetryServer, network, silent)
return doSpeedTest(c, []defs.Server{servers[serverIdx]}, telemetryServer, network, silent, noICMP)
}
}

Expand Down Expand Up @@ -474,3 +471,31 @@ func contains(arr []int, val int) bool {
}
return false
}

func newDialerAddressBound(src string, network string) (dialer *net.Dialer, err error) {
// first we parse the IP to see if it's valid
addr, err := net.ResolveIPAddr(network, src)
if err != nil {
if strings.Contains(err.Error(), "no suitable address") {
if network == "ip6" {
log.Errorf("Address %s is not a valid IPv6 address", src)
} else {
log.Errorf("Address %s is not a valid IPv4 address", src)
}
} else {
log.Errorf("Error parsing source IP: %s", err)
}
return nil, err
}

log.Debugf("Using %s as source IP", src)
localTCPAddr := &net.TCPAddr{IP: addr.IP}

defaultDialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}

defaultDialer.LocalAddr = localTCPAddr
return defaultDialer, nil
}
32 changes: 32 additions & 0 deletions speedtest/util_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package speedtest

import (
"net"
"syscall"
"time"

"golang.org/x/sys/unix"
)

func newDialerInterfaceBound(iface string) (dialer *net.Dialer, err error) {
// In linux there is the socket option SO_BINDTODEVICE.
// Therefore we can really bind the socket to the device instead of binding to the address that
// would be affected by the default routes.
control := func(network, address string, c syscall.RawConn) error {
var errSock error
err := c.Control((func(fd uintptr) {
errSock = unix.BindToDevice(int(fd), iface)
}))
if err != nil {
return err
}
return errSock
}

dialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Control: control,
}
return dialer, nil
}
10 changes: 10 additions & 0 deletions speedtest/util_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package speedtest

import (
"fmt"
"net"
)

func newDialerInterfaceBound(iface string) (dialer *net.Dialer, err error) {
return nil, fmt.Errorf("cannot bound to interface on Windows")
}

0 comments on commit 11183cb

Please sign in to comment.