Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nil pointer dereference during close #96

Merged
merged 2 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Client struct {
token string
data protocol.Raw
transport transport
disconnectedCh chan struct{}
state State
subs map[string]*Subscription
serverSubs map[string]*serverSub
Expand Down Expand Up @@ -138,7 +139,9 @@ func newClient(endpoint string, isProtobuf bool, config Config) *Client {
}

// Queue to run callbacks on.
client.cbQueue = &cbQueue{}
client.cbQueue = &cbQueue{
closeCh: make(chan struct{}),
}
client.cbQueue.cond = sync.NewCond(&client.cbQueue.mu)
go client.cbQueue.dispatch()

Expand Down Expand Up @@ -534,10 +537,20 @@ func (c *Client) moveToClosed() {
})
}

c.mu.RLock()
disconnectedCh := c.disconnectedCh
c.mu.RUnlock()
// At this point connection close was issued, so we wait until the reader goroutine
// finishes its work, after that it's safe to close the callback queue.
if disconnectedCh != nil {
<-disconnectedCh
}

c.mu.Lock()
defer c.mu.Unlock()
c.disconnectedCh = nil
c.cbQueue.close()
c.cbQueue = nil
c.mu.Unlock()
}

func (c *Client) handleError(err error) {
Expand Down Expand Up @@ -959,6 +972,7 @@ func (c *Client) startReconnecting() error {
disconnectCh := make(chan struct{})
c.receive = make(chan []byte, 64)
c.transport = t
c.disconnectedCh = disconnectCh

go c.reader(t, disconnectCh)

Expand Down
21 changes: 17 additions & 4 deletions queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ import (
// https://github.com/nats-io/nats.go client released under Apache 2.0
// license: see https://github.com/nats-io/nats.go/blob/master/LICENSE.
type cbQueue struct {
mu sync.Mutex
cond *sync.Cond
head *asyncCB
tail *asyncCB
mu sync.Mutex
cond *sync.Cond
head *asyncCB
tail *asyncCB
closeCh chan struct{}
closed bool
}

type asyncCB struct {
Expand Down Expand Up @@ -43,6 +45,7 @@ func (q *cbQueue) dispatch() {
// This signals that the dispatcher has been closed and all
// previous callbacks have been dispatched.
if curr.fn == nil {
close(q.closeCh)
return
}
curr.fn(time.Since(curr.tm))
Expand All @@ -56,13 +59,22 @@ func (q *cbQueue) push(f func(duration time.Duration)) {
}

// Close signals that async queue must be closed.
// Queue won't accept any more callbacks after that – ignoring them if pushed.
func (q *cbQueue) close() {
q.pushOrClose(nil, true)
q.waitClose()
}

func (q *cbQueue) waitClose() {
<-q.closeCh
}

func (q *cbQueue) pushOrClose(f func(time.Duration), close bool) {
q.mu.Lock()
defer q.mu.Unlock()
if q.closed {
return
}
// Make sure that library is not calling push with nil function,
// since this is used to notify the dispatcher that it must stop.
if !close && f == nil {
Expand All @@ -76,6 +88,7 @@ func (q *cbQueue) pushOrClose(f func(time.Duration), close bool) {
}
q.tail = cb
if close {
q.closed = true
q.cond.Broadcast()
} else {
q.cond.Signal()
Expand Down
131 changes: 131 additions & 0 deletions queue_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package centrifuge

import (
"sync"
"testing"
"time"
)

func assertTrue(t *testing.T, condition bool, msg string) {
if !condition {
t.Fatalf("Assertion failed: %s", msg)
}
}

func assertEqual(t *testing.T, expected, actual interface{}, msg string) {
if expected != actual {
t.Fatalf("Assertion failed: %s - expected: %v, got: %v", msg, expected, actual)
}
}

func newTestQueue() *cbQueue {
q := &cbQueue{
closeCh: make(chan struct{}),
}
q.cond = sync.NewCond(&q.mu)
return q
}

func TestCbQueue_PushAndDispatch(t *testing.T) {
q := newTestQueue()

var wg sync.WaitGroup
wg.Add(1)

// Start the dispatcher in a separate goroutine.
go q.dispatch()

startTime := time.Now()
q.push(func(d time.Duration) {
defer wg.Done()
assertTrue(t, d >= 0, "Callback duration should be positive")
})

// Wait for the callback to finish.
wg.Wait()

// Ensure the callback executed quickly.
elapsed := time.Since(startTime)
assertTrue(t, elapsed < 100*time.Millisecond, "Callback should be dispatched immediately")
}

func TestCbQueue_OrderPreservation(t *testing.T) {
q := newTestQueue()

// Start the dispatcher in a separate goroutine.
go q.dispatch()

var results []int
var mu sync.Mutex
expectedResults := []int{1, 2, 3}

for _, i := range expectedResults {
i := i
q.push(func(d time.Duration) {
mu.Lock()
defer mu.Unlock()
results = append(results, i)
})
}

// Allow time for the queue to process.
time.Sleep(100 * time.Millisecond)

mu.Lock()
defer mu.Unlock()

for i, r := range results {
assertEqual(t, expectedResults[i], r, "unexpected result")
}
}

func TestCbQueue_Close(t *testing.T) {
q := newTestQueue()

go q.dispatch()

var executed bool
q.push(func(d time.Duration) {
executed = true
})

q.close()

// Ensure the closeCh channel is closed.
select {
case <-q.closeCh:
// Channel was closed as expected.
case <-time.After(1 * time.Second):
t.Fatal("closeCh was not closed after queue close")
}

assertTrue(t, executed, "Callback should be executed before close")
}

func TestCbQueue_IgnorePushAfterClose(t *testing.T) {
q := newTestQueue()
go q.dispatch()
q.close()

var executed bool
q.push(func(d time.Duration) {
executed = true
})

// Allow some time to see if the callback is executed.
time.Sleep(100 * time.Millisecond)

assertTrue(t, !executed, "Callback should not be executed after queue close")
}

func TestCbQueue_PushNilCallbackPanics(t *testing.T) {
q := newTestQueue()

defer func() {
if r := recover(); r == nil {
t.Fatal("Expected panic when pushing nil callback with close set to false")
}
}()

q.pushOrClose(nil, false)
}
Loading