diff --git a/protocol/app/flags/flags.go b/protocol/app/flags/flags.go index 1a0ad084e1..cd1cfa80f7 100644 --- a/protocol/app/flags/flags.go +++ b/protocol/app/flags/flags.go @@ -72,8 +72,8 @@ const ( DefaultGrpcStreamingEnabled = false DefaultGrpcStreamingFlushIntervalMs = 50 - DefaultGrpcStreamingMaxBatchSize = 10000 - DefaultGrpcStreamingMaxChannelBufferSize = 10000 + DefaultGrpcStreamingMaxBatchSize = 100_000 + DefaultGrpcStreamingMaxChannelBufferSize = 100_000 DefaultWebsocketStreamingEnabled = false DefaultWebsocketStreamingPort = 9092 DefaultFullNodeStreamingSnapshotInterval = 0 diff --git a/protocol/app/flags/flags_test.go b/protocol/app/flags/flags_test.go index 8b3ed1a8f4..b4107335b1 100644 --- a/protocol/app/flags/flags_test.go +++ b/protocol/app/flags/flags_test.go @@ -257,8 +257,8 @@ func TestGetFlagValuesFromOptions(t *testing.T) { expectedGrpcEnable: true, expectedGrpcStreamingEnable: false, expectedGrpcStreamingFlushMs: 50, - expectedGrpcStreamingBatchSize: 10000, - expectedGrpcStreamingMaxChannelBufferSize: 10000, + expectedGrpcStreamingBatchSize: 100_000, + expectedGrpcStreamingMaxChannelBufferSize: 100_000, expectedWebsocketEnabled: false, expectedWebsocketPort: 9092, expectedFullNodeStreamingSnapshotInterval: 0, diff --git a/protocol/streaming/ws/websocket_server.go b/protocol/streaming/ws/websocket_server.go index 3cbb9219a4..33a7434e42 100644 --- a/protocol/streaming/ws/websocket_server.go +++ b/protocol/streaming/ws/websocket_server.go @@ -19,6 +19,8 @@ import ( const ( CLOB_PAIR_IDS_QUERY_PARAM = "clobPairIds" MARKET_IDS_QUERY_PARAM = "marketIds" + + CLOSE_DEADLINE = 5 * time.Second ) var upgrader = websocket.Upgrader{ @@ -68,33 +70,30 @@ func (ws *WebsocketServer) Handler(w http.ResponseWriter, r *http.Request) { // Parse clobPairIds from query parameters clobPairIds, err := parseUint32(r, CLOB_PAIR_IDS_QUERY_PARAM) if err != nil { - ws.logger.Error( - "Error parsing clobPairIds", - "err", err, - ) - http.Error(w, err.Error(), http.StatusBadRequest) + ws.logger.Error("Error parsing clobPairIds", "err", err) + if err := sendCloseWithReason(conn, websocket.CloseUnsupportedData, err.Error()); err != nil { + ws.logger.Error("Error sending close message", "err", err) + } return } // Parse marketIds from query parameters marketIds, err := parseUint32(r, MARKET_IDS_QUERY_PARAM) if err != nil { - ws.logger.Error( - "Error parsing marketIds", - "err", err, - ) - http.Error(w, err.Error(), http.StatusBadRequest) + ws.logger.Error("Error parsing marketIds", "err", err) + if err := sendCloseWithReason(conn, websocket.CloseUnsupportedData, err.Error()); err != nil { + ws.logger.Error("Error sending close message", "err", err) + } return } // Parse subaccountIds from query parameters subaccountIds, err := parseSubaccountIds(r) if err != nil { - ws.logger.Error( - "Error parsing subaccountIds", - "err", err, - ) - http.Error(w, err.Error(), http.StatusBadRequest) + ws.logger.Error("Error parsing subaccountIds", "err", err) + if err := sendCloseWithReason(conn, websocket.CloseUnsupportedData, err.Error()); err != nil { + ws.logger.Error("Error sending close message", "err", err) + } return } @@ -118,10 +117,26 @@ func (ws *WebsocketServer) Handler(w http.ResponseWriter, r *http.Request) { "Ending handler for websocket connection", "err", err, ) + if err := sendCloseWithReason(conn, websocket.CloseInternalServerErr, err.Error()); err != nil { + ws.logger.Error("Error sending close message", "err", err) + } return } } +func sendCloseWithReason(conn *websocket.Conn, closeCode int, reason string) error { + closeMessage := websocket.FormatCloseMessage(closeCode, reason) + // Set a write deadline to avoid blocking indefinitely + if err := conn.SetWriteDeadline(time.Now().Add(CLOSE_DEADLINE)); err != nil { + return err + } + return conn.WriteControl( + websocket.CloseMessage, + closeMessage, + time.Now().Add(CLOSE_DEADLINE), + ) +} + // parseSubaccountIds is a helper function to parse the subaccountIds from the query parameters. func parseSubaccountIds(r *http.Request) ([]*satypes.SubaccountId, error) { subaccountIdsParam := r.URL.Query().Get("subaccountIds")