diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go index fe8d7b4ef19..39b55f1052b 100644 --- a/client/internal/routemanager/routemanager.go +++ b/client/internal/routemanager/routemanager.go @@ -1,8 +1,9 @@ -//go:build !android +//go:build !android && !ios package routemanager import ( + "errors" "fmt" "net/netip" "sync" @@ -53,6 +54,9 @@ func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Pref if ref.count == 0 { log.Debugf("Adding route for prefix %s", prefix) nexthop, intf, err := rm.addRoute(prefix) + if errors.Is(err, errRouteNotFound) { + return nil + } if err != nil { return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) } diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go new file mode 100644 index 00000000000..c6f3376e032 --- /dev/null +++ b/client/internal/routemanager/systemops.go @@ -0,0 +1,410 @@ +//go:build !android && !ios + +package routemanager + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + + "github.com/hashicorp/go-multierror" + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" +) + +var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) +var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) +var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) +var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) + +var errRouteNotFound = fmt.Errorf("route not found") + +// TODO: fix: for default our wg address now appears as the default gw +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + addr := netip.IPv4Unspecified() + if prefix.Addr().Is6() { + addr = netip.IPv6Unspecified() + } + + defaultGateway, _, err := getNextHop(addr) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) + } + + if !prefix.Contains(defaultGateway) { + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) + if defaultGateway.Is6() { + gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) + } + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + var exitIntf string + gatewayHop, intf, err := getNextHop(defaultGateway) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + if intf != nil { + exitIntf = intf.Name + } + + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) +} + +func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { + r, err := netroute.New() + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) + if err != nil { + log.Warnf("Failed to get route for %s: %v", ip, err) + return netip.Addr{}, nil, errRouteNotFound + } + + log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + if gateway == nil { + if preferredSrc == nil { + return netip.Addr{}, nil, errRouteNotFound + } + log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) + + addr, ok := netip.AddrFromSlice(preferredSrc) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) + } + return addr.Unmap(), intf, nil + } + + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) + } + + return addr.Unmap(), intf, nil +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +// getRouteToNonVPNIntf returns the next hop and interface for the given prefix. +// If the next hop or interface is pointing to the VPN interface, it will return an error +func addRouteToNonVPNIntf( + prefix netip.Prefix, + vpnIntf *iface.WGIface, + initialNextHop netip.Addr, + initialIntf *net.Interface, +) (netip.Addr, string, error) { + addr := prefix.Addr() + switch { + case addr.IsLoopback(): + return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix) + case addr.IsLinkLocalUnicast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) + case addr.IsLinkLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) + case addr.IsInterfaceLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) + case addr.IsUnspecified(): + return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) + case addr.IsMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix) + } + + // Determine the exit interface and next hop for the prefix, so we can add a specific route + nexthop, intf, err := getNextHop(addr) + if err != nil { + return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) + } + + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) + exitNextHop := nexthop + var exitIntf string + if intf != nil { + exitIntf = intf.Name + } + + vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) + if !ok { + return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") + } + + // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values + if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { + log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) + exitNextHop = initialNextHop + if initialIntf != nil { + exitIntf = initialIntf.Name + } + } + + log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) + if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { + return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) + } + + return exitNextHop, exitIntf, nil +} + +// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix +// in two /1 prefixes to avoid replacing the existing default route +func genericAddVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + return err + } + if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return err + } + + // TODO: remove once IPv6 is supported on the interface + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } else if prefix == defaultv6 { + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } + + return addNonExistingRoute(prefix, intf) +} + +// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table +func addNonExistingRoute(prefix netip.Prefix, intf string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return fmt.Errorf("exists in route table: %w", err) + } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return fmt.Errorf("sub range: %w", err) + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, netip.Addr{}, intf) +} + +// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, +// it will remove the split /1 prefixes +func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + // TODO: remove once IPv6 is supported on the interface + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } else if prefix == defaultv6 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } + + return removeFromRouteTable(prefix, netip.Addr{}, intf) +} + +func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("parse IP address: %s", ip) + } + addr = addr.Unmap() + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return nil, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) + return &prefix, nil +} + +func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) + if err != nil && !errors.Is(err, errRouteNotFound) { + log.Errorf("Unable to get initial v4 default next hop: %v", err) + } + initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) + if err != nil && !errors.Is(err, errRouteNotFound) { + log.Errorf("Unable to get initial v6 default next hop: %v", err) + } + + *routeManager = NewRouteManager( + func(prefix netip.Prefix) (netip.Addr, string, error) { + addr := prefix.Addr() + nexthop, intf := initialNextHopV4, initialIntfV4 + if addr.Is6() { + nexthop, intf = initialNextHopV6, initialIntfV6 + } + return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) + }, + removeFromRouteTable, + ) + + return setupHooks(*routeManager, initAddresses) +} + +func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { + if routeManager == nil { + return nil + } + + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() + + if err := routeManager.Flush(); err != nil { + return fmt.Errorf("flush route manager: %w", err) + } + + return nil +} + +func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := getPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + + if err := routeManager.AddRouteRef(connID, *prefix); err != nil { + return fmt.Errorf("adding route reference: %v", err) + } + + return nil + } + afterHook := func(connID nbnet.ConnectionID) error { + if err := routeManager.RemoveRouteRef(connID); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + return nil + } + + for _, ip := range initAddresses { + if err := beforeHook("init", ip); err != nil { + log.Errorf("Failed to add route reference: %v", err) + } + } + + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } + + var result *multierror.Error + for _, ip := range resolvedIPs { + result = multierror.Append(result, beforeHook(connID, ip.IP)) + } + return result.ErrorOrNil() + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + return beforeHook, afterHook, nil +} diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index d21a3bfbfea..44691f0d65b 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -35,6 +35,9 @@ const ( var ErrTableIDExists = errors.New("ID exists with different name") +var routeManager = &RouteManager{} +var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" + type ruleParams struct { fwmark int tableID int @@ -66,7 +69,12 @@ func getSetupRules() []ruleParams { // enabling VPN connectivity. // // The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. -func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { + if isLegacy { + log.Infof("Using legacy routing setup") + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + } + if err = addRoutingTableName(); err != nil { log.Errorf("Error adding routing table name: %v", err) } @@ -82,6 +90,11 @@ func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ pee rules := getSetupRules() for _, rule := range rules { if err := addRule(rule); err != nil { + if errors.Is(err, syscall.EOPNOTSUPP) { + log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") + isLegacy = true + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + } return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } } @@ -93,6 +106,10 @@ func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ pee // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. func cleanupRouting() error { + if isLegacy { + return cleanupRoutingWithRouteManager(routeManager) + } + var result *multierror.Error if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { @@ -104,7 +121,7 @@ func cleanupRouting() error { rules := getSetupRules() for _, rule := range rules { - if err := removeAllRules(rule); err != nil { + if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) } } @@ -112,49 +129,104 @@ func cleanupRouting() error { return result.ErrorOrNil() } +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + func addVPNRoute(prefix netip.Prefix, intf string) error { - // No need to check if routes exist as main table takes precedence over the VPN table via Rule 2 + if isLegacy { + return genericAddVPNRoute(prefix, intf) + } + + // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // TODO remove this once we have ipv6 support if prefix == defaultv4 { - if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { return fmt.Errorf("add blackhole: %w", err) } } - if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { return fmt.Errorf("add route: %w", err) } return nil } func removeVPNRoute(prefix netip.Prefix, intf string) error { + if isLegacy { + return genericRemoveVPNRoute(prefix, intf) + } + // TODO remove this once we have ipv6 support if prefix == defaultv4 { - if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { return fmt.Errorf("remove unreachable route: %w", err) } } - if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { return fmt.Errorf("remove route: %w", err) } return nil } +func getRoutesFromTable() ([]netip.Prefix, error) { + v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4) + if err != nil { + return nil, fmt.Errorf("get v4 routes: %w", err) + } + v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6) + if err != nil { + return nil, fmt.Errorf("get v6 routes: %w", err) + + } + return append(v4Routes, v6Routes...), nil +} + +// getRoutes fetches routes from a specific routing table identified by tableID. +func getRoutes(tableID, family int) ([]netip.Prefix, error) { + var prefixList []netip.Prefix + + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) + if err != nil { + return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) + } + + for _, route := range routes { + if route.Dst != nil { + addr, ok := netip.AddrFromSlice(route.Dst.IP) + if !ok { + return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) + } + + ones, _ := route.Dst.Mask.Size() + + prefix := netip.PrefixFrom(addr, ones) + if prefix.IsValid() { + prefixList = append(prefixList, prefix) + } + } + } + + return prefixList, nil +} + // addRoute adds a route to a specific routing table identified by tableID. -func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { +func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { route := &netlink.Route{ Scope: netlink.SCOPE_UNIVERSE, Table: tableID, - Family: family, + Family: getAddressFamily(prefix), } - if prefix != nil { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - route.Dst = ipNet + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) } + route.Dst = ipNet if err := addNextHop(addr, intf, route); err != nil { return fmt.Errorf("add gateway and device: %w", err) @@ -170,7 +242,7 @@ func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) err // addUnreachableRoute adds an unreachable route for the specified IP family and routing table. // ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. // tableID specifies the routing table to which the unreachable route will be added. -func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { +func addUnreachableRoute(prefix netip.Prefix, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -179,7 +251,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { route := &netlink.Route{ Type: syscall.RTN_UNREACHABLE, Table: tableID, - Family: ipFamily, + Family: getAddressFamily(prefix), Dst: ipNet, } @@ -190,7 +262,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { return nil } -func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { +func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -199,7 +271,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { route := &netlink.Route{ Type: syscall.RTN_UNREACHABLE, Table: tableID, - Family: ipFamily, + Family: getAddressFamily(prefix), Dst: ipNet, } @@ -212,7 +284,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { } // removeRoute removes a route from a specific routing table identified by tableID. -func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { +func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return fmt.Errorf("parse prefix %s: %w", prefix, err) @@ -221,7 +293,7 @@ func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) route := &netlink.Route{ Scope: netlink.SCOPE_UNIVERSE, Table: tableID, - Family: family, + Family: getAddressFamily(prefix), Dst: ipNet, } @@ -392,23 +464,25 @@ func removeAllRules(params ruleParams) error { } // addNextHop adds the gateway and device to the route. -func addNextHop(addr *string, intf *string, route *netlink.Route) error { - if addr != nil { - ip := net.ParseIP(*addr) - if ip == nil { - return fmt.Errorf("parsing address %s failed", *addr) - } - - route.Gw = ip +func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { + if addr.IsValid() { + route.Gw = addr.AsSlice() } - if intf != nil { - link, err := netlink.LinkByName(*intf) + if intf != "" { + link, err := netlink.LinkByName(intf) if err != nil { - return fmt.Errorf("set interface %s: %w", *intf, err) + return fmt.Errorf("set interface %s: %w", intf, err) } route.LinkIndex = link.Attrs().Index } return nil } + +func getAddressFamily(prefix netip.Prefix) int { + if prefix.Addr().Is4() { + return netlink.FAMILY_V4 + } + return netlink.FAMILY_V6 +} diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index 50a02401a68..d77c7cc7dcf 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -21,8 +21,6 @@ var expectedLoopbackInt = "lo" var expectedExternalInt = "dummyext0" var expectedInternalInt = "dummyint0" -var errRouteNotFound = fmt.Errorf("route not found") - func init() { testCases = append(testCases, []testCase{ { diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 4bc186f215e..38026107ec7 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -3,414 +3,21 @@ package routemanager import ( - "context" - "errors" - "fmt" - "net" "net/netip" "runtime" - "github.com/hashicorp/go-multierror" - "github.com/libp2p/go-netroute" log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" ) -var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) -var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) -var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) -var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) - -var errRouteNotFound = fmt.Errorf("route not found") - func enableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } -// TODO: fix: for default our wg address now appears as the default gw -func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - addr := netip.IPv4Unspecified() - if prefix.Addr().Is6() { - addr = netip.IPv6Unspecified() - } - - defaultGateway, _, err := getNextHop(addr) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - if !prefix.Contains(defaultGateway) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) - if defaultGateway.Is6() { - gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) - } - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - var exitIntf string - gatewayHop, intf, err := getNextHop(defaultGateway) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - if intf != nil { - exitIntf = intf.Name - } - - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) -} - -func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { - r, err := netroute.New() - if err != nil { - return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) - } - intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) - if err != nil { - log.Errorf("Getting routes returned an error: %v", err) - return netip.Addr{}, nil, errRouteNotFound - } - - log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) - if gateway == nil { - if preferredSrc == nil { - return netip.Addr{}, nil, errRouteNotFound - } - log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) - - addr, ok := netip.AddrFromSlice(preferredSrc) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) - } - return addr.Unmap(), intf, nil - } - - addr, ok := netip.AddrFromSlice(gateway) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) - } - - return addr.Unmap(), intf, nil -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} - -// getRouteToNonVPNIntf returns the next hop and interface for the given prefix. -// If the next hop or interface is pointing to the VPN interface, it will return an error -func addRouteToNonVPNIntf( - prefix netip.Prefix, - vpnIntf *iface.WGIface, - initialNextHop netip.Addr, - initialIntf *net.Interface, -) (netip.Addr, string, error) { - addr := prefix.Addr() - switch { - case addr.IsLoopback(): - return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix) - case addr.IsLinkLocalUnicast(): - return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) - case addr.IsLinkLocalMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) - case addr.IsInterfaceLocalMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) - case addr.IsUnspecified(): - return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) - case addr.IsMulticast(): - return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix) - } - - // Determine the exit interface and next hop for the prefix, so we can add a specific route - nexthop, intf, err := getNextHop(addr) - if err != nil { - return netip.Addr{}, "", fmt.Errorf("get next hop: %s", err) - } - - log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) - exitNextHop := nexthop - var exitIntf string - if intf != nil { - exitIntf = intf.Name - } - - vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) - if !ok { - return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") - } - - // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values - if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { - log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) - exitNextHop = initialNextHop - if initialIntf != nil { - exitIntf = initialIntf.Name - } - } - - log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) - if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { - return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) - } - - return exitNextHop, exitIntf, nil -} - -// addVPNRoute adds a new route to the vpn interface, it splits the default prefix -// in two /1 prefixes to avoid replacing the existing default route func addVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - return err - } - if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return err - } - - // TODO: remove once IPv6 is supported on the interface - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } else if prefix == defaultv6 { - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } - - return addNonExistingRoute(prefix, intf) + return genericAddVPNRoute(prefix, intf) } -// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table -func addNonExistingRoute(prefix netip.Prefix, intf string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - err := addRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return addToRouteTable(prefix, netip.Addr{}, intf) -} - -// removeVPNRoute removes the route from the vpn interface. If a default prefix is given, -// it will remove the split /1 prefixes func removeVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - // TODO: remove once IPv6 is supported on the interface - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } else if prefix == defaultv6 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } - - return removeFromRouteTable(prefix, netip.Addr{}, intf) -} - -func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return nil, fmt.Errorf("parse IP address: %s", ip) - } - addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return nil, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) - return &prefix, nil -} - -func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) - if err != nil { - log.Errorf("Unable to get initial v4 default next hop: %v", err) - } - initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) - if err != nil { - log.Errorf("Unable to get initial v6 default next hop: %v", err) - } - - *routeManager = NewRouteManager( - func(prefix netip.Prefix) (netip.Addr, string, error) { - addr := prefix.Addr() - nexthop, intf := initialNextHopV4, initialIntfV4 - if addr.Is6() { - nexthop, intf = initialNextHopV6, initialIntfV6 - } - return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) - }, - removeFromRouteTable, - ) - - return setupHooks(*routeManager, initAddresses) -} - -func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { - if routeManager == nil { - return nil - } - - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() - - if err := routeManager.Flush(); err != nil { - return fmt.Errorf("flush route manager: %w", err) - } - - return nil -} - -func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := getPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - - if err := routeManager.AddRouteRef(connID, *prefix); err != nil { - return fmt.Errorf("adding route reference: %v", err) - } - - return nil - } - afterHook := func(connID nbnet.ConnectionID) error { - if err := routeManager.RemoveRouteRef(connID); err != nil { - return fmt.Errorf("remove route reference: %w", err) - } - - return nil - } - - for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - log.Errorf("Failed to add route reference: %v", err) - } - } - - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } - - var result *multierror.Error - for _, ip := range resolvedIPs { - result = multierror.Append(result, beforeHook(connID, ip.IP)) - } - return result.ErrorOrNil() - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - return beforeHook, afterHook, nil + return genericRemoveVPNRoute(prefix, intf) } diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_test.go similarity index 70% rename from client/internal/routemanager/systemops_nonlinux_test.go rename to client/internal/routemanager/systemops_test.go index adb83bac6d8..97386f19a1a 100644 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -1,13 +1,15 @@ -//go:build !linux && !ios +//go:build !android && !ios package routemanager import ( "bytes" + "context" "fmt" "net" "net/netip" "os" + "runtime" "strings" "testing" @@ -20,16 +22,9 @@ import ( "github.com/netbirdio/netbird/iface" ) -func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { - t.Helper() - - prefixGateway, _, err := getNextHop(prefix.Addr()) - require.NoError(t, err, "getNextHop should not return err") - if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") - } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") - } +type dialer interface { + Dial(network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (net.Conn, error) } func TestAddRemoveRoutes(t *testing.T) { @@ -72,8 +67,8 @@ func TestAddRemoveRoutes(t *testing.T) { assert.NoError(t, cleanupRouting()) }) - err = addVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericAddVPNRoute should not return err") if testCase.shouldRouteToWireguard { assertWGOutInterface(t, testCase.prefix, wgInterface, false) @@ -83,8 +78,8 @@ func TestAddRemoveRoutes(t *testing.T) { exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = removeVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "removeVPNRoute should not return err") + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericRemoveVPNRoute should not return err") prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) require.NoError(t, err, "getNextHop should not return err") @@ -144,7 +139,7 @@ func TestGetNextHop(t *testing.T) { } } -func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { +func TestAddExistAndRemoveRoute(t *testing.T) { defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) t.Log("defaultGateway: ", defaultGateway) if err != nil { @@ -205,20 +200,14 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - _, _, err = setupRouting(nil, nil) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := addVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) + err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = addVPNRoute(testCase.prefix, wgInterface.Name()) + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -228,7 +217,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = removeVPNRoute(testCase.prefix, wgInterface.Name()) + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err") } @@ -284,12 +273,6 @@ func TestIsSubRange(t *testing.T) { } func TestExistsInRouteTable(t *testing.T) { - _, _, err := setupRouting(nil, nil) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - addresses, err := net.InterfaceAddrs() if err != nil { t.Fatal("shouldn't return error when fetching interface addresses: ", err) @@ -298,10 +281,19 @@ func TestExistsInRouteTable(t *testing.T) { var addressPrefixes []netip.Prefix for _, address := range addresses { p := netip.MustParsePrefix(address.String()) + if p.Addr().Is6() { + continue + } // Windows sometimes has hidden interface link local addrs that don't turn up on any interface - if p.Addr().Is4() && !p.Addr().IsLinkLocalUnicast() { - addressPrefixes = append(addressPrefixes, p.Masked()) + if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() { + continue + } + // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence + if runtime.GOOS == "linux" && p.Addr().IsLoopback() { + continue } + + addressPrefixes = append(addressPrefixes, p.Masked()) } for _, prefix := range addressPrefixes { @@ -314,3 +306,97 @@ func TestExistsInRouteTable(t *testing.T) { } } } + +func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { + t.Helper() + + peerPrivateKey, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + newNet, err := stdnet.NewNet() + require.NoError(t, err) + + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WireGuard interface") + + err = wgInterface.Create() + require.NoError(t, err, "should create testing WireGuard interface") + + t.Cleanup(func() { + wgInterface.Close() + }) + + return wgInterface +} + +func setupTestEnv(t *testing.T) { + t.Helper() + + setupDummyInterfacesAndRoutes(t) + + wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) + t.Cleanup(func() { + assert.NoError(t, wgIface.Close()) + }) + + _, _, err := setupRouting(nil, wgIface) + require.NoError(t, err, "setupRouting should not return err") + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // default route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.0.0.0/8 route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.10.0.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 127.0.10.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // unique route in vpn table + err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) +} + +func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { + t.Helper() + if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() { + return + } + + prefixGateway, _, err := getNextHop(prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + if invert { + assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + } else { + assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } +} diff --git a/client/internal/routemanager/sytemops_test.go b/client/internal/routemanager/sytemops_test.go deleted file mode 100644 index 28a6502d2ef..00000000000 --- a/client/internal/routemanager/sytemops_test.go +++ /dev/null @@ -1,101 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "context" - "net" - "net/netip" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" -) - -type dialer interface { - Dial(network, address string) (net.Conn, error) - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} - -func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { - t.Helper() - - peerPrivateKey, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) - - newNet, err := stdnet.NewNet(nil) - require.NoError(t, err) - - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WireGuard interface") - - err = wgInterface.Create() - require.NoError(t, err, "should create testing WireGuard interface") - - t.Cleanup(func() { - wgInterface.Close() - }) - - return wgInterface -} - -func setupTestEnv(t *testing.T) { - t.Helper() - - setupDummyInterfacesAndRoutes(t) - - wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) - t.Cleanup(func() { - assert.NoError(t, wgIface.Close()) - }) - - _, _, err := setupRouting(nil, wgIface) - require.NoError(t, err, "setupRouting should not return err") - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - // default route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.0.0.0/8 route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.10.0.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 127.0.10.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // unique route in vpn table - err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) -} diff --git a/util/net/dialer.go b/util/net/dialer.go index 7b9bddbb52a..d3adef363a0 100644 --- a/util/net/dialer.go +++ b/util/net/dialer.go @@ -35,7 +35,7 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { udpConn, ok := conn.(*net.UDPConn) if !ok { if err := conn.Close(); err != nil { - log.Errorf("Failed to closeConn connection: %v", err) + log.Errorf("Failed to close connection: %v", err) } return nil, fmt.Errorf("expected UDP connection, got different type") } diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index ae412415ff9..a195bdeb917 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -145,10 +145,19 @@ func closeConn(id ConnectionID, conn net.PacketConn) error { // ListenUDP listens on the network address and returns a transport.UDPConn // which includes support for write and close hooks. func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { - udpConn, err := net.ListenUDP(network, laddr) + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) if err != nil { return nil, fmt.Errorf("listen UDP: %w", err) } - connID := GenerateConnID() - return &UDPConn{UDPConn: udpConn, ID: connID, seenAddrs: &sync.Map{}}, nil + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type") + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil }