diff --git a/pkg/bindingsdriver/driver.go b/pkg/bindingsdriver/driver.go index 66dfd929..7cc93746 100644 --- a/pkg/bindingsdriver/driver.go +++ b/pkg/bindingsdriver/driver.go @@ -59,9 +59,11 @@ type ConnectionHandler func(net.Conn) error type bindingsListener struct { listener net.Listener - stop chan struct{} cnxnHandler ConnectionHandler log logr.Logger + + stopOnce sync.Once + stop chan struct{} } func newBindingsListener(address string, cnxnHandler ConnectionHandler) (*bindingsListener, error) { @@ -72,8 +74,8 @@ func newBindingsListener(address string, cnxnHandler ConnectionHandler) (*bindin bl := &bindingsListener{ listener: l, - stop: make(chan struct{}), cnxnHandler: cnxnHandler, + stop: make(chan struct{}), } go bl.run() @@ -83,18 +85,16 @@ func newBindingsListener(address string, cnxnHandler ConnectionHandler) (*bindin // Stop stops the listener. It is safe to call stop multiple times. func (b *bindingsListener) Stop() { - select { - case b.stop <- struct{}{}: - close(b.stop) - default: - } + b.stopOnce.Do(func() { + b.listener.Close() + b.stop <- struct{}{} + }) } func (b *bindingsListener) run() { for { select { case <-b.stop: - b.listener.Close() return default: } diff --git a/pkg/bindingsdriver/driver_test.go b/pkg/bindingsdriver/driver_test.go index 4dd1d116..2a8ecb3b 100644 --- a/pkg/bindingsdriver/driver_test.go +++ b/pkg/bindingsdriver/driver_test.go @@ -6,6 +6,7 @@ import ( "math/rand/v2" "net" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -35,7 +36,7 @@ func TestBindingsListener(t *testing.T) { assert.NotNil(t, bl) // test that we can connect to the listener - conn, err := net.Dial("tcp", loopbackAddr(port)) + conn, err := net.DialTimeout("tcp", loopbackAddr(port), 10*time.Millisecond) assert.NoError(t, err) out, err := io.ReadAll(conn) @@ -44,6 +45,13 @@ func TestBindingsListener(t *testing.T) { assert.Equal(t, "hello world", string(out)) assert.NotPanics(t, func() { bl.Stop() }) + + // test that we can't connect to the listener after it's stopped + conn, err = net.DialTimeout("tcp", loopbackAddr(port), 10*time.Millisecond) + assert.Error(t, err) + assert.Nil(t, conn) + + // test that we can stop the listener multiple times assert.NotPanics(t, func() { bl.Stop() }) }