Skip to content

Commit

Permalink
feat: Supports callbacks when reading a message fails
Browse files Browse the repository at this point in the history
  • Loading branch information
tttoad committed Feb 10, 2025
1 parent 2ecac8d commit 8335e42
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 22 deletions.
46 changes: 29 additions & 17 deletions server/serverimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"compress/gzip"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -226,27 +227,39 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co
// Loop until fail to read from the WebSocket connection.
for {
msgContext := context.Background()
request := protobufs.AgentToServer{}

// Block until the next message can be read.
mt, msgBytes, err := wsConn.ReadMessage()
if err != nil {
if !websocket.IsUnexpectedCloseError(err) {
s.logger.Errorf(msgContext, "Cannot read a message from WebSocket: %v", err)
break
isBreak, err := func() (bool, error) {
if err != nil {
if !websocket.IsUnexpectedCloseError(err) {
s.logger.Errorf(msgContext, "Cannot read a message from WebSocket: %v", err)
return true, err
}

Check warning on line 239 in server/serverimpl.go

View check run for this annotation

Codecov / codecov/patch

server/serverimpl.go#L237-L239

Added lines #L237 - L239 were not covered by tests
// This is a normal closing of the WebSocket connection.
s.logger.Debugf(msgContext, "Agent disconnected: %v", err)
return true, err
}
if mt != websocket.BinaryMessage {
err = fmt.Errorf("Received unexpected message type from WebSocket: %v", mt)
s.logger.Errorf(msgContext, err.Error())
return false, err
}
// This is a normal closing of the WebSocket connection.
s.logger.Debugf(msgContext, "Agent disconnected: %v", err)
break
}
if mt != websocket.BinaryMessage {
s.logger.Errorf(msgContext, "Received unexpected message type from WebSocket: %v", mt)
continue
}

// Decode WebSocket message as a Protobuf message.
var request protobufs.AgentToServer
err = internal.DecodeWSMessage(msgBytes, &request)
// Decode WebSocket message as a Protobuf message.
err = internal.DecodeWSMessage(msgBytes, &request)
if err != nil {
s.logger.Errorf(msgContext, "Cannot decode message from WebSocket: %v", err)
return false, err
}

Check warning on line 255 in server/serverimpl.go

View check run for this annotation

Codecov / codecov/patch

server/serverimpl.go#L253-L255

Added lines #L253 - L255 were not covered by tests
return false, nil
}()
if err != nil {
s.logger.Errorf(msgContext, "Cannot decode message from WebSocket: %v", err)
connectionCallbacks.OnReadMessageError(agentConn, mt, msgBytes, err)
if isBreak {
break
}
continue
}

Expand Down Expand Up @@ -366,7 +379,6 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter
w.Header().Set(headerContentEncoding, contentEncodingGzip)
}
_, err = w.Write(bodyBytes)

if err != nil {
s.logger.Debugf(req.Context(), "Cannot send HTTP response: %v", err)
}
Expand Down
51 changes: 46 additions & 5 deletions server/serverimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,52 @@ func TestServerReceiveSendMessage(t *testing.T) {
assert.EqualValues(t, settings.CustomCapabilities, response.CustomCapabilities.Capabilities)
}

func TestServerReceiveSendErrorMessage(t *testing.T) {
var rcvMsg atomic.Value
type ErrorInfo struct {
mt int
msgByte []byte
err error
}
callbacks := types.Callbacks{
OnConnecting: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{
OnReadMessageError: func(conn types.Connection, mt int, msgByte []byte, err error) {
rcvMsg.Store(ErrorInfo{
mt: mt,
msgByte: msgByte,
err: err,
})
},
}}
},
}

// Start a Server.
settings := &StartSettings{Settings: Settings{
Callbacks: callbacks,
CustomCapabilities: []string{"local.test.capability"},
}}
srv := startServer(t, settings)
defer srv.Stop(context.Background())

// Connect using a WebSocket client.
conn, _, _ := dialClient(settings)
require.NotNil(t, conn)
defer conn.Close()

// Send a message to the Server.
err := conn.WriteMessage(websocket.TextMessage, []byte("abc"))
require.NoError(t, err)

// Wait until Server receives the message.
eventually(t, func() bool { return rcvMsg.Load() != nil })
errInfo := rcvMsg.Load().(ErrorInfo)
assert.EqualValues(t, websocket.TextMessage, errInfo.mt)
assert.EqualValues(t, []byte("abc"), errInfo.msgByte)
assert.NotNil(t, errInfo.err)
}

func TestServerReceiveSendMessageWithCompression(t *testing.T) {
// Use highly compressible config body.
uncompressedCfg := []byte(strings.Repeat("test", 10000))
Expand Down Expand Up @@ -620,7 +666,6 @@ func TestServerAttachSendMessagePlainHTTP(t *testing.T) {
}

func TestServerHonoursClientRequestContentEncoding(t *testing.T) {

hc := http.Client{}
var rcvMsg atomic.Value
var onConnectedCalled, onCloseCalled int32
Expand Down Expand Up @@ -698,7 +743,6 @@ func TestServerHonoursClientRequestContentEncoding(t *testing.T) {
}

func TestServerHonoursAcceptEncoding(t *testing.T) {

hc := http.Client{}
var rcvMsg atomic.Value
var onConnectedCalled, onCloseCalled int32
Expand Down Expand Up @@ -985,7 +1029,6 @@ func BenchmarkSendToClient(b *testing.B) {
}
srv := New(&sharedinternal.NopLogger{})
err := srv.Start(*settings)

if err != nil {
b.Error(err)
}
Expand Down Expand Up @@ -1017,7 +1060,6 @@ func BenchmarkSendToClient(b *testing.B) {

for _, conn := range serverConnections {
err := conn.Send(context.Background(), &protobufs.ServerToAgent{})

if err != nil {
b.Error(err)
}
Expand All @@ -1026,5 +1068,4 @@ func BenchmarkSendToClient(b *testing.B) {
for _, conn := range clientConnections {
conn.Close()
}

}
9 changes: 9 additions & 0 deletions server/types/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ type ConnectionCallbacks struct {

// OnConnectionClose is called when the OpAMP connection is closed.
OnConnectionClose func(conn Connection)

// OnReadMessageError is called when an error occurs while reading or deserializing a message.
OnReadMessageError func(conn Connection, mt int, msgByte []byte, err error)
}

func defaultOnConnected(ctx context.Context, conn Connection) {}
Expand All @@ -77,6 +80,8 @@ func defaultOnMessage(

func defaultOnConnectionClose(conn Connection) {}

func defaultOnReadMessageError(conn Connection, mt int, msgByte []byte, err error) {}

func (c *ConnectionCallbacks) SetDefaults() {
if c.OnConnected == nil {
c.OnConnected = defaultOnConnected
Expand All @@ -89,4 +94,8 @@ func (c *ConnectionCallbacks) SetDefaults() {
if c.OnConnectionClose == nil {
c.OnConnectionClose = defaultOnConnectionClose
}

if c.OnReadMessageError == nil {
c.OnReadMessageError = defaultOnReadMessageError
}
}

0 comments on commit 8335e42

Please sign in to comment.