diff --git a/internal/examples/server/data/agent.go b/internal/examples/server/data/agent.go index e920f030..b562bd3e 100644 --- a/internal/examples/server/data/agent.go +++ b/internal/examples/server/data/agent.go @@ -25,8 +25,6 @@ type Agent struct { // Connection to the Agent. conn types.Connection - // Mutex to protect Send() operation. - connMutex sync.Mutex // mutex for the fields that follow it. mux sync.RWMutex @@ -421,9 +419,6 @@ func (agent *Agent) calcConnectionSettings(response *protobufs.ServerToAgent) { } func (agent *Agent) SendToAgent(msg *protobufs.ServerToAgent) { - agent.connMutex.Lock() - defer agent.connMutex.Unlock() - agent.conn.Send(context.Background(), msg) } diff --git a/server/serverimpl.go b/server/serverimpl.go index 5fc7cac0..cd19d68a 100644 --- a/server/serverimpl.go +++ b/server/serverimpl.go @@ -8,6 +8,7 @@ import ( "io" "net" "net/http" + "sync" "github.com/gorilla/websocket" "google.golang.org/protobuf/proto" @@ -179,7 +180,7 @@ func (s *server) httpHandler(w http.ResponseWriter, req *http.Request) { } func (s *server) handleWSConnection(wsConn *websocket.Conn, connectionCallbacks serverTypes.ConnectionCallbacks) { - agentConn := wsConnection{wsConn: wsConn} + agentConn := wsConnection{wsConn: wsConn, connMutex: &sync.Mutex{}} defer func() { // Close the connection when all is done. diff --git a/server/serverimpl_test.go b/server/serverimpl_test.go index 7ebaaa32..04827d38 100644 --- a/server/serverimpl_test.go +++ b/server/serverimpl_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "sync/atomic" "testing" "time" @@ -723,3 +724,124 @@ func TestDecodeMessage(t *testing.T) { } } } + +func TestConnectionAllowsConcurrentWrites(t *testing.T) { + srvConnVal := atomic.Value{} + callbacks := CallbacksStruct{ + OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ + OnConnectedFunc: func(conn types.Connection) { + srvConnVal.Store(conn) + }, + }} + }, + } + + // Start a Server. + settings := &StartSettings{Settings: Settings{Callbacks: callbacks}} + srv := startServer(t, settings) + defer srv.Stop(context.Background()) + + // Connect to the Server. + conn, _, err := dialClient(settings) + + // Verify that the connection is successful. + assert.NoError(t, err) + assert.NotNil(t, conn) + + defer conn.Close() + + timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) + + select { + case <-timeout.Done(): + t.Error("Client failed to connect before timeout") + default: + if _, ok := srvConnVal.Load().(types.Connection); ok == true { + break + } + } + + cancel() + + srvConn := srvConnVal.Load().(types.Connection) + for i := 0; i < 20; i++ { + go func() { + defer func() { + if recover() != nil { + require.Fail(t, "Sending to client panicked") + } + }() + + srvConn.Send(context.Background(), &protobufs.ServerToAgent{}) + }() + } +} + +func BenchmarkSendToClient(b *testing.B) { + clientConnections := []*websocket.Conn{} + serverConnections := []types.Connection{} + srvConnectionsMutex := sync.Mutex{} + callbacks := CallbacksStruct{ + OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ + OnConnectedFunc: func(conn types.Connection) { + srvConnectionsMutex.Lock() + serverConnections = append(serverConnections, conn) + srvConnectionsMutex.Unlock() + }, + }} + }, + } + + // Start a Server. + settings := &StartSettings{ + Settings: Settings{Callbacks: callbacks}, + ListenEndpoint: testhelpers.GetAvailableLocalAddress(), + ListenPath: "/", + } + srv := New(&sharedinternal.NopLogger{}) + err := srv.Start(*settings) + + if err != nil { + b.Error(err) + } + + defer srv.Stop(context.Background()) + + for i := 0; i < b.N; i++ { + conn, resp, err := dialClient(settings) + + if err != nil || resp == nil || conn == nil { + b.Error("Could not establish connection:", err) + } + + clientConnections = append(clientConnections, conn) + } + + timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) + + select { + case <-timeout.Done(): + b.Error("Connections failed to establish in time") + default: + if len(serverConnections) == b.N { + break + } + } + + cancel() + + for _, conn := range serverConnections { + err := conn.Send(context.Background(), &protobufs.ServerToAgent{}) + + if err != nil { + b.Error(err) + } + } + + for _, conn := range clientConnections { + conn.Close() + } + +} diff --git a/server/wsconnection.go b/server/wsconnection.go index 9ce53a49..b2b99219 100644 --- a/server/wsconnection.go +++ b/server/wsconnection.go @@ -3,6 +3,7 @@ package server import ( "context" "net" + "sync" "github.com/gorilla/websocket" @@ -13,7 +14,11 @@ import ( // wsConnection represents a persistent OpAMP connection over a WebSocket. type wsConnection struct { - wsConn *websocket.Conn + // The websocket library does not allow multiple concurrent write operations, + // so ensure that we only have a single operation in progress at a time. + // For more: https://pkg.go.dev/github.com/gorilla/websocket#hdr-Concurrency + connMutex *sync.Mutex + wsConn *websocket.Conn } var _ types.Connection = (*wsConnection)(nil) @@ -22,10 +27,10 @@ func (c wsConnection) Connection() net.Conn { return c.wsConn.UnderlyingConn() } -// Message header is currently uint64 zero value. -const wsMsgHeader = uint64(0) - func (c wsConnection) Send(_ context.Context, message *protobufs.ServerToAgent) error { + c.connMutex.Lock() + defer c.connMutex.Unlock() + return internal.WriteWSMessage(c.wsConn, message) }