From 309d5a6322b85633f02cda428db66ce61cd54afa Mon Sep 17 00:00:00 2001 From: Artem Glazychev Date: Tue, 16 May 2023 21:39:56 +0700 Subject: [PATCH] Fix vl3 dns configurations (#1460) Signed-off-by: Artem Glazychev --- pkg/networkservice/chains/nsmgr/vl3_test.go | 23 +++++++++----- .../connectioncontext/dnscontext/client.go | 13 +------- .../dnscontext/client_test.go | 12 ++++--- .../dnscontext/vl3dns/server.go | 31 +++++++++++++------ 4 files changed, 46 insertions(+), 33 deletions(-) diff --git a/pkg/networkservice/chains/nsmgr/vl3_test.go b/pkg/networkservice/chains/nsmgr/vl3_test.go index 1a21d4eff..813e60e74 100644 --- a/pkg/networkservice/chains/nsmgr/vl3_test.go +++ b/pkg/networkservice/chains/nsmgr/vl3_test.go @@ -360,8 +360,9 @@ func Test_Interdomain_vl3_dns(t *testing.T) { vl3.NewServer(ctx, serverPrefixCh), vl3dns.NewServer(ctx, dnsServerIPCh, + vl3dns.WithDNSPort(40053), vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ target .NetworkService }}.{{ domain .NetworkService }}."), - vl3dns.WithDNSPort(40053)), + ), checkrequest.NewServer(t, func(t *testing.T, nsr *networkservice.NetworkServiceRequest) { require.False(t, interdomain.Is(nsr.GetConnection().GetNetworkService())) }, @@ -394,18 +395,20 @@ func Test_Interdomain_vl3_dns(t *testing.T) { req.Connection = resp.Clone() require.Len(t, resp.GetContext().GetDnsContext().GetConfigs(), 1) require.Len(t, resp.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps, 1) + require.Len(t, resp.GetContext().GetDnsContext().GetConfigs()[0].SearchDomains, 1) - requireIPv4Lookup(ctx, t, &resolver, nscName+".vl3", "10.0.0.1") + searchDomain := resp.GetContext().GetDnsContext().GetConfigs()[0].SearchDomains[0] + requireIPv4Lookup(ctx, t, &resolver, fmt.Sprintf("%s.%s", nscName, searchDomain), "10.0.0.1") resp, err = nsc.Request(ctx, req) require.NoError(t, err) - requireIPv4Lookup(ctx, t, &resolver, nscName+".vl3", "10.0.0.1") + requireIPv4Lookup(ctx, t, &resolver, fmt.Sprintf("%s.%s", nscName, searchDomain), "10.0.0.1") _, err = nsc.Close(ctx, resp) require.NoError(t, err) - _, err = resolver.LookupIP(ctx, "ip4", nscName+".vl3") + _, err = resolver.LookupIP(ctx, "ip4", fmt.Sprintf("%s.%s", nscName, searchDomain)) require.Error(t, err) } @@ -458,8 +461,9 @@ func Test_FloatingInterdomain_vl3_dns(t *testing.T) { vl3.NewServer(ctx, serverPrefixCh), vl3dns.NewServer(ctx, dnsServerIPCh, + vl3dns.WithDNSPort(40053), vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ target .NetworkService }}.{{ domain .NetworkService }}."), - vl3dns.WithDNSPort(40053)), + ), ) resolver := net.Resolver{ @@ -488,17 +492,20 @@ func Test_FloatingInterdomain_vl3_dns(t *testing.T) { req.Connection = resp.Clone() require.Len(t, resp.GetContext().GetDnsContext().GetConfigs(), 1) require.Len(t, resp.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps, 1) + require.Len(t, resp.GetContext().GetDnsContext().GetConfigs()[0].SearchDomains, 3) + + searchDomain := resp.GetContext().GetDnsContext().GetConfigs()[0].SearchDomains[0] - requireIPv4Lookup(ctx, t, &resolver, nscName+".vl3."+floating.Name, "10.0.0.1") + requireIPv4Lookup(ctx, t, &resolver, fmt.Sprintf("%s.%s", nscName, searchDomain), "10.0.0.1") resp, err = nsc.Request(ctx, req) require.NoError(t, err) - requireIPv4Lookup(ctx, t, &resolver, nscName+".vl3."+floating.Name, "10.0.0.1") + requireIPv4Lookup(ctx, t, &resolver, fmt.Sprintf("%s.%s", nscName, searchDomain), "10.0.0.1") _, err = nsc.Close(ctx, resp) require.NoError(t, err) - _, err = resolver.LookupIP(ctx, "ip4", nscName+".vl3."+floating.Name) + _, err = resolver.LookupIP(ctx, "ip4", fmt.Sprintf("%s.%s", nscName, searchDomain)) require.Error(t, err) } diff --git a/pkg/networkservice/connectioncontext/dnscontext/client.go b/pkg/networkservice/connectioncontext/dnscontext/client.go index b93ec574e..afd8fba4c 100644 --- a/pkg/networkservice/connectioncontext/dnscontext/client.go +++ b/pkg/networkservice/connectioncontext/dnscontext/client.go @@ -31,7 +31,6 @@ import ( "google.golang.org/grpc" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" - "github.com/networkservicemesh/sdk/pkg/tools/dnsutils" "github.com/networkservicemesh/sdk/pkg/tools/log" ) @@ -72,22 +71,12 @@ func (c *dnsContextClient) Request(ctx context.Context, request *networkservice. request.Connection.Context.DnsContext = &networkservice.DNSContext{} } - if !dnsutils.ContainsDNSConfig(request.GetConnection().GetContext().GetDnsContext().Configs, c.resolvconfDNSConfig) { - request.GetConnection().GetContext().GetDnsContext().Configs = append(request.GetConnection().GetContext().GetDnsContext().Configs, c.resolvconfDNSConfig) - } - rv, err := next.Client(ctx).Request(ctx, request, opts...) if err != nil { return nil, err } - var configs []*networkservice.DNSConfig - if rv.GetContext().GetDnsContext() != nil { - configs = rv.GetContext().GetDnsContext().GetConfigs() - } - - c.dnsConfigsMap.Store(rv.Id, configs) - + c.dnsConfigsMap.Store(rv.Id, append(rv.GetContext().GetDnsContext().Configs, c.resolvconfDNSConfig)) return rv, nil } diff --git a/pkg/networkservice/connectioncontext/dnscontext/client_test.go b/pkg/networkservice/connectioncontext/dnscontext/client_test.go index 327302a2c..8eacbf75c 100644 --- a/pkg/networkservice/connectioncontext/dnscontext/client_test.go +++ b/pkg/networkservice/connectioncontext/dnscontext/client_test.go @@ -46,12 +46,13 @@ func Test_DNSContextClient_Usecases(t *testing.T) { err := os.WriteFile(resolveConfigPath, []byte("nameserver 8.8.4.4\nsearch example.com\n"), os.ModePerm) require.NoError(t, err) + dnsConfigMap := new(genericsync.Map[string, []*networkservice.DNSConfig]) client := chain.NewNetworkServiceClient( metadata.NewClient(), dnscontext.NewClient( dnscontext.WithChainContext(ctx), dnscontext.WithResolveConfigPath(resolveConfigPath), - dnscontext.WithDNSConfigsMap(new(genericsync.Map[string, []*networkservice.DNSConfig])), + dnscontext.WithDNSConfigsMap(dnsConfigMap), ), ) @@ -73,9 +74,12 @@ nameserver 127.0.0.1` resp, err := client.Request(ctx, request) require.NoError(t, err) require.NotNil(t, resp.GetContext().GetDnsContext()) - require.Len(t, resp.GetContext().GetDnsContext().GetConfigs(), len(request.GetConnection().Context.DnsContext.GetConfigs())) - require.Contains(t, resp.Context.DnsContext.Configs[0].DnsServerIps, "8.8.4.4") - require.Contains(t, resp.Context.DnsContext.Configs[0].SearchDomains, "example.com") + require.Len(t, resp.GetContext().GetDnsContext().GetConfigs(), 0) + // Check updated dnsConfigMap + loadedDNSConfig, ok := dnsConfigMap.Load(resp.Id) + require.True(t, ok) + require.Contains(t, loadedDNSConfig[0].DnsServerIps, "8.8.4.4") + require.Contains(t, loadedDNSConfig[0].SearchDomains, "example.com") _, err = client.Close(ctx, resp) require.NoError(t, err) } diff --git a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go index e24981b29..a012f76d7 100644 --- a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go +++ b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go @@ -109,13 +109,7 @@ func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.Netw var clientsConfigs = request.GetConnection().GetContext().GetDnsContext().GetConfigs() - dnsServerIPStr, err := n.addDNSContext(request.GetConnection()) - if err != nil { - return nil, err - } - - var recordNames []string - recordNames, err = n.buildSrcDNSRecords(request.GetConnection()) + recordNames, err := n.buildSrcDNSRecords(request.GetConnection()) if err != nil { return nil, err } @@ -129,6 +123,11 @@ func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.Netw } } + dnsServerIPStr, err := n.addDNSContext(request.GetConnection(), recordNames) + if err != nil { + return nil, err + } + resp, err := next.Server(ctx).Request(ctx, request) if err == nil { ips := getSrcIPs(resp) @@ -171,12 +170,26 @@ func (n *vl3DNSServer) Close(ctx context.Context, conn *networkservice.Connectio return next.Server(ctx).Close(ctx, conn) } -func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection) (added string, err error) { +func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection, dnsRecords []string) (serverIP string, err error) { if ip := n.dnsServerIP.Load(); ip != nil { dnsServerIP := ip.(net.IP) + + // Construct searchDomains for a client + // Example: dnsRecord = "target.d1.d2.d3." ---> searchDomain = ["d1.d2.d3", "d2.d3", "d3"] + var searchDomains []string + for _, dnsRecord := range dnsRecords { + var ok bool + searchDomain := dnsRecord + for _, searchDomain, ok = strings.Cut(strings.Trim(searchDomain, "."), "."); ok; _, searchDomain, ok = strings.Cut(searchDomain, ".") { + searchDomains = append(searchDomains, searchDomain) + } + } + + // Add dnsConfig to the connection var dnsContext = c.GetContext().GetDnsContext() configToAdd := &networkservice.DNSConfig{ - DnsServerIps: []string{dnsServerIP.String()}, + DnsServerIps: []string{dnsServerIP.String()}, + SearchDomains: searchDomains, } if !dnsutils.ContainsDNSConfig(dnsContext.Configs, configToAdd) { dnsContext.Configs = append(dnsContext.Configs, configToAdd)