Skip to content

Commit

Permalink
Refactor code to add context
Browse files Browse the repository at this point in the history
  • Loading branch information
bcmmbaga committed Jul 3, 2024
1 parent 0dd69ca commit 38b23a2
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
}

routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect)
routesFirewallRules := a.getPeerRoutesFirewallRules(peerID, validatedPeersMap)
routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)

dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
Expand Down
8 changes: 4 additions & 4 deletions management/server/http/policies_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func (h *Policies) savePolicy(
}

if (rule.Ports != nil && len(*rule.Ports) != 0) && (rule.PortRanges != nil && len(*rule.PortRanges) != 0) {
util.WriteError(status.Errorf(status.InvalidArgument, "specify either individual ports or port ranges, not both"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either individual ports or port ranges, not both"), w)
return
}

Expand All @@ -193,7 +193,7 @@ func (h *Policies) savePolicy(
if rule.PortRanges != nil && len(*rule.PortRanges) != 0 {
for _, portRange := range *rule.PortRanges {
if portRange.Start < 1 || portRange.End > 65535 {
util.WriteError(status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
return
}
pr.PortRanges = append(pr.PortRanges, server.RulePortRange{
Expand All @@ -206,7 +206,7 @@ func (h *Policies) savePolicy(
// validate policy object
switch pr.Protocol {
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
if len(pr.Ports) == 0 || len(pr.PortRanges) != 0{
if len(pr.Ports) == 0 || len(pr.PortRanges) != 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
return
}
Expand All @@ -215,7 +215,7 @@ func (h *Policies) savePolicy(
return
}
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0){
if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
}
Expand Down
16 changes: 8 additions & 8 deletions management/server/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,10 @@ func getPlaceholderIP() netip.Prefix {
}

// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
func (a *Account) getPeerRoutesFirewallRules(peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))

enabledRoutes, _ := a.getRoutingPeerRoutes(peerID)
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
for _, route := range enabledRoutes {
// If no access control groups are specified, accept all incoming traffic.
if len(route.AccessControlGroups) == 0 {
Expand All @@ -444,8 +444,8 @@ func (a *Account) getPeerRoutesFirewallRules(peerID string, validatedPeersMap ma
continue
}

distributionGroupPeers, _ := getAllPeersFromGroups(a, route.Groups, peerID, nil, validatedPeersMap)
rules := generateRouteFirewallRules(route, rule, distributionGroupPeers, firewallRuleDirectionIN)
distributionGroupPeers, _ := getAllPeersFromGroups(ctx, a, route.Groups, peerID, nil, validatedPeersMap)
rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN)
routesFirewallRules = append(routesFirewallRules, rules...)
}
}
Expand Down Expand Up @@ -481,7 +481,7 @@ func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []strin
}

// generateRouteFirewallRules generates a list of firewall rules for a given route.
func generateRouteFirewallRules(route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
rulesExists := make(map[string]struct{})
rules := make([]*RouteFirewallRule, 0)

Expand All @@ -505,7 +505,7 @@ func generateRouteFirewallRules(route *route.Route, rule *PolicyRule, groupPeers
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
continue
}
rules = append(rules, generateRulesWithPorts(baseRule, rule, rulesExists)...)
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
}

return rules
Expand Down Expand Up @@ -545,7 +545,7 @@ func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, r
}

// generateRulesWithPorts generates rules when specific ports are provided.
func generateRulesWithPorts(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
rules := make([]*RouteFirewallRule, 0)
ruleIDBase := generateRuleIDBase(rule, baseRule)

Expand All @@ -559,7 +559,7 @@ func generateRulesWithPorts(baseRule RouteFirewallRule, rule *PolicyRule, rulesE
pr := baseRule
p, err := strconv.ParseUint(port, 10, 16)
if err != nil {
log.Errorf("failed to parse port %s for rule: %s", port, rule.ID)
log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID)
continue
}

Expand Down
16 changes: 8 additions & 8 deletions management/server/route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,9 @@ func TestCreateRoute(t *testing.T) {
if testCase.createInitRoute {
groupAll, errInit := account.GetGroupAll()
require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID},[]string{}, true, userID, false)
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID},[]string{groupAll.ID}, true, userID, false)
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
require.NoError(t, errInit)
}

Expand Down Expand Up @@ -1073,7 +1073,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")

newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups,baseRoute.Enabled, userID, baseRoute.KeepRoute)
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
require.NoError(t, err)
require.Equal(t, newRoute.Enabled, true)

Expand Down Expand Up @@ -1165,7 +1165,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")

createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups,false, userID, baseRoute.KeepRoute)
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
require.NoError(t, err)

noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
Expand Down Expand Up @@ -1697,7 +1697,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
})

t.Run("check peer routes firewall rules", func(t *testing.T) {
routesFirewallRules := account.getPeerRoutesFirewallRules("peerA", validatedPeers)
routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
assert.Len(t, routesFirewallRules, 6)

expectedRoutesFirewallRules := []*RouteFirewallRule{
Expand Down Expand Up @@ -1759,12 +1759,12 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)

// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
routesFirewallRules = account.getPeerRoutesFirewallRules("peerD", validatedPeers)
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
assert.Len(t, routesFirewallRules, 6)
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)

// peerE is a single routing peer for route 2 and route 3
routesFirewallRules = account.getPeerRoutesFirewallRules("peerE", validatedPeers)
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
assert.Len(t, routesFirewallRules, 3)

expectedRoutesFirewallRules = []*RouteFirewallRule{
Expand Down Expand Up @@ -1798,7 +1798,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)

// peerC is part of route1 distribution groups but should not receive the routes firewall rules
routesFirewallRules = account.getPeerRoutesFirewallRules("peerC", validatedPeers)
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
assert.Len(t, routesFirewallRules, 0)
})

Expand Down

0 comments on commit 38b23a2

Please sign in to comment.