From dd7b968d3577d5a61c749eabb89d77495232ea2a Mon Sep 17 00:00:00 2001 From: harsh-98 Date: Thu, 14 Sep 2023 16:50:53 +0700 Subject: [PATCH] fix: add mutex to PushToChan --- waku/v2/protocol/common_service.go | 19 ++++++++++--------- waku/v2/rendezvous/rendezvous.go | 22 +++++++++++----------- waku/v2/rendezvous/rendezvous_test.go | 2 ++ 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/waku/v2/protocol/common_service.go b/waku/v2/protocol/common_service.go index 6cfcfa100..f33e12b48 100644 --- a/waku/v2/protocol/common_service.go +++ b/waku/v2/protocol/common_service.go @@ -4,7 +4,6 @@ import ( "context" "errors" "sync" - "sync/atomic" ) // this is common layout for all the services that require mutex protection and a guarantee that all running goroutines will be finished before stop finishes execution. This guarantee comes from waitGroup all one has to use CommonService.WaitGroup() in the goroutines that should finish by the end of stop function. @@ -13,7 +12,7 @@ type CommonService[T any] struct { cancel context.CancelFunc ctx context.Context wg sync.WaitGroup - started atomic.Bool + started bool channel chan T } @@ -30,15 +29,15 @@ func NewCommonService[T any]() *CommonService[T] { func (sp *CommonService[T]) Start(ctx context.Context, fn func() error) error { sp.Lock() defer sp.Unlock() - if sp.started.Load() { + if sp.started { return ErrAlreadyStarted } - sp.started.Store(true) + sp.started = true sp.ctx, sp.cancel = context.WithCancel(ctx) // currently is used in discv5 for returning new discovered Peers to peerConnector for connecting with them sp.channel = make(chan T) if err := fn(); err != nil { - sp.started.Store(false) + sp.started = false sp.cancel() return err } @@ -52,19 +51,19 @@ var ErrNotStarted = errors.New("not started") func (sp *CommonService[T]) Stop(fn func()) { sp.Lock() defer sp.Unlock() - if !sp.started.Load() { + if !sp.started { return } sp.cancel() fn() sp.wg.Wait() close(sp.channel) - sp.started.Store(false) + sp.started = false } // This is not a mutex protected function, it is up to the caller to use it in a mutex protected context func (sp *CommonService[T]) ErrOnNotRunning() error { - if !sp.started.Load() { + if !sp.started { return ErrNotStarted } return nil @@ -80,7 +79,9 @@ func (sp *CommonService[T]) GetListeningChan() <-chan T { return sp.channel } func (sp *CommonService[T]) PushToChan(data T) bool { - if !sp.started.Load() { + sp.RLock() + defer sp.RUnlock() + if !sp.started { return false } select { diff --git a/waku/v2/rendezvous/rendezvous.go b/waku/v2/rendezvous/rendezvous.go index 2739e69e6..74f20f0ed 100644 --- a/waku/v2/rendezvous/rendezvous.go +++ b/waku/v2/rendezvous/rendezvous.go @@ -31,7 +31,7 @@ type Rendezvous struct { peerConnector PeerConnector log *zap.Logger - *protocol.CommonService[struct{}] + *protocol.CommonService[peermanager.PeerData] } // PeerConnector will subscribe to a channel containing the information for all peers found by this discovery protocol @@ -46,7 +46,7 @@ func NewRendezvous(db *DB, peerConnector PeerConnector, log *zap.Logger) *Rendez db: db, peerConnector: peerConnector, log: logger, - CommonService: protocol.NewCommonService[struct{}](), + CommonService: protocol.NewCommonService[peermanager.PeerData](), } } @@ -60,10 +60,15 @@ func (r *Rendezvous) Start(ctx context.Context) error { } func (r *Rendezvous) start() error { - err := r.db.Start(r.Context()) - if err != nil { - return err + if r.db != nil { + if err := r.db.Start(r.Context()); err != nil { + return err + } } + if r.peerConnector != nil { + r.peerConnector.Subscribe(r.Context(), r.GetListeningChan()) + } + r.rendezvousSvc = rvs.NewRendezvousService(r.host, r.db) r.log.Info("rendezvous protocol started") @@ -98,18 +103,13 @@ func (r *Rendezvous) DiscoverWithNamespace(ctx context.Context, namespace string if len(addrInfo) != 0 { rp.SetSuccess(cookie) - peerCh := make(chan peermanager.PeerData) - defer close(peerCh) - r.peerConnector.Subscribe(ctx, peerCh) for _, p := range addrInfo { peer := peermanager.PeerData{ Origin: peerstore.Rendezvous, AddrInfo: p, } - select { - case <-ctx.Done(): + if !r.PushToChan(peer) { return - case peerCh <- peer: } } } else { diff --git a/waku/v2/rendezvous/rendezvous_test.go b/waku/v2/rendezvous/rendezvous_test.go index ccf2ac789..2a8563cfb 100644 --- a/waku/v2/rendezvous/rendezvous_test.go +++ b/waku/v2/rendezvous/rendezvous_test.go @@ -92,6 +92,7 @@ func TestRendezvous(t *testing.T) { rendezvousClient2 := NewRendezvous(nil, myPeerConnector, utils.Logger()) rendezvousClient2.SetHost(host3) + rendezvousClient2.Start(ctx) timedCtx, cancel := context.WithTimeout(ctx, 4*time.Second) defer cancel() @@ -108,4 +109,5 @@ func TestRendezvous(t *testing.T) { case p := <-myPeerConnector.ch: require.Equal(t, p.AddrInfo.ID.Pretty(), host2.ID().Pretty()) } + rendezvousClient2.Stop() }