diff --git a/client/internal/engine.go b/client/internal/engine.go index 28e1f1b55d0..9ee804a87fa 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -112,7 +112,7 @@ type Engine struct { TURNs []*stun.URI // clientRoutes is the most recent list of clientRoutes received from the Management Service - clientRoutes map[string][]*route.Route + clientRoutes route.HAMap cancel context.CancelFunc @@ -736,9 +736,9 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { for _, protoRoute := range protoRoutes { _, prefix, _ := route.ParseNetwork(protoRoute.Network) convertedRoute := &route.Route{ - ID: protoRoute.ID, + ID: route.ID(protoRoute.ID), Network: prefix, - NetID: protoRoute.NetID, + NetID: route.NetID(protoRoute.NetID), NetworkType: route.NetworkType(protoRoute.NetworkType), Peer: protoRoute.Peer, Metric: int(protoRoute.Metric), @@ -1238,18 +1238,15 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { } // GetClientRoutes returns the current routes from the route map -func (e *Engine) GetClientRoutes() map[string][]*route.Route { +func (e *Engine) GetClientRoutes() route.HAMap { return e.clientRoutes } // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only -func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route { - routes := make(map[string][]*route.Route, len(e.clientRoutes)) +func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { + routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes)) for id, v := range e.clientRoutes { - if i := strings.LastIndex(id, "-"); i != -1 { - id = id[:i] - } - routes[id] = v + routes[id.NetID()] = v } return routes } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index f487cc71e72..13a18cf392d 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -578,7 +578,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { }{} mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { input.inputSerial = updateSerial input.inputRoutes = newRoutes return nil, nil, testCase.inputErr @@ -743,7 +743,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { return nil, nil, nil }, } diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index ef71bb60da8..e82f4b1dac3 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -33,7 +33,7 @@ type clientNetwork struct { stop context.CancelFunc statusRecorder *peer.Status wgInterface *iface.WGIface - routes map[string]*route.Route + routes map[route.ID]*route.Route routeUpdate chan routesUpdate peerStateUpdate chan struct{} routePeersNotifiers map[string]chan struct{} @@ -50,7 +50,7 @@ func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, st stop: cancel, statusRecorder: statusRecorder, wgInterface: wgInterface, - routes: make(map[string]*route.Route), + routes: make(map[route.ID]*route.Route), routePeersNotifiers: make(map[string]chan struct{}), routeUpdate: make(chan routesUpdate), peerStateUpdate: make(chan struct{}), @@ -59,8 +59,8 @@ func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, st return client } -func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { - routePeerStatuses := make(map[string]routerPeerStatus) +func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus { + routePeerStatuses := make(map[route.ID]routerPeerStatus) for _, r := range c.routes { peerStatus, err := c.statusRecorder.GetPeer(r.Peer) if err != nil { @@ -90,12 +90,12 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { // * Latency: Routes with lower latency are prioritized. // // It returns the ID of the selected optimal route. -func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { - chosen := "" +func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID { + chosen := route.ID("") chosenScore := float64(0) currScore := float64(0) - currID := "" + currID := route.ID("") if c.chosenRoute != nil { currID = c.chosenRoute.ID } @@ -295,7 +295,7 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { } func (c *clientNetwork) handleUpdate(update routesUpdate) { - updateMap := make(map[string]*route.Route) + updateMap := make(map[route.ID]*route.Route) for _, r := range update.routes { updateMap[r.ID] = r diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go index ca1456c924e..9419ea777fe 100644 --- a/client/internal/routemanager/client_test.go +++ b/client/internal/routemanager/client_test.go @@ -12,21 +12,21 @@ func TestGetBestrouteFromStatuses(t *testing.T) { testCases := []struct { name string - statuses map[string]routerPeerStatus - expectedRouteID string - currentRoute string - existingRoutes map[string]*route.Route + statuses map[route.ID]routerPeerStatus + expectedRouteID route.ID + currentRoute route.ID + existingRoutes map[route.ID]*route.Route }{ { name: "one route", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: true, relayed: false, direct: true, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: route.MaxMetric, @@ -38,14 +38,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, { name: "one connected routes with relayed and direct", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: true, relayed: true, direct: true, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: route.MaxMetric, @@ -57,14 +57,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, { name: "one connected routes with relayed and no direct", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: true, relayed: true, direct: false, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: route.MaxMetric, @@ -76,14 +76,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, { name: "no connected peers", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: false, relayed: false, direct: false, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: route.MaxMetric, @@ -95,7 +95,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, { name: "multiple connected peers with different metrics", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: true, relayed: false, @@ -107,7 +107,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { direct: true, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: 9000, @@ -124,7 +124,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, { name: "multiple connected peers with one relayed", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: true, relayed: false, @@ -136,7 +136,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { direct: true, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: route.MaxMetric, @@ -153,7 +153,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, { name: "multiple connected peers with one direct", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: true, relayed: false, @@ -165,7 +165,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { direct: false, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: route.MaxMetric, @@ -182,7 +182,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, { name: "multiple connected peers with different latencies", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: true, latency: 300 * time.Millisecond, @@ -192,7 +192,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { latency: 10 * time.Millisecond, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: route.MaxMetric, @@ -209,7 +209,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, { name: "should ignore routes with latency 0", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: true, latency: 0 * time.Millisecond, @@ -219,7 +219,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { latency: 10 * time.Millisecond, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: route.MaxMetric, @@ -236,7 +236,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, { name: "current route with similar score and similar but slightly worse latency should not change", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: true, relayed: false, @@ -250,7 +250,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { latency: 10 * time.Millisecond, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: route.MaxMetric, @@ -267,7 +267,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, { name: "current route with bad score should be changed to route with better score", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: true, relayed: false, @@ -281,7 +281,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { latency: 10 * time.Millisecond, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: route.MaxMetric, @@ -298,7 +298,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, { name: "current chosen route doesn't exist anymore", - statuses: map[string]routerPeerStatus{ + statuses: map[route.ID]routerPeerStatus{ "route1": { connected: true, relayed: false, @@ -312,7 +312,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { latency: 10 * time.Millisecond, }, }, - existingRoutes: map[string]*route.Route{ + existingRoutes: map[route.ID]*route.Route{ "route1": { ID: "route1", Metric: route.MaxMetric, diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index dfc39102f9c..9f0f7421335 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -29,8 +29,8 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) // Manager is a route manager interface type Manager interface { Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) - UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) - TriggerSelection(map[string][]*route.Route) + UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) + TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -43,7 +43,7 @@ type DefaultManager struct { ctx context.Context stop context.CancelFunc mux sync.Mutex - clientNetworks map[string]*clientNetwork + clientNetworks map[route.HAUniqueID]*clientNetwork routeSelector *routeselector.RouteSelector serverRouter serverRouter statusRecorder *peer.Status @@ -57,7 +57,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, dm := &DefaultManager{ ctx: mCTX, stop: cancel, - clientNetworks: make(map[string]*clientNetwork), + clientNetworks: make(map[route.HAUniqueID]*clientNetwork), routeSelector: routeselector.NewRouteSelector(), statusRecorder: statusRecorder, wgInterface: wgInterface, @@ -122,7 +122,7 @@ func (m *DefaultManager) Stop() { } // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps -func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { select { case <-m.ctx.Done(): log.Infof("not updating routes as context is closed") @@ -164,12 +164,12 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector { } // GetClientRoutes returns the client routes -func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork { +func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork { return m.clientNetworks } // TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones -func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) { +func (m *DefaultManager) TriggerSelection(networks route.HAMap) { m.mux.Lock() defer m.mux.Unlock() @@ -190,7 +190,7 @@ func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) { } // stopObsoleteClients stops the client network watcher for the networks that are not in the new list -func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) { +func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) { for id, client := range m.clientNetworks { if _, ok := networks[id]; !ok { log.Debugf("Stopping client network watcher, %s", id) @@ -200,7 +200,7 @@ func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) } } -func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { +func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks route.HAMap) { // removing routes that do not exist as per the update from the Management service. m.stopObsoleteClients(networks) @@ -219,15 +219,15 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[ } } -func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) { - newClientRoutesIDMap := make(map[string][]*route.Route) - newServerRoutesMap := make(map[string]*route.Route) - ownNetworkIDs := make(map[string]bool) +func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) { + newClientRoutesIDMap := make(route.HAMap) + newServerRoutesMap := make(map[route.ID]*route.Route) + ownNetworkIDs := make(map[route.HAUniqueID]bool) for _, newRoute := range newRoutes { - networkID := route.GetHAUniqueID(newRoute) + haID := route.GetHAUniqueID(newRoute) if newRoute.Peer == m.pubKey { - ownNetworkIDs[networkID] = true + ownNetworkIDs[haID] = true // only linux is supported for now if runtime.GOOS != "linux" { log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) @@ -238,12 +238,12 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*r } for _, newRoute := range newRoutes { - networkID := route.GetHAUniqueID(newRoute) - if !ownNetworkIDs[networkID] { + haID := route.GetHAUniqueID(newRoute) + if !ownNetworkIDs[haID] { if !isPrefixSupported(newRoute.Network) { continue } - newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) + newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute) } } diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index b3464018ece..adbef80618d 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -14,8 +14,8 @@ import ( // MockManager is the mock instance of a route manager type MockManager struct { - UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) - TriggerSelectionFunc func(map[string][]*route.Route) + UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) + TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector StopFunc func() } @@ -30,14 +30,14 @@ func (m *MockManager) InitialRouteRange() []string { } // UpdateRoutes mock implementation of UpdateRoutes from Manager interface -func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { +func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { if m.UpdateRoutesFunc != nil { return m.UpdateRoutesFunc(updateSerial, newRoutes) } return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented") } -func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) { +func (m *MockManager) TriggerSelection(networks route.HAMap) { if m.TriggerSelectionFunc != nil { m.TriggerSelectionFunc(networks) } diff --git a/client/internal/routemanager/notifier.go b/client/internal/routemanager/notifier.go index ede8f02c4f0..d0c02612e37 100644 --- a/client/internal/routemanager/notifier.go +++ b/client/internal/routemanager/notifier.go @@ -36,7 +36,7 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) { n.initialRouteRangers = nets } -func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) { +func (n *notifier) onNewRoutes(idMap route.HAMap) { newNets := make([]string, 0) for _, routes := range idMap { for _, r := range routes { diff --git a/client/internal/routemanager/server.go b/client/internal/routemanager/server.go index c9a13a90414..368421eb70f 100644 --- a/client/internal/routemanager/server.go +++ b/client/internal/routemanager/server.go @@ -3,7 +3,7 @@ package routemanager import "github.com/netbirdio/netbird/route" type serverRouter interface { - updateRoutes(map[string]*route.Route) error + updateRoutes(map[route.ID]*route.Route) error removeFromServerNetwork(*route.Route) error cleanUp() } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index af82dc91349..95672e4805c 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -19,7 +19,7 @@ import ( type defaultServerRouter struct { mux sync.Mutex ctx context.Context - routes map[string]*route.Route + routes map[route.ID]*route.Route firewall firewall.Manager wgInterface *iface.WGIface statusRecorder *peer.Status @@ -28,15 +28,15 @@ type defaultServerRouter struct { func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { return &defaultServerRouter{ ctx: ctx, - routes: make(map[string]*route.Route), + routes: make(map[route.ID]*route.Route), firewall: firewall, wgInterface: wgInterface, statusRecorder: statusRecorder, }, nil } -func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error { - serverRoutesToRemove := make([]string, 0) +func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { + serverRoutesToRemove := make([]route.ID, 0) for routeID := range m.routes { update, found := routesMap[routeID] @@ -168,7 +168,7 @@ func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, return firewall.RouterPair{}, err } return firewall.RouterPair{ - ID: route.ID, + ID: string(route.ID), Source: parsed.String(), Destination: route.Network.Masked().String(), Masquerade: route.Masquerade, diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 7bd93b46ebb..1c17e880393 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -12,22 +12,22 @@ import ( ) type RouteSelector struct { - selectedRoutes map[string]struct{} + selectedRoutes map[route.NetID]struct{} selectAll bool } func NewRouteSelector() *RouteSelector { return &RouteSelector{ - selectedRoutes: map[string]struct{}{}, + selectedRoutes: map[route.NetID]struct{}{}, // default selects all routes selectAll: true, } } // SelectRoutes updates the selected routes based on the provided route IDs. -func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error { +func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error { if !appendRoute { - rs.selectedRoutes = map[string]struct{}{} + rs.selectedRoutes = map[route.NetID]struct{}{} } var multiErr *multierror.Error @@ -51,15 +51,15 @@ func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRout // SelectAllRoutes sets the selector to select all routes. func (rs *RouteSelector) SelectAllRoutes() { rs.selectAll = true - rs.selectedRoutes = map[string]struct{}{} + rs.selectedRoutes = map[route.NetID]struct{}{} } // DeselectRoutes removes specific routes from the selection. // If the selector is in "select all" mode, it will transition to "select specific" mode. -func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error { +func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { if rs.selectAll { rs.selectAll = false - rs.selectedRoutes = map[string]struct{}{} + rs.selectedRoutes = map[route.NetID]struct{}{} for _, route := range allRoutes { rs.selectedRoutes[route] = struct{}{} } @@ -85,11 +85,11 @@ func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) err // DeselectAllRoutes deselects all routes, effectively disabling route selection. func (rs *RouteSelector) DeselectAllRoutes() { rs.selectAll = false - rs.selectedRoutes = map[string]struct{}{} + rs.selectedRoutes = map[route.NetID]struct{}{} } // IsSelected checks if a specific route is selected. -func (rs *RouteSelector) IsSelected(routeID string) bool { +func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { if rs.selectAll { return true } @@ -98,18 +98,14 @@ func (rs *RouteSelector) IsSelected(routeID string) bool { } // FilterSelected removes unselected routes from the provided map. -func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route { +func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { if rs.selectAll { return maps.Clone(routes) } - filtered := map[string][]*route.Route{} + filtered := route.HAMap{} for id, rt := range routes { - netID := id - if i := strings.LastIndex(id, "-"); i != -1 { - netID = id[:i] - } - if rs.IsSelected(netID) { + if rs.IsSelected(id.NetID()) { filtered[id] = rt } } diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index b3d0547b591..fb1e456cd00 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -12,53 +12,53 @@ import ( ) func TestRouteSelector_SelectRoutes(t *testing.T) { - allRoutes := []string{"route1", "route2", "route3"} + allRoutes := []route.NetID{"route1", "route2", "route3"} tests := []struct { name string - initialSelected []string + initialSelected []route.NetID - selectRoutes []string + selectRoutes []route.NetID append bool - wantSelected []string + wantSelected []route.NetID wantError bool }{ { name: "Select specific routes, initial all selected", - selectRoutes: []string{"route1", "route2"}, - wantSelected: []string{"route1", "route2"}, + selectRoutes: []route.NetID{"route1", "route2"}, + wantSelected: []route.NetID{"route1", "route2"}, }, { name: "Select specific routes, initial all deselected", - initialSelected: []string{}, - selectRoutes: []string{"route1", "route2"}, - wantSelected: []string{"route1", "route2"}, + initialSelected: []route.NetID{}, + selectRoutes: []route.NetID{"route1", "route2"}, + wantSelected: []route.NetID{"route1", "route2"}, }, { name: "Select specific routes with initial selection", - initialSelected: []string{"route1"}, - selectRoutes: []string{"route2", "route3"}, - wantSelected: []string{"route2", "route3"}, + initialSelected: []route.NetID{"route1"}, + selectRoutes: []route.NetID{"route2", "route3"}, + wantSelected: []route.NetID{"route2", "route3"}, }, { name: "Select non-existing route", - selectRoutes: []string{"route1", "route4"}, - wantSelected: []string{"route1"}, + selectRoutes: []route.NetID{"route1", "route4"}, + wantSelected: []route.NetID{"route1"}, wantError: true, }, { name: "Append route with initial selection", - initialSelected: []string{"route1"}, - selectRoutes: []string{"route2"}, + initialSelected: []route.NetID{"route1"}, + selectRoutes: []route.NetID{"route2"}, append: true, - wantSelected: []string{"route1", "route2"}, + wantSelected: []route.NetID{"route1", "route2"}, }, { name: "Append route without initial selection", - selectRoutes: []string{"route2"}, + selectRoutes: []route.NetID{"route2"}, append: true, - wantSelected: []string{"route2"}, + wantSelected: []route.NetID{"route2"}, }, } @@ -86,32 +86,32 @@ func TestRouteSelector_SelectRoutes(t *testing.T) { } func TestRouteSelector_SelectAllRoutes(t *testing.T) { - allRoutes := []string{"route1", "route2", "route3"} + allRoutes := []route.NetID{"route1", "route2", "route3"} tests := []struct { name string - initialSelected []string + initialSelected []route.NetID - wantSelected []string + wantSelected []route.NetID }{ { name: "Initial all selected", - wantSelected: []string{"route1", "route2", "route3"}, + wantSelected: []route.NetID{"route1", "route2", "route3"}, }, { name: "Initial all deselected", - initialSelected: []string{}, - wantSelected: []string{"route1", "route2", "route3"}, + initialSelected: []route.NetID{}, + wantSelected: []route.NetID{"route1", "route2", "route3"}, }, { name: "Initial some selected", - initialSelected: []string{"route1"}, - wantSelected: []string{"route1", "route2", "route3"}, + initialSelected: []route.NetID{"route1"}, + wantSelected: []route.NetID{"route1", "route2", "route3"}, }, { name: "Initial all selected", - initialSelected: []string{"route1", "route2", "route3"}, - wantSelected: []string{"route1", "route2", "route3"}, + initialSelected: []route.NetID{"route1", "route2", "route3"}, + wantSelected: []route.NetID{"route1", "route2", "route3"}, }, } @@ -134,39 +134,39 @@ func TestRouteSelector_SelectAllRoutes(t *testing.T) { } func TestRouteSelector_DeselectRoutes(t *testing.T) { - allRoutes := []string{"route1", "route2", "route3"} + allRoutes := []route.NetID{"route1", "route2", "route3"} tests := []struct { name string - initialSelected []string + initialSelected []route.NetID - deselectRoutes []string + deselectRoutes []route.NetID - wantSelected []string + wantSelected []route.NetID wantError bool }{ { name: "Deselect specific routes, initial all selected", - deselectRoutes: []string{"route1", "route2"}, - wantSelected: []string{"route3"}, + deselectRoutes: []route.NetID{"route1", "route2"}, + wantSelected: []route.NetID{"route3"}, }, { name: "Deselect specific routes, initial all deselected", - initialSelected: []string{}, - deselectRoutes: []string{"route1", "route2"}, - wantSelected: []string{}, + initialSelected: []route.NetID{}, + deselectRoutes: []route.NetID{"route1", "route2"}, + wantSelected: []route.NetID{}, }, { name: "Deselect specific routes with initial selection", - initialSelected: []string{"route1", "route2"}, - deselectRoutes: []string{"route1", "route3"}, - wantSelected: []string{"route2"}, + initialSelected: []route.NetID{"route1", "route2"}, + deselectRoutes: []route.NetID{"route1", "route3"}, + wantSelected: []route.NetID{"route2"}, }, { name: "Deselect non-existing route", - initialSelected: []string{"route1", "route2"}, - deselectRoutes: []string{"route1", "route4"}, - wantSelected: []string{"route2"}, + initialSelected: []route.NetID{"route1", "route2"}, + deselectRoutes: []route.NetID{"route1", "route4"}, + wantSelected: []route.NetID{"route2"}, wantError: true, }, } @@ -195,32 +195,32 @@ func TestRouteSelector_DeselectRoutes(t *testing.T) { } func TestRouteSelector_DeselectAll(t *testing.T) { - allRoutes := []string{"route1", "route2", "route3"} + allRoutes := []route.NetID{"route1", "route2", "route3"} tests := []struct { name string - initialSelected []string + initialSelected []route.NetID - wantSelected []string + wantSelected []route.NetID }{ { name: "Initial all selected", - wantSelected: []string{}, + wantSelected: []route.NetID{}, }, { name: "Initial all deselected", - initialSelected: []string{}, - wantSelected: []string{}, + initialSelected: []route.NetID{}, + wantSelected: []route.NetID{}, }, { name: "Initial some selected", - initialSelected: []string{"route1", "route2"}, - wantSelected: []string{}, + initialSelected: []route.NetID{"route1", "route2"}, + wantSelected: []route.NetID{}, }, { name: "Initial all selected", - initialSelected: []string{"route1", "route2", "route3"}, - wantSelected: []string{}, + initialSelected: []route.NetID{"route1", "route2", "route3"}, + wantSelected: []route.NetID{}, }, } @@ -245,7 +245,7 @@ func TestRouteSelector_DeselectAll(t *testing.T) { func TestRouteSelector_IsSelected(t *testing.T) { rs := routeselector.NewRouteSelector() - err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"}) + err := rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, []route.NetID{"route1", "route2", "route3"}) require.NoError(t, err) assert.True(t, rs.IsSelected("route1")) @@ -257,10 +257,10 @@ func TestRouteSelector_IsSelected(t *testing.T) { func TestRouteSelector_FilterSelected(t *testing.T) { rs := routeselector.NewRouteSelector() - err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"}) + err := rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, []route.NetID{"route1", "route2", "route3"}) require.NoError(t, err) - routes := map[string][]*route.Route{ + routes := route.HAMap{ "route1-10.0.0.0/8": {}, "route2-192.168.0.0/16": {}, "route3-172.16.0.0/12": {}, @@ -268,7 +268,7 @@ func TestRouteSelector_FilterSelected(t *testing.T) { filtered := rs.FilterSelected(routes) - assert.Equal(t, map[string][]*route.Route{ + assert.Equal(t, route.HAMap{ "route1-10.0.0.0/8": {}, "route2-192.168.0.0/16": {}, }, filtered) diff --git a/client/server/route.go b/client/server/route.go index 018fe23c64d..768535d1815 100644 --- a/client/server/route.go +++ b/client/server/route.go @@ -9,10 +9,11 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/route" ) type selectRoute struct { - NetID string + NetID route.NetID Network netip.Prefix Selected bool } @@ -60,7 +61,7 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) ( var pbRoutes []*proto.Route for _, route := range routes { pbRoutes = append(pbRoutes, &proto.Route{ - ID: route.NetID, + ID: string(route.NetID), Network: route.Network.String(), Selected: route.Selected, }) @@ -81,7 +82,8 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) if req.GetAll() { routeSelector.SelectAllRoutes() } else { - if err := routeSelector.SelectRoutes(req.GetRouteIDs(), req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { + routes := toNetIDs(req.GetRouteIDs()) + if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { return nil, fmt.Errorf("select routes: %w", err) } } @@ -100,7 +102,8 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques if req.GetAll() { routeSelector.DeselectAllRoutes() } else { - if err := routeSelector.DeselectRoutes(req.GetRouteIDs(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { + routes := toNetIDs(req.GetRouteIDs()) + if err := routeSelector.DeselectRoutes(routes, maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } } @@ -108,3 +111,11 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques return &proto.SelectRoutesResponse{}, nil } + +func toNetIDs(routes []string) []route.NetID { + var netIDs []route.NetID + for _, rt := range routes { + netIDs = append(netIDs, route.NetID(rt)) + } + return netIDs +} diff --git a/management/server/account.go b/management/server/account.go index aac13665749..da1e433703a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -100,10 +100,10 @@ type AccountManager interface { SavePolicy(accountID, userID string, policy *Policy) error DeletePolicy(accountID, policyID, userID string) error ListPolicies(accountID, userID string) ([]*Policy, error) - GetRoute(accountID, routeID, userID string) (*route.Route, error) - CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) + GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) + CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) SaveRoute(accountID, userID string, route *route.Route) error - DeleteRoute(accountID, routeID, userID string) error + DeleteRoute(accountID string, routeID route.ID, userID string) error ListRoutes(accountID, userID string) ([]*route.Route, error) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) @@ -229,7 +229,7 @@ type Account struct { Groups map[string]*nbgroup.Group `gorm:"-"` GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` - Routes map[string]*route.Route `gorm:"-"` + Routes map[route.ID]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` @@ -266,7 +266,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID) peerRoutesMembership := make(lookupMap) for _, r := range append(routes, peerDisabledRoutes...) { - peerRoutesMembership[route.GetHAUniqueID(r)] = struct{}{} + peerRoutesMembership[string(route.GetHAUniqueID(r))] = struct{}{} } groupListMap := a.getPeerGroups(peerID) @@ -284,7 +284,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { var filteredRoutes []*route.Route for _, r := range routes { - _, found := peerMemberships[route.GetHAUniqueID(r)] + _, found := peerMemberships[string(route.GetHAUniqueID(r))] if !found { filteredRoutes = append(filteredRoutes, r) } @@ -323,7 +323,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro return enabledRoutes, disabledRoutes } - seenRoute := make(map[string]struct{}) + seenRoute := make(map[route.ID]struct{}) takeRoute := func(r *route.Route, id string) { if _, ok := seenRoute[r.ID]; ok { @@ -354,7 +354,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro newPeerRoute := r.Copy() newPeerRoute.Peer = id newPeerRoute.PeerGroups = nil - newPeerRoute.ID = r.ID + ":" + id // we have to provide unique route id when distribute network map + newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map takeRoute(newPeerRoute, id) break } @@ -693,7 +693,7 @@ func (a *Account) Copy() *Account { policies = append(policies, policy.Copy()) } - routes := map[string]*route.Route{} + routes := map[route.ID]*route.Route{} for id, r := range a.Routes { routes[id] = r.Copy() } @@ -1946,7 +1946,7 @@ func newAccountWithId(accountID, userID, domain string) *Account { network := NewNetwork() peers := make(map[string]*nbpeer.Peer) users := make(map[string]*User) - routes := make(map[string]*route.Route) + routes := make(map[route.ID]*route.Route) setupKeys := map[string]*SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) users[userID] = NewOwnerUser(userID) diff --git a/management/server/account_test.go b/management/server/account_test.go index a0eff239b54..456963361e3 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1408,7 +1408,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) { t.Fatal(err) } account := &Account{ - Routes: map[string]*route.Route{ + Routes: map[route.ID]*route.Route{ "route-1": { ID: "route-1", Network: prefix, @@ -1437,12 +1437,12 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) { routes := account.GetRoutesByPrefix(prefix) assert.Len(t, routes, 2) - routeIDs := make(map[string]struct{}, 2) + routeIDs := make(map[route.ID]struct{}, 2) for _, r := range routes { routeIDs[r.ID] = struct{}{} } - assert.Contains(t, routeIDs, "route-1") - assert.Contains(t, routeIDs, "route-2") + assert.Contains(t, routeIDs, route.ID("route-1")) + assert.Contains(t, routeIDs, route.ID("route-2")) } func TestAccount_GetRoutesToSync(t *testing.T) { @@ -1459,7 +1459,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, }, Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, - Routes: map[string]*route.Route{ + Routes: map[route.ID]*route.Route{ "route-1": { ID: "route-1", Network: prefix, @@ -1502,12 +1502,12 @@ func TestAccount_GetRoutesToSync(t *testing.T) { routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) assert.Len(t, routes, 2) - routeIDs := make(map[string]struct{}, 2) + routeIDs := make(map[route.ID]struct{}, 2) for _, r := range routes { routeIDs[r.ID] = struct{}{} } - assert.Contains(t, routeIDs, "route-2") - assert.Contains(t, routeIDs, "route-3") + assert.Contains(t, routeIDs, route.ID("route-2")) + assert.Contains(t, routeIDs, route.ID("route-3")) emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) @@ -1573,7 +1573,7 @@ func TestAccount_Copy(t *testing.T) { SourcePostureChecks: make([]string, 0), }, }, - Routes: map[string]*route.Route{ + Routes: map[route.ID]*route.Route{ "route1": { ID: "route1", PeerGroups: []string{}, diff --git a/management/server/group.go b/management/server/group.go index 0fc952cdbc9..0d93ab5e514 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -242,7 +242,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) for _, r := range account.Routes { for _, g := range r.Groups { if g == groupID { - return &GroupLinkError{"route", r.NetID} + return &GroupLinkError{"route", string(r.NetID)} } } } diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index 348bdbfd688..f755e7a16a2 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -107,7 +107,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { newRoute, err := h.accountManager.CreateRoute( account.Id, newPrefix.String(), peerId, peerGroupIds, - req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, ) if err != nil { util.WriteError(err, w) @@ -135,7 +135,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetRoute(account.Id, routeID, user.Id) + _, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id) if err != nil { util.WriteError(err, w) return @@ -185,9 +185,9 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { } newRoute := &route.Route{ - ID: routeID, + ID: route.ID(routeID), Network: newPrefix, - NetID: req.NetworkId, + NetID: route.NetID(req.NetworkId), NetworkType: prefixType, Masquerade: req.Masquerade, Metric: req.Metric, @@ -230,7 +230,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteRoute(account.Id, routeID, user.Id) + err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id) if err != nil { util.WriteError(err, w) return @@ -254,7 +254,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { return } - foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id) + foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id) if err != nil { util.WriteError(status.Errorf(status.NotFound, "route not found"), w) return @@ -265,9 +265,9 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { func toRouteResponse(serverRoute *route.Route) *api.Route { route := &api.Route{ - Id: serverRoute.ID, + Id: string(serverRoute.ID), Description: serverRoute.Description, - NetworkId: serverRoute.NetID, + NetworkId: string(serverRoute.NetID), Enabled: serverRoute.Enabled, Peer: &serverRoute.Peer, Network: serverRoute.Network.String(), diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index c02292f2a94..1c8288d5f7f 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -82,7 +82,7 @@ var testingAccount = &server.Account{ func initRoutesTestData() *RoutesHandler { return &RoutesHandler{ accountManager: &mock_server.MockAccountManager{ - GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) { + GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) { if routeID == existingRouteID { return baseExistingRoute, nil } @@ -93,7 +93,7 @@ func initRoutesTestData() *RoutesHandler { } return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) }, - CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) { + CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) { if peerID == notFoundPeerID { return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } @@ -120,7 +120,7 @@ func initRoutesTestData() *RoutesHandler { } return nil }, - DeleteRouteFunc: func(_ string, routeID string, _ string) error { + DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error { if routeID != existingRouteID { return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID) } diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index c479867d294..c5b18607adc 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -67,7 +67,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account { SourcePostureChecks: []string{"1"}, }, }, - Routes: map[string]*route.Route{ + Routes: map[route.ID]*route.Route{ "1": { ID: "1", PeerGroups: make([]string, 1), @@ -151,7 +151,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account { }, }, }, - Routes: map[string]*route.Route{ + Routes: map[route.ID]*route.Route{ "1": { ID: "1", PeerGroups: make([]string, 1), diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8687937dc49..e3f0edd01dd 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -22,76 +22,76 @@ type MockAccountManager struct { GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) - GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) - GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) - ListUsersFunc func(accountID string) ([]*server.User, error) - GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error - DeletePeerFunc func(accountID, peerKey, userID string) error - GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) - GetPeerNetworkFunc func(peerKey string) (*server.Network, error) - AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) - GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) - GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) - GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) - SaveGroupFunc func(accountID, userID string, group *group.Group) error - DeleteGroupFunc func(accountID, userId, groupID string) error - ListGroupsFunc func(accountID string) ([]*group.Group, error) - GroupAddPeerFunc func(accountID, groupID, peerID string) error - GroupDeletePeerFunc func(accountID, groupID, peerID string) error - DeleteRuleFunc func(accountID, ruleID, userID string) error - GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(accountID, userID string, policy *server.Policy) error - DeletePolicyFunc func(accountID, policyID, userID string) error - ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) - GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) - MarkPATUsedFunc func(pat string) error - UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error - UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error - UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) - GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) - SaveRouteFunc func(accountID, userID string, route *route.Route) error - DeleteRouteFunc func(accountID, routeID, userID string) error - ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) - SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) - ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) - SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) - SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) - DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error - CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) - DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error - GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) - GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) - GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) - SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error - ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error) - CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) - GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) - CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error - DeleteAccountFunc func(accountID, userID string) error - GetDNSDomainFunc func() string - StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) - GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) - GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) - SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error - GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) - LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) - SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) - InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error - GetAllConnectedPeersFunc func() (map[string]struct{}, error) - HasConnectedChannelFunc func(peerID string) bool - GetExternalCacheManagerFunc func() server.ExternalCacheManager - GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error - DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error - ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) - GetIdpManagerFunc func() idp.Manager + GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) + GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) + GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) + ListUsersFunc func(accountID string) ([]*server.User, error) + GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error) + MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error + DeletePeerFunc func(accountID, peerKey, userID string) error + GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) + GetPeerNetworkFunc func(peerKey string) (*server.Network, error) + AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) + GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) + GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) + GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) + SaveGroupFunc func(accountID, userID string, group *group.Group) error + DeleteGroupFunc func(accountID, userId, groupID string) error + ListGroupsFunc func(accountID string) ([]*group.Group, error) + GroupAddPeerFunc func(accountID, groupID, peerID string) error + GroupDeletePeerFunc func(accountID, groupID, peerID string) error + DeleteRuleFunc func(accountID, ruleID, userID string) error + GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error) + SavePolicyFunc func(accountID, userID string, policy *server.Policy) error + DeletePolicyFunc func(accountID, policyID, userID string) error + ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) + GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) + GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + MarkPATUsedFunc func(pat string) error + UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error + UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error + UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) + CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) + GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error) + SaveRouteFunc func(accountID string, userID string, route *route.Route) error + DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error + ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) + SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) + ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) + SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) + SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) + DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error + CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error + GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) + GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) + GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) + SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error + DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error + ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error) + CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) + GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) + CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error + DeleteAccountFunc func(accountID, userID string) error + GetDNSDomainFunc func() string + StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) + GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) + GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) + SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error + GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error) + UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) + LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) + SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) + InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error + GetAllConnectedPeersFunc func() (map[string]struct{}, error) + HasConnectedChannelFunc func(peerID string) bool + GetExternalCacheManagerFunc func() server.ExternalCacheManager + GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error) + SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error + DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error + ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) + GetIdpManagerFunc func() idp.Manager UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error GroupValidationFunc func(accountId string, groups []string) (bool, error) } @@ -399,15 +399,15 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer. } // CreateRoute mock implementation of CreateRoute from server.AccountManager interface -func (am *MockAccountManager) CreateRoute(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { +func (am *MockAccountManager) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(accountID, network, peerID, peerGroups, description, netID, masquerade, metric, groups, enabled, userID) + return am.CreateRouteFunc(accountID, prefix, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } // GetRoute mock implementation of GetRoute from server.AccountManager interface -func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { +func (am *MockAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) { if am.GetRouteFunc != nil { return am.GetRouteFunc(accountID, routeID, userID) } @@ -415,7 +415,7 @@ func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*rout } // SaveRoute mock implementation of SaveRoute from server.AccountManager interface -func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.Route) error { +func (am *MockAccountManager) SaveRoute(accountID string, userID string, route *route.Route) error { if am.SaveRouteFunc != nil { return am.SaveRouteFunc(accountID, userID, route) } @@ -423,7 +423,7 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R } // DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface -func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error { +func (am *MockAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error { if am.DeleteRouteFunc != nil { return am.DeleteRouteFunc(accountID, routeID, userID) } diff --git a/management/server/route.go b/management/server/route.go index 4de552a2d43..0b7658441f0 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -13,7 +13,7 @@ import ( ) // GetRoute gets a route object from account and route IDs -func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { +func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -40,7 +40,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r } // checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. -func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID, routeID string, peerGroupIDs []string, prefix netip.Prefix) error { +func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix) error { // routes can have both peer and peer_groups routesWithPrefix := account.GetRoutesByPrefix(prefix) @@ -56,7 +56,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account } if prefixRoute.Peer != "" { - seenPeers[prefixRoute.ID] = true + seenPeers[string(prefixRoute.ID)] = true } for _, groupID := range prefixRoute.PeerGroups { seenPeerGroups[groupID] = true @@ -114,7 +114,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account } // CreateRoute creates and saves a new route -func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { +func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -131,7 +131,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, } var newRoute route.Route - newRoute.ID = xid.New().String() + newRoute.ID = route.ID(xid.New().String()) prefixType, newPrefix, err := route.ParseNetwork(network) if err != nil { @@ -154,7 +154,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) } - if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" { + if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" { return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } @@ -175,7 +175,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, newRoute.Groups = groups if account.Routes == nil { - account.Routes = make(map[string]*route.Route) + account.Routes = make(map[route.ID]*route.Route) } account.Routes[newRoute.ID] = &newRoute @@ -187,7 +187,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, am.updateAccountPeers(account) - am.StoreEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta()) + am.StoreEvent(userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) return &newRoute, nil } @@ -209,7 +209,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) } - if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" { + if utf8.RuneCountInString(string(routeToSave.NetID)) > route.MaxNetIDChar || routeToSave.NetID == "" { return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } @@ -248,13 +248,13 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave am.updateAccountPeers(account) - am.StoreEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta()) + am.StoreEvent(userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) return nil } // DeleteRoute deletes route with routeID -func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error { +func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -274,7 +274,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) return err } - am.StoreEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta()) + am.StoreEvent(userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) am.updateAccountPeers(account) @@ -310,8 +310,8 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route. func toProtocolRoute(route *route.Route) *proto.Route { return &proto.Route{ - ID: route.ID, - NetID: route.NetID, + ID: string(route.ID), + NetID: string(route.NetID), Network: route.Network.String(), NetworkType: int64(route.NetworkType), Peer: route.Peer, diff --git a/management/server/route_test.go b/management/server/route_test.go index 9f8ea08c932..e06ac650c71 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -40,7 +40,7 @@ const ( func TestCreateRoute(t *testing.T) { type input struct { network string - netID string + netID route.NetID peerKey string peerGroupIDs []string description string @@ -382,8 +382,8 @@ func TestSaveRoute(t *testing.T) { invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34") validMetric := 1000 invalidMetric := 99999 - validNetID := "12345678901234567890qw" - invalidNetID := "12345678901234567890qwertyuiopqwertyuiop1" + validNetID := route.NetID("12345678901234567890qw") + invalidNetID := route.NetID("12345678901234567890qwertyuiopqwertyuiop1") validGroupHA1 := routeGroupHA1 validGroupHA2 := routeGroupHA2 diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index bfde82a6de7..65281a4f841 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -451,7 +451,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { } account.GroupsG = nil - account.Routes = make(map[string]*route.Route, len(account.RoutesG)) + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) for _, route := range account.RoutesG { account.Routes[route.ID] = route.Copy() } diff --git a/management/server/sqlite_store_test.go b/management/server/sqlite_store_test.go index 8a1bcd10aeb..cc033e61fef 100644 --- a/management/server/sqlite_store_test.go +++ b/management/server/sqlite_store_test.go @@ -2,8 +2,6 @@ package server import ( "fmt" - nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" "math/rand" "net" "net/netip" @@ -12,6 +10,9 @@ import ( "testing" "time" + nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -75,9 +76,9 @@ func TestSqlite_SaveAccount_Large(t *testing.T) { } account.Users[user.Id] = user route := &route2.Route{ - ID: fmt.Sprintf("network-id-%d", n), + ID: route2.ID(fmt.Sprintf("network-id-%d", n)), Description: "base route", - NetID: fmt.Sprintf("network-id-%d", n), + NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)), Network: netip.MustParsePrefix(netIP.String() + "/24"), NetworkType: route2.IPv4Network, Metric: 9999, diff --git a/route/hauniqueid.go b/route/hauniqueid.go new file mode 100644 index 00000000000..6f74563e261 --- /dev/null +++ b/route/hauniqueid.go @@ -0,0 +1,22 @@ +package route + +import "strings" + +type HAUniqueID string + +// GetHAUniqueID returns a highly available route ID by combining Network ID and Network range address +func GetHAUniqueID(input *Route) HAUniqueID { + return HAUniqueID(string(input.NetID) + "-" + input.Network.String()) +} + +func (id HAUniqueID) String() string { + return string(id) +} + +// NetID returns the Network ID from the HAUniqueID +func (id HAUniqueID) NetID() NetID { + if i := strings.LastIndex(string(id), "-"); i != -1 { + return NetID(id[:i]) + } + return NetID(id) +} diff --git a/route/route.go b/route/route.go index 97c75f3b492..50c53cbe6da 100644 --- a/route/route.go +++ b/route/route.go @@ -36,6 +36,12 @@ const ( IPv6Network ) +type ID string + +type NetID string + +type HAMap map[HAUniqueID][]*Route + // NetworkType route network type type NetworkType int @@ -65,11 +71,11 @@ func ToPrefixType(prefix string) NetworkType { // Route represents a route type Route struct { - ID string `gorm:"primaryKey"` + ID ID `gorm:"primaryKey"` // AccountID is a reference to Account that this object belongs AccountID string `gorm:"index"` Network netip.Prefix `gorm:"serializer:json"` - NetID string + NetID NetID Description string Peer string PeerGroups []string `gorm:"serializer:json"` @@ -165,8 +171,3 @@ func compareList(list, other []string) bool { return true } - -// GetHAUniqueID returns a highly available route ID by combining Network ID and Network range address -func GetHAUniqueID(input *Route) string { - return input.NetID + "-" + input.Network.String() -}