diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index f6fab80c527..42f740e9b54 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -14,8 +14,8 @@ jobs: test: strategy: matrix: - arch: ['386','amd64'] - store: ['jsonfile', 'sqlite'] + arch: [ '386','amd64' ] + store: [ 'jsonfile', 'sqlite' ] runs-on: ubuntu-latest steps: - name: Install Go @@ -36,7 +36,11 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - name: Install modules run: go mod tidy @@ -67,7 +71,7 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - name: Install modules run: go mod tidy @@ -82,7 +86,7 @@ jobs: run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - name: Generate RouteManager Test bin - run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/... + run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/... - name: Generate nftables Manager Test bin run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... @@ -109,7 +113,7 @@ jobs: - name: Run Engine tests in docker with file store run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 - + - name: Run Engine tests in docker with sqlite store run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 9f543c74c45..13228250d59 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -40,7 +40,7 @@ jobs: cache: false - name: Install dependencies if: matrix.os == 'ubuntu-latest' - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: diff --git a/client/internal/engine.go b/client/internal/engine.go index 78d26f0b8fb..7f7b5ef55ba 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -230,8 +230,8 @@ func (e *Engine) Start() error { wgIface, err := e.newWgIface() if err != nil { - log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err.Error()) - return err + log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err) + return fmt.Errorf("new wg interface: %w", err) } e.wgInterface = wgIface @@ -244,29 +244,33 @@ func (e *Engine) Start() error { } e.rpManager, err = rosenpass.NewManager(e.config.PreSharedKey, e.config.WgIfaceName) if err != nil { - return err + return fmt.Errorf("create rosenpass manager: %w", err) } err := e.rpManager.Run() if err != nil { - return err + return fmt.Errorf("run rosenpass manager: %w", err) } } initialRoutes, dnsServer, err := e.newDnsServer() if err != nil { e.close() - return err + return fmt.Errorf("create dns server: %w", err) } e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) + if err := e.routeManager.Init(); err != nil { + e.close() + return fmt.Errorf("init route manager: %w", err) + } e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) err = e.wgInterfaceCreate() if err != nil { log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) e.close() - return err + return fmt.Errorf("create wg interface: %w", err) } e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface) @@ -278,7 +282,7 @@ func (e *Engine) Start() error { err = e.routeManager.EnableServerRouter(e.firewall) if err != nil { e.close() - return err + return fmt.Errorf("enable server router: %w", err) } } @@ -286,7 +290,7 @@ func (e *Engine) Start() error { if err != nil { log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error()) e.close() - return err + return fmt.Errorf("up wg interface: %w", err) } if e.firewall != nil { @@ -296,7 +300,7 @@ func (e *Engine) Start() error { err = e.dnsServer.Initialize() if err != nil { e.close() - return err + return fmt.Errorf("initialize dns server: %w", err) } e.receiveSignalEvents() diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 1d8e6846d4e..84fd72e49c9 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -10,6 +10,9 @@ import ( "github.com/pion/stun/v2" "github.com/pion/turn/v3" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" + nbnet "github.com/netbirdio/netbird/util/net" ) // ProbeResult holds the info about the result of a relay probe request @@ -27,7 +30,15 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } }() - client, err := stun.DialURI(uri, &stun.DialConfig{}) + net, err := stdnet.NewNet(nil) + if err != nil { + probeErr = fmt.Errorf("new net: %w", err) + return + } + + client, err := stun.DialURI(uri, &stun.DialConfig{ + Net: net, + }) if err != nil { probeErr = fmt.Errorf("dial: %w", err) return @@ -85,14 +96,13 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) switch uri.Proto { case stun.ProtoTypeUDP: var err error - conn, err = net.ListenPacket("udp", "") + conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "") if err != nil { probeErr = fmt.Errorf("listen: %w", err) return } case stun.ProtoTypeTCP: - dialer := net.Dialer{} - tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr) + tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr) if err != nil { probeErr = fmt.Errorf("dial: %w", err) return @@ -109,12 +119,18 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } }() + net, err := stdnet.NewNet(nil) + if err != nil { + probeErr = fmt.Errorf("new net: %w", err) + return + } cfg := &turn.ClientConfig{ STUNServerAddr: turnServerAddr, TURNServerAddr: turnServerAddr, Conn: conn, Username: uri.Username, Password: uri.Password, + Net: net, } client, err := turn.NewClient(cfg) if err != nil { diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index f7ead582720..b2dff7f08cf 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -41,6 +41,7 @@ type clientNetwork struct { func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork { ctx, cancel := context.WithCancel(ctx) + client := &clientNetwork{ ctx: ctx, stop: cancel, @@ -72,6 +73,18 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { return routePeerStatuses } +// getBestRouteFromStatuses determines the most optimal route from the available routes +// within a clientNetwork, taking into account peer connection status, route metrics, and +// preference for non-relayed and direct connections. +// +// It follows these prioritization rules: +// * Connected peers: Only routes with connected peers are considered. +// * Metric: Routes with lower metrics (better) are prioritized. +// * Non-relayed: Routes without relays are preferred. +// * Direct connections: Routes with direct peer connections are favored. +// * Stability: In case of equal scores, the currently active route (if any) is maintained. +// +// It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { chosen := "" chosenScore := 0 @@ -158,7 +171,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { state, err := c.statusRecorder.GetPeer(peerKey) if err != nil { - return err + return fmt.Errorf("get peer state: %v", err) } delete(state.Routes, c.network.String()) @@ -172,7 +185,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String()) if err != nil { - return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", + return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } return nil @@ -180,30 +193,26 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) - if err != nil { - return err + if err := removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { + return fmt.Errorf("remove route %s from system, err: %v", c.network, err) } - err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String()) - if err != nil { - return fmt.Errorf("couldn't remove route %s from system, err: %v", - c.network, err) + + if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { + return fmt.Errorf("remove route: %v", err) } } return nil } func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { - - var err error - routerPeerStatuses := c.getRouterPeerStatuses() chosen := c.getBestRouteFromStatuses(routerPeerStatuses) + + // If no route is chosen, remove the route from the peer and system if chosen == "" { - err = c.removeRouteFromPeerAndSystem() - if err != nil { - return err + if err := c.removeRouteFromPeerAndSystem(); err != nil { + return fmt.Errorf("remove route from peer and system: %v", err) } c.chosenRoute = nil @@ -211,6 +220,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return nil } + // If the chosen route is the same as the current route, do nothing if c.chosenRoute != nil && c.chosenRoute.ID == chosen { if c.chosenRoute.IsEqual(c.routes[chosen]) { return nil @@ -218,13 +228,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } if c.chosenRoute != nil { - err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) - if err != nil { - return err + // If a previous route exists, remove it from the peer + if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { + return fmt.Errorf("remove route from peer: %v", err) } } else { - err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String()) - if err != nil { + // otherwise add the route to the system + if err := addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } @@ -245,8 +255,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } - err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()) - if err != nil { + if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil { log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } @@ -287,21 +296,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { log.Debugf("stopping watcher for network %s", c.network) err := c.removeRouteFromPeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err) } return case <-c.peerStateUpdate: err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't recalculate route and update peer and system: %v", err) } case update := <-c.routeUpdate: if update.updateSerial < c.updateSerial { - log.Warnf("received a routes update with smaller serial number, ignoring it") + log.Warnf("Received a routes update with smaller serial number, ignoring it") continue } - log.Debugf("received a new client network route update for %s", c.network) + log.Debugf("Received a new client network route update for %s", c.network) c.handleUpdate(update) @@ -309,7 +318,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Error(err) + log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err) } c.startPeersStatusChangeWatcher() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index b624d8c34ce..6a0d954da09 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -2,6 +2,8 @@ package routemanager import ( "context" + "fmt" + "net/netip" "runtime" "sync" @@ -15,8 +17,14 @@ import ( "github.com/netbirdio/netbird/version" ) +var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + +// nolint:unused +var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) + // Manager is a route manager interface type Manager interface { + Init() error UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -56,6 +64,19 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, return dm } +// Init sets up the routing +func (m *DefaultManager) Init() error { + if err := cleanupRouting(); err != nil { + log.Warnf("Failed cleaning up routing: %v", err) + } + + if err := setupRouting(); err != nil { + return fmt.Errorf("setup routing: %w", err) + } + log.Info("Routing setup complete") + return nil +} + func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) @@ -71,9 +92,15 @@ func (m *DefaultManager) Stop() { if m.serverRouter != nil { m.serverRouter.cleanUp() } + if err := cleanupRouting(); err != nil { + log.Errorf("Error cleaning up routing: %v", err) + } else { + log.Info("Routing cleanup complete") + } + m.ctx = nil } -// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps +// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { select { case <-m.ctx.Done(): @@ -91,7 +118,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro if m.serverRouter != nil { err := m.serverRouter.updateRoutes(newServerRoutesMap) if err != nil { - return err + return fmt.Errorf("update routes: %w", err) } } @@ -156,11 +183,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string] for _, newRoute := range newRoutes { networkID := route.GetHAUniqueID(newRoute) if !ownNetworkIDs[networkID] { - // if prefix is too small, lets assume is a possible default route which is not yet supported - // we skip this route management - if newRoute.Network.Bits() < minRangeBits { - log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route", - version.NetbirdVersion(), newRoute.Network) + if !isPrefixSupported(newRoute.Network) { continue } newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) @@ -178,3 +201,18 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } return rs } + +func isPrefixSupported(prefix netip.Prefix) bool { + if runtime.GOOS == "linux" { + return true + } + + // If prefix is too small, lets assume it is a possible default prefix which is not yet supported + // we skip this prefix management + if prefix.Bits() < minRangeBits { + log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", + version.NetbirdVersion(), prefix) + return false + } + return true +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 2e5cf6649d8..9d92bf90d2f 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -28,13 +28,14 @@ const remotePeerKey2 = "remote1" func TestManagerUpdateRoutes(t *testing.T) { testCases := []struct { - name string - inputInitRoutes []*route.Route - inputRoutes []*route.Route - inputSerial uint64 - removeSrvRouter bool - serverRoutesExpected int - clientNetworkWatchersExpected int + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + removeSrvRouter bool + serverRoutesExpected int + clientNetworkWatchersExpected int + clientNetworkWatchersExpectedLinux int }{ { name: "Should create 2 client networks", @@ -200,8 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) { Enabled: true, }, }, - inputSerial: 1, - clientNetworkWatchersExpected: 0, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + clientNetworkWatchersExpectedLinux: 1, }, { name: "Remove 1 Client Route", @@ -415,6 +417,8 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) + err = routeManager.Init() + require.NoError(t, err, "should init route manager") defer routeManager.Stop() if testCase.removeSrvRouter { @@ -429,7 +433,11 @@ func TestManagerUpdateRoutes(t *testing.T) { err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) require.NoError(t, err, "should update routes") - require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") + expectedWatchers := testCase.clientNetworkWatchersExpected + if runtime.GOOS == "linux" && testCase.clientNetworkWatchersExpectedLinux != 0 { + expectedWatchers = testCase.clientNetworkWatchersExpectedLinux + } + require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") if runtime.GOOS == "linux" && routeManager.serverRouter != nil { sr := routeManager.serverRouter.(*defaultServerRouter) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index a1214cbb9ec..e812b3a85b6 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -16,6 +16,10 @@ type MockManager struct { StopFunc func() } +func (m *MockManager) Init() error { + return nil +} + // InitialRouteRange mock implementation of InitialRouteRange from Manager interface func (m *MockManager) InitialRouteRange() []string { return nil diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 19236787772..00df735fb8a 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -4,6 +4,7 @@ package routemanager import ( "context" + "fmt" "net/netip" "sync" @@ -48,7 +49,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er oldRoute := m.routes[routeID] err := m.removeFromServerNetwork(oldRoute) if err != nil { - log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", + log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", oldRoute.ID, oldRoute.Network, err) } delete(m.routes, routeID) @@ -62,7 +63,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er err := m.addToServerNetwork(newRoute) if err != nil { - log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) + log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) continue } m.routes[id] = newRoute @@ -81,15 +82,22 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("not removing from server network because context is done") + log.Infof("Not removing from server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) + + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) if err != nil { - return err + return fmt.Errorf("parse prefix: %w", err) + } + + err = m.firewall.RemoveRoutingRules(routerPair) + if err != nil { + return fmt.Errorf("remove routing rules: %w", err) } + delete(m.routes, route.ID) state := m.statusRecorder.GetLocalPeerState() @@ -103,15 +111,22 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("not adding to server network because context is done") + log.Infof("Not adding to server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) + + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) if err != nil { - return err + return fmt.Errorf("parse prefix: %w", err) + } + + err = m.firewall.InsertRoutingRules(routerPair) + if err != nil { + return fmt.Errorf("insert routing rules: %w", err) } + m.routes[route.ID] = route state := m.statusRecorder.GetLocalPeerState() @@ -129,9 +144,15 @@ func (m *defaultServerRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() for _, r := range m.routes { - err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r)) + routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r) + if err != nil { + log.Errorf("Failed to convert route to router pair: %v", err) + continue + } + + err = m.firewall.RemoveRoutingRules(routerPair) if err != nil { - log.Warnf("failed to remove clean up route: %s", r.ID) + log.Errorf("Failed to remove cleanup route: %v", err) } state := m.statusRecorder.GetLocalPeerState() @@ -139,13 +160,15 @@ func (m *defaultServerRouter) cleanUp() { m.statusRecorder.UpdateLocalPeerState(state) } } - -func routeToRouterPair(source string, route *route.Route) firewall.RouterPair { - parsed := netip.MustParsePrefix(source).Masked() +func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { + parsed, err := netip.ParsePrefix(source) + if err != nil { + return firewall.RouterPair{}, err + } return firewall.RouterPair{ ID: route.ID, Source: parsed.String(), Destination: route.Network.Masked().String(), Masquerade: route.Masquerade, - } + }, nil } diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 950a268434c..291826780af 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -4,10 +4,10 @@ import ( "net/netip" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { +func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index b2da8075cfa..173e7c0e847 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -1,5 +1,4 @@ //go:build darwin || dragonfly || freebsd || netbsd || openbsd -// +build darwin dragonfly freebsd netbsd openbsd package routemanager diff --git a/client/internal/routemanager/systemops_bsd_nonios.go b/client/internal/routemanager/systemops_bsd_nonios.go new file mode 100644 index 00000000000..f60c7afc3a0 --- /dev/null +++ b/client/internal/routemanager/systemops_bsd_nonios.go @@ -0,0 +1,13 @@ +//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && !ios + +package routemanager + +import "net/netip" + +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { + return genericAddToRouteTableIfNoExists(prefix, addr, intf) +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { + return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) +} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index aae0f8dc8f2..291826780af 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -1,15 +1,13 @@ -//go:build ios - package routemanager import ( "net/netip" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { +func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 0562826a55d..192509992c7 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -3,142 +3,298 @@ package routemanager import ( + "bufio" + "errors" + "fmt" "net" "net/netip" "os" "syscall" - "unsafe" + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + + nbnet "github.com/netbirdio/netbird/util/net" ) -// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html -// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'. -type routeInfoInMemory struct { - Family byte - DstLen byte - SrcLen byte - TOS byte +const ( + // NetbirdVPNTableID is the ID of the custom routing table used by Netbird. + NetbirdVPNTableID = 0x1BD0 + // NetbirdVPNTableName is the name of the custom routing table used by Netbird. + NetbirdVPNTableName = "netbird" + + // rtTablesPath is the path to the file containing the routing table names. + rtTablesPath = "/etc/iproute2/rt_tables" - Table byte - Protocol byte - Scope byte - Type byte + // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. + ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" +) - Flags uint32 +var ErrTableIDExists = errors.New("ID exists with different name") + +type ruleParams struct { + fwmark int + tableID int + family int + priority int + invert bool + suppressPrefix int + description string } -const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" +func getSetupRules() []ruleParams { + return []ruleParams{ + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "add rule v4 netbird"}, + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "add rule v6 netbird"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "add rule with suppress prefixlen v4"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "add rule with suppress prefixlen v6"}, + } +} -func addToRouteTable(prefix netip.Prefix, addr string) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return err +// setupRouting establishes the routing configuration for the VPN, including essential rules +// to ensure proper traffic flow for management, locally configured routes, and VPN traffic. +// +// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over +// potential routes received and configured for the VPN. This rule is skipped for the default route and routes +// that are not in the main table. +// +// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. +// This table is where a default route or other specific routes received from the management server are configured, +// enabling VPN connectivity. +// +// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. +func setupRouting() (err error) { + if err = addRoutingTableName(); err != nil { + log.Errorf("Error adding routing table name: %v", err) } - addrMask := "/32" - if prefix.Addr().Unmap().Is6() { - addrMask = "/128" + defer func() { + if err != nil { + if cleanErr := cleanupRouting(); cleanErr != nil { + log.Errorf("Error cleaning up routing: %v", cleanErr) + } + } + }() + + rules := getSetupRules() + for _, rule := range rules { + if err := addRule(rule); err != nil { + return fmt.Errorf("%s: %w", rule.description, err) + } } - ip, _, err := net.ParseCIDR(addr + addrMask) - if err != nil { - return err + return nil +} + +// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. +// It systematically removes the three rules and any associated routing table entries to ensure a clean state. +// The function uses error aggregation to report any errors encountered during the cleanup process. +func cleanupRouting() error { + var result *multierror.Error + + if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err)) + } + if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err)) } + rules := getSetupRules() + for _, rule := range rules { + if err := removeAllRules(rule); err != nil { + result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) + } + } + + return result.ErrorOrNil() +} + +func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error { + // No need to check if routes exist as main table takes precedence over the VPN table via Rule 2 + + // TODO remove this once we have ipv6 support + if prefix == defaultv4 { + if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + return fmt.Errorf("add blackhole: %w", err) + } + } + if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + return fmt.Errorf("add route: %w", err) + } + return nil +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) error { + // TODO remove this once we have ipv6 support + if prefix == defaultv4 { + if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + return fmt.Errorf("remove unreachable route: %w", err) + } + } + if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + return fmt.Errorf("remove route: %w", err) + } + return nil +} + +func getRoutesFromTable() ([]netip.Prefix, error) { + return getRoutes(NetbirdVPNTableID, netlink.FAMILY_V4) +} + +// addRoute adds a route to a specific routing table identified by tableID. +func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Dst: ipNet, - Gw: ip, + Scope: netlink.SCOPE_UNIVERSE, + Table: tableID, + Family: family, } - err = netlink.RouteAdd(route) - if err != nil { - return err + if prefix != nil { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + route.Dst = ipNet + } + + if err := addNextHop(addr, intf, route); err != nil { + return fmt.Errorf("add gateway and device: %w", err) + } + + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) { + return fmt.Errorf("netlink add route: %w", err) } return nil } -func removeFromRouteTable(prefix netip.Prefix, addr string) error { +// addUnreachableRoute adds an unreachable route for the specified IP family and routing table. +// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. +// tableID specifies the routing table to which the unreachable route will be added. +func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return err + return fmt.Errorf("parse prefix %s: %w", prefix, err) } - addrMask := "/32" - if prefix.Addr().Unmap().Is6() { - addrMask = "/128" + route := &netlink.Route{ + Type: syscall.RTN_UNREACHABLE, + Table: tableID, + Family: ipFamily, + Dst: ipNet, } - ip, _, err := net.ParseCIDR(addr + addrMask) + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) { + return fmt.Errorf("netlink add unreachable route: %w", err) + } + + return nil +} + +func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return err + return fmt.Errorf("parse prefix %s: %w", prefix, err) } route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Dst: ipNet, - Gw: ip, + Type: syscall.RTN_UNREACHABLE, + Table: tableID, + Family: ipFamily, + Dst: ipNet, } - err = netlink.RouteDel(route) - if err != nil { - return err + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { + return fmt.Errorf("netlink remove unreachable route: %w", err) } return nil + } -func getRoutesFromTable() ([]netip.Prefix, error) { - tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) +// removeRoute removes a route from a specific routing table identified by tableID. +func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { - return nil, err + return fmt.Errorf("parse prefix %s: %w", prefix, err) } - msgs, err := syscall.ParseNetlinkMessage(tab) + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Table: tableID, + Family: family, + Dst: ipNet, + } + + if err := addNextHop(addr, intf, route); err != nil { + return fmt.Errorf("add gateway and device: %w", err) + } + + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { + return fmt.Errorf("netlink remove route: %w", err) + } + + return nil +} + +func flushRoutes(tableID, family int) error { + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) if err != nil { - return nil, err + return fmt.Errorf("list routes from table %d: %w", tableID, err) } - var prefixList []netip.Prefix -loop: - for _, m := range msgs { - switch m.Header.Type { - case syscall.NLMSG_DONE: - break loop - case syscall.RTM_NEWROUTE: - rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0])) - msg := m - attrs, err := syscall.ParseNetlinkRouteAttr(&msg) - if err != nil { - return nil, err + + var result *multierror.Error + for i := range routes { + route := routes[i] + // unreachable default routes don't come back with Dst set + if route.Gw == nil && route.Src == nil && route.Dst == nil { + if family == netlink.FAMILY_V4 { + routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)} + } else { + routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} } - if rt.Family != syscall.AF_INET { - continue loop + } + if err := netlink.RouteDel(&routes[i]); err != nil { + result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) + } + } + + return result.ErrorOrNil() +} + +// getRoutes fetches routes from a specific routing table identified by tableID. +func getRoutes(tableID, family int) ([]netip.Prefix, error) { + var prefixList []netip.Prefix + + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) + if err != nil { + return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) + } + + for _, route := range routes { + if route.Dst != nil { + addr, ok := netip.AddrFromSlice(route.Dst.IP) + if !ok { + return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) } - for _, attr := range attrs { - if attr.Attr.Type == syscall.RTA_DST { - addr, ok := netip.AddrFromSlice(attr.Value) - if !ok { - continue - } - mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8) - cidr, _ := mask.Size() - routePrefix := netip.PrefixFrom(addr, cidr) - if routePrefix.IsValid() && routePrefix.Addr().Is4() { - prefixList = append(prefixList, routePrefix) - } - } + ones, _ := route.Dst.Mask.Size() + + prefix := netip.PrefixFrom(addr, ones) + if prefix.IsValid() { + prefixList = append(prefixList, prefix) } } } + return prefixList, nil } func enableIPForwarding() error { bytes, err := os.ReadFile(ipv4ForwardingPath) if err != nil { - return err + return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err) } // check if it is already enabled @@ -147,5 +303,142 @@ func enableIPForwarding() error { return nil } - return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec + //nolint:gosec + if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil { + return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err) + } + return nil +} + +// entryExists checks if the specified ID or name already exists in the rt_tables file +// and verifies if existing names start with "netbird_". +func entryExists(file *os.File, id int) (bool, error) { + if _, err := file.Seek(0, 0); err != nil { + return false, fmt.Errorf("seek rt_tables: %w", err) + } + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + var existingID int + var existingName string + if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil { + if existingID == id { + if existingName != NetbirdVPNTableName { + return true, ErrTableIDExists + } + return true, nil + } + } + } + if err := scanner.Err(); err != nil { + return false, fmt.Errorf("scan rt_tables: %w", err) + } + return false, nil +} + +// addRoutingTableName adds human-readable names for custom routing tables. +func addRoutingTableName() error { + file, err := os.Open(rtTablesPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return fmt.Errorf("open rt_tables: %w", err) + } + defer func() { + if err := file.Close(); err != nil { + log.Errorf("Error closing rt_tables: %v", err) + } + }() + + exists, err := entryExists(file, NetbirdVPNTableID) + if err != nil { + return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err) + } + if exists { + return nil + } + + // Reopen the file in append mode to add new entries + if err := file.Close(); err != nil { + log.Errorf("Error closing rt_tables before appending: %v", err) + } + file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) + if err != nil { + return fmt.Errorf("open rt_tables for appending: %w", err) + } + + if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil { + return fmt.Errorf("append entry to rt_tables: %w", err) + } + + return nil +} + +// addRule adds a routing rule to a specific routing table identified by tableID. +func addRule(params ruleParams) error { + rule := netlink.NewRule() + rule.Table = params.tableID + rule.Mark = params.fwmark + rule.Family = params.family + rule.Priority = params.priority + rule.Invert = params.invert + rule.SuppressPrefixlen = params.suppressPrefix + + if err := netlink.RuleAdd(rule); err != nil { + return fmt.Errorf("add routing rule: %w", err) + } + + return nil +} + +// removeRule removes a routing rule from a specific routing table identified by tableID. +func removeRule(params ruleParams) error { + rule := netlink.NewRule() + rule.Table = params.tableID + rule.Mark = params.fwmark + rule.Family = params.family + rule.Invert = params.invert + rule.Priority = params.priority + rule.SuppressPrefixlen = params.suppressPrefix + + if err := netlink.RuleDel(rule); err != nil { + return fmt.Errorf("remove routing rule: %w", err) + } + + return nil +} + +func removeAllRules(params ruleParams) error { + for { + if err := removeRule(params); err != nil { + if errors.Is(err, syscall.ENOENT) { + break + } + return err + } + } + return nil +} + +// addNextHop adds the gateway and device to the route. +func addNextHop(addr *string, intf *string, route *netlink.Route) error { + if addr != nil { + ip := net.ParseIP(*addr) + if ip == nil { + return fmt.Errorf("parsing address %s failed", *addr) + } + + route.Gw = ip + } + + if intf != nil { + link, err := netlink.LinkByName(*intf) + if err != nil { + return fmt.Errorf("set interface %s: %w", *intf, err) + } + route.LinkIndex = link.Attrs().Index + } + + return nil } diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go new file mode 100644 index 00000000000..96e43d20f0b --- /dev/null +++ b/client/internal/routemanager/systemops_linux_test.go @@ -0,0 +1,469 @@ +//go:build !android + +package routemanager + +import ( + "errors" + "fmt" + "net" + "net/netip" + "os" + "strings" + "syscall" + "testing" + "time" + + "github.com/gopacket/gopacket" + "github.com/gopacket/gopacket/layers" + "github.com/gopacket/gopacket/pcap" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" +) + +type PacketExpectation struct { + SrcIP net.IP + DstIP net.IP + SrcPort int + DstPort int + UDP bool + TCP bool +} + +func TestEntryExists(t *testing.T) { + tempDir := t.TempDir() + tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir) + + content := []string{ + "1000 reserved", + fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName), + "9999 other_table", + } + require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644)) + + file, err := os.Open(tempFilePath) + require.NoError(t, err) + defer func() { + assert.NoError(t, file.Close()) + }() + + tests := []struct { + name string + id int + shouldExist bool + err error + }{ + { + name: "ExistsWithNetbirdPrefix", + id: 7120, + shouldExist: true, + err: nil, + }, + { + name: "ExistsWithDifferentName", + id: 1000, + shouldExist: true, + err: ErrTableIDExists, + }, + { + name: "DoesNotExist", + id: 1234, + shouldExist: false, + err: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + exists, err := entryExists(file, tc.id) + if tc.err != nil { + assert.ErrorIs(t, err, tc.err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.shouldExist, exists) + }) + } +} + +func TestRoutingWithTables(t *testing.T) { + testCases := []struct { + name string + destination string + captureInterface string + dialer *net.Dialer + packetExpectation PacketExpectation + }{ + { + name: "To external host without fwmark via vpn", + destination: "192.0.2.1:53", + captureInterface: "wgtest0", + dialer: &net.Dialer{}, + packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), + }, + { + name: "To external host with fwmark via physical interface", + destination: "192.0.2.1:53", + captureInterface: "dummyext0", + dialer: nbnet.NewDialer(), + packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), + }, + + { + name: "To duplicate internal route with fwmark via physical interface", + destination: "10.0.0.1:53", + captureInterface: "dummyint0", + dialer: nbnet.NewDialer(), + packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), + }, + { + name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence + destination: "10.0.0.1:53", + captureInterface: "dummyint0", + dialer: &net.Dialer{}, + packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), + }, + + { + name: "To unique vpn route with fwmark via physical interface", + destination: "172.16.0.1:53", + captureInterface: "dummyext0", + dialer: nbnet.NewDialer(), + packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53), + }, + { + name: "To unique vpn route without fwmark via vpn", + destination: "172.16.0.1:53", + captureInterface: "wgtest0", + dialer: &net.Dialer{}, + packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53), + }, + + { + name: "To more specific route without fwmark via vpn interface", + destination: "10.10.0.1:53", + captureInterface: "dummyint0", + dialer: &net.Dialer{}, + packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53), + }, + + { + name: "To more specific route (local) without fwmark via physical interface", + destination: "127.0.10.1:53", + captureInterface: "lo", + dialer: &net.Dialer{}, + packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wgIface, _, _ := setupTestEnv(t) + + // default route exists in main table and vpn table + err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // 10.0.0.0/8 route exists in main table and vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // 10.10.0.0/24 more specific route exists in vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // 127.0.10.0/24 more specific route exists in vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // unique route in vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + filter := createBPFFilter(tc.destination) + handle := startPacketCapture(t, tc.captureInterface, filter) + + sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer) + + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + packet, err := packetSource.NextPacket() + require.NoError(t, err) + + verifyPacket(t, packet, tc.packetExpectation) + }) + } +} + +func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { + t.Helper() + + ipLayer := packet.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") + + ip, ok := ipLayer.(*layers.IPv4) + require.True(t, ok, "Failed to cast to IPv4 layer") + + // Convert both source and destination IP addresses to 16-byte representation + expectedSrcIP := exp.SrcIP.To16() + actualSrcIP := ip.SrcIP.To16() + assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") + + expectedDstIP := exp.DstIP.To16() + actualDstIP := ip.DstIP.To16() + assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") + + if exp.UDP { + udpLayer := packet.Layer(layers.LayerTypeUDP) + require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") + + udp, ok := udpLayer.(*layers.UDP) + require.True(t, ok, "Failed to cast to UDP layer") + + assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") + assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") + } + + if exp.TCP { + tcpLayer := packet.Layer(layers.LayerTypeTCP) + require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") + + tcp, ok := tcpLayer.(*layers.TCP) + require.True(t, ok, "Failed to cast to TCP layer") + + assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") + assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") + } + +} + +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy { + t.Helper() + + dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} + err := netlink.LinkDel(dummy) + if err != nil && !errors.Is(err, syscall.EINVAL) { + t.Logf("Failed to delete dummy interface: %v", err) + } + + err = netlink.LinkAdd(dummy) + require.NoError(t, err) + + err = netlink.LinkSetUp(dummy) + require.NoError(t, err) + + if ipAddressCIDR != "" { + addr, err := netlink.ParseAddr(ipAddressCIDR) + require.NoError(t, err) + err = netlink.AddrAdd(dummy, addr) + require.NoError(t, err) + } + + return dummy +} + +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) { + t.Helper() + + _, dstIPNet, err := net.ParseCIDR(dstCIDR) + require.NoError(t, err) + + if dstIPNet.String() == "0.0.0.0/0" { + gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4) + if err != nil { + t.Logf("Failed to fetch original gateway: %v", err) + } + + // Handle existing routes with metric 0 + err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) + if err == nil { + t.Cleanup(func() { + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0}) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } + }) + } else if !errors.Is(err, syscall.ESRCH) { + t.Logf("Failed to delete route: %v", err) + } + } + + route := &netlink.Route{ + Dst: dstIPNet, + Gw: gw, + LinkIndex: linkIndex, + } + err = netlink.RouteDel(route) + if err != nil && !errors.Is(err, syscall.ESRCH) { + t.Logf("Failed to delete route: %v", err) + } + + err = netlink.RouteAdd(route) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } +} + +// fetchOriginalGateway returns the original gateway IP address and the interface index. +func fetchOriginalGateway(family int) (net.IP, int, error) { + routes, err := netlink.RouteList(nil, family) + if err != nil { + return nil, 0, err + } + + for _, route := range routes { + if route.Dst == nil { + return route.Gw, route.LinkIndex, nil + } + } + + return nil, 0, fmt.Errorf("default route not found") +} + +func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) { + t.Helper() + + defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index) + + otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") + addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index) + + t.Cleanup(func() { + err := netlink.LinkDel(defaultDummy) + assert.NoError(t, err) + err = netlink.LinkDel(otherDummy) + assert.NoError(t, err) + }) + + return defaultDummy.Name, otherDummy.Name +} + +func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { + t.Helper() + + peerPrivateKey, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + newNet, err := stdnet.NewNet(nil) + require.NoError(t, err) + + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WireGuard interface") + + err = wgInterface.Create() + require.NoError(t, err, "should create testing WireGuard interface") + + t.Cleanup(func() { + wgInterface.Close() + }) + + return wgInterface +} + +func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) { + t.Helper() + + defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t) + + wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820) + t.Cleanup(func() { + assert.NoError(t, wgIface.Close()) + }) + + err := setupRouting() + require.NoError(t, err, "setupRouting should not return err") + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + return wgIface, defaultDummy, otherDummy +} + +func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { + t.Helper() + + inactive, err := pcap.NewInactiveHandle(intf) + require.NoError(t, err, "Failed to create inactive pcap handle") + defer inactive.CleanUp() + + err = inactive.SetSnapLen(1600) + require.NoError(t, err, "Failed to set snap length on inactive handle") + + err = inactive.SetTimeout(time.Second * 10) + require.NoError(t, err, "Failed to set timeout on inactive handle") + + err = inactive.SetImmediateMode(true) + require.NoError(t, err, "Failed to set immediate mode on inactive handle") + + handle, err := inactive.Activate() + require.NoError(t, err, "Failed to activate pcap handle") + t.Cleanup(handle.Close) + + err = handle.SetBPFFilter(filter) + require.NoError(t, err, "Failed to set BPF filter") + + return handle +} + +func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) { + t.Helper() + + if dialer == nil { + dialer = &net.Dialer{} + } + + if sourcePort != 0 { + localUDPAddr := &net.UDPAddr{ + IP: net.IPv4zero, + Port: sourcePort, + } + dialer.LocalAddr = localUDPAddr + } + + msg := new(dns.Msg) + msg.Id = dns.Id() + msg.RecursionDesired = true + msg.Question = []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + conn, err := dialer.Dial("udp", destination) + require.NoError(t, err, "Failed to dial UDP") + defer conn.Close() + + data, err := msg.Pack() + require.NoError(t, err, "Failed to pack DNS message") + + _, err = conn.Write(data) + if err != nil { + if strings.Contains(err.Error(), "required key not available") { + t.Logf("Ignoring WireGuard key error: %v", err) + return + } + t.Fatalf("Failed to send DNS query: %v", err) + } +} + +func createBPFFilter(destination string) string { + host, port, err := net.SplitHostPort(destination) + if err != nil { + return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) + } + return "udp" +} + +func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { + return PacketExpectation{ + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + UDP: true, + } +} diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go index 11247c7dccd..65f670ace17 100644 --- a/client/internal/routemanager/systemops_nonandroid.go +++ b/client/internal/routemanager/systemops_nonandroid.go @@ -1,11 +1,15 @@ -//go:build !android && !ios +//go:build !android +//nolint:unused package routemanager import ( + "errors" "fmt" "net" "net/netip" + "os/exec" + "runtime" "github.com/libp2p/go-netroute" log "github.com/sirupsen/logrus" @@ -13,41 +17,16 @@ import ( var errRouteNotFound = fmt.Errorf("route not found") -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return err - } - if ok { - log.Warnf("skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return err - } - - if ok { - err := addRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return addToRouteTable(prefix, addr) -} - -func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - if err != nil && err != errRouteNotFound { - return err +func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + defaultGateway, err := getExistingRIBRouteGateway(defaultv4) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) } addr := netip.MustParseAddr(defaultGateway.String()) if !prefix.Contains(addr) { - log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) return nil } @@ -59,56 +38,79 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { } if ok { - log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) return nil } gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) - if err != nil && err != errRouteNotFound { + if err != nil && !errors.Is(err, errRouteNotFound) { return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) } - log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop.String()) + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), "") } -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() +func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { + ok, err := existsInRouteTable(prefix) if err != nil { - return false, err + return fmt.Errorf("exists in route table: %w", err) } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil } - return false, nil -} -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() + ok, err = isSubRange(prefix) if err != nil { - return false, err + return fmt.Errorf("sub range: %w", err) } - for _, tableRoute := range routes { - if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil + + if ok { + err := genericAddRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) } } - return false, nil + + return genericAddToRouteTable(prefix, addr, intf) +} + +func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { + return genericRemoveFromRouteTable(prefix, addr, intf) +} + +func genericAddToRouteTable(prefix netip.Prefix, addr, _ string) error { + cmd := exec.Command("route", "add", prefix.String(), addr) + out, err := cmd.Output() + if err != nil { + return fmt.Errorf("add route: %w", err) + } + log.Debugf(string(out)) + return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { - return removeFromRouteTable(prefix, addr) +func genericRemoveFromRouteTable(prefix netip.Prefix, addr, _ string) error { + args := []string{"delete", prefix.String()} + if runtime.GOOS == "darwin" { + args = append(args, addr) + } + cmd := exec.Command("route", args...) + out, err := cmd.Output() + if err != nil { + return fmt.Errorf("remove route: %w", err) + } + log.Debugf(string(out)) + return nil } func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { r, err := netroute.New() if err != nil { - return nil, err + return nil, fmt.Errorf("new netroute: %w", err) } _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice()) if err != nil { - log.Errorf("getting routes returned an error: %v", err) + log.Errorf("Getting routes returned an error: %v", err) return nil, errRouteNotFound } @@ -118,3 +120,29 @@ func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { return gateway, nil } + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if isPrefixSupported(tableRoute) && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_nonandroid_test.go index 6f32d9634bc..aae5e5faa16 100644 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ b/client/internal/routemanager/systemops_nonandroid_test.go @@ -8,17 +8,63 @@ import ( "net" "net/netip" "os" + "os/exec" + "runtime" "strings" "testing" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/iface" ) +func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { + t.Helper() + + if runtime.GOOS == "linux" { + outIntf, err := getOutgoingInterfaceLinux(prefix.Addr().String()) + require.NoError(t, err, "getOutgoingInterfaceLinux should not return error") + if invert { + require.NotEqual(t, wgIface.Name(), outIntf, "outgoing interface should not be the wireguard interface") + } else { + require.Equal(t, wgIface.Name(), outIntf, "outgoing interface should be the wireguard interface") + } + return + } + + prefixGateway, err := getExistingRIBRouteGateway(prefix) + require.NoError(t, err, "getExistingRIBRouteGateway should not return err") + if invert { + assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + } else { + assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } +} + +func getOutgoingInterfaceLinux(destination string) (string, error) { + cmd := exec.Command("ip", "route", "get", destination) + output, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("executing ip route get: %w", err) + } + + return parseOutgoingInterface(string(output)), nil +} + +func parseOutgoingInterface(routeGetOutput string) string { + fields := strings.Fields(routeGetOutput) + for i, field := range fields { + if field == "dev" && i+1 < len(fields) { + return fields[i+1] + } + } + return "" +} + func TestAddRemoveRoutes(t *testing.T) { testCases := []struct { name string @@ -54,23 +100,26 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String()) + require.NoError(t, setupRouting()) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") if testCase.shouldRouteToWireguard { - require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + assertWGOutInterface(t, testCase.prefix, wgInterface, false) } else { - require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface") + assertWGOutInterface(t, testCase.prefix, wgInterface, true) } exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) + err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") - prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) + prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) require.NoError(t, err, "getExistingRIBRouteGateway should not return err") internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) @@ -189,16 +238,21 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") + require.NoError(t, setupRouting()) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + MockAddr := wgInterface.Address().IP.String() // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr) + err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr, wgInterface.Name()) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = addToRouteTableIfNoExists(testCase.prefix, MockAddr) + err = addToRouteTableIfNoExists(testCase.prefix, MockAddr, wgInterface.Name()) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -208,7 +262,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr) + err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr, wgInterface.Name()) require.NoError(t, err, "should not return err") } @@ -217,72 +271,12 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { ok, err := existsInRouteTable(testCase.prefix) t.Log("Buffer string: ", buf.String()) require.NoError(t, err, "should not return err") - if !strings.Contains(buf.String(), "because it already exists") { + + // Linux uses a separate routing table, so the route can exist in both tables. + // The main routing table takes precedence over the wireguard routing table. + if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" { require.False(t, ok, "route should not exist") } }) } } - -func TestExistsInRouteTable(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var addressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if p.Addr().Is4() { - addressPrefixes = append(addressPrefixes, p.Masked()) - } - } - - for _, prefix := range addressPrefixes { - exists, err := existsInRouteTable(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address exists in route table: ", err) - } - if !exists { - t.Fatalf("address %s should exist in route table", prefix) - } - } -} - -func TestIsSubRange(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var subRangeAddressPrefixes []netip.Prefix - var nonSubRangeAddressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 { - p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1) - subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2) - nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked()) - } - } - - for _, prefix := range subRangeAddressPrefixes { - isSubRangePrefix, err := isSubRange(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address is sub-range: ", err) - } - if !isSubRangePrefix { - t.Fatalf("address %s should be sub-range of an existing route in the table", prefix) - } - } - - for _, prefix := range nonSubRangeAddressPrefixes { - isSubRangePrefix, err := isSubRange(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address is sub-range: ", err) - } - if isSubRangePrefix { - t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix) - } - } -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 47bd60eb02b..d793f0fbde0 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -1,41 +1,22 @@ -//go:build !linux -// +build !linux +//go:build !linux || android package routemanager import ( - "net/netip" - "os/exec" "runtime" log "github.com/sirupsen/logrus" ) -func addToRouteTable(prefix netip.Prefix, addr string) error { - cmd := exec.Command("route", "add", prefix.String(), addr) - out, err := cmd.Output() - if err != nil { - return err - } - log.Debugf(string(out)) +func setupRouting() error { return nil } -func removeFromRouteTable(prefix netip.Prefix, addr string) error { - args := []string{"delete", prefix.String()} - if runtime.GOOS == "darwin" { - args = append(args, addr) - } - cmd := exec.Command("route", args...) - out, err := cmd.Output() - if err != nil { - return err - } - log.Debugf(string(out)) +func cleanupRouting() error { return nil } func enableIPForwarding() error { - log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_nonlinux_test.go new file mode 100644 index 00000000000..afaf5ba7724 --- /dev/null +++ b/client/internal/routemanager/systemops_nonlinux_test.go @@ -0,0 +1,80 @@ +//go:build !linux || android + +package routemanager + +import ( + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsSubRange(t *testing.T) { + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var subRangeAddressPrefixes []netip.Prefix + var nonSubRangeAddressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 { + p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1) + subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2) + nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked()) + } + } + + for _, prefix := range subRangeAddressPrefixes { + isSubRangePrefix, err := isSubRange(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address is sub-range: ", err) + } + if !isSubRangePrefix { + t.Fatalf("address %s should be sub-range of an existing route in the table", prefix) + } + } + + for _, prefix := range nonSubRangeAddressPrefixes { + isSubRangePrefix, err := isSubRange(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address is sub-range: ", err) + } + if isSubRangePrefix { + t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix) + } + } +} + +func TestExistsInRouteTable(t *testing.T) { + require.NoError(t, setupRouting()) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var addressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if p.Addr().Is4() { + addressPrefixes = append(addressPrefixes, p.Masked()) + } + } + + for _, prefix := range addressPrefixes { + exists, err := existsInRouteTable(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address exists in route table: ", err) + } + if !exists { + t.Fatalf("address %s should exist in route table", prefix) + } + } +} diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 309c184b9ca..c009ce66b9d 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -1,12 +1,13 @@ //go:build windows -// +build windows package routemanager import ( + "fmt" "net" "net/netip" + log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" ) @@ -21,17 +22,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) { err := wmi.Query(query, &routes) if err != nil { - return nil, err + return nil, fmt.Errorf("get routes: %w", err) } var prefixList []netip.Prefix for _, route := range routes { addr, err := netip.ParseAddr(route.Destination) if err != nil { + log.Warnf("Unable to parse route destination %s: %v", route.Destination, err) continue } maskSlice := net.ParseIP(route.Mask).To4() if maskSlice == nil { + log.Warnf("Unable to parse route mask %s", route.Mask) continue } mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3]) @@ -44,3 +47,11 @@ func getRoutesFromTable() ([]netip.Prefix, error) { } return prefixList, nil } + +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { + return genericAddToRouteTableIfNoExists(prefix, addr, intf) +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { + return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) +} diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go new file mode 100644 index 00000000000..e80adb42b20 --- /dev/null +++ b/client/internal/stdnet/dialer.go @@ -0,0 +1,24 @@ +package stdnet + +import ( + "net" + + "github.com/pion/transport/v3" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +// Dial connects to the address on the named network. +func (n *Net) Dial(network, address string) (net.Conn, error) { + return nbnet.NewDialer().Dial(network, address) +} + +// DialUDP connects to the address on the named UDP network. +func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { + return nbnet.DialUDP(network, laddr, raddr) +} + +// DialTCP connects to the address on the named TCP network. +func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { + return nbnet.DialTCP(network, laddr, raddr) +} diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go new file mode 100644 index 00000000000..9ce0a555610 --- /dev/null +++ b/client/internal/stdnet/listener.go @@ -0,0 +1,20 @@ +package stdnet + +import ( + "context" + "net" + + "github.com/pion/transport/v3" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +// ListenPacket listens for incoming packets on the given network and address. +func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) { + return nbnet.NewListener().ListenPacket(context.Background(), network, address) +} + +// ListenUDP acts like ListenPacket for UDP networks. +func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { + return nbnet.ListenUDP(network, locAddr) +} diff --git a/client/internal/wgproxy/portlookup.go b/client/internal/wgproxy/portlookup.go index 6f3d33487ea..6ede4b83f1d 100644 --- a/client/internal/wgproxy/portlookup.go +++ b/client/internal/wgproxy/portlookup.go @@ -1,8 +1,10 @@ package wgproxy import ( + "context" "fmt" - "net" + + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -23,7 +25,7 @@ func (pl portLookup) searchFreePort() (int, error) { } func (pl portLookup) tryToBind(port int) error { - l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port)) + l, err := nbnet.NewListener().ListenPacket(context.Background(), "udp", fmt.Sprintf(":%d", port)) if err != nil { return err } diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index 6ca19c9737e..b91cd7b439d 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" + nbnet "github.com/netbirdio/netbird/util/net" ) // WGEBPFProxy definition for proxy with EBPF support @@ -66,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error { IP: net.ParseIP("127.0.0.1"), } - p.conn, err = net.ListenUDP("udp", &addr) + p.conn, err = nbnet.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() if cErr != nil { @@ -208,20 +209,41 @@ generatePort: } func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { + // Create a raw socket. fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) if err != nil { - return nil, err + return nil, fmt.Errorf("creating raw socket failed: %w", err) } + + // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) if err != nil { - return nil, err + return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) } + + // Bind the socket to the "lo" interface. err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") if err != nil { - return nil, err + return nil, fmt.Errorf("binding to lo interface failed: %w", err) + } + + // Set the fwmark on the socket. + err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark) + if err != nil { + return nil, fmt.Errorf("setting fwmark failed: %w", err) + } + + // Convert the file descriptor to a PacketConn. + file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) + if file == nil { + return nil, fmt.Errorf("converting fd to file failed") + } + packetConn, err := net.FilePacketConn(file) + if err != nil { + return nil, fmt.Errorf("converting file to packet conn failed: %w", err) } - return net.FilePacketConn(os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))) + return packetConn, nil } func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go index b692ea70842..17ebfbc499b 100644 --- a/client/internal/wgproxy/proxy_userspace.go +++ b/client/internal/wgproxy/proxy_userspace.go @@ -6,6 +6,8 @@ import ( "net" log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" ) // WGUserSpaceProxy proxies @@ -33,7 +35,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) { p.remoteConn = remoteConn var err error - p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) + p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) return nil, err diff --git a/go.mod b/go.mod index 6aba599f810..f3bd1634b86 100644 --- a/go.mod +++ b/go.mod @@ -47,8 +47,9 @@ require ( github.com/google/go-cmp v0.5.9 github.com/google/gopacket v1.1.19 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 - github.com/hashicorp/go-multierror v1.1.0 + github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 github.com/libp2p/go-netroute v0.2.0 @@ -123,7 +124,6 @@ require ( github.com/google/s2a-go v0.1.4 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.10.0 // indirect - github.com/gopacket/gopacket v1.1.1 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index ca10cd55367..48fb1333f8b 100644 --- a/go.sum +++ b/go.sum @@ -291,8 +291,8 @@ github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f2 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= -github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= diff --git a/iface/address.go b/iface/address.go index 5ff4fbc0645..2920d009fa1 100644 --- a/iface/address.go +++ b/iface/address.go @@ -23,6 +23,24 @@ func parseWGAddress(address string) (WGAddress, error) { }, nil } +// Masked returns the WGAddress with the IP address part masked according to its network mask. +func (addr WGAddress) Masked() WGAddress { + ip := addr.IP.To4() + if ip == nil { + ip = addr.IP.To16() + } + + maskedIP := make(net.IP, len(ip)) + for i := range ip { + maskedIP[i] = ip[i] & addr.Network.Mask[i] + } + + return WGAddress{ + IP: maskedIP, + Network: addr.Network, + } +} + func (addr WGAddress) String() string { maskSize, _ := addr.Network.Mask.Size() return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) diff --git a/iface/wg_configurer_kernel.go b/iface/wg_configurer_kernel.go index 36fd13cc262..9fe987cee21 100644 --- a/iface/wg_configurer_kernel.go +++ b/iface/wg_configurer_kernel.go @@ -10,6 +10,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + nbnet "github.com/netbirdio/netbird/util/net" ) type wgKernelConfigurer struct { @@ -29,7 +31,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err if err != nil { return err } - fwmark := 0 + fwmark := nbnet.NetbirdFwmark config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go index 200bfbc9614..24dfadf1408 100644 --- a/iface/wg_configurer_usp.go +++ b/iface/wg_configurer_usp.go @@ -13,6 +13,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + nbnet "github.com/netbirdio/netbird/util/net" ) type wgUSPConfigurer struct { @@ -37,7 +39,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error if err != nil { return err } - fwmark := 0 + fwmark := getFwmark() config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, @@ -345,3 +347,10 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { } return sb.String() } + +func getFwmark() int { + if runtime.GOOS == "linux" { + return nbnet.NetbirdFwmark + } + return 0 +} diff --git a/management/client/grpc.go b/management/client/grpc.go index 0234f866cb8..0b1804906c2 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" + nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second @@ -57,6 +58,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE mgmCtx, addr, transportOption, + nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 656fdc8ca24..74ac6c163ad 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -21,6 +21,8 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" + + nbnet "github.com/netbirdio/netbird/util/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -55,8 +57,7 @@ var writeSerializerOptions = gopacket.SerializeOptions{ } // Listen creates an IPv4 and IPv6 raw sockets, starts a reader and routing table routines -func Listen(port int, filter BPFFilter) (net.PacketConn, error) { - var err error +func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { ctx, cancel := context.WithCancel(context.Background()) rawSock := &SharedSocket{ ctx: ctx, @@ -65,37 +66,51 @@ func Listen(port int, filter BPFFilter) (net.PacketConn, error) { packetDemux: make(chan rcvdPacket), } + defer func() { + if err != nil { + if closeErr := rawSock.Close(); closeErr != nil { + log.Errorf("Failed to close raw socket: %v", closeErr) + } + } + }() + rawSock.router, err = netroute.New() if err != nil { - return nil, fmt.Errorf("failed to create raw socket router: %v", err) + return nil, fmt.Errorf("failed to create raw socket router: %w", err) } rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil) if err != nil { - return nil, fmt.Errorf("failed to create ipv4 raw socket: %v", err) + return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) } - rawSock.conn6, err = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) - if err != nil { - log.Errorf("failed to create ipv6 raw socket: %v", err) + if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { + return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) + } + + var sockErr error + rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) + if sockErr != nil { + log.Errorf("Failed to create ipv6 raw socket: %v", err) + } else { + if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { + return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) + } } ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port)) if err != nil { - _ = rawSock.Close() - return nil, fmt.Errorf("getBPFInstructions failed with: %rawSock", err) + return nil, fmt.Errorf("getBPFInstructions failed with: %w", err) } err = rawSock.conn4.SetBPF(ipv4Instructions) if err != nil { - _ = rawSock.Close() - return nil, fmt.Errorf("socket4.SetBPF failed with: %rawSock", err) + return nil, fmt.Errorf("socket4.SetBPF failed with: %w", err) } if rawSock.conn6 != nil { err = rawSock.conn6.SetBPF(ipv6Instructions) if err != nil { - _ = rawSock.Close() - return nil, fmt.Errorf("socket6.SetBPF failed with: %rawSock", err) + return nil, fmt.Errorf("socket6.SetBPF failed with: %w", err) } } @@ -121,7 +136,7 @@ func (s *SharedSocket) updateRouter() { case <-ticker.C: router, err := netroute.New() if err != nil { - log.Errorf("failed to create and update packet router for stunListener: %s", err) + log.Errorf("Failed to create and update packet router for stunListener: %s", err) continue } s.routerMux.Lock() @@ -144,7 +159,7 @@ func (s *SharedSocket) LocalAddr() net.Addr { func (s *SharedSocket) SetDeadline(t time.Time) error { err := s.conn4.SetDeadline(t) if err != nil { - return fmt.Errorf("s.conn4.SetDeadline error: %s", err) + return fmt.Errorf("s.conn4.SetDeadline error: %w", err) } if s.conn6 == nil { return nil @@ -152,7 +167,7 @@ func (s *SharedSocket) SetDeadline(t time.Time) error { err = s.conn6.SetDeadline(t) if err != nil { - return fmt.Errorf("s.conn6.SetDeadline error: %s", err) + return fmt.Errorf("s.conn6.SetDeadline error: %w", err) } return nil } @@ -161,7 +176,7 @@ func (s *SharedSocket) SetDeadline(t time.Time) error { func (s *SharedSocket) SetReadDeadline(t time.Time) error { err := s.conn4.SetReadDeadline(t) if err != nil { - return fmt.Errorf("s.conn4.SetReadDeadline error: %s", err) + return fmt.Errorf("s.conn4.SetReadDeadline error: %w", err) } if s.conn6 == nil { return nil @@ -169,7 +184,7 @@ func (s *SharedSocket) SetReadDeadline(t time.Time) error { err = s.conn6.SetReadDeadline(t) if err != nil { - return fmt.Errorf("s.conn6.SetReadDeadline error: %s", err) + return fmt.Errorf("s.conn6.SetReadDeadline error: %w", err) } return nil } @@ -178,7 +193,7 @@ func (s *SharedSocket) SetReadDeadline(t time.Time) error { func (s *SharedSocket) SetWriteDeadline(t time.Time) error { err := s.conn4.SetWriteDeadline(t) if err != nil { - return fmt.Errorf("s.conn4.SetWriteDeadline error: %s", err) + return fmt.Errorf("s.conn4.SetWriteDeadline error: %w", err) } if s.conn6 == nil { return nil @@ -186,7 +201,7 @@ func (s *SharedSocket) SetWriteDeadline(t time.Time) error { err = s.conn6.SetWriteDeadline(t) if err != nil { - return fmt.Errorf("s.conn6.SetWriteDeadline error: %s", err) + return fmt.Errorf("s.conn6.SetWriteDeadline error: %w", err) } return nil } @@ -282,7 +297,7 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { _, _, src, err := s.router.Route(rUDPAddr.IP) if err != nil { - return 0, fmt.Errorf("got an error while checking route, err: %s", err) + return 0, fmt.Errorf("got an error while checking route, err: %w", err) } rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP) @@ -292,7 +307,7 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { } if err := gopacket.SerializeLayers(buffer, writeSerializerOptions, udp, payload); err != nil { - return -1, fmt.Errorf("failed serialize rcvdPacket: %s", err) + return -1, fmt.Errorf("failed serialize rcvdPacket: %w", err) } bufser := buffer.Bytes() diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7531608c3bb..7c4535e2896 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/signal/proto" + nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -76,6 +77,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo sigCtx, addr, transportOption, + nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/util/grpc/dialer_generic.go b/util/grpc/dialer_generic.go new file mode 100644 index 00000000000..1c2285b14bf --- /dev/null +++ b/util/grpc/dialer_generic.go @@ -0,0 +1,9 @@ +//go:build !linux || android + +package grpc + +import "google.golang.org/grpc" + +func WithCustomDialer() grpc.DialOption { + return grpc.EmptyDialOption{} +} diff --git a/util/grpc/dialer_linux.go b/util/grpc/dialer_linux.go new file mode 100644 index 00000000000..b29ee4b2936 --- /dev/null +++ b/util/grpc/dialer_linux.go @@ -0,0 +1,18 @@ +//go:build !android + +package grpc + +import ( + "context" + "net" + + "google.golang.org/grpc" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +func WithCustomDialer() grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return nbnet.NewDialer().DialContext(ctx, "tcp", addr) + }) +} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go new file mode 100644 index 00000000000..a3c3ad67c74 --- /dev/null +++ b/util/net/dialer_generic.go @@ -0,0 +1,19 @@ +//go:build !linux || android + +package net + +import ( + "net" +) + +func NewDialer() *net.Dialer { + return &net.Dialer{} +} + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + return net.DialUDP(network, laddr, raddr) +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + return net.DialTCP(network, laddr, raddr) +} diff --git a/util/net/dialer_linux.go b/util/net/dialer_linux.go new file mode 100644 index 00000000000..d559490c517 --- /dev/null +++ b/util/net/dialer_linux.go @@ -0,0 +1,60 @@ +//go:build !android + +package net + +import ( + "context" + "fmt" + "net" + "syscall" + + log "github.com/sirupsen/logrus" +) + +func NewDialer() *net.Dialer { + return &net.Dialer{ + Control: func(network, address string, c syscall.RawConn) error { + return SetRawSocketMark(c) + }, + } +} + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.DialContext(context.Background(), network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type") + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.DialContext(context.Background(), network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type") + } + + return tcpConn, nil +} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go new file mode 100644 index 00000000000..241c744e528 --- /dev/null +++ b/util/net/listener_generic.go @@ -0,0 +1,13 @@ +//go:build !linux || android + +package net + +import "net" + +func NewListener() *net.ListenConfig { + return &net.ListenConfig{} +} + +func ListenUDP(network string, locAddr *net.UDPAddr) (*net.UDPConn, error) { + return net.ListenUDP(network, locAddr) +} diff --git a/util/net/listener_linux.go b/util/net/listener_linux.go new file mode 100644 index 00000000000..7b9bda97c7d --- /dev/null +++ b/util/net/listener_linux.go @@ -0,0 +1,30 @@ +//go:build !android + +package net + +import ( + "context" + "fmt" + "net" + "syscall" +) + +func NewListener() *net.ListenConfig { + return &net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + return SetRawSocketMark(c) + }, + } +} + +func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + pc, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listening on %s:%s with fwmark: %w", network, laddr, err) + } + udpConn, ok := pc.(*net.UDPConn) + if !ok { + return nil, fmt.Errorf("packetConn is not a *net.UDPConn") + } + return udpConn, nil +} diff --git a/util/net/net.go b/util/net/net.go new file mode 100644 index 00000000000..5714e52294e --- /dev/null +++ b/util/net/net.go @@ -0,0 +1,6 @@ +package net + +const ( + // NetbirdFwmark is the fwmark value used by Netbird via wireguard + NetbirdFwmark = 0x1BD00 +) diff --git a/util/net/net_linux.go b/util/net/net_linux.go new file mode 100644 index 00000000000..82141750029 --- /dev/null +++ b/util/net/net_linux.go @@ -0,0 +1,35 @@ +//go:build !android + +package net + +import ( + "fmt" + "syscall" +) + +// SetSocketMark sets the SO_MARK option on the given socket connection +func SetSocketMark(conn syscall.Conn) error { + sysconn, err := conn.SyscallConn() + if err != nil { + return fmt.Errorf("get raw conn: %w", err) + } + + return SetRawSocketMark(sysconn) +} + +func SetRawSocketMark(conn syscall.RawConn) error { + var setErr error + + err := conn.Control(func(fd uintptr) { + setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) + }) + if err != nil { + return fmt.Errorf("control: %w", err) + } + + if setErr != nil { + return fmt.Errorf("set SO_MARK: %w", setErr) + } + + return nil +}