Skip to content

Commit

Permalink
Merge pull request #151 from ngrok/josh/fix-heartbeat-panic
Browse files Browse the repository at this point in the history
internal: fix a panic due to a racy Close routine
  • Loading branch information
jrobsonchase authored Jan 8, 2024
2 parents 9d060ca + b44c85d commit 8237d67
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 15 deletions.
49 changes: 35 additions & 14 deletions internal/tunnel/client/raw_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ type SessionHandler interface {
// When RawSession.Accept() returns an error, that means the session is dead.
// Client sessions run over a muxado session.
type rawSession struct {
// NOTE(review): the field list below appears twice because this rendering
// interleaves the pre-change and post-change versions of the struct.
mux *muxado.Heartbeat // the muxado session we're multiplexing streams over
id string // session id for logging purposes
handler SessionHandler // callbacks to allow the application to handle requests from the server
latency chan time.Duration // heartbeat latencies; onHeartbeat does a non-blocking send, Close closes it
closeLatencyOnce sync.Once // wraps the close of latency in Close so it runs at most once
mux *muxado.Heartbeat // the muxado session we're multiplexing streams over
id string // session id for logging purposes
handler SessionHandler // callbacks to allow the application to handle requests from the server
latency chan time.Duration // heartbeat latencies; onHeartbeat does a non-blocking send, Close closes it
closed bool // set by Close just before it closes latency; read and written under closedLock
closedLock sync.RWMutex // Close takes the write lock; onHeartbeat takes the read lock before sending on latency
log.Logger
}

Expand Down Expand Up @@ -230,10 +231,19 @@ func (s *rawSession) respFunc(raw net.Conn) func(v any) error {
}

// Close closes the underlying muxado heartbeat session and closes the latency
// channel. The latency channel is closed at most once, so calling Close more
// than once does not attempt a second close of the channel. The returned
// error is the result of s.mux.Close().
func (s *rawSession) Close() error {
s.closeLatencyOnce.Do(func() {
// Close the muxado heartbeat session. After this, the goroutine calling the
// callback handler should exit.
err := s.mux.Close()

// Prevent sending on a closed channel in the callback handler by ensuring
// exclusive access to the channel and the closed boolean here.
s.closedLock.Lock()
defer s.closedLock.Unlock()
if !s.closed {
s.closed = true
close(s.latency)
})
return s.mux.Close()
}
return err
}

// This is essentially the RPC protocol. The request and response are just JSON
Expand Down Expand Up @@ -271,12 +281,23 @@ func (s *rawSession) onHeartbeat(pingTime time.Duration, timeout bool) {
if timeout {
s.Error("heartbeat timeout, terminating session")
s.Close()
} else {
s.Debug("heartbeat received", "latency_ms", int(pingTime.Milliseconds()))
select {
case s.latency <- pingTime:
default:
}
return
}

// Make sure we don't send on a closed channel.
// Any number of `onHeartbeat` callbacks can be in flight at a given time,
// but only one Close.
s.closedLock.RLock()
defer s.closedLock.RUnlock()

if s.closed {
return
}

s.Debug("heartbeat received", "latency_ms", int(pingTime.Milliseconds()))
select {
case s.latency <- pingTime:
default:
}
}

Expand Down
54 changes: 54 additions & 0 deletions internal/tunnel/client/raw_session_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package client

import (
"context"
"sync"
"testing"
"time"

"github.com/inconshreveable/log15/v3"

Expand All @@ -21,3 +24,54 @@ func TestRawSessionDoubleClose(t *testing.T) {
r.Close()
r.Close()
}

// TestHeartbeatTimeout invokes the timeout branch of onHeartbeat directly on a
// freshly constructed session. The test passes as long as the call returns.
func TestHeartbeatTimeout(t *testing.T) {
	sess := NewRawSession(log15.New(), muxado.Client(&dummyStream{}, nil), nil, nil)
	raw := sess.(*rawSession)

	// A timed-out heartbeat calls Close from inside the callback; that must
	// return rather than deadlock.
	raw.onHeartbeat(1, true)
}

// TestRawSessionCloseRace closes a session while a background goroutine is
// calling onHeartbeat in a tight loop, checking that the two can overlap
// without a panic.
func TestRawSessionCloseRace(t *testing.T) {
	timebox, stopAll := context.WithTimeout(context.Background(), time.Second*1)
	defer stopAll()

	// The failure is timing dependent, so repeat the scenario as many times
	// as fit inside the timebox to improve the odds of hitting it.
	for timebox.Err() == nil {
		iterCtx, stopIter := context.WithCancel(timebox)
		sess := NewRawSession(log15.New(), muxado.Client(&dummyStream{}, nil), nil, nil)

		var wg sync.WaitGroup
		wg.Add(1)

		// Fire heartbeats as fast as possible until this iteration is
		// cancelled.
		go func() {
			defer wg.Done()
			for iterCtx.Err() == nil {
				sess.(*rawSession).onHeartbeat(time.Millisecond*1, false)
			}
		}()

		// Closing the session while a heartbeat may be in flight must not
		// panic.
		sess.Close()

		stopIter()

		// Wait until the heartbeat goroutine exits so that any panic is
		// attributed to this test rather than surfacing after it completes.
		wg.Wait()
	}
}
3 changes: 2 additions & 1 deletion tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ func (t *tunnelImpl) CloseWithContext(_ context.Context) error {
return err
}
}
return t.Tunnel.Close()
err := t.Tunnel.Close()
return err
}

func (t *tunnelImpl) Addr() net.Addr {
Expand Down

0 comments on commit 8237d67

Please sign in to comment.