Skip to content

Commit

Permalink
Fix nil pointer dereference during close, wait inflight callbacks to …
Browse files Browse the repository at this point in the history
…finish (#96)
  • Loading branch information
FZambia authored Oct 21, 2024
1 parent 4a54acc commit c4f7a6d
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 6 deletions.
18 changes: 16 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Client struct {
token string
data protocol.Raw
transport transport
disconnectedCh chan struct{}
state State
subs map[string]*Subscription
serverSubs map[string]*serverSub
Expand Down Expand Up @@ -138,7 +139,9 @@ func newClient(endpoint string, isProtobuf bool, config Config) *Client {
}

// Queue to run callbacks on.
client.cbQueue = &cbQueue{}
client.cbQueue = &cbQueue{
closeCh: make(chan struct{}),
}
client.cbQueue.cond = sync.NewCond(&client.cbQueue.mu)
go client.cbQueue.dispatch()

Expand Down Expand Up @@ -534,10 +537,20 @@ func (c *Client) moveToClosed() {
})
}

c.mu.RLock()
disconnectedCh := c.disconnectedCh
c.mu.RUnlock()
// At this point connection close was issued, so we wait until the reader goroutine
// finishes its work, after that it's safe to close the callback queue.
if disconnectedCh != nil {
<-disconnectedCh
}

c.mu.Lock()
defer c.mu.Unlock()
c.disconnectedCh = nil
c.cbQueue.close()
c.cbQueue = nil
c.mu.Unlock()
}

func (c *Client) handleError(err error) {
Expand Down Expand Up @@ -959,6 +972,7 @@ func (c *Client) startReconnecting() error {
disconnectCh := make(chan struct{})
c.receive = make(chan []byte, 64)
c.transport = t
c.disconnectedCh = disconnectCh

go c.reader(t, disconnectCh)

Expand Down
21 changes: 17 additions & 4 deletions queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ import (
// https://github.com/nats-io/nats.go client released under Apache 2.0
// license: see https://github.com/nats-io/nats.go/blob/master/LICENSE.
type cbQueue struct {
mu sync.Mutex
cond *sync.Cond
head *asyncCB
tail *asyncCB
mu sync.Mutex
cond *sync.Cond
head *asyncCB
tail *asyncCB
closeCh chan struct{}
closed bool
}

type asyncCB struct {
Expand Down Expand Up @@ -43,6 +45,7 @@ func (q *cbQueue) dispatch() {
// This signals that the dispatcher has been closed and all
// previous callbacks have been dispatched.
if curr.fn == nil {
close(q.closeCh)
return
}
curr.fn(time.Since(curr.tm))
Expand All @@ -56,13 +59,22 @@ func (q *cbQueue) push(f func(duration time.Duration)) {
}

// Close signals that async queue must be closed.
// Queue won't accept any more callbacks after that – ignoring them if pushed.
func (q *cbQueue) close() {
q.pushOrClose(nil, true)
q.waitClose()
}

func (q *cbQueue) waitClose() {
<-q.closeCh
}

func (q *cbQueue) pushOrClose(f func(time.Duration), close bool) {
q.mu.Lock()
defer q.mu.Unlock()
if q.closed {
return
}
// Make sure that library is not calling push with nil function,
// since this is used to notify the dispatcher that it must stop.
if !close && f == nil {
Expand All @@ -76,6 +88,7 @@ func (q *cbQueue) pushOrClose(f func(time.Duration), close bool) {
}
q.tail = cb
if close {
q.closed = true
q.cond.Broadcast()
} else {
q.cond.Signal()
Expand Down
131 changes: 131 additions & 0 deletions queue_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package centrifuge

import (
"sync"
"testing"
"time"
)

func assertTrue(t *testing.T, condition bool, msg string) {
if !condition {
t.Fatalf("Assertion failed: %s", msg)
}
}

func assertEqual(t *testing.T, expected, actual interface{}, msg string) {
if expected != actual {
t.Fatalf("Assertion failed: %s - expected: %v, got: %v", msg, expected, actual)
}
}

func newTestQueue() *cbQueue {
q := &cbQueue{
closeCh: make(chan struct{}),
}
q.cond = sync.NewCond(&q.mu)
return q
}

func TestCbQueue_PushAndDispatch(t *testing.T) {
q := newTestQueue()

var wg sync.WaitGroup
wg.Add(1)

// Start the dispatcher in a separate goroutine.
go q.dispatch()

startTime := time.Now()
q.push(func(d time.Duration) {
defer wg.Done()
assertTrue(t, d >= 0, "Callback duration should be positive")
})

// Wait for the callback to finish.
wg.Wait()

// Ensure the callback executed quickly.
elapsed := time.Since(startTime)
assertTrue(t, elapsed < 100*time.Millisecond, "Callback should be dispatched immediately")
}

func TestCbQueue_OrderPreservation(t *testing.T) {
q := newTestQueue()

// Start the dispatcher in a separate goroutine.
go q.dispatch()

var results []int
var mu sync.Mutex
expectedResults := []int{1, 2, 3}

for _, i := range expectedResults {
i := i
q.push(func(d time.Duration) {
mu.Lock()
defer mu.Unlock()
results = append(results, i)
})
}

// Allow time for the queue to process.
time.Sleep(100 * time.Millisecond)

mu.Lock()
defer mu.Unlock()

for i, r := range results {
assertEqual(t, expectedResults[i], r, "unexpected result")
}
}

func TestCbQueue_Close(t *testing.T) {
q := newTestQueue()

go q.dispatch()

var executed bool
q.push(func(d time.Duration) {
executed = true
})

q.close()

// Ensure the closeCh channel is closed.
select {
case <-q.closeCh:
// Channel was closed as expected.
case <-time.After(1 * time.Second):
t.Fatal("closeCh was not closed after queue close")
}

assertTrue(t, executed, "Callback should be executed before close")
}

func TestCbQueue_IgnorePushAfterClose(t *testing.T) {
q := newTestQueue()
go q.dispatch()
q.close()

var executed bool
q.push(func(d time.Duration) {
executed = true
})

// Allow some time to see if the callback is executed.
time.Sleep(100 * time.Millisecond)

assertTrue(t, !executed, "Callback should not be executed after queue close")
}

func TestCbQueue_PushNilCallbackPanics(t *testing.T) {
q := newTestQueue()

defer func() {
if r := recover(); r == nil {
t.Fatal("Expected panic when pushing nil callback with close set to false")
}
}()

q.pushOrClose(nil, false)
}

0 comments on commit c4f7a6d

Please sign in to comment.