diff --git a/client/iface/bind/control_android.go b/client/iface/bind/control_android.go new file mode 100644 index 00000000000..b8a865e3908 --- /dev/null +++ b/client/iface/bind/control_android.go @@ -0,0 +1,12 @@ +package bind + +import ( + wireguard "golang.zx2c4.com/wireguard/conn" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +func init() { + // ControlFns is not thread safe and should only be modified during init. + *wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket) +} diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 71a0f26aedf..82d1f7718ad 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -55,7 +55,7 @@ type ruleParams struct { // isLegacy determines whether to use the legacy routing setup func isLegacy() bool { - return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true" + return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark() } // setIsLegacy sets the legacy routing setup diff --git a/go.mod b/go.mod index 0a16753ea43..e8c65542280 100644 --- a/go.mod +++ b/go.mod @@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index a4d7ea7f9c1..47975d4eab4 100644 --- a/go.sum +++ b/go.sum @@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g= -github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY= +github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= diff --git a/util/net/conn.go b/util/net/conn.go new file mode 100644 index 00000000000..26693f84166 --- /dev/null +++ b/util/net/conn.go @@ -0,0 +1,31 @@ +//go:build !ios + +package net + +import ( + "net" + + log "github.com/sirupsen/logrus" +) + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} diff --git a/util/net/dial.go b/util/net/dial.go new file mode 100644 index 00000000000..59531149278 --- /dev/null +++ b/util/net/dial.go @@ -0,0 +1,58 @@ +//go:build !ios + +package net + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) + } + + return tcpConn, nil +} diff --git a/util/net/dialer_ios.go b/util/net/dial_ios.go similarity index 100% rename from util/net/dialer_ios.go rename to util/net/dial_ios.go diff --git a/util/net/dialer_android.go b/util/net/dialer_android.go deleted file mode 100644 index 4cbded53634..00000000000 --- a/util/net/dialer_android.go +++ /dev/null @@ -1,25 +0,0 @@ -package net - -import ( - "syscall" - - log "github.com/sirupsen/logrus" -) - -func (d *Dialer) init() { - d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - err := c.Control(func(fd uintptr) { - androidProtectSocketLock.Lock() - f := androidProtectSocket - androidProtectSocketLock.Unlock() - if f == nil { - return - } - ok := f(int32(fd)) - if !ok { - log.Errorf("failed to protect socket: %d", fd) - } - }) - return err - } -} diff --git a/util/net/dialer_nonios.go b/util/net/dialer_dial.go similarity index 63% rename from util/net/dialer_nonios.go rename to util/net/dialer_dial.go index 34004a368c1..1659b622051 100644 --- a/util/net/dialer_nonios.go +++ b/util/net/dialer_dial.go @@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) { return d.DialContext(context.Background(), network, address) } -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} - func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { host, _, err := net.SplitHostPort(address) if err != nil { @@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r return result.ErrorOrNil() } - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - if CustomRoutingDisabled() { - return net.DialUDP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - if CustomRoutingDisabled() { - return net.DialTCP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_init_android.go b/util/net/dialer_init_android.go new file mode 100644 index 00000000000..63b9033484e --- /dev/null +++ b/util/net/dialer_init_android.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = ControlProtectSocket +} diff --git a/util/net/dialer_linux.go b/util/net/dialer_init_linux.go similarity index 88% rename from util/net/dialer_linux.go rename to util/net/dialer_init_linux.go index aed5c59a322..d801e608086 100644 --- a/util/net/dialer_linux.go +++ b/util/net/dialer_init_linux.go @@ -7,6 +7,6 @@ import "syscall" // init configures the net.Dialer Control function to set the fwmark on the socket func (d *Dialer) init() { d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) + return setRawSocketMark(c) } } diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_init_nonlinux.go similarity index 58% rename from util/net/dialer_nonlinux.go rename to util/net/dialer_init_nonlinux.go index c838441bdb5..8c57ebbaa52 100644 --- a/util/net/dialer_nonlinux.go +++ b/util/net/dialer_init_nonlinux.go @@ -3,4 +3,5 @@ package net func (d *Dialer) init() { + // implemented on Linux and Android only } diff --git a/util/net/env.go b/util/net/env.go new file mode 100644 index 00000000000..099da39b760 --- /dev/null +++ b/util/net/env.go @@ -0,0 +1,29 @@ +package net + +import ( + "os" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" +) + +const ( + envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" +) + +func CustomRoutingDisabled() bool { + if netstack.IsEnabled() { + return true + } + return os.Getenv(envDisableCustomRouting) == "true" +} + +func SkipSocketMark() bool { + if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" { + log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark) + return true + } + return false +} diff --git a/util/net/listen.go b/util/net/listen.go new file mode 100644 index 00000000000..3ae8a9435cb --- /dev/null +++ b/util/net/listen.go @@ -0,0 +1,37 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil +} diff --git a/util/net/listener_ios.go b/util/net/listen_ios.go similarity index 100% rename from util/net/listener_ios.go rename to util/net/listen_ios.go diff --git a/util/net/listener_android.go b/util/net/listener_android.go deleted file mode 100644 index d4167ad53a6..00000000000 --- a/util/net/listener_android.go +++ /dev/null @@ -1,26 +0,0 @@ -package net - -import ( - "syscall" - - log "github.com/sirupsen/logrus" -) - -// init configures the net.ListenerConfig Control function to set the fwmark on the socket -func (l *ListenerConfig) init() { - l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - err := c.Control(func(fd uintptr) { - androidProtectSocketLock.Lock() - f := androidProtectSocket - androidProtectSocketLock.Unlock() - if f == nil { - return - } - ok := f(int32(fd)) - if !ok { - log.Errorf("failed to protect listener socket: %d", fd) - } - }) - return err - } -} diff --git a/util/net/listener_init_android.go b/util/net/listener_init_android.go new file mode 100644 index 00000000000..f7bfa1dab27 --- /dev/null +++ b/util/net/listener_init_android.go @@ -0,0 +1,6 @@ +package net + +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = ControlProtectSocket +} diff --git a/util/net/listener_linux.go b/util/net/listener_init_linux.go similarity index 89% rename from util/net/listener_linux.go rename to util/net/listener_init_linux.go index 8d332160a04..e32d5d8942e 100644 --- a/util/net/listener_linux.go +++ b/util/net/listener_init_linux.go @@ -9,6 +9,6 @@ import ( // init configures the net.ListenerConfig Control function to set the fwmark on the socket func (l *ListenerConfig) init() { l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) + return setRawSocketMark(c) } } diff --git a/util/net/listener_nonlinux.go b/util/net/listener_init_nonlinux.go similarity index 61% rename from util/net/listener_nonlinux.go rename to util/net/listener_init_nonlinux.go index 14a6be49dc3..80f6f7f1a55 100644 --- a/util/net/listener_nonlinux.go +++ b/util/net/listener_init_nonlinux.go @@ -3,4 +3,5 @@ package net func (l *ListenerConfig) init() { + // implemented on Linux and Android only } diff --git a/util/net/listener_nonios.go b/util/net/listener_listen.go similarity index 84% rename from util/net/listener_nonios.go rename to util/net/listener_listen.go index ae4be34949b..efffba40e6e 100644 --- a/util/net/listener_nonios.go +++ b/util/net/listener_listen.go @@ -8,7 +8,6 @@ import ( "net" "sync" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" ) @@ -146,27 +145,3 @@ func closeConn(id ConnectionID, conn net.PacketConn) error { return err } - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.ListenUDP(network, laddr) - } - - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/net.go b/util/net/net.go index 5448eb85a5f..403aa87e7d1 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -2,9 +2,6 @@ package net import ( "net" - "os" - - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/google/uuid" ) @@ -16,8 +13,6 @@ const ( PreroutingFwmarkRedirected = 0x1BD01 PreroutingFwmarkMasquerade = 0x1BD11 PreroutingFwmarkMasqueradeReturn = 0x1BD12 - - envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" ) // ConnectionID provides a globally unique identifier for network connections. @@ -31,10 +26,3 @@ type RemoveHookFunc func(connID ConnectionID) error func GenerateConnID() ConnectionID { return ConnectionID(uuid.NewString()) } - -func CustomRoutingDisabled() bool { - if netstack.IsEnabled() { - return true - } - return os.Getenv(envDisableCustomRouting) == "true" -} diff --git a/util/net/net_linux.go b/util/net/net_linux.go index 98f49af8d00..fc486ebd496 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -4,29 +4,42 @@ package net import ( "fmt" - "os" "syscall" log "github.com/sirupsen/logrus" ) -const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK" - // SetSocketMark sets the SO_MARK option on the given socket connection func SetSocketMark(conn syscall.Conn) error { + if isSocketMarkDisabled() { + return nil + } + sysconn, err := conn.SyscallConn() if err != nil { return fmt.Errorf("get raw conn: %w", err) } - return SetRawSocketMark(sysconn) + return setRawSocketMark(sysconn) +} + +// SetSocketOpt sets the SO_MARK option on the given file descriptor +func SetSocketOpt(fd int) error { + if isSocketMarkDisabled() { + return nil + } + + return setSocketOptInt(fd) } -func SetRawSocketMark(conn syscall.RawConn) error { +func setRawSocketMark(conn syscall.RawConn) error { var setErr error err := conn.Control(func(fd uintptr) { - setErr = SetSocketOpt(int(fd)) + if isSocketMarkDisabled() { + return + } + setErr = setSocketOptInt(int(fd)) }) if err != nil { return fmt.Errorf("control: %w", err) @@ -39,17 +52,18 @@ func SetRawSocketMark(conn syscall.RawConn) error { return nil } -func SetSocketOpt(fd int) error { +func setSocketOptInt(fd int) error { + return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) +} + +func isSocketMarkDisabled() bool { if CustomRoutingDisabled() { log.Infof("Custom routing is disabled, skipping SO_MARK") - return nil + return true } - // Check for the new environment variable - if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" { - log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK") - return nil + if SkipSocketMark() { + return true } - - return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) + return false } diff --git a/util/net/protectsocket_android.go b/util/net/protectsocket_android.go index 64fb45aa44e..febed8a1e2b 100644 --- a/util/net/protectsocket_android.go +++ b/util/net/protectsocket_android.go @@ -1,14 +1,42 @@ package net -import "sync" +import ( + "fmt" + "sync" + "syscall" +) var ( androidProtectSocketLock sync.Mutex androidProtectSocket func(fd int32) bool ) -func SetAndroidProtectSocketFn(f func(fd int32) bool) { +func SetAndroidProtectSocketFn(fn func(fd int32) bool) { androidProtectSocketLock.Lock() - androidProtectSocket = f + androidProtectSocket = fn androidProtectSocketLock.Unlock() } + +// ControlProtectSocket is a Control function that sets the fwmark on the socket +func ControlProtectSocket(_, _ string, c syscall.RawConn) error { + var aErr error + err := c.Control(func(fd uintptr) { + androidProtectSocketLock.Lock() + defer androidProtectSocketLock.Unlock() + + if androidProtectSocket == nil { + aErr = fmt.Errorf("socket protection function not set") + return + } + + if !androidProtectSocket(int32(fd)) { + aErr = fmt.Errorf("failed to protect socket via Android") + } + }) + + if err != nil { + return err + } + + return aErr +}