Skip to content

Commit

Permalink
fix: add mutex to PushToChan
Browse files Browse the repository at this point in the history
  • Loading branch information
harsh-98 committed Sep 14, 2023
1 parent c5c61a2 commit dd7b968
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
19 changes: 10 additions & 9 deletions waku/v2/protocol/common_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
)

// this is common layout for all the services that require mutex protection and a guarantee that all running goroutines will be finished before stop finishes execution. This guarantee comes from waitGroup all one has to use CommonService.WaitGroup() in the goroutines that should finish by the end of stop function.
Expand All @@ -13,7 +12,7 @@ type CommonService[T any] struct {
cancel context.CancelFunc
ctx context.Context
wg sync.WaitGroup
started atomic.Bool
started bool
channel chan T
}

Expand All @@ -30,15 +29,15 @@ func NewCommonService[T any]() *CommonService[T] {
func (sp *CommonService[T]) Start(ctx context.Context, fn func() error) error {
sp.Lock()
defer sp.Unlock()
if sp.started.Load() {
if sp.started {
return ErrAlreadyStarted
}
sp.started.Store(true)
sp.started = true
sp.ctx, sp.cancel = context.WithCancel(ctx)
// currently is used in discv5 for returning new discovered Peers to peerConnector for connecting with them
sp.channel = make(chan T)
if err := fn(); err != nil {
sp.started.Store(false)
sp.started = false
sp.cancel()
return err
}
Expand All @@ -52,19 +51,19 @@ var ErrNotStarted = errors.New("not started")
func (sp *CommonService[T]) Stop(fn func()) {
sp.Lock()
defer sp.Unlock()
if !sp.started.Load() {
if !sp.started {
return
}
sp.cancel()
fn()
sp.wg.Wait()
close(sp.channel)
sp.started.Store(false)
sp.started = false
}

// This is not a mutex protected function, it is up to the caller to use it in a mutex protected context
func (sp *CommonService[T]) ErrOnNotRunning() error {
if !sp.started.Load() {
if !sp.started {
return ErrNotStarted
}
return nil
Expand All @@ -80,7 +79,9 @@ func (sp *CommonService[T]) GetListeningChan() <-chan T {
return sp.channel
}
func (sp *CommonService[T]) PushToChan(data T) bool {
if !sp.started.Load() {
sp.RLock()
defer sp.RUnlock()
if !sp.started {
return false
}
select {
Expand Down
22 changes: 11 additions & 11 deletions waku/v2/rendezvous/rendezvous.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type Rendezvous struct {
peerConnector PeerConnector

log *zap.Logger
*protocol.CommonService[struct{}]
*protocol.CommonService[peermanager.PeerData]
}

// PeerConnector will subscribe to a channel containing the information for all peers found by this discovery protocol
Expand All @@ -46,7 +46,7 @@ func NewRendezvous(db *DB, peerConnector PeerConnector, log *zap.Logger) *Rendez
db: db,
peerConnector: peerConnector,
log: logger,
CommonService: protocol.NewCommonService[struct{}](),
CommonService: protocol.NewCommonService[peermanager.PeerData](),
}
}

Expand All @@ -60,10 +60,15 @@ func (r *Rendezvous) Start(ctx context.Context) error {
}

func (r *Rendezvous) start() error {
err := r.db.Start(r.Context())
if err != nil {
return err
if r.db != nil {
if err := r.db.Start(r.Context()); err != nil {
return err
}
}
if r.peerConnector != nil {
r.peerConnector.Subscribe(r.Context(), r.GetListeningChan())
}

r.rendezvousSvc = rvs.NewRendezvousService(r.host, r.db)

r.log.Info("rendezvous protocol started")
Expand Down Expand Up @@ -98,18 +103,13 @@ func (r *Rendezvous) DiscoverWithNamespace(ctx context.Context, namespace string
if len(addrInfo) != 0 {
rp.SetSuccess(cookie)

peerCh := make(chan peermanager.PeerData)
defer close(peerCh)
r.peerConnector.Subscribe(ctx, peerCh)
for _, p := range addrInfo {
peer := peermanager.PeerData{
Origin: peerstore.Rendezvous,
AddrInfo: p,
}
select {
case <-ctx.Done():
if !r.PushToChan(peer) {
return
case peerCh <- peer:
}
}
} else {
Expand Down
2 changes: 2 additions & 0 deletions waku/v2/rendezvous/rendezvous_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ func TestRendezvous(t *testing.T) {

rendezvousClient2 := NewRendezvous(nil, myPeerConnector, utils.Logger())
rendezvousClient2.SetHost(host3)
rendezvousClient2.Start(ctx)

timedCtx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
Expand All @@ -108,4 +109,5 @@ func TestRendezvous(t *testing.T) {
case p := <-myPeerConnector.ch:
require.Equal(t, p.AddrInfo.ID.Pretty(), host2.ID().Pretty())
}
rendezvousClient2.Stop()
}

0 comments on commit dd7b968

Please sign in to comment.