diff --git a/pkg/networkservice/chains/nsmgr/heal_test.go b/pkg/networkservice/chains/nsmgr/heal_test.go index c3dfd2be0..89bbdf4b0 100644 --- a/pkg/networkservice/chains/nsmgr/heal_test.go +++ b/pkg/networkservice/chains/nsmgr/heal_test.go @@ -538,10 +538,6 @@ func testNSMGRCloseHeal(t *testing.T, withNSEExpiration bool) { SetNSMgrProxySupplier(nil). SetRegistryProxySupplier(nil) - if withNSEExpiration { - builder = builder.SetRegistryExpiryDuration(time.Second / 2) - } - domain := builder.Build() nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) @@ -549,9 +545,13 @@ func testNSMGRCloseHeal(t *testing.T, withNSEExpiration bool) { nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) - nseCtx, nseCtxCancel := context.WithCancel(ctx) - - domain.Nodes[0].NewEndpoint(nseCtx, defaultRegistryEndpoint(nsReg.Name), sandbox.GenerateTestToken) + nseCtx, nseCtxCancel := context.WithTimeout(ctx, time.Second/2) + if withNSEExpiration { + // NSE will be unregistered after (tokenTimeout - registerTimeout) + domain.Nodes[0].NewEndpoint(nseCtx, defaultRegistryEndpoint(nsReg.Name), sandbox.GenerateExpiringToken(time.Second)) + } else { + domain.Nodes[0].NewEndpoint(nseCtx, defaultRegistryEndpoint(nsReg.Name), sandbox.GenerateTestToken) + } request := defaultRequest(nsReg.Name) @@ -601,10 +601,6 @@ func testNSMGRCloseHeal(t *testing.T, withNSEExpiration bool) { nscCtxCancel() - for _, fwd := range domain.Nodes[0].Forwarders { - fwd.Cancel() - } - require.Eventually(t, func() bool { logrus.Error(goleak.Find()) return goleak.Find(ignoreCurrent) == nil diff --git a/pkg/networkservice/chains/nsmgr/server.go b/pkg/networkservice/chains/nsmgr/server.go index eaf4ac4ed..e1db71565 100644 --- a/pkg/networkservice/chains/nsmgr/server.go +++ b/pkg/networkservice/chains/nsmgr/server.go @@ -272,11 +272,11 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, options var nseRegistry = chain.NewNetworkServiceEndpointRegistryServer( grpcmetadata.NewNetworkServiceEndpointRegistryServer(), - begin.NewNetworkServiceEndpointRegistryServer(), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGenerator), opts.authorizeNSERegistryServer, + begin.NewNetworkServiceEndpointRegistryServer(), registryclientinfo.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, time.Minute), + expire.NewNetworkServiceEndpointRegistryServer(ctx), registryrecvfd.NewNetworkServiceEndpointRegistryServer(), // Allow to receive a passed files registrysendfd.NewNetworkServiceEndpointRegistryServer(), remoteOrLocalRegistry, diff --git a/pkg/networkservice/chains/nsmgr/single_test.go b/pkg/networkservice/chains/nsmgr/single_test.go index 8203ff241..1186a50a5 100644 --- a/pkg/networkservice/chains/nsmgr/single_test.go +++ b/pkg/networkservice/chains/nsmgr/single_test.go @@ -27,25 +27,32 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" - "github.com/networkservicemesh/api/pkg/api/networkservice" - "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/cls" - kernelmech "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/kernel" - registryapi "github.com/networkservicemesh/api/pkg/api/registry" "github.com/stretchr/testify/require" "go.uber.org/goleak" "google.golang.org/grpc" "google.golang.org/grpc/credentials" - registryclient "github.com/networkservicemesh/sdk/pkg/registry/chains/client" + "github.com/networkservicemesh/api/pkg/api/networkservice" + "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/cls" + kernelmech "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/kernel" + registryapi "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/networkservice/chains/client" + "github.com/networkservicemesh/sdk/pkg/networkservice/chains/endpoint" "github.com/networkservicemesh/sdk/pkg/networkservice/chains/nsmgr" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/authorize" "github.com/networkservicemesh/sdk/pkg/networkservice/common/excludedprefixes" "github.com/networkservicemesh/sdk/pkg/networkservice/ipam/point2pointipam" + countutils "github.com/networkservicemesh/sdk/pkg/networkservice/utils/count" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/inject/injecterror" "github.com/networkservicemesh/sdk/pkg/registry" + registryclient "github.com/networkservicemesh/sdk/pkg/registry/chains/client" "github.com/networkservicemesh/sdk/pkg/registry/chains/memory" - "github.com/networkservicemesh/sdk/pkg/registry/common/authorize" + authorizeregistry "github.com/networkservicemesh/sdk/pkg/registry/common/authorize" + "github.com/networkservicemesh/sdk/pkg/registry/common/sendfd" + injecterrorregistry "github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injecterror" "github.com/networkservicemesh/sdk/pkg/tools/clientinfo" + "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" "github.com/networkservicemesh/sdk/pkg/tools/sandbox" "github.com/networkservicemesh/sdk/pkg/tools/token" ) @@ -257,7 +264,6 @@ func Test_UsecasePoint2MultiPoint(t *testing.T) { SetNodeSetup(func(ctx context.Context, node *sandbox.Node, _ int) { node.NewNSMgr(ctx, "nsmgr", nil, sandbox.GenerateTestToken, nsmgr.NewServer) }). - SetRegistryExpiryDuration(time.Second). Build() domain.Nodes[0].NewForwarder(ctx, ®istryapi.NetworkServiceEndpoint{ @@ -379,7 +385,6 @@ func Test_RemoteUsecase_Point2MultiPoint(t *testing.T) { SetNodeSetup(func(ctx context.Context, node *sandbox.Node, _ int) { node.NewNSMgr(ctx, "nsmgr", nil, sandbox.GenerateTestToken, nsmgr.NewServer) }). - SetRegistryExpiryDuration(time.Second). Build() for i := 0; i < nodeCount; i++ { @@ -500,13 +505,13 @@ func Test_FailedRegistryAuthorization(t *testing.T) { nsmgrSuppier := func(ctx context.Context, tokenGenerator token.GeneratorFunc, options ...nsmgr.Option) nsmgr.Nsmgr { options = append(options, nsmgr.WithAuthorizeNSERegistryServer( - authorize.NewNetworkServiceEndpointRegistryServer(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), + authorizeregistry.NewNetworkServiceEndpointRegistryServer(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), nsmgr.WithAuthorizeNSRegistryServer( - authorize.NewNetworkServiceRegistryServer(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), + authorizeregistry.NewNetworkServiceRegistryServer(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), nsmgr.WithAuthorizeNSERegistryClient( - authorize.NewNetworkServiceEndpointRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), + authorizeregistry.NewNetworkServiceEndpointRegistryClient(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), nsmgr.WithAuthorizeNSRegistryClient( - authorize.NewNetworkServiceRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), + authorizeregistry.NewNetworkServiceRegistryClient(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), ) return nsmgr.NewServer(ctx, tokenGenerator, options...) } @@ -514,7 +519,6 @@ func Test_FailedRegistryAuthorization(t *testing.T) { registrySupplier := func( ctx context.Context, tokenGenerator token.GeneratorFunc, - expiryDuration time.Duration, proxyRegistryURL *url.URL, options ...grpc.DialOption) registry.Registry { registryName := sandbox.UniqueName("registry-memory") @@ -522,11 +526,10 @@ func Test_FailedRegistryAuthorization(t *testing.T) { return memory.NewServer( ctx, tokenGeneratorFunc("spiffe://test.com/"+registryName), - memory.WithExpireDuration(expiryDuration), memory.WithProxyRegistryURL(proxyRegistryURL), memory.WithDialOptions(options...), memory.WithAuthorizeNSRegistryServer( - authorize.NewNetworkServiceRegistryServer(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), + authorizeregistry.NewNetworkServiceRegistryServer(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), ) } domain := sandbox.NewBuilder(ctx, t). @@ -553,7 +556,7 @@ func Test_FailedRegistryAuthorization(t *testing.T) { nsRegistryClient1 := domain.NewNSRegistryClient(ctx, tokenGeneratorFunc("spiffe://test.com/ns-1"), registryclient.WithAuthorizeNSRegistryClient( - authorize.NewNetworkServiceRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")))) + authorizeregistry.NewNetworkServiceRegistryClient(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")))) ns1 := defaultRegistryService("ns-1") _, err := nsRegistryClient1.Register(ctx, ns1) @@ -561,9 +564,184 @@ func Test_FailedRegistryAuthorization(t *testing.T) { nsRegistryClient2 := domain.NewNSRegistryClient(ctx, tokenGeneratorFunc("spiffe://test.com/ns-2"), registryclient.WithAuthorizeNSRegistryClient( - authorize.NewNetworkServiceRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")))) + authorizeregistry.NewNetworkServiceRegistryClient(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")))) ns2 := defaultRegistryService("ns-1") _, err = nsRegistryClient2.Register(ctx, ns2) require.Error(t, err) } + +func createAuthorizedEndpoint(ctx context.Context, t *testing.T, ns string, nsmgrURL *url.URL, counter networkservice.NetworkServiceServer) { + nseReg := defaultRegistryEndpoint(ns) + + nse := endpoint.NewServer(ctx, sandbox.GenerateTestToken, + endpoint.WithName("final-endpoint"), + endpoint.WithAuthorizeServer(authorize.NewServer(authorize.WithPolicies("etc/nsm/opa/common/tokens_expired.rego"))), + endpoint.WithAdditionalFunctionality(counter), + ) + + nseServer := grpc.NewServer() + nse.Register(nseServer) + nseURL := &url.URL{Scheme: "tcp", Host: "127.0.0.1:0"} + errCh := grpcutils.ListenAndServe(ctx, nseURL, nseServer) + select { + case err := <-errCh: + require.NoError(t, err) + default: + } + + nseRegistryClient := registryclient.NewNetworkServiceEndpointRegistryClient( + ctx, + registryclient.WithClientURL(nsmgrURL), + registryclient.WithDialOptions(sandbox.DialOptions(sandbox.WithTokenGenerator(sandbox.GenerateTestToken))...), + registryclient.WithNSEAdditionalFunctionality(sendfd.NewNetworkServiceEndpointRegistryClient()), + ) + + nseReg.Url = nseURL.String() + _, err := nseRegistryClient.Register(ctx, nseReg.Clone()) + require.NoError(t, err) +} + +// This test checks timeout on sandbox +// We run nsmgr and NSE with networkservice authorize chain element (tokens_expired.rego) +func Test_Timeout(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + // timeout chain element will call Close() after (tokenTimeout - requestTimeout) + // to be sure that token is not expired + tokenTimeout := time.Second * 2 + requestTimeout := time.Second + time.Millisecond*500 + + chainCtx, chainCtxCancel := context.WithTimeout(context.Background(), time.Second*5) + defer chainCtxCancel() + + // Set tokens_expired policy + nsmgrSuppier := func(ctx context.Context, tokenGenerator token.GeneratorFunc, options ...nsmgr.Option) nsmgr.Nsmgr { + options = append(options, + nsmgr.WithAuthorizeServer(authorize.NewServer(authorize.WithPolicies("etc/nsm/opa/common/tokens_expired.rego"))), + ) + return nsmgr.NewServer(ctx, tokenGenerator, options...) + } + + domain := sandbox.NewBuilder(chainCtx, t). + SetNodesCount(1). + SetNSMgrSupplier(nsmgrSuppier). + Build() + + nsRegistryClient := domain.NewNSRegistryClient(chainCtx, sandbox.GenerateTestToken) + ns := defaultRegistryService("ns") + + nsReg, err := nsRegistryClient.Register(chainCtx, ns) + require.NoError(t, err) + + counter := new(countutils.Server) + + createAuthorizedEndpoint(chainCtx, t, ns.Name, domain.Nodes[0].NSMgr.URL, counter) + + // Set an expiring token. + // Add injecterror to allow only the first Request. All subsequent ones will fall. + // This emulates the death of the client. + nsc := domain.Nodes[0].NewClient(chainCtx, + sandbox.GenerateExpiringToken(tokenTimeout), + client.WithAdditionalFunctionality( + injecterror.NewClient(injecterror.WithRequestErrorTimes(1, -1)), + ), + ) + + request := defaultRequest(nsReg.Name) + requestCtx, requestCtxCancel := context.WithTimeout(context.Background(), requestTimeout) + defer requestCtxCancel() + + conn, err := nsc.Request(requestCtx, request) + require.NoError(t, err) + require.NotNil(t, conn) + // Closes equal to 0 for now + require.Equal(t, 1, counter.Requests()) + require.Equal(t, 0, counter.Closes()) + + // Waiting for the timeout + require.Eventually(t, func() bool { return counter.Closes() == 1 }, tokenTimeout, time.Millisecond*100) +} + +// This test checks registry expire on sandbox +// We run nsmgr and registry with registry authorize chain element (tokens_expired.rego) +func Test_Expire(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + // expire chain element will call Unregister() after (tokenTimeout - registerTimeout) + // to be sure that token is not expired + tokenTimeout := time.Second * 2 + registerTimeout := time.Second + time.Millisecond*500 + + chainCtx, chainCtxCancel := context.WithTimeout(context.Background(), time.Second*5) + defer chainCtxCancel() + + // Set tokens_expired policy for nsmgr and registry + nsmgrSuppier := func(ctx context.Context, tokenGenerator token.GeneratorFunc, options ...nsmgr.Option) nsmgr.Nsmgr { + options = append(options, + nsmgr.WithAuthorizeNSERegistryServer( + authorizeregistry.NewNetworkServiceEndpointRegistryServer(authorizeregistry.WithPolicies("etc/nsm/opa/common/tokens_expired.rego"))), + ) + return nsmgr.NewServer(ctx, tokenGenerator, options...) + } + + registrySupplier := func( + ctx context.Context, + tokenGenerator token.GeneratorFunc, + proxyRegistryURL *url.URL, + options ...grpc.DialOption) registry.Registry { + return memory.NewServer( + ctx, + tokenGenerator, + memory.WithProxyRegistryURL(proxyRegistryURL), + memory.WithDialOptions(options...), + memory.WithAuthorizeNSRegistryServer( + authorizeregistry.NewNetworkServiceRegistryServer(authorizeregistry.WithPolicies("etc/nsm/opa/common/tokens_expired.rego"))), + ) + } + + domain := sandbox.NewBuilder(chainCtx, t). + SetNodesCount(1). + SetNSMgrSupplier(nsmgrSuppier). + SetRegistrySupplier(registrySupplier). + Build() + + nsRegistryClient := domain.NewNSRegistryClient(chainCtx, sandbox.GenerateTestToken) + ns := defaultRegistryService("ns") + + nsReg, err := nsRegistryClient.Register(chainCtx, ns) + require.NoError(t, err) + + // Set an expiring token. + // Add injecterrorregistry to allow only the first Register. All subsequent ones will fall. + // This emulates the death of the NSE. + nseRegistryClient := registryclient.NewNetworkServiceEndpointRegistryClient(chainCtx, + registryclient.WithClientURL(domain.Nodes[0].NSMgr.URL), + registryclient.WithDialOptions(sandbox.DialOptions(sandbox.WithTokenGenerator(sandbox.GenerateExpiringToken(tokenTimeout)))...), + registryclient.WithNSEAdditionalFunctionality( + injecterrorregistry.NewNetworkServiceEndpointRegistryClient( + injecterrorregistry.WithRegisterErrorTimes(1, -1), + injecterrorregistry.WithFindErrorTimes())), + ) + + registerCtx, registerCtxCancel := context.WithTimeout(context.Background(), registerTimeout) + defer registerCtxCancel() + _, err = nseRegistryClient.Register(registerCtx, ®istryapi.NetworkServiceEndpoint{ + Name: "final-endpoint", + Url: "nseURL", + NetworkServiceNames: []string{nsReg.Name}, + }) + require.NoError(t, err) + + // Wait for the endpoint expiration + time.Sleep(tokenTimeout) + stream, err := nseRegistryClient.Find(chainCtx, ®istryapi.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istryapi.NetworkServiceEndpoint{ + Name: "final-endpoint", + }, + }) + require.NoError(t, err) + + // Eventually expire will call Unregister + require.Len(t, registryapi.ReadNetworkServiceEndpointList(stream), 0) +} diff --git a/pkg/networkservice/chains/nsmgrproxy/server.go b/pkg/networkservice/chains/nsmgrproxy/server.go index d8c2b99b9..2a0795727 100644 --- a/pkg/networkservice/chains/nsmgrproxy/server.go +++ b/pkg/networkservice/chains/nsmgrproxy/server.go @@ -292,9 +292,9 @@ func NewServer(ctx context.Context, regURL, proxyURL *url.URL, tokenGenerator to var nseServerChain = chain.NewNetworkServiceEndpointRegistryServer( grpcmetadata.NewNetworkServiceEndpointRegistryServer(), - begin.NewNetworkServiceEndpointRegistryServer(), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGenerator), opts.authorizeNSERegistryServer, + begin.NewNetworkServiceEndpointRegistryServer(), clienturl.NewNetworkServiceEndpointRegistryServer(proxyURL), interdomainBypassNSEServer, registryswapip.NewNetworkServiceEndpointRegistryServer(opts.openMapIPChannel(ctx)), diff --git a/pkg/networkservice/common/timeout/server.go b/pkg/networkservice/common/timeout/server.go index f82eb2ff0..45515ddc2 100644 --- a/pkg/networkservice/common/timeout/server.go +++ b/pkg/networkservice/common/timeout/server.go @@ -51,6 +51,14 @@ func NewServer(ctx context.Context) networkservice.NetworkServiceServer { } func (s *timeoutServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (conn *networkservice.Connection, err error) { + timeClock := clock.FromContext(ctx) + + deadline, ok := ctx.Deadline() + requestTimeout := timeClock.Until(deadline) + if !ok { + requestTimeout = 0 + } + conn, err = next.Server(ctx).Request(ctx, request) if err != nil { return nil, err @@ -67,8 +75,8 @@ func (s *timeoutServer) Request(ctx context.Context, request *networkservice.Net } store(ctx, metadata.IsClient(s), cancel) eventFactory := begin.FromContext(ctx) - timeClock := clock.FromContext(ctx) - afterCh := timeClock.After(timeClock.Until(expirationTime)) + afterCh := timeClock.After(timeClock.Until(expirationTime) - requestTimeout) + go func(cancelCtx context.Context, afterCh <-chan time.Time) { select { case <-cancelCtx.Done(): diff --git a/pkg/registry/chains/memory/server.go b/pkg/registry/chains/memory/server.go index c35b84595..bddf005ee 100644 --- a/pkg/registry/chains/memory/server.go +++ b/pkg/registry/chains/memory/server.go @@ -20,7 +20,6 @@ package memory import ( "context" "net/url" - "time" "google.golang.org/grpc" @@ -51,7 +50,6 @@ type serverOptions struct { authorizeNSERegistryServer registry.NetworkServiceEndpointRegistryServer authorizeNSRegistryClient registry.NetworkServiceRegistryClient authorizeNSERegistryClient registry.NetworkServiceEndpointRegistryClient - expireDuration time.Duration proxyRegistryURL *url.URL dialOptions []grpc.DialOption } @@ -99,13 +97,6 @@ func WithAuthorizeNSERegistryClient(authorizeNSERegistryClient registry.NetworkS } } -// WithExpireDuration sets expire duration for the server -func WithExpireDuration(expireDuration time.Duration) Option { - return func(o *serverOptions) { - o.expireDuration = expireDuration - } -} - // WithProxyRegistryURL sets URL to reach the proxy registry func WithProxyRegistryURL(proxyRegistryURL *url.URL) Option { return func(o *serverOptions) { @@ -127,7 +118,6 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, options authorizeNSERegistryServer: registryauthorize.NewNetworkServiceEndpointRegistryServer(registryauthorize.Any()), authorizeNSRegistryClient: registryauthorize.NewNetworkServiceRegistryClient(registryauthorize.Any()), authorizeNSERegistryClient: registryauthorize.NewNetworkServiceEndpointRegistryClient(registryauthorize.Any()), - expireDuration: time.Minute, proxyRegistryURL: nil, } for _, opt := range options { @@ -136,9 +126,9 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, options nseChain := chain.NewNetworkServiceEndpointRegistryServer( grpcmetadata.NewNetworkServiceEndpointRegistryServer(), - begin.NewNetworkServiceEndpointRegistryServer(), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGenerator), opts.authorizeNSERegistryServer, + begin.NewNetworkServiceEndpointRegistryServer(), switchcase.NewNetworkServiceEndpointRegistryServer(switchcase.NSEServerCase{ Condition: func(c context.Context, nse *registry.NetworkServiceEndpoint) bool { if interdomain.Is(nse.GetName()) { @@ -171,7 +161,7 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, options Condition: func(c context.Context, nse *registry.NetworkServiceEndpoint) bool { return true }, Action: chain.NewNetworkServiceEndpointRegistryServer( setregistrationtime.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, opts.expireDuration), + expire.NewNetworkServiceEndpointRegistryServer(ctx), memory.NewNetworkServiceEndpointRegistryServer(), ), }, diff --git a/pkg/registry/chains/proxydns/server.go b/pkg/registry/chains/proxydns/server.go index b5bcd136a..7f61e5c9f 100644 --- a/pkg/registry/chains/proxydns/server.go +++ b/pkg/registry/chains/proxydns/server.go @@ -110,9 +110,9 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, dnsResol nseChain := chain.NewNetworkServiceEndpointRegistryServer( grpcmetadata.NewNetworkServiceEndpointRegistryServer(), - begin.NewNetworkServiceEndpointRegistryServer(), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGenerator), opts.authorizeNSERegistryServer, + begin.NewNetworkServiceEndpointRegistryServer(), dnsresolve.NewNetworkServiceEndpointRegistryServer(dnsresolve.WithResolver(dnsResolver)), connect.NewNetworkServiceEndpointRegistryServer( chain.NewNetworkServiceEndpointRegistryClient( diff --git a/pkg/registry/common/authorize/ns_client.go b/pkg/registry/common/authorize/ns_client.go index e2cdb4f8f..b26997618 100644 --- a/pkg/registry/common/authorize/ns_client.go +++ b/pkg/registry/common/authorize/ns_client.go @@ -19,7 +19,6 @@ package authorize import ( "context" - "sync/atomic" "github.com/golang/protobuf/ptypes/empty" "github.com/pkg/errors" @@ -36,7 +35,6 @@ import ( type authorizeNSClient struct { policies policiesList nsPathIdsMap *PathIdsMap - serverPeer atomic.Value } // NewNetworkServiceRegistryClient - returns a new authorization registry.NetworkServiceRegistryClient @@ -75,7 +73,6 @@ func (c *authorizeNSClient) Register(ctx context.Context, ns *registry.NetworkSe } if p != (peer.Peer{}) { - c.serverPeer.Store(&p) ctx = peer.NewContext(ctx, &p) } @@ -117,12 +114,18 @@ func (c *authorizeNSClient) Unregister(ctx context.Context, ns *registry.Network path := grpcmetadata.PathFromContext(ctx) ctx = grpcmetadata.PathWithContext(ctx, path) + var p peer.Peer + opts = append(opts, grpc.Peer(&p)) + resp, err := next.NetworkServiceRegistryClient(ctx).Unregister(ctx, ns, opts...) if err != nil { return nil, err } - path = grpcmetadata.PathFromContext(ctx) + if p != (peer.Peer{}) { + ctx = peer.NewContext(ctx, &p) + } + spiffeID := getSpiffeIDFromPath(ctx, path) rawMap := getRawMap(c.nsPathIdsMap) diff --git a/pkg/registry/common/authorize/nse_client.go b/pkg/registry/common/authorize/nse_client.go index e51730a8a..fb22905d7 100644 --- a/pkg/registry/common/authorize/nse_client.go +++ b/pkg/registry/common/authorize/nse_client.go @@ -19,7 +19,6 @@ package authorize import ( "context" - "sync/atomic" "github.com/golang/protobuf/ptypes/empty" "github.com/pkg/errors" @@ -36,7 +35,6 @@ import ( type authorizeNSEClient struct { policies policiesList nsePathIdsMap *PathIdsMap - serverPeer atomic.Value } // NewNetworkServiceEndpointRegistryClient - returns a new authorization registry.NetworkServiceEndpointRegistryClient @@ -76,13 +74,10 @@ func (c *authorizeNSEClient) Register(ctx context.Context, nse *registry.Network } if p != (peer.Peer{}) { - c.serverPeer.Store(&p) ctx = peer.NewContext(ctx, &p) } - path = grpcmetadata.PathFromContext(ctx) spiffeID := getSpiffeIDFromPath(ctx, path) - rawMap := getRawMap(c.nsePathIdsMap) input := RegistryOpaInput{ ResourceID: spiffeID.String(), @@ -116,22 +111,21 @@ func (c *authorizeNSEClient) Unregister(ctx context.Context, nse *registry.Netwo } path := grpcmetadata.PathFromContext(ctx) - ctx = grpcmetadata.PathWithContext(ctx, path) + var p peer.Peer + opts = append(opts, grpc.Peer(&p)) + resp, err := next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, nse, opts...) if err != nil { return nil, err } - p, ok := c.serverPeer.Load().(*peer.Peer) - if ok && p != nil { - ctx = peer.NewContext(ctx, p) + if p != (peer.Peer{}) { + ctx = peer.NewContext(ctx, &p) } - path = grpcmetadata.PathFromContext(ctx) spiffeID := getSpiffeIDFromPath(ctx, path) - rawMap := getRawMap(c.nsePathIdsMap) input := RegistryOpaInput{ ResourceID: spiffeID.String(), diff --git a/pkg/registry/common/begin/ns_client.go b/pkg/registry/common/begin/ns_client.go index 6635e3e5d..899209dd3 100644 --- a/pkg/registry/common/begin/ns_client.go +++ b/pkg/registry/common/begin/ns_client.go @@ -24,6 +24,7 @@ import ( "github.com/pkg/errors" "google.golang.org/grpc" + "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/tools/log" ) @@ -75,7 +76,7 @@ func (b *beginNSClient) Register(ctx context.Context, in *registry.NetworkServic eventFactoryClient.state = established eventFactoryClient.registration = mergeNS(in, resp.Clone()) eventFactoryClient.response = resp.Clone() - eventFactoryClient.updateContext(ctx) + eventFactoryClient.updateContext(grpcmetadata.PathWithContext(ctx, grpcmetadata.PathFromContext(ctx).Clone())) }) return resp, err } diff --git a/pkg/registry/common/begin/ns_server.go b/pkg/registry/common/begin/ns_server.go index b803c5ff8..35264f328 100644 --- a/pkg/registry/common/begin/ns_server.go +++ b/pkg/registry/common/begin/ns_server.go @@ -24,6 +24,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/types/known/emptypb" + "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/tools/log" ) @@ -73,7 +74,7 @@ func (b *beginNSServer) Register(ctx context.Context, in *registry.NetworkServic eventFactoryServer.registration = mergeNS(in, resp) eventFactoryServer.state = established eventFactoryServer.response = resp - eventFactoryServer.updateContext(ctx) + eventFactoryServer.updateContext(grpcmetadata.PathWithContext(ctx, grpcmetadata.PathFromContext(ctx).Clone())) }) return resp, err } diff --git a/pkg/registry/common/begin/nse_client.go b/pkg/registry/common/begin/nse_client.go index 9f5cba839..960574030 100644 --- a/pkg/registry/common/begin/nse_client.go +++ b/pkg/registry/common/begin/nse_client.go @@ -24,6 +24,7 @@ import ( "github.com/pkg/errors" "google.golang.org/grpc" + "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/tools/log" ) @@ -75,7 +76,7 @@ func (b *beginNSEClient) Register(ctx context.Context, in *registry.NetworkServi eventFactoryClient.state = established eventFactoryClient.registration = mergeNSE(in, resp.Clone()) eventFactoryClient.response = resp.Clone() - eventFactoryClient.updateContext(ctx) + eventFactoryClient.updateContext(grpcmetadata.PathWithContext(ctx, grpcmetadata.PathFromContext(ctx).Clone())) }) return resp, err } diff --git a/pkg/registry/common/begin/nse_server.go b/pkg/registry/common/begin/nse_server.go index 8d2b80dd6..866d9f6ca 100644 --- a/pkg/registry/common/begin/nse_server.go +++ b/pkg/registry/common/begin/nse_server.go @@ -24,6 +24,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/types/known/emptypb" + "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/tools/log" ) @@ -73,7 +74,7 @@ func (b *beginNSEServer) Register(ctx context.Context, in *registry.NetworkServi eventFactoryServer.registration = mergeNSE(in, resp) eventFactoryServer.state = established eventFactoryServer.response = resp - eventFactoryServer.updateContext(ctx) + eventFactoryServer.updateContext(grpcmetadata.PathWithContext(ctx, grpcmetadata.PathFromContext(ctx).Clone())) }) return resp, err } diff --git a/pkg/registry/common/expire/nse_server.go b/pkg/registry/common/expire/nse_server.go index bca5375f4..488400ff1 100644 --- a/pkg/registry/common/expire/nse_server.go +++ b/pkg/registry/common/expire/nse_server.go @@ -18,11 +18,8 @@ package expire import ( "context" - "time" "github.com/golang/protobuf/ptypes/empty" - "google.golang.org/protobuf/types/known/timestamppb" - "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/registry/common/begin" @@ -32,36 +29,31 @@ import ( ) type expireNSEServer struct { - nseExpiration time.Duration - ctx context.Context + ctx context.Context cancelsMap } // NewNetworkServiceEndpointRegistryServer creates a new NetworkServiceServer chain element that implements unregister // of expired connections for the subsequent chain elements. -func NewNetworkServiceEndpointRegistryServer(ctx context.Context, nseExpiration time.Duration) registry.NetworkServiceEndpointRegistryServer { +func NewNetworkServiceEndpointRegistryServer(ctx context.Context) registry.NetworkServiceEndpointRegistryServer { return &expireNSEServer{ - nseExpiration: nseExpiration, - ctx: ctx, + ctx: ctx, } } func (s *expireNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { factory := begin.FromContext(ctx) timeClock := clock.FromContext(ctx) - expirationTime := timeClock.Now().Add(s.nseExpiration).Local() logger := log.FromContext(ctx).WithField("expireNSEServer", "Register") - if nse.GetExpirationTime() != nil { - if nseExpirationTime := nse.GetExpirationTime().AsTime().Local(); nseExpirationTime.Before(expirationTime) { - expirationTime = nseExpirationTime - logger.Infof("selected expiration time %v for %v", expirationTime, nse.GetName()) - } + deadline, ok := ctx.Deadline() + requestTimeout := timeClock.Until(deadline) + if !ok { + requestTimeout = 0 } - nse.ExpirationTime = timestamppb.New(expirationTime) - + expirationTime := nse.GetExpirationTime().AsTime() resp, err := next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) if err != nil { return nil, err @@ -78,7 +70,7 @@ func (s *expireNSEServer) Register(ctx context.Context, nse *registry.NetworkSer } s.cancelsMap.Store(nse.GetName(), cancel) - expireCh := timeClock.After(timeClock.Until(expirationTime.Local())) + expireCh := timeClock.After(timeClock.Until(expirationTime.Local()) - requestTimeout) go func() { select { diff --git a/pkg/registry/common/expire/nse_server_test.go b/pkg/registry/common/expire/nse_server_test.go index 0b53f7646..3a6e9b0fb 100644 --- a/pkg/registry/common/expire/nse_server_test.go +++ b/pkg/registry/common/expire/nse_server_test.go @@ -23,9 +23,11 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v4" "github.com/golang/protobuf/ptypes/empty" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "google.golang.org/grpc/credentials" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" @@ -36,11 +38,14 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/common/localbypass" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" "github.com/networkservicemesh/sdk/pkg/registry/common/refresh" + "github.com/networkservicemesh/sdk/pkg/registry/common/updatepath" "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injecterror" + "github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injectpeertoken" "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/tools/clockmock" + "github.com/networkservicemesh/sdk/pkg/tools/token" ) const ( @@ -70,6 +75,20 @@ func find(ctx context.Context, c registry.NetworkServiceEndpointRegistryClient) return nses, nil } +func generateTestToken(ctx context.Context, duration time.Duration) token.GeneratorFunc { + return func(_ credentials.AuthInfo) (string, time.Time, error) { + expireTime := clock.FromContext(ctx).Now().Add(duration).Local() + + claims := jwt.RegisteredClaims{ + Subject: "spiffe://test.com/subject", + ExpiresAt: jwt.NewNumericDate(expireTime), + } + + tok, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("supersecret")) + return tok, expireTime, err + } +} + func TestExpireNSEServer_ShouldCorrectlySetExpirationTime_InRemoteCase(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) @@ -80,8 +99,10 @@ func TestExpireNSEServer_ShouldCorrectlySetExpirationTime_InRemoteCase(t *testin ctx = clock.WithClock(ctx, clockMock) s := next.NewNetworkServiceEndpointRegistryServer( + injectpeertoken.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), begin.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), + expire.NewNetworkServiceEndpointRegistryServer(ctx), new(remoteNSEServer), ) @@ -105,8 +126,10 @@ func TestExpireNSEServer_ShouldUseLessExpirationTimeFromInput_AndWork(t *testing mem := memory.NewNetworkServiceEndpointRegistryServer() s := next.NewNetworkServiceEndpointRegistryServer( + injectpeertoken.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), begin.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), + expire.NewNetworkServiceEndpointRegistryServer(ctx), mem, ) @@ -136,10 +159,12 @@ func TestExpireNSEServer_ShouldUseLessExpirationTimeFromResponse(t *testing.T) { s := next.NewNetworkServiceEndpointRegistryServer( begin.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + expire.NewNetworkServiceEndpointRegistryServer(ctx), new(remoteNSEServer), // <-- GRPC invocation begin.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout/2), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout/2)), ) resp, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-1"}) @@ -161,7 +186,9 @@ func TestExpireNSEServer_ShouldRemoveNSEAfterExpirationTime(t *testing.T) { s := next.NewNetworkServiceEndpointRegistryServer( begin.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + expire.NewNetworkServiceEndpointRegistryServer(ctx), new(remoteNSEServer), // <-- GRPC invocation mem, ) @@ -195,7 +222,7 @@ func TestExpireNSEServer_DataRace(t *testing.T) { s := next.NewNetworkServiceEndpointRegistryServer( begin.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, 0), + expire.NewNetworkServiceEndpointRegistryServer(ctx), localbypass.NewNetworkServiceEndpointRegistryServer("tcp://0.0.0.0"), mem, ) @@ -228,8 +255,10 @@ func TestExpireNSEServer_RefreshFailure(t *testing.T) { refresh.NewNetworkServiceEndpointRegistryClient(ctx), adapters.NetworkServiceEndpointServerToClient(next.NewNetworkServiceEndpointRegistryServer( new(remoteNSEServer), // <-- GRPC invocation + injectpeertoken.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), begin.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), + expire.NewNetworkServiceEndpointRegistryServer(ctx), injecterror.NewNetworkServiceEndpointRegistryServer( injecterror.WithRegisterErrorTimes(1, -1), injecterror.WithFindErrorTimes(), @@ -261,14 +290,16 @@ func TestExpireNSEServer_UnregisterFailure(t *testing.T) { mem := memory.NewNetworkServiceEndpointRegistryServer() s := next.NewNetworkServiceEndpointRegistryServer( + injectpeertoken.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), begin.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), + expire.NewNetworkServiceEndpointRegistryServer(ctx), injecterror.NewNetworkServiceEndpointRegistryServer( injecterror.WithRegisterErrorTimes(), injecterror.WithFindErrorTimes(), injecterror.WithUnregisterErrorTimes(0), ), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), + expire.NewNetworkServiceEndpointRegistryServer(ctx), mem, ) @@ -312,7 +343,9 @@ func TestExpireNSEServer_RefreshKeepsNoUnregister(t *testing.T) { next.NewNetworkServiceEndpointRegistryServer( // NSMgr chain new(remoteNSEServer), // <-- GRPC invocation - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + expire.NewNetworkServiceEndpointRegistryServer(ctx), unregisterServer, )), ) diff --git a/pkg/registry/common/grpcmetadata/common_test.go b/pkg/registry/common/grpcmetadata/common_test.go index a71d0add8..cf984f292 100644 --- a/pkg/registry/common/grpcmetadata/common_test.go +++ b/pkg/registry/common/grpcmetadata/common_test.go @@ -19,6 +19,8 @@ package grpcmetadata_test import ( "time" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" + "github.com/golang-jwt/jwt/v4" "google.golang.org/grpc/credentials" @@ -29,9 +31,16 @@ const ( key = "supersecret" ) -func tokenGeneratorFunc(spiffeID string) token.GeneratorFunc { +// tokenGeneratorFunc generates new tokens automatically (based on time change). +// time.Second + smth - the time tick for jwt is a second. +func tokenGeneratorFunc(clock *clockmock.Mock, spiffeID string) token.GeneratorFunc { return func(peerAuthInfo credentials.AuthInfo) (string, time.Time, error) { - tok, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"sub": spiffeID}).SignedString([]byte(key)) - return tok, time.Date(3000, 1, 1, 1, 1, 1, 1, time.UTC), err + clock.Add(time.Second + time.Millisecond*10) + tok, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": spiffeID, + "exp": jwt.NewNumericDate(clock.Now().Add(time.Hour)), + }, + ).SignedString([]byte(key)) + return tok, clock.Now(), err } } diff --git a/pkg/registry/common/grpcmetadata/ns_client.go b/pkg/registry/common/grpcmetadata/ns_client.go index a7459e2db..8a5d95c70 100644 --- a/pkg/registry/common/grpcmetadata/ns_client.go +++ b/pkg/registry/common/grpcmetadata/ns_client.go @@ -75,5 +75,18 @@ func (c *grpcMetadataNSClient) Unregister(ctx context.Context, ns *registry.Netw return nil, err } - return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, ns, opts...) + var header metadata.MD + opts = append(opts, grpc.Header(&header)) + + resp, err := next.NetworkServiceRegistryClient(ctx).Unregister(ctx, ns, opts...) + if err != nil { + return nil, err + } + + newpath, err := fromMD(header) + if err == nil { + path.Index = newpath.Index + path.PathSegments = newpath.PathSegments + } + return resp, nil } diff --git a/pkg/registry/common/grpcmetadata/ns_test.go b/pkg/registry/common/grpcmetadata/ns_test.go index a4694ee92..d315ff18a 100644 --- a/pkg/registry/common/grpcmetadata/ns_test.go +++ b/pkg/registry/common/grpcmetadata/ns_test.go @@ -24,8 +24,10 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/types/known/emptypb" "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/common/updatepath" @@ -33,8 +35,8 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checkcontext" "github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injectpeertoken" - - "go.uber.org/goleak" + "github.com/networkservicemesh/sdk/pkg/tools/clock" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" ) const ( @@ -43,22 +45,64 @@ const ( serverID = "spiffe://test.com/server" ) +type pathCheckerNSClient struct { + funcBefore func(ctx context.Context) *grpcmetadata.Path + funcAfter func(ctx context.Context, pBefore *grpcmetadata.Path) +} + +func newPathCheckerNSClient(t *testing.T, expectedPathIndex int) registry.NetworkServiceRegistryClient { + client := &pathCheckerNSClient{} + + client.funcBefore = func(ctx context.Context) *grpcmetadata.Path { + p := grpcmetadata.PathFromContext(ctx).Clone() + require.Equal(t, int(p.Index), expectedPathIndex) + + return p + } + client.funcAfter = func(ctx context.Context, pBefore *grpcmetadata.Path) { + pAfter := grpcmetadata.PathFromContext(ctx).Clone() + require.Equal(t, int(pAfter.Index), expectedPathIndex) + for i := expectedPathIndex; i < len(pBefore.PathSegments); i++ { + require.NotEqual(t, pBefore.PathSegments[i].Token, pAfter.PathSegments[i].Token) + } + } + return client +} + +func (p *pathCheckerNSClient) Register(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + pBefore := p.funcBefore(ctx) + r, e := next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) + p.funcAfter(ctx, pBefore) + return r, e +} + +func (p *pathCheckerNSClient) Find(ctx context.Context, query *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + return next.NetworkServiceRegistryClient(ctx).Find(ctx, query, opts...) +} + +func (p *pathCheckerNSClient) Unregister(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*emptypb.Empty, error) { + pBefore := p.funcBefore(ctx) + r, e := next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) + p.funcAfter(ctx, pBefore) + return r, e +} + func TestGRPCMetadataNetworkService(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - ctx, cacncel := context.WithTimeout(context.Background(), time.Second) - defer cacncel() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() - clientToken, _, _ := tokenGeneratorFunc(clientID)(nil) - proxyToken, _, _ := tokenGeneratorFunc(proxyID)(nil) + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) serverLis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) server := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), grpcmetadata.NewNetworkServiceRegistryServer(), - updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(serverID)), + updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, serverID)), checkcontext.NewNSServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 2) @@ -83,14 +127,15 @@ func TestGRPCMetadataNetworkService(t *testing.T) { }() proxyServer := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, clientID)), grpcmetadata.NewNetworkServiceRegistryServer(), - updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), + updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), checkcontext.NewNSServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 1) }), adapters.NetworkServiceClientToServer(next.NewNetworkServiceRegistryClient( + newPathCheckerNSClient(t, 1), grpcmetadata.NewNetworkServiceRegistryClient(), registry.NewNetworkServiceRegistryClient(serverConn), )), @@ -111,6 +156,7 @@ func TestGRPCMetadataNetworkService(t *testing.T) { }() client := next.NewNetworkServiceRegistryClient( + newPathCheckerNSClient(t, 0), grpcmetadata.NewNetworkServiceRegistryClient(), registry.NewNetworkServiceRegistryClient(conn)) @@ -124,6 +170,10 @@ func TestGRPCMetadataNetworkService(t *testing.T) { require.Equal(t, int(path.Index), 0) require.Len(t, path.PathSegments, 3) + // Simulate refresh + _, err = client.Register(ctx, ns) + require.NoError(t, err) + _, err = client.Unregister(ctx, ns) require.NoError(t, err) @@ -134,10 +184,11 @@ func TestGRPCMetadataNetworkService(t *testing.T) { func TestGRPCMetadataNetworkService_BackwardCompatibility(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - ctx, cacncel := context.WithTimeout(context.Background(), time.Second) - defer cacncel() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() - clientToken, _, _ := tokenGeneratorFunc(clientID)(nil) + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) serverLis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -165,9 +216,9 @@ func TestGRPCMetadataNetworkService_BackwardCompatibility(t *testing.T) { }() proxyServer := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, clientID)), grpcmetadata.NewNetworkServiceRegistryServer(), - updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), + updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), checkcontext.NewNSServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 1) @@ -196,6 +247,7 @@ func TestGRPCMetadataNetworkService_BackwardCompatibility(t *testing.T) { }() client := next.NewNetworkServiceRegistryClient( + newPathCheckerNSClient(t, 0), grpcmetadata.NewNetworkServiceRegistryClient(), registry.NewNetworkServiceRegistryClient(conn)) @@ -209,6 +261,10 @@ func TestGRPCMetadataNetworkService_BackwardCompatibility(t *testing.T) { require.Equal(t, int(path.Index), 0) require.Len(t, path.PathSegments, 2) + // Simulate refresh + _, err = client.Register(ctx, ns) + require.NoError(t, err) + _, err = client.Unregister(ctx, ns) require.NoError(t, err) } diff --git a/pkg/registry/common/grpcmetadata/nse_client.go b/pkg/registry/common/grpcmetadata/nse_client.go index e5ec57e32..b77ba989b 100644 --- a/pkg/registry/common/grpcmetadata/nse_client.go +++ b/pkg/registry/common/grpcmetadata/nse_client.go @@ -73,5 +73,18 @@ func (c *grpcMetadataNSEClient) Unregister(ctx context.Context, nse *registry.Ne return nil, err } - return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, nse, opts...) + var header metadata.MD + opts = append(opts, grpc.Header(&header)) + + resp, err := next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, nse, opts...) + if err != nil { + return nil, err + } + + newpath, err := fromMD(header) + if err == nil { + path.Index = newpath.Index + path.PathSegments = newpath.PathSegments + } + return resp, nil } diff --git a/pkg/registry/common/grpcmetadata/nse_test.go b/pkg/registry/common/grpcmetadata/nse_test.go index 83ed2520f..949965aec 100644 --- a/pkg/registry/common/grpcmetadata/nse_test.go +++ b/pkg/registry/common/grpcmetadata/nse_test.go @@ -22,36 +22,86 @@ import ( "testing" "time" - "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/golang/protobuf/ptypes/empty" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/common/updatepath" "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checkcontext" "github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injectpeertoken" - - "go.uber.org/goleak" + "github.com/networkservicemesh/sdk/pkg/tools/clock" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" ) +type pathCheckerNSEClient struct { + funcBefore func(ctx context.Context) *grpcmetadata.Path + funcAfter func(ctx context.Context, pBefore *grpcmetadata.Path) +} + +func newPathCheckerNSEClient(t *testing.T, expectedPathIndex int) registry.NetworkServiceEndpointRegistryClient { + client := &pathCheckerNSEClient{} + + client.funcBefore = func(ctx context.Context) *grpcmetadata.Path { + p := grpcmetadata.PathFromContext(ctx).Clone() + require.Equal(t, int(p.Index), expectedPathIndex) + + return p + } + client.funcAfter = func(ctx context.Context, pBefore *grpcmetadata.Path) { + pAfter := grpcmetadata.PathFromContext(ctx).Clone() + require.Equal(t, int(pAfter.Index), expectedPathIndex) + for i := expectedPathIndex; i < len(pBefore.PathSegments); i++ { + require.NotEqual(t, pBefore.PathSegments[i].Token, pAfter.PathSegments[i].Token) + } + } + return client +} + +func (p *pathCheckerNSEClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + pBefore := p.funcBefore(ctx) + r, e := next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) + p.funcAfter(ctx, pBefore) + return r, e +} + +func (p *pathCheckerNSEClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) +} + +func (p *pathCheckerNSEClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + pBefore := p.funcBefore(ctx) + r, e := next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) + p.funcAfter(ctx, pBefore) + return r, e +} + +// This test checks that registry Path is correctly updated and passed through grpc metadata +// Test scheme: client ---> proxyServer ---> server func TestGRPCMetadataNetworkServiceEndpoint(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() - clientToken, _, _ := tokenGeneratorFunc(clientID)(nil) - proxyToken, _, _ := tokenGeneratorFunc(proxyID)(nil) + // Add clockMock to the context + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) serverLis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) + // tokenGeneratorFunc generates new tokens automatically server := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), grpcmetadata.NewNetworkServiceEndpointRegistryServer(), - updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(serverID)), + updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, serverID)), checkcontext.NewNSEServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 2) @@ -76,14 +126,15 @@ func TestGRPCMetadataNetworkServiceEndpoint(t *testing.T) { }() proxyServer := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, clientID)), grpcmetadata.NewNetworkServiceEndpointRegistryServer(), - updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), + updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), checkcontext.NewNSEServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 1) }), adapters.NetworkServiceEndpointClientToServer(next.NewNetworkServiceEndpointRegistryClient( + newPathCheckerNSEClient(t, 1), grpcmetadata.NewNetworkServiceEndpointRegistryClient(), registry.NewNetworkServiceEndpointRegistryClient(serverConn), )), @@ -104,10 +155,7 @@ func TestGRPCMetadataNetworkServiceEndpoint(t *testing.T) { }() client := next.NewNetworkServiceEndpointRegistryClient( - checkcontext.NewNSEClient(t, func(t *testing.T, ctx context.Context) { - path := grpcmetadata.PathFromContext(ctx) - require.Equal(t, int(path.Index), 0) - }), + newPathCheckerNSEClient(t, 0), grpcmetadata.NewNetworkServiceEndpointRegistryClient(), registry.NewNetworkServiceEndpointRegistryClient(conn)) @@ -121,6 +169,10 @@ func TestGRPCMetadataNetworkServiceEndpoint(t *testing.T) { require.Equal(t, int(path.Index), 0) require.Len(t, path.PathSegments, 3) + // Simulate refresh + _, err = client.Register(ctx, nse) + require.NoError(t, err) + _, err = client.Unregister(ctx, nse) require.NoError(t, err) @@ -131,10 +183,12 @@ func TestGRPCMetadataNetworkServiceEndpoint(t *testing.T) { func TestGRPCMetadataNetworkServiceEndpoint_BackwardCompatibility(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - ctx, cacncel := context.WithTimeout(context.Background(), time.Second) - defer cacncel() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() - clientToken, _, _ := tokenGeneratorFunc(clientID)(nil) + // Add clockMock to the context + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) serverLis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -162,9 +216,9 @@ func TestGRPCMetadataNetworkServiceEndpoint_BackwardCompatibility(t *testing.T) }() proxyServer := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, clientID)), grpcmetadata.NewNetworkServiceEndpointRegistryServer(), - updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), + updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), checkcontext.NewNSEServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 1) @@ -193,19 +247,24 @@ func TestGRPCMetadataNetworkServiceEndpoint_BackwardCompatibility(t *testing.T) }() client := next.NewNetworkServiceEndpointRegistryClient( + newPathCheckerNSEClient(t, 0), grpcmetadata.NewNetworkServiceEndpointRegistryClient(), registry.NewNetworkServiceEndpointRegistryClient(conn)) path := grpcmetadata.Path{} ctx = grpcmetadata.PathWithContext(ctx, &path) - ns := ®istry.NetworkServiceEndpoint{Name: "ns"} - _, err = client.Register(ctx, ns) + nse := ®istry.NetworkServiceEndpoint{Name: "ns"} + _, err = client.Register(ctx, nse) require.NoError(t, err) require.Equal(t, int(path.Index), 0) require.Len(t, path.PathSegments, 2) - _, err = client.Unregister(ctx, ns) + // Simulate refresh + _, err = client.Register(ctx, nse) + require.NoError(t, err) + + _, err = client.Unregister(ctx, nse) require.NoError(t, err) } diff --git a/pkg/registry/common/updatepath/ns_server_test.go b/pkg/registry/common/updatepath/ns_server_test.go index d5894d3b0..124e0d286 100644 --- a/pkg/registry/common/updatepath/ns_server_test.go +++ b/pkg/registry/common/updatepath/ns_server_test.go @@ -53,7 +53,7 @@ var nsSamples = []*nsSample{ } server := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(serverID)), ) @@ -107,9 +107,9 @@ var nsSamples = []*nsSample{ } server := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), - injectpeertoken.NewNetworkServiceRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(serverID)), ) @@ -155,9 +155,9 @@ var nsSamples = []*nsSample{ } server := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), - injectpeertoken.NewNetworkServiceRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(serverID)), ) diff --git a/pkg/registry/common/updatepath/nse_server.go b/pkg/registry/common/updatepath/nse_server.go index b072e9b68..e9e5bf2ce 100644 --- a/pkg/registry/common/updatepath/nse_server.go +++ b/pkg/registry/common/updatepath/nse_server.go @@ -21,6 +21,7 @@ import ( "github.com/golang/protobuf/ptypes/empty" "github.com/pkg/errors" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/networkservicemesh/api/pkg/api/registry" @@ -45,11 +46,11 @@ func (s *updatePathNSEServer) Register(ctx context.Context, nse *registry.Networ path := grpcmetadata.PathFromContext(ctx) // Update path - peerTok, _, tokenErr := token.FromContext(ctx) - if tokenErr != nil { - log.FromContext(ctx).Warnf("an error during getting peer token from the context: %+v", tokenErr) + peerTok, peerExpirationTime, peerTokenErr := token.FromContext(ctx) + if peerTokenErr != nil { + log.FromContext(ctx).Warnf("an error during getting peer token from the context: %+v", peerTokenErr) } - tok, _, tokenErr := generateToken(ctx, s.tokenGenerator) + tok, expirationTime, tokenErr := generateToken(ctx, s.tokenGenerator) if tokenErr != nil { return nil, errors.Wrap(tokenErr, "an error during generating token") } @@ -70,6 +71,13 @@ func (s *updatePathNSEServer) Register(ctx context.Context, nse *registry.Networ nse.PathIds = updatePathIds(nse.PathIds, int(path.Index-1), peerID.String()) nse.PathIds = updatePathIds(nse.PathIds, int(path.Index), id.String()) + if nse.GetExpirationTime() == nil || expirationTime.Before(nse.GetExpirationTime().AsTime().Local()) { + nse.ExpirationTime = timestamppb.New(expirationTime) + } + if peerTokenErr == nil && peerExpirationTime.Before(nse.GetExpirationTime().AsTime().Local()) { + nse.ExpirationTime = timestamppb.New(peerExpirationTime) + } + nse, err = next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) if err != nil { return nil, err diff --git a/pkg/registry/common/updatepath/nse_server_test.go b/pkg/registry/common/updatepath/nse_server_test.go index ad6344dd2..bb4e5935e 100644 --- a/pkg/registry/common/updatepath/nse_server_test.go +++ b/pkg/registry/common/updatepath/nse_server_test.go @@ -53,7 +53,7 @@ var nseSamples = []*nseSample{ } server := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(serverID)), ) @@ -107,9 +107,9 @@ var nseSamples = []*nseSample{ } server := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), - injectpeertoken.NewNetworkServiceEndpointRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(serverID)), ) @@ -155,9 +155,9 @@ var nseSamples = []*nseSample{ } server := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), - injectpeertoken.NewNetworkServiceEndpointRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(serverID)), ) diff --git a/pkg/registry/utils/inject/injectpeertoken/ns_server.go b/pkg/registry/utils/inject/injectpeertoken/ns_server.go index 0499a54eb..be060c225 100644 --- a/pkg/registry/utils/inject/injectpeertoken/ns_server.go +++ b/pkg/registry/utils/inject/injectpeertoken/ns_server.go @@ -19,6 +19,8 @@ package injectpeertoken import ( "context" + "github.com/networkservicemesh/sdk/pkg/tools/token" + "google.golang.org/protobuf/types/known/emptypb" "github.com/networkservicemesh/api/pkg/api/registry" @@ -27,18 +29,19 @@ import ( ) type injectSpiffeIDNSServer struct { - peerToken string + tokenGenerator token.GeneratorFunc } // NewNetworkServiceRegistryServer returns a server chain element putting spiffeID to context on Register and Unregister -func NewNetworkServiceRegistryServer(peerToken string) registry.NetworkServiceRegistryServer { +func NewNetworkServiceRegistryServer(tokenGenerator token.GeneratorFunc) registry.NetworkServiceRegistryServer { return &injectSpiffeIDNSServer{ - peerToken: peerToken, + tokenGenerator: tokenGenerator, } } func (s *injectSpiffeIDNSServer) Register(ctx context.Context, ns *registry.NetworkService) (*registry.NetworkService, error) { - ctx = withPeerToken(ctx, s.peerToken) + peerToken, _, _ := s.tokenGenerator(nil) + ctx = withPeerToken(ctx, peerToken) return next.NetworkServiceRegistryServer(ctx).Register(ctx, ns) } @@ -47,6 +50,7 @@ func (s *injectSpiffeIDNSServer) Find(query *registry.NetworkServiceQuery, serve } func (s *injectSpiffeIDNSServer) Unregister(ctx context.Context, ns *registry.NetworkService) (*emptypb.Empty, error) { - ctx = withPeerToken(ctx, s.peerToken) + peerToken, _, _ := s.tokenGenerator(nil) + ctx = withPeerToken(ctx, peerToken) return next.NetworkServiceRegistryServer(ctx).Unregister(ctx, ns) } diff --git a/pkg/registry/utils/inject/injectpeertoken/nse_server.go b/pkg/registry/utils/inject/injectpeertoken/nse_server.go index 4fbdecb80..20e05a013 100644 --- a/pkg/registry/utils/inject/injectpeertoken/nse_server.go +++ b/pkg/registry/utils/inject/injectpeertoken/nse_server.go @@ -24,21 +24,23 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/token" ) type injectSpiffeIDNSEServer struct { - peerToken string + tokenGenerator token.GeneratorFunc } -// NewNetworkServiceEndpointRegistryServer returns a server chain element putting spiffeID to context on Register and Unregister -func NewNetworkServiceEndpointRegistryServer(peerToken string) registry.NetworkServiceEndpointRegistryServer { +// NewNetworkServiceEndpointRegistryServer returns a server chain element putting peer token to context on Register and Unregister +func NewNetworkServiceEndpointRegistryServer(tokenGenerator token.GeneratorFunc) registry.NetworkServiceEndpointRegistryServer { return &injectSpiffeIDNSEServer{ - peerToken: peerToken, + tokenGenerator: tokenGenerator, } } func (s *injectSpiffeIDNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { - ctx = withPeerToken(ctx, s.peerToken) + peerToken, _, _ := s.tokenGenerator(nil) + ctx = withPeerToken(ctx, peerToken) return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) } @@ -47,6 +49,7 @@ func (s *injectSpiffeIDNSEServer) Find(query *registry.NetworkServiceEndpointQue } func (s *injectSpiffeIDNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - ctx = withPeerToken(ctx, s.peerToken) + peerToken, _, _ := s.tokenGenerator(nil) + ctx = withPeerToken(ctx, peerToken) return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) } diff --git a/pkg/tools/opa/policies/common/tokens_expired.rego b/pkg/tools/opa/policies/common/tokens_expired.rego index 4f44d1a4f..5a87cb286 100644 --- a/pkg/tools/opa/policies/common/tokens_expired.rego +++ b/pkg/tools/opa/policies/common/tokens_expired.rego @@ -1,4 +1,4 @@ -# Copyright (c) 2020 Cisco and/or its affiliates. +# Copyright (c) 2020-2022 Cisco and/or its affiliates. # # SPDX-License-Identifier: Apache-2.0 # @@ -19,11 +19,11 @@ package nsm default valid = false valid { - count({x | input.path_segments[x]; token_expired(input.path_segments[x].token)}) == count(input.path_segments) + count({x | input.path_segments[x]; token_alive(input.path_segments[x].token)}) == count(input.path_segments) } -token_expired(token) { - print(token) +# alive means not expired +token_alive(token) { [_, payload, _] := io.jwt.decode(token) now < payload.exp } diff --git a/pkg/tools/sandbox/builder.go b/pkg/tools/sandbox/builder.go index 635819448..dc425aa8c 100644 --- a/pkg/tools/sandbox/builder.go +++ b/pkg/tools/sandbox/builder.go @@ -24,7 +24,6 @@ import ( "os" "runtime" "testing" - "time" "github.com/stretchr/testify/require" "google.golang.org/grpc" @@ -53,21 +52,19 @@ type Builder struct { supplyRegistryProxy SupplyRegistryProxyFunc setupNode SetupNodeFunc - name string - dnsResolver dnsresolve.Resolver - generateTokenFunc token.GeneratorFunc - registryExpiryDuration time.Duration + name string + dnsResolver dnsresolve.Resolver + generateTokenFunc token.GeneratorFunc useUnixSockets bool domain *Domain } -func newRegistryMemoryServer(ctx context.Context, tokenGenerator token.GeneratorFunc, expiryDuration time.Duration, proxyRegistryURL *url.URL, options ...grpc.DialOption) registry.Registry { +func newRegistryMemoryServer(ctx context.Context, tokenGenerator token.GeneratorFunc, proxyRegistryURL *url.URL, options ...grpc.DialOption) registry.Registry { return memory.NewServer( ctx, tokenGenerator, - memory.WithExpireDuration(expiryDuration), memory.WithProxyRegistryURL(proxyRegistryURL), memory.WithDialOptions(options...)) } @@ -75,17 +72,16 @@ func newRegistryMemoryServer(ctx context.Context, tokenGenerator token.Generator // NewBuilder creates new SandboxBuilder func NewBuilder(ctx context.Context, t *testing.T) *Builder { b := &Builder{ - t: t, - ctx: ctx, - nodesCount: 1, - supplyNSMgr: nsmgr.NewServer, - supplyNSMgrProxy: nsmgrproxy.NewServer, - supplyRegistry: newRegistryMemoryServer, - supplyRegistryProxy: proxydns.NewServer, - name: "cluster.local", - dnsResolver: NewFakeResolver(), - generateTokenFunc: GenerateTestToken, - registryExpiryDuration: time.Minute, + t: t, + ctx: ctx, + nodesCount: 1, + supplyNSMgr: nsmgr.NewServer, + supplyNSMgrProxy: nsmgrproxy.NewServer, + supplyRegistry: newRegistryMemoryServer, + supplyRegistryProxy: proxydns.NewServer, + name: "cluster.local", + dnsResolver: NewFakeResolver(), + generateTokenFunc: GenerateTestToken, } b.setupNode = func(ctx context.Context, node *Node, _ int) { @@ -151,12 +147,6 @@ func (b *Builder) SetTokenGenerateFunc(f token.GeneratorFunc) *Builder { return b } -// SetRegistryExpiryDuration replaces registry expiry duration to custom -func (b *Builder) SetRegistryExpiryDuration(registryExpiryDuration time.Duration) *Builder { - b.registryExpiryDuration = registryExpiryDuration - return b -} - // UseUnixSockets sets 1 node and mark it to use unix socket to listen on. func (b *Builder) UseUnixSockets() *Builder { require.NotEqual(b.t, "windows", runtime.GOOS, "Unix sockets are not available for windows") @@ -275,7 +265,6 @@ func (b *Builder) newRegistry() *RegistryEntry { entry.Registry = b.supplyRegistry( ctx, b.generateTokenFunc, - b.registryExpiryDuration, nsmgrProxyURL, DialOptions(WithTokenGenerator(b.generateTokenFunc))..., ) diff --git a/pkg/tools/sandbox/types.go b/pkg/tools/sandbox/types.go index 4d6618eec..571713b01 100644 --- a/pkg/tools/sandbox/types.go +++ b/pkg/tools/sandbox/types.go @@ -19,7 +19,6 @@ package sandbox import ( "context" "net/url" - "time" registryapi "github.com/networkservicemesh/api/pkg/api/registry" "google.golang.org/grpc" @@ -41,7 +40,7 @@ type SupplyNSMgrProxyFunc func(ctx context.Context, regURL, proxyURL *url.URL, t type SupplyNSMgrFunc func(ctx context.Context, tokenGenerator token.GeneratorFunc, options ...nsmgr.Option) nsmgr.Nsmgr // SupplyRegistryFunc supplies Registry -type SupplyRegistryFunc func(ctx context.Context, tokenGenerator token.GeneratorFunc, expiryDuration time.Duration, proxyRegistryURL *url.URL, options ...grpc.DialOption) registry.Registry +type SupplyRegistryFunc func(ctx context.Context, tokenGenerator token.GeneratorFunc, proxyRegistryURL *url.URL, options ...grpc.DialOption) registry.Registry // SupplyRegistryProxyFunc supplies registry proxy type SupplyRegistryProxyFunc func(ctx context.Context, tokenGenerator token.GeneratorFunc, dnsResolver dnsresolve.Resolver, options ...proxydns.Option) registry.Registry