From 8fe3d10e407a4be27d7fd7aacc5ab244dc1c90e3 Mon Sep 17 00:00:00 2001 From: Vladimir Popov Date: Thu, 22 Apr 2021 18:44:43 +0700 Subject: [PATCH] Rework pkg/registry to use clock.Clock (#784) * Rework registry/querycache to clock Signed-off-by: Vladimir Popov * Rework registry/expire to clock Signed-off-by: Vladimir Popov * Rework registry/connect to clock Signed-off-by: Vladimir Popov * Rework registry/refresh to clock Signed-off-by: Vladimir Popov * Remove time.Sleep from clocked registry tests Signed-off-by: Vladimir Popov * Remove clockMock.IsTimerSet Signed-off-by: Vladimir Popov --- pkg/registry/common/connect/ns_server.go | 9 +- pkg/registry/common/connect/nse_server.go | 9 +- pkg/registry/common/expire/ns_server.go | 12 +- pkg/registry/common/expire/ns_server_test.go | 288 ++++++++++++------ pkg/registry/common/expire/nse_server.go | 13 +- pkg/registry/common/expire/nse_server_test.go | 163 ++++++---- pkg/registry/common/memory/nse_server.go | 12 +- pkg/registry/common/querycache/cache.go | 16 +- .../common/querycache/nse_client_test.go | 41 ++- .../common/refresh/nse_registry_client.go | 14 +- .../refresh/nse_registry_client_test.go | 107 +++++-- 11 files changed, 457 insertions(+), 227 deletions(-) diff --git a/pkg/registry/common/connect/ns_server.go b/pkg/registry/common/connect/ns_server.go index 7901c55ae..045e8ed9b 100644 --- a/pkg/registry/common/connect/ns_server.go +++ b/pkg/registry/common/connect/ns_server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -21,6 +21,7 @@ import ( "time" "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" + "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/registry/common/clienturl" "github.com/networkservicemesh/sdk/pkg/tools/extend" @@ -33,7 +34,7 @@ import ( ) type nsCacheEntry struct { - expirationTimer *time.Timer + expirationTimer clock.Timer client registry.NetworkServiceRegistryClient } @@ -43,6 +44,7 @@ type connectNSServer struct { cache nsClientMap connectExpiration time.Duration ctx context.Context + clock clock.Clock } // NewNetworkServiceRegistryServer creates new connect NetworkServiceEndpointRegistryServer with specific chain context, registry client factory and options @@ -52,6 +54,7 @@ func NewNetworkServiceRegistryServer(ctx context.Context, clientFactory func(ctx ctx: ctx, clientFactory: clientFactory, connectExpiration: defaultConnectExpiration, + clock: clock.FromContext(ctx), } for _, o := range options { o.apply(r) @@ -87,7 +90,7 @@ func (c *connectNSServer) connect(ctx context.Context) registry.NetworkServiceRe ctx = extend.WithValuesFromContext(c.ctx, ctx) client := clienturl.NewNetworkServiceRegistryClient(ctx, c.clientFactory, c.dialOptions...) cached, _ := c.cache.LoadOrStore(key, &nsCacheEntry{ - expirationTimer: time.AfterFunc(c.connectExpiration, func() { + expirationTimer: c.clock.AfterFunc(c.connectExpiration, func() { c.cache.Delete(key) }), client: client, diff --git a/pkg/registry/common/connect/nse_server.go b/pkg/registry/common/connect/nse_server.go index 68fab8ff6..3b12184fd 100644 --- a/pkg/registry/common/connect/nse_server.go +++ b/pkg/registry/common/connect/nse_server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -21,6 +21,7 @@ import ( "time" "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" + "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/registry/common/clienturl" "github.com/networkservicemesh/sdk/pkg/tools/extend" @@ -33,7 +34,7 @@ import ( ) type nseCacheEntry struct { - expirationTimer *time.Timer + expirationTimer clock.Timer client registry.NetworkServiceEndpointRegistryClient } @@ -43,6 +44,7 @@ type connectNSEServer struct { cache nseClientMap connectExpiration time.Duration ctx context.Context + clock clock.Clock } // NewNetworkServiceEndpointRegistryServer creates new connect NetworkServiceEndpointEndpointRegistryServer with specific chain context, registry client factory and options @@ -55,6 +57,7 @@ func NewNetworkServiceEndpointRegistryServer(ctx context.Context, ctx: ctx, clientFactory: clientFactory, connectExpiration: defaultConnectExpiration, + clock: clock.FromContext(ctx), } for _, o := range options { o.apply(r) @@ -88,7 +91,7 @@ func (c *connectNSEServer) connect(ctx context.Context) registry.NetworkServiceE ctx = extend.WithValuesFromContext(c.ctx, ctx) client := clienturl.NewNetworkServiceEndpointRegistryClient(ctx, c.clientFactory, c.dialOptions...) cached, _ := c.cache.LoadOrStore(key, &nseCacheEntry{ - expirationTimer: time.AfterFunc(c.connectExpiration, func() { + expirationTimer: c.clock.AfterFunc(c.connectExpiration, func() { c.cache.Delete(key) }), client: client, diff --git a/pkg/registry/common/expire/ns_server.go b/pkg/registry/common/expire/ns_server.go index c5753a083..00423d141 100644 --- a/pkg/registry/common/expire/ns_server.go +++ b/pkg/registry/common/expire/ns_server.go @@ -20,9 +20,9 @@ import ( "context" "errors" "sync" - "time" "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/tools/extend" "github.com/golang/protobuf/ptypes/empty" @@ -39,12 +39,14 @@ type expireNSServer struct { } type nsState struct { - Timers map[string]*time.Timer + Timers map[string]clock.Timer Context context.Context sync.Mutex } func (n *expireNSServer) checkUpdates(eventCh <-chan *registry.NetworkServiceEndpoint) { + clockTime := clock.FromContext(n.chainCtx) + for event := range eventCh { nse := event if nse.ExpirationTime == nil { @@ -59,10 +61,10 @@ func (n *expireNSServer) checkUpdates(eventCh <-chan *registry.NetworkServiceEnd } state.Lock() timer, ok := state.Timers[nse.Name] - expirationDuration := time.Until(nse.ExpirationTime.AsTime().Local()) + expirationDuration := clockTime.Until(nse.ExpirationTime.AsTime().Local()) if !ok { if expirationDuration > 0 { - state.Timers[nse.Name] = time.AfterFunc(expirationDuration, func() { + state.Timers[nse.Name] = clockTime.AfterFunc(expirationDuration, func() { state.Lock() ctx := state.Context delete(state.Timers, nse.Name) @@ -119,7 +121,7 @@ func (n *expireNSServer) Register(ctx context.Context, request *registry.Network valuesCtx := extend.WithValuesFromContext(n.chainCtx, ctx) v, _ := n.nsStates.LoadOrStore(request.Name, &nsState{ - Timers: make(map[string]*time.Timer), + Timers: make(map[string]clock.Timer), }) v.Lock() diff --git a/pkg/registry/common/expire/ns_server_test.go b/pkg/registry/common/expire/ns_server_test.go index 2bffa393b..322c7d76f 100644 --- a/pkg/registry/common/expire/ns_server_test.go +++ b/pkg/registry/common/expire/ns_server_test.go @@ -18,10 +18,14 @@ package expire_test import ( "context" + "fmt" + "io" + "sync" "testing" "time" "go.uber.org/goleak" + "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/networkservicemesh/api/pkg/api/registry" @@ -31,9 +35,16 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/common/memory" "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/clock" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" ) -const testPeriod = time.Millisecond * 200 +const ( + expireTimeout = time.Minute + nsName = "ns" + testWait = 100 * time.Millisecond + testTick = testWait / 100 +) func TestExpireNSServer_NSE_Expired(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) @@ -41,151 +52,238 @@ func TestExpireNSServer_NSE_Expired(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - nseMem := next.NewNetworkServiceEndpointRegistryServer( - memory.NewNetworkServiceEndpointRegistryServer(), - ) - - nseClient := adapters.NetworkServiceEndpointServerToClient(nseMem) + clockMock := clockmock.NewMock() + ctx = clock.WithClock(ctx, clockMock) + nseMem := memory.NewNetworkServiceEndpointRegistryServer() nsMem := memory.NewNetworkServiceRegistryServer() - s := next.NewNetworkServiceRegistryServer(expire.NewNetworkServiceServer(ctx, nseClient), nsMem) + updateServer := new(updateNSEServer) + + s := next.NewNetworkServiceRegistryServer( + expire.NewNetworkServiceServer( + ctx, + adapters.NetworkServiceEndpointServerToClient(next.NewNetworkServiceEndpointRegistryServer( + updateServer, + nseMem, + ))), + nsMem, + ) _, err := s.Register(ctx, ®istry.NetworkService{ - Name: "IP terminator", + Name: nsName, }) - require.Nil(t, err) + require.NoError(t, err) - for i := 0; i < 100; i++ { + names := make([]string, 10) + for i := 0; i < len(names); i++ { + names[i] = fmt.Sprint("nse-", i) _, err = nseMem.Register(ctx, ®istry.NetworkServiceEndpoint{ - Name: "nse-1", - NetworkServiceNames: []string{"IP terminator"}, - ExpirationTime: timestamppb.New(time.Now().Add(testPeriod * 2)), + Name: names[i], + NetworkServiceNames: []string{nsName}, + ExpirationTime: timestamppb.New(clockMock.Now().Add(expireTimeout)), }) - require.Nil(t, err) + require.NoError(t, err) } - nsClient := adapters.NetworkServiceServerToClient(s) - - stream, err := nsClient.Find(ctx, ®istry.NetworkServiceQuery{ - NetworkService: ®istry.NetworkService{}, + // Wait for the update from nseMem + require.Eventually(t, func() bool { + for _, name := range names { + if _, ok := updateServer.updates.Load(name); !ok { + return false + } + } + return true + }, testWait, testTick) + + c := adapters.NetworkServiceServerToClient(nsMem) + + stream, err := c.Find(ctx, ®istry.NetworkServiceQuery{ + NetworkService: new(registry.NetworkService), }) - require.Nil(t, err) + require.NoError(t, err) - list := registry.ReadNetworkServiceList(stream) - require.NotEmpty(t, list) + ns, err := stream.Recv() + require.NoError(t, err) + require.Equal(t, nsName, ns.Name) + clockMock.Add(expireTimeout) require.Eventually(t, func() bool { - stream, err = nsClient.Find(ctx, ®istry.NetworkServiceQuery{ - NetworkService: ®istry.NetworkService{}, + stream, err = c.Find(ctx, ®istry.NetworkServiceQuery{ + NetworkService: new(registry.NetworkService), }) - require.Nil(t, err) - list = registry.ReadNetworkServiceList(stream) - return len(list) == 0 - }, time.Second, time.Millisecond*100) + require.NoError(t, err) + + _, err = stream.Recv() + return err == io.EOF + }, testWait, testTick) } -func TestExpireNSServer_NSE_UnregisterdBeforeExpired(t *testing.T) { +func TestExpireNSServer_NSE_Unregistered(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - nseMem := next.NewNetworkServiceEndpointRegistryServer( - memory.NewNetworkServiceEndpointRegistryServer(), - ) + clockMock := clockmock.NewMock() + ctx = clock.WithClock(ctx, clockMock) - nseClient := adapters.NetworkServiceEndpointServerToClient(nseMem) + nseMem := memory.NewNetworkServiceEndpointRegistryServer() + nsMem := memory.NewNetworkServiceRegistryServer() + + updateServer := new(updateNSEServer) - s := next.NewNetworkServiceRegistryServer(expire.NewNetworkServiceServer(ctx, nseClient), memory.NewNetworkServiceRegistryServer()) + s := next.NewNetworkServiceRegistryServer( + expire.NewNetworkServiceServer( + ctx, + adapters.NetworkServiceEndpointServerToClient(next.NewNetworkServiceEndpointRegistryServer( + updateServer, + nseMem, + ))), + nsMem, + ) _, err := s.Register(ctx, ®istry.NetworkService{ - Name: "IP terminator", + Name: nsName, }) - require.Nil(t, err) + require.NoError(t, err) - for i := 0; i < 100; i++ { + names := make([]string, 10) + for i := 0; i < len(names); i++ { + names[i] = fmt.Sprint("nse-", i) _, err = nseMem.Register(ctx, ®istry.NetworkServiceEndpoint{ - Name: "nse-1", - NetworkServiceNames: []string{"IP terminator"}, - ExpirationTime: timestamppb.New(time.Now().Add(time.Hour)), + Name: names[i], + NetworkServiceNames: []string{nsName}, + ExpirationTime: timestamppb.New(clockMock.Now().Add(expireTimeout)), }) - require.Nil(t, err) + require.NoError(t, err) } - nsClient := adapters.NetworkServiceServerToClient(s) - stream, err := nsClient.Find(ctx, ®istry.NetworkServiceQuery{ - NetworkService: ®istry.NetworkService{}, + // Wait for the update from nseMem + require.Eventually(t, func() bool { + for _, name := range names { + if _, ok := updateServer.updates.Load(name); !ok { + return false + } + } + return true + }, testWait, testTick) + + c := adapters.NetworkServiceServerToClient(nsMem) + + stream, err := c.Find(ctx, ®istry.NetworkServiceQuery{ + NetworkService: new(registry.NetworkService), }) - require.Nil(t, err) + require.NoError(t, err) - list := registry.ReadNetworkServiceList(stream) - require.NotEmpty(t, list) + ns, err := stream.Recv() + require.NoError(t, err) + require.Equal(t, nsName, ns.Name) - <-time.After(testPeriod * 2) - _, err = nseClient.Unregister(ctx, ®istry.NetworkServiceEndpoint{ - Name: "nse-1", - NetworkServiceNames: []string{"IP terminator"}, - }) - require.Nil(t, err) + for i := 0; i < 10; i++ { + _, err = nseMem.Unregister(ctx, ®istry.NetworkServiceEndpoint{ + Name: fmt.Sprint("nse-", i), + }) + require.NoError(t, err) + } require.Eventually(t, func() bool { - stream, err = nsClient.Find(ctx, ®istry.NetworkServiceQuery{ - NetworkService: ®istry.NetworkService{}, + stream, err = c.Find(ctx, ®istry.NetworkServiceQuery{ + NetworkService: new(registry.NetworkService), }) - require.Nil(t, err) - list = registry.ReadNetworkServiceList(stream) - return len(list) == 0 - }, time.Second, time.Millisecond*100) + require.NoError(t, err) + + _, err = stream.Recv() + return err == io.EOF + }, testWait, testTick) } -func TestExpireNSServer_NSEServerSendsExpirationUpdate(t *testing.T) { +func TestExpireNSServer_NSE_Update(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) + + const nseName = "nse" ctx, cancel := context.WithCancel(context.Background()) defer cancel() - nseMem := next.NewNetworkServiceEndpointRegistryServer( - expire.NewNetworkServiceEndpointRegistryServer(ctx, time.Second), - memory.NewNetworkServiceEndpointRegistryServer(), + + clockMock := clockmock.NewMock() + ctx = clock.WithClock(ctx, clockMock) + + nseMem := memory.NewNetworkServiceEndpointRegistryServer() + nsMem := memory.NewNetworkServiceRegistryServer() + + updateServer := new(updateNSEServer) + + s := next.NewNetworkServiceRegistryServer( + expire.NewNetworkServiceServer( + ctx, + adapters.NetworkServiceEndpointServerToClient(next.NewNetworkServiceEndpointRegistryServer( + updateServer, + nseMem, + ))), + nsMem, ) - _, err := nseMem.Register(ctx, ®istry.NetworkServiceEndpoint{ - Name: "nse-1", - NetworkServiceNames: []string{"IP terminator"}, - ExpirationTime: timestamppb.New(time.Now().Add(time.Hour)), + _, err := s.Register(ctx, ®istry.NetworkService{ + Name: nsName, }) - require.Nil(t, err) + require.NoError(t, err) - _, err = nseMem.Register(ctx, ®istry.NetworkServiceEndpoint{ - Name: "nse-2", - NetworkServiceNames: []string{"IP terminator"}, - ExpirationTime: timestamppb.New(time.Now().Add(testPeriod)), - }) - require.Nil(t, err) + for i := 0; i < 3; i++ { + updateServer.updates = sync.Map{} - nseClient := adapters.NetworkServiceEndpointServerToClient(nseMem) - nsMem := memory.NewNetworkServiceRegistryServer() - s := next.NewNetworkServiceRegistryServer(expire.NewNetworkServiceServer(ctx, nseClient), nsMem) - _, err = s.Register(ctx, ®istry.NetworkService{ - Name: "IP terminator", - }) + _, err = nseMem.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: nseName, + NetworkServiceNames: []string{nsName}, + ExpirationTime: timestamppb.New(clockMock.Now().Add(expireTimeout)), + }) + require.NoError(t, err) - require.Nil(t, err) - nsClient := adapters.NetworkServiceServerToClient(s) - stream, err := nsClient.Find(ctx, ®istry.NetworkServiceQuery{ - NetworkService: ®istry.NetworkService{}, - }) + // Wait for the update from nseMem + require.Eventually(t, func() bool { + _, ok := updateServer.updates.Load(nseName) + return ok + }, testWait, testTick) - require.Nil(t, err) - list := registry.ReadNetworkServiceList(stream) - require.NotEmpty(t, list) - <-time.After(testPeriod * 2) + c := adapters.NetworkServiceServerToClient(nsMem) - require.Eventually(t, func() bool { - stream, err = nsClient.Find(ctx, ®istry.NetworkServiceQuery{ - NetworkService: ®istry.NetworkService{}, + stream, err := c.Find(ctx, ®istry.NetworkServiceQuery{ + NetworkService: new(registry.NetworkService), }) - require.Nil(t, err) - list = registry.ReadNetworkServiceList(stream) - return len(list) == 1 - }, time.Second, time.Millisecond*100) + require.NoError(t, err) + + ns, err := stream.Recv() + require.NoError(t, err) + require.Equal(t, nsName, ns.Name) + + clockMock.Add(expireTimeout / 2) + } +} + +type updateNSEServer struct { + updates sync.Map +} + +func (s *updateNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { + return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) +} + +func (s *updateNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { + return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, &updateNSEFindServer{ + updateNSEServer: s, + NetworkServiceEndpointRegistry_FindServer: server, + }) +} + +func (s *updateNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*emptypb.Empty, error) { + return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) +} + +type updateNSEFindServer struct { + *updateNSEServer + registry.NetworkServiceEndpointRegistry_FindServer +} + +func (s *updateNSEFindServer) Send(nse *registry.NetworkServiceEndpoint) error { + s.updates.Store(nse.Name, struct{}{}) + return s.NetworkServiceEndpointRegistry_FindServer.Send(nse) } diff --git a/pkg/registry/common/expire/nse_server.go b/pkg/registry/common/expire/nse_server.go index 0a5f3b6fb..5e9c15851 100644 --- a/pkg/registry/common/expire/nse_server.go +++ b/pkg/registry/common/expire/nse_server.go @@ -27,6 +27,7 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/clock" ) // TODO: rework with serialize (#749) @@ -39,7 +40,7 @@ type expireNSEServer struct { type unregisterTimer struct { expirationTime time.Time started, canceled bool - timer *time.Timer + timer clock.Timer executor serialize.Executor } @@ -52,6 +53,8 @@ func NewNetworkServiceEndpointRegistryServer(ctx context.Context, nseExpiration } func (n *expireNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { + clockTime := clock.FromContext(ctx) + t, loaded := n.timers.LoadAndDelete(nse.Name) stopped := loaded && t.timer.Stop() @@ -72,7 +75,7 @@ func (n *expireNSEServer) Register(ctx context.Context, nse *registry.NetworkSer if err != nil { if stopped { // Timer has been stopped, we need only to reset it. - t.timer.Reset(time.Until(expirationTime)) + t.timer.Reset(clockTime.Until(expirationTime)) } else if loaded && !started { // Timer function has been stopped with the `canceled` flag, we need to remove the flag. t.executor.AsyncExec(func() { @@ -86,7 +89,7 @@ func (n *expireNSEServer) Register(ctx context.Context, nse *registry.NetworkSer return nil, err } - expirationTime = time.Now().Add(n.nseExpiration) + expirationTime = clockTime.Now().Add(n.nseExpiration) if resp.ExpirationTime != nil { if respExpirationTime := resp.ExpirationTime.AsTime().Local(); respExpirationTime.Before(expirationTime) { expirationTime = respExpirationTime @@ -110,11 +113,13 @@ func (n *expireNSEServer) newTimer( expirationTime time.Time, nse *registry.NetworkServiceEndpoint, ) *unregisterTimer { + clockTime := clock.FromContext(ctx) + t := &unregisterTimer{ expirationTime: expirationTime, } - t.timer = time.AfterFunc(time.Until(expirationTime), func() { + t.timer = clockTime.AfterFunc(clockTime.Until(expirationTime), func() { t.executor.AsyncExec(func() { t.started = true if t.canceled || n.ctx.Err() != nil { diff --git a/pkg/registry/common/expire/nse_server_test.go b/pkg/registry/common/expire/nse_server_test.go index c6fc4cdb9..b0e17771c 100644 --- a/pkg/registry/common/expire/nse_server_test.go +++ b/pkg/registry/common/expire/nse_server_test.go @@ -18,13 +18,15 @@ package expire_test import ( "context" + "io" + "sync/atomic" "testing" - "time" "github.com/golang/protobuf/ptypes/empty" "github.com/pkg/errors" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/networkservicemesh/api/pkg/api/registry" @@ -36,96 +38,126 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checkcontext" + "github.com/networkservicemesh/sdk/pkg/tools/clock" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" +) + +const ( + nseName = "nse" ) func TestExpireNSEServer_ShouldCorrectlySetExpirationTime_InRemoteCase(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) + clockMock := clockmock.NewMock() + ctx := clock.WithClock(context.Background(), clockMock) + s := next.NewNetworkServiceEndpointRegistryServer( - expire.NewNetworkServiceEndpointRegistryServer(context.Background(), time.Hour), + expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), new(remoteNSEServer), ) - resp, err := s.Register(context.Background(), ®istry.NetworkServiceEndpoint{Name: "nse-1"}) + resp, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: nseName, + }) require.NoError(t, err) - require.Greater(t, time.Until(resp.ExpirationTime.AsTime()).Minutes(), float64(50)) + require.Equal(t, clockMock.Until(resp.ExpirationTime.AsTime()), expireTimeout) } func TestExpireNSEServer_ShouldUseLessExpirationTimeFromInput_AndWork(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) + clockMock := clockmock.NewMock() + ctx := clock.WithClock(context.Background(), clockMock) + + mem := memory.NewNetworkServiceEndpointRegistryServer() + s := next.NewNetworkServiceEndpointRegistryServer( - expire.NewNetworkServiceEndpointRegistryServer(context.Background(), time.Hour), - memory.NewNetworkServiceEndpointRegistryServer(), + expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), + mem, ) - resp, err := s.Register(context.Background(), ®istry.NetworkServiceEndpoint{ - Name: "nse-1", - ExpirationTime: timestamppb.New(time.Now().Add(time.Millisecond * 200)), + resp, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: nseName, + ExpirationTime: timestamppb.New(clockMock.Now().Add(expireTimeout / 2)), }) require.NoError(t, err) - require.Less(t, time.Until(resp.ExpirationTime.AsTime()).Seconds(), float64(65)) + require.Equal(t, clockMock.Until(resp.ExpirationTime.AsTime()), expireTimeout/2) - c := adapters.NetworkServiceEndpointServerToClient(s) + c := adapters.NetworkServiceEndpointServerToClient(mem) + clockMock.Add(expireTimeout / 2) require.Eventually(t, func() bool { - stream, err := c.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{ + stream, err := c.Find(ctx, ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), }) require.NoError(t, err) - list := registry.ReadNetworkServiceEndpointList(stream) - return len(list) == 0 - }, time.Second, time.Millisecond*100) + _, err = stream.Recv() + return err == io.EOF + }, testWait, testTick) } func TestExpireNSEServer_ShouldUseLessExpirationTimeFromResponse(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) + clockMock := clockmock.NewMock() + ctx := clock.WithClock(context.Background(), clockMock) + s := next.NewNetworkServiceEndpointRegistryServer( - expire.NewNetworkServiceEndpointRegistryServer(context.Background(), time.Hour), + expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), new(remoteNSEServer), // <-- GRPC invocation - expire.NewNetworkServiceEndpointRegistryServer(context.Background(), 10*time.Minute), + expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout/2), ) - resp, err := s.Register(context.Background(), ®istry.NetworkServiceEndpoint{Name: "nse-1"}) + resp, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-1"}) require.NoError(t, err) - require.Less(t, time.Until(resp.ExpirationTime.AsTime()).Minutes(), float64(11)) + require.Equal(t, clockMock.Until(resp.ExpirationTime.AsTime()), expireTimeout/2) } func TestExpireNSEServer_ShouldRemoveNSEAfterExpirationTime(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) + clockMock := clockmock.NewMock() + ctx := clock.WithClock(context.Background(), clockMock) + + mem := memory.NewNetworkServiceEndpointRegistryServer() + s := next.NewNetworkServiceEndpointRegistryServer( - expire.NewNetworkServiceEndpointRegistryServer(context.Background(), testPeriod*2), + expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), new(remoteNSEServer), // <-- GRPC invocation - memory.NewNetworkServiceEndpointRegistryServer(), + mem, ) - _, err := s.Register(context.Background(), ®istry.NetworkServiceEndpoint{}) + _, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: nseName, + }) require.NoError(t, err) - c := adapters.NetworkServiceEndpointServerToClient(s) - stream, err := c.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{ + c := adapters.NetworkServiceEndpointServerToClient(mem) + + stream, err := c.Find(ctx, ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), }) require.NoError(t, err) - list := registry.ReadNetworkServiceEndpointList(stream) - require.NotEmpty(t, list) + nse, err := stream.Recv() + require.NoError(t, err) + require.Equal(t, nseName, nse.Name) + clockMock.Add(expireTimeout) require.Eventually(t, func() bool { - stream, err = c.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{ + stream, err = c.Find(ctx, ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), }) require.NoError(t, err) - list = registry.ReadNetworkServiceEndpointList(stream) - return len(list) == 0 - }, time.Second, time.Millisecond*100) + _, err = stream.Recv() + return err == io.EOF + }, testWait, testTick) } func TestExpireNSEServer_DataRace(t *testing.T) { @@ -139,9 +171,9 @@ func TestExpireNSEServer_DataRace(t *testing.T) { mem, ) - for i := 0; i < 1000; i++ { + for i := 0; i < 200; i++ { _, err := s.Register(context.Background(), ®istry.NetworkServiceEndpoint{ - Name: "nse-1", + Name: nseName, Url: "tcp://1.1.1.1", }) require.NoError(t, err) @@ -155,9 +187,9 @@ func TestExpireNSEServer_DataRace(t *testing.T) { }) require.NoError(t, err) - list := registry.ReadNetworkServiceEndpointList(stream) - return len(list) == 0 - }, time.Second, time.Millisecond*100) + _, err = stream.Recv() + return err == io.EOF + }, testWait, testTick) } func TestExpireNSEServer_RefreshFailure(t *testing.T) { @@ -166,11 +198,14 @@ func TestExpireNSEServer_RefreshFailure(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + clockMock := clockmock.NewMock() + ctx = clock.WithClock(ctx, clockMock) + c := next.NewNetworkServiceEndpointRegistryClient( refresh.NewNetworkServiceEndpointRegistryClient(refresh.WithChainContext(ctx)), adapters.NetworkServiceEndpointServerToClient(next.NewNetworkServiceEndpointRegistryServer( new(remoteNSEServer), // <-- GRPC invocation - expire.NewNetworkServiceEndpointRegistryServer(ctx, testPeriod), + expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), newFailureNSEServer(1, -1), memory.NewNetworkServiceEndpointRegistryServer(), )), @@ -179,15 +214,16 @@ func TestExpireNSEServer_RefreshFailure(t *testing.T) { _, err := c.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-1"}) require.NoError(t, err) + clockMock.Add(expireTimeout) require.Eventually(t, func() bool { - stream, err := c.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{ + stream, err := c.Find(ctx, ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), }) require.NoError(t, err) - list := registry.ReadNetworkServiceEndpointList(stream) - return len(list) == 0 - }, time.Second, time.Millisecond*100) + _, err = stream.Recv() + return err == io.EOF + }, testWait, testTick) } func TestExpireNSEServer_RefreshKeepsNoUnregister(t *testing.T) { @@ -196,37 +232,37 @@ func TestExpireNSEServer_RefreshKeepsNoUnregister(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - mem := memory.NewNetworkServiceEndpointRegistryServer() + clockMock := clockmock.NewMock() + ctx = clock.WithClock(ctx, clockMock) + + unregisterServer := new(unregisterNSEServer) c := next.NewNetworkServiceEndpointRegistryClient( refresh.NewNetworkServiceEndpointRegistryClient(refresh.WithChainContext(ctx)), adapters.NetworkServiceEndpointServerToClient(next.NewNetworkServiceEndpointRegistryServer( // NSMgr chain new(remoteNSEServer), // <-- GRPC invocation - expire.NewNetworkServiceEndpointRegistryServer(ctx, 2*testPeriod), + expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), // Registry chain new(remoteNSEServer), // <-- GRPC invocation checkcontext.NewNSEServer(t, func(_ *testing.T, _ context.Context) { - <-time.After(testPeriod) + clockMock.Add(expireTimeout / 2) }), - expire.NewNetworkServiceEndpointRegistryServer(ctx, time.Minute), - mem, + expire.NewNetworkServiceEndpointRegistryServer(ctx, 10*expireTimeout), + unregisterServer, )), ) - _, err := c.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-1"}) - require.NoError(t, err) - - stream, err := adapters.NetworkServiceEndpointServerToClient(mem).Find(ctx, ®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), - Watch: true, + _, err := c.Register(ctx, ®istry.NetworkServiceEndpoint{ + Name: nseName, }) require.NoError(t, err) - for start := time.Now(); time.Since(start).Seconds() < 1; { - nse, err := stream.Recv() - require.NoError(t, err) - require.NotEqual(t, int64(-1), nse.ExpirationTime.Seconds) + for i := 0; i < 3; i++ { + clockMock.Add(expireTimeout / 10 * 9) + require.Never(t, func() bool { + return atomic.LoadInt32(&unregisterServer.unregisterCount) > 0 + }, testWait, testTick) } } @@ -277,3 +313,20 @@ func (s *failureNSEServer) Find(query *registry.NetworkServiceEndpointQuery, ser func (s *failureNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) } + +type unregisterNSEServer struct { + unregisterCount int32 +} + +func (s *unregisterNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { + return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) +} + +func (s *unregisterNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { + return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, server) +} + +func (s *unregisterNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*emptypb.Empty, error) { + atomic.AddInt32(&s.unregisterCount, 1) + return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) +} diff --git a/pkg/registry/common/memory/nse_server.go b/pkg/registry/common/memory/nse_server.go index f9dd27938..0604cbbc7 100644 --- a/pkg/registry/common/memory/nse_server.go +++ b/pkg/registry/common/memory/nse_server.go @@ -155,12 +155,12 @@ func (s *memoryNSEServer) receiveEvent( } func (s *memoryNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - s.networkServiceEndpoints.Delete(nse.Name) - - nse.ExpirationTime = ×tamp.Timestamp{ - Seconds: -1, + if unregisterNSE, ok := s.networkServiceEndpoints.LoadAndDelete(nse.Name); ok { + unregisterNSE = unregisterNSE.Clone() + unregisterNSE.ExpirationTime = ×tamp.Timestamp{ + Seconds: -1, + } + s.sendEvent(unregisterNSE) } - s.sendEvent(nse) - return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) } diff --git a/pkg/registry/common/querycache/cache.go b/pkg/registry/common/querycache/cache.go index d280652ee..ba44f1c0b 100644 --- a/pkg/registry/common/querycache/cache.go +++ b/pkg/registry/common/querycache/cache.go @@ -22,35 +22,39 @@ import ( "time" "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/tools/clock" ) type cache struct { expireTimeout time.Duration entries cacheEntryMap + clockTime clock.Clock } func newCache(ctx context.Context, opts ...Option) *cache { c := &cache{ expireTimeout: time.Minute, + clockTime: clock.FromContext(ctx), } for _, opt := range opts { opt(c) } - ticker := time.NewTicker(c.expireTimeout) + ticker := c.clockTime.Ticker(c.expireTimeout) go func() { for { select { case <-ctx.Done(): ticker.Stop() return - case <-ticker.C: + case <-ticker.C(): c.entries.Range(func(_ string, e *cacheEntry) bool { e.lock.Lock() defer e.lock.Unlock() - if time.Until(e.expirationTime) < 0 { + if c.clockTime.Until(e.expirationTime) < 0 { e.cleanup() } @@ -67,7 +71,7 @@ func (c *cache) LoadOrStore(key string, nse *registry.NetworkServiceEndpoint, ca var once sync.Once return c.entries.LoadOrStore(key, &cacheEntry{ nse: nse, - expirationTime: time.Now().Add(c.expireTimeout), + expirationTime: c.clockTime.Now().Add(c.expireTimeout), cleanup: func() { once.Do(func() { c.entries.Delete(key) @@ -86,12 +90,12 @@ func (c *cache) Load(key string) (*registry.NetworkServiceEndpoint, bool) { e.lock.Lock() defer e.lock.Unlock() - if time.Until(e.expirationTime) < 0 { + if c.clockTime.Until(e.expirationTime) < 0 { e.cleanup() return nil, false } - e.expirationTime = time.Now().Add(c.expireTimeout) + e.expirationTime = c.clockTime.Now().Add(c.expireTimeout) return e.nse, true } diff --git a/pkg/registry/common/querycache/nse_client_test.go b/pkg/registry/common/querycache/nse_client_test.go index fda31b595..a97881c73 100644 --- a/pkg/registry/common/querycache/nse_client_test.go +++ b/pkg/registry/common/querycache/nse_client_test.go @@ -33,13 +33,17 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/common/querycache" "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/clock" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" ) const ( - name = "nse" - url1 = "tcp://1.1.1.1" - url2 = "tcp://2.2.2.2" - expirationTime = 100 * time.Millisecond + expireTimeout = time.Minute + name = "nse" + url1 = "tcp://1.1.1.1" + url2 = "tcp://2.2.2.2" + testWait = 100 * time.Millisecond + testTick = testWait / 100 ) func testNSEQuery(nseName string) *registry.NetworkServiceEndpointQuery { @@ -60,7 +64,7 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { failureClient := new(failureNSEClient) c := next.NewNetworkServiceEndpointRegistryClient( - querycache.NewClient(ctx, querycache.WithExpireTimeout(time.Minute)), + querycache.NewClient(ctx, querycache.WithExpireTimeout(expireTimeout)), failureClient, adapters.NetworkServiceEndpointServerToClient(mem), ) @@ -75,6 +79,8 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) // 1. Find from memory + atomic.StoreInt32(&failureClient.shouldFail, 0) + stream, err := c.Find(ctx, testNSEQuery("")) require.NoError(t, err) @@ -95,7 +101,7 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { return false } return name == nse.Name && url1 == nse.Url - }, 100*time.Millisecond, time.Millisecond) + }, testWait, testTick) // 3. Update NSE in memory reg.Url = url2 @@ -111,7 +117,7 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { return false } return name == nse.Name && url2 == nse.Url - }, 100*time.Millisecond, time.Millisecond) + }, testWait, testTick) // 4. Delete NSE from memory _, err = mem.Unregister(ctx, reg) @@ -120,7 +126,7 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { require.Eventually(t, func() bool { _, err = c.Find(ctx, testNSEQuery(name)) return err != nil - }, 100*time.Millisecond, time.Millisecond) + }, testWait, testTick) } func Test_QueryCacheClient_ShouldCleanUpOnTimeout(t *testing.T) { @@ -129,11 +135,14 @@ func Test_QueryCacheClient_ShouldCleanUpOnTimeout(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + clockMock := clockmock.NewMock() + ctx = clock.WithClock(ctx, clockMock) + mem := memory.NewNetworkServiceEndpointRegistryServer() failureClient := new(failureNSEClient) c := next.NewNetworkServiceEndpointRegistryClient( - querycache.NewClient(ctx, querycache.WithExpireTimeout(expirationTime)), + querycache.NewClient(ctx, querycache.WithExpireTimeout(expireTimeout)), failureClient, adapters.NetworkServiceEndpointServerToClient(mem), ) @@ -147,6 +156,8 @@ func Test_QueryCacheClient_ShouldCleanUpOnTimeout(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) // 1. Find from memory + atomic.StoreInt32(&failureClient.shouldFail, 0) + stream, err := c.Find(ctx, testNSEQuery("")) require.NoError(t, err) @@ -161,11 +172,11 @@ func Test_QueryCacheClient_ShouldCleanUpOnTimeout(t *testing.T) { _, err = stream.Recv() } return err == nil - }, 100*time.Millisecond, time.Millisecond) + }, testWait, testTick) // 3. Keep finding from cache to prevent expiration - for start := time.Now(); time.Since(start) > 2*expirationTime; time.Sleep(expirationTime / 10) { - stream, err = c.Find(ctx, testNSEQuery("")) + for start := clockMock.Now(); clockMock.Since(start) < 2*expireTimeout; clockMock.Add(expireTimeout / 3) { + stream, err = c.Find(ctx, testNSEQuery(name)) require.NoError(t, err) _, err = stream.Recv() @@ -173,10 +184,10 @@ func Test_QueryCacheClient_ShouldCleanUpOnTimeout(t *testing.T) { } // 4. Wait for the expire to happen - time.Sleep(expirationTime) + clockMock.Add(expireTimeout) - _, err = c.Find(ctx, testNSEQuery("")) - require.Error(t, err) + _, err = c.Find(ctx, testNSEQuery(name)) + require.Errorf(t, err, "find error") } type failureNSEClient struct { diff --git a/pkg/registry/common/refresh/nse_registry_client.go b/pkg/registry/common/refresh/nse_registry_client.go index 92235d4b6..332d42c97 100644 --- a/pkg/registry/common/refresh/nse_registry_client.go +++ b/pkg/registry/common/refresh/nse_registry_client.go @@ -27,6 +27,7 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/tools/log" ) @@ -54,6 +55,7 @@ func NewNetworkServiceEndpointRegistryClient(options ...Option) registry.Network func (c *refreshNSEClient) startRefresh( ctx context.Context, + clockTime clock.Clock, client registry.NetworkServiceEndpointRegistryClient, nse *registry.NetworkServiceEndpoint, expiryDuration time.Duration, @@ -66,8 +68,8 @@ func (c *refreshNSEClient) startRefresh( select { case <-ctx.Done(): return - case <-time.After(2 * time.Until(t) / 3): - nse.ExpirationTime = timestamppb.New(time.Now().Add(expiryDuration)) + case <-clockTime.After(2 * clockTime.Until(t) / 3): + nse.ExpirationTime = timestamppb.New(clockTime.Now().Add(expiryDuration)) res, err := client.Register(ctx, nse.Clone()) if err != nil { @@ -88,12 +90,14 @@ func (c *refreshNSEClient) Register( nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption, ) (*registry.NetworkServiceEndpoint, error) { + clockTime := clock.FromContext(ctx) + var expiryDuration time.Duration if nse.ExpirationTime == nil { expiryDuration = c.defaultExpiryDuration - nse.ExpirationTime = timestamppb.New(time.Now().Add(expiryDuration)) + nse.ExpirationTime = timestamppb.New(clockTime.Now().Add(expiryDuration)) } else { - expiryDuration = time.Until(nse.ExpirationTime.AsTime().Local()) + expiryDuration = clockTime.Until(nse.ExpirationTime.AsTime().Local()) } refreshNSE := nse.Clone() @@ -114,7 +118,7 @@ func (c *refreshNSEClient) Register( ctx, cancel := context.WithCancel(c.chainContext) c.nseCancels.Store(resp.Name, cancel) - c.startRefresh(ctx, nextClient, refreshNSE, expiryDuration) + c.startRefresh(ctx, clockTime, nextClient, refreshNSE, expiryDuration) return resp, err } diff --git a/pkg/registry/common/refresh/nse_registry_client_test.go b/pkg/registry/common/refresh/nse_registry_client_test.go index 528897486..0588be621 100644 --- a/pkg/registry/common/refresh/nse_registry_client_test.go +++ b/pkg/registry/common/refresh/nse_registry_client_test.go @@ -35,9 +35,15 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checknse" + "github.com/networkservicemesh/sdk/pkg/tools/clock" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" ) -const testExpiryDuration = time.Millisecond * 100 +const ( + expireTimeout = time.Minute + testWait = 100 * time.Millisecond + testTick = testWait / 100 +) func testNSE() *registry.NetworkServiceEndpoint { return ®istry.NetworkServiceEndpoint{ @@ -48,22 +54,29 @@ func testNSE() *registry.NetworkServiceEndpoint { func TestNewNetworkServiceEndpointRegistryClient(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) + clockMock := clockmock.NewMock() + ctx := clock.WithClock(context.Background(), clockMock) + countClient := new(requestCountClient) client := next.NewNetworkServiceEndpointRegistryClient( - refresh.NewNetworkServiceEndpointRegistryClient(refresh.WithDefaultExpiryDuration(testExpiryDuration)), + refresh.NewNetworkServiceEndpointRegistryClient( + refresh.WithChainContext(ctx), + refresh.WithDefaultExpiryDuration(expireTimeout)), countClient, ) - _, err := client.Register(context.Background(), ®istry.NetworkServiceEndpoint{ - Name: "nse-1", - }) + _, err := client.Register(ctx, testNSE()) require.NoError(t, err) + // Wait for the Refresh goroutine to start + time.Sleep(testTick) + + clockMock.Add(expireTimeout) require.Eventually(t, func() bool { return atomic.LoadInt32(&countClient.requestCount) > 1 - }, time.Second, testExpiryDuration/4) + }, testWait, testTick) - _, err = client.Unregister(context.Background(), testNSE()) + _, err = client.Unregister(ctx, testNSE()) require.NoError(t, err) } @@ -87,29 +100,41 @@ func TestRefreshNSEClient_ShouldSetExpirationTime_BeforeCallNext(t *testing.T) { func Test_RefreshNSEClient_CalledRegisterTwice(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) + clockMock := clockmock.NewMock() + ctx := clock.WithClock(context.Background(), clockMock) + countClient := new(requestCountClient) client := next.NewNetworkServiceEndpointRegistryClient( - refresh.NewNetworkServiceEndpointRegistryClient(refresh.WithDefaultExpiryDuration(testExpiryDuration)), + refresh.NewNetworkServiceEndpointRegistryClient( + refresh.WithChainContext(ctx), + refresh.WithDefaultExpiryDuration(expireTimeout)), countClient, ) - _, err := client.Register(context.Background(), testNSE()) + _, err := client.Register(ctx, testNSE()) require.NoError(t, err) - reg, err := client.Register(context.Background(), testNSE()) + reg, err := client.Register(ctx, testNSE()) require.NoError(t, err) + // Wait for the Refresh goroutine to start + time.Sleep(testTick) + + clockMock.Add(expireTimeout) require.Eventually(t, func() bool { return atomic.LoadInt32(&countClient.requestCount) > 2 - }, time.Second, testExpiryDuration/4) + }, testWait, testTick) - _, err = client.Unregister(context.Background(), reg) + _, err = client.Unregister(ctx, reg) require.NoError(t, err) } func Test_RefreshNSEClient_ShouldOverrideNameAndDuration(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) + clockMock := clockmock.NewMock() + ctx := clock.WithClock(context.Background(), clockMock) + endpoint := ®istry.NetworkServiceEndpoint{ Name: "nse-1", Url: "url", @@ -118,10 +143,12 @@ func Test_RefreshNSEClient_ShouldOverrideNameAndDuration(t *testing.T) { countClient := new(requestCountClient) registryServer := &nseRegistryServer{ name: uuid.New().String(), - expiryDuration: testExpiryDuration, + expiryDuration: expireTimeout, } client := next.NewNetworkServiceEndpointRegistryClient( - refresh.NewNetworkServiceEndpointRegistryClient(refresh.WithDefaultExpiryDuration(time.Hour)), + refresh.NewNetworkServiceEndpointRegistryClient( + refresh.WithChainContext(ctx), + refresh.WithDefaultExpiryDuration(10*expireTimeout)), checknse.NewClient(t, func(t *testing.T, nse *registry.NetworkServiceEndpoint) { if atomic.LoadInt32(&countClient.requestCount) > 0 { require.Equal(t, registryServer.name, nse.Name) @@ -132,43 +159,63 @@ func Test_RefreshNSEClient_ShouldOverrideNameAndDuration(t *testing.T) { adapters.NetworkServiceEndpointServerToClient(registryServer), ) - reg, err := client.Register(context.Background(), endpoint.Clone()) + reg, err := client.Register(ctx, endpoint.Clone()) require.NoError(t, err) - require.Eventually(t, func() bool { - return atomic.LoadInt32(&countClient.requestCount) > 3 - }, time.Second, testExpiryDuration/4) + // Wait for the Refresh goroutine to start + time.Sleep(testTick) + + for i := 1; i <= 3; i++ { + count := int32(i) + + clockMock.Add(expireTimeout) + require.Eventually(t, func() bool { + return atomic.LoadInt32(&countClient.requestCount) > count + }, testWait, testTick) + } reg.Url = endpoint.Url - _, err = client.Unregister(context.Background(), reg) + _, err = client.Unregister(ctx, reg) require.NoError(t, err) } func Test_RefreshNSEClient_SetsCorrectExpireTime(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - const expiryDuration = 100 * time.Millisecond - const timeoutDuration = 200 * time.Millisecond + clockMock := clockmock.NewMock() + ctx := clock.WithClock(context.Background(), clockMock) countClient := new(requestCountClient) client := next.NewNetworkServiceEndpointRegistryClient( - refresh.NewNetworkServiceEndpointRegistryClient(refresh.WithDefaultExpiryDuration(expiryDuration)), + refresh.NewNetworkServiceEndpointRegistryClient( + refresh.WithChainContext(ctx), + refresh.WithDefaultExpiryDuration(expireTimeout)), checknse.NewClient(t, func(t *testing.T, nse *registry.NetworkServiceEndpoint) { - require.Greater(t, int64(expiryDuration+timeoutDuration)/2, int64(time.Until(nse.ExpirationTime.AsTime().Local()))) - nse.ExpirationTime = timestamppb.New(time.Now().Add(timeoutDuration)) + require.Equal(t, expireTimeout, clockMock.Until(nse.ExpirationTime.AsTime().Local())) + nse.ExpirationTime = timestamppb.New(clockMock.Now().Add(10 * expireTimeout)) }), countClient, ) - reg, err := client.Register(context.Background(), testNSE()) + reg, err := client.Register(ctx, testNSE()) require.NoError(t, err) - require.Eventually(t, func() bool { - return atomic.LoadInt32(&countClient.requestCount) > 3 - }, time.Second, testExpiryDuration/4) + // Wait for the Refresh goroutine to start + time.Sleep(testTick) - _, err = client.Unregister(context.Background(), reg) + for i := 1; i <= 3; i++ { + count := int32(i) + + clockMock.Add(10 * expireTimeout) + require.Eventually(t, func() bool { + return atomic.LoadInt32(&countClient.requestCount) > count + }, testWait, testTick) + } + + reg.ExpirationTime = timestamppb.New(clockMock.Now().Add(expireTimeout)) + + _, err = client.Unregister(ctx, reg) require.Nil(t, err) } @@ -199,7 +246,7 @@ func (s *nseRegistryServer) Register(ctx context.Context, nse *registry.NetworkS nse = nse.Clone() nse.Name = s.name nse.Url = uuid.New().String() - nse.ExpirationTime = timestamppb.New(time.Now().Add(s.expiryDuration)) + nse.ExpirationTime = timestamppb.New(clock.FromContext(ctx).Now().Add(s.expiryDuration)) return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) }