diff --git a/pkg/networkservice/common/begin/event_factory.go b/pkg/networkservice/common/begin/event_factory.go index 233753483..3b447980c 100644 --- a/pkg/networkservice/common/begin/event_factory.go +++ b/pkg/networkservice/common/begin/event_factory.go @@ -101,13 +101,13 @@ func (f *eventFactoryClient) Request(opts ...Option) <-chan error { request := f.request.Clone() if o.reselect { ctx, cancel := f.ctxFunc() - defer cancel() _, _ = f.client.Close(ctx, request.GetConnection(), f.opts...) if request.GetConnection() != nil { request.GetConnection().Mechanism = nil request.GetConnection().NetworkServiceEndpointName = "" request.GetConnection().State = networkservice.State_RESELECT_REQUESTED } + cancel() } ctx, cancel := f.ctxFunc() defer cancel() diff --git a/pkg/networkservice/common/begin/event_factory_client_test.go b/pkg/networkservice/common/begin/event_factory_client_test.go index 3209e3384..7e5ef6421 100644 --- a/pkg/networkservice/common/begin/event_factory_client_test.go +++ b/pkg/networkservice/common/begin/event_factory_client_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cisco and/or its affiliates. +// Copyright (c) 2022-2023 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,6 +18,8 @@ package begin_test import ( "context" + + "sync/atomic" "testing" "time" @@ -33,6 +35,8 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/common/begin" "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkcontext" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/count" "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/tools/clockmock" ) @@ -166,6 +170,54 @@ func TestContextTimeout_Client(t *testing.T) { eventFactoryCl.callClose() } +// This test checks if the eventFactory reselect option cancels Close ctx correctly +func TestContextCancellationOnReselect(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + + var closeCall atomic.Bool + ch := make(chan struct{}, 1) + + eventFactoryCl := &eventFactoryClient{} + counter := &count.Client{} + client := chain.NewNetworkServiceClient( + begin.NewClient(), + eventFactoryCl, + checkcontext.NewClient(t, func(t *testing.T, reqCtx context.Context) { + if !closeCall.Load() { + <-ch + } else { + closeCall.Store(false) + go func() { + <-reqCtx.Done() + ch <- struct{}{} + }() + } + }), + counter, + ) + + // Do Request + // Write to ch for the first request + ch <- struct{}{} + request := testRequest("1") + conn, err := client.Request(ctx, request.Clone()) + require.NotNil(t, t, conn) + require.NoError(t, err) + + // Call Reselect + // It will call Close first + closeCall.Store(true) + eventFactoryCl.callReselect() + + // Waiting for the re-request + require.Eventually(t, func() bool { + return counter.Requests() == 2 + }, time.Second, time.Millisecond*100) +} + type eventFactoryClient struct { ctx context.Context } @@ -191,6 +243,11 @@ func (s *eventFactoryClient) callRefresh() { <-eventFactory.Request() } +func (s *eventFactoryClient) callReselect() { + eventFactory := begin.FromContext(s.ctx) + eventFactory.Request(begin.WithReselect()) +} + type contextKey struct{} type checkContextClient struct {