diff --git a/docs/src/content/docs/commands/QWATCH.md b/docs/src/content/docs/commands/QWATCH.md index 5f36ac173..a8408b32f 100644 --- a/docs/src/content/docs/commands/QWATCH.md +++ b/docs/src/content/docs/commands/QWATCH.md @@ -1,9 +1,9 @@ --- -title: QWATCH -description: The `QWATCH` command is a novel feature designed to provide real-time updates to clients based on changes in underlying data. +title: Q.WATCH +description: The `Q.WATCH` command is a novel feature designed to provide real-time updates to clients based on changes in underlying data. --- -The `QWATCH` command is a novel feature designed to provide real-time updates to clients based on changes in underlying +The `Q.WATCH` command is a novel feature designed to provide real-time updates to clients based on changes in underlying data. It operates similarly to the `SUBSCRIBE` command but focuses on SQL-like queries over data structures. Whenever data modifications affect the query's results, the updated result set is pushed to the subscribed client. This eliminates the need for clients to constantly poll for changes. @@ -13,16 +13,16 @@ to build real-time reactive applications like leaderboards. ## Protocol Support -| Protocol | Supported | -|-----------|-----------| -| TCP-RESP | ✅ | -| HTTP | ✅ | -| WebSocket | ❌ | +| Protocol | Supported | +| -------- | --------- | +| TCP-RESP | ✅ | +| HTTP | ✅ | +| WebSocket| ✅ | ## Syntax ```bash -QWATCH +Q.WATCH ``` ## Parameters @@ -96,7 +96,7 @@ Supported conditions: 1. `Missing query` - - Error Message: `(error) ERROR wrong number of arguments for 'qwatch' command` + - Error Message: `(error) ERROR wrong number of arguments for 'q.watch' command` - Occurs if no DSQL Query is provided. 2. `Invalid query`: @@ -113,12 +113,12 @@ Supported conditions: ### Basic Usage -Let's explore a practical example of using the `QWATCH` command to create a real-time leaderboard for a game match, +Let's explore a practical example of using the `Q.WATCH` command to create a real-time leaderboard for a game match, including filtering with a `WHERE` clause. ```bash -127.0.0.1:7379> QWATCH "SELECT $key, $value WHERE $key like 'match:100:*' AND $value > 10 ORDER BY $value DESC LIMIT 3" -qwatch from SELECT $key, $value WHERE $key like 'match:100:*' AND $value > 10 ORDER BY $value asc: [] +127.0.0.1:7379> Q.WATCH "SELECT $key, $value WHERE $key like 'match:100:*' AND $value > 10 ORDER BY $value DESC LIMIT 3" +q.watch from SELECT $key, $value WHERE $key like 'match:100:*' AND $value > 10 ORDER BY $value asc: [] ``` This query does the following: @@ -133,7 +133,7 @@ This query does the following: Imagine we're tracking player scores in a game match with ID 100. Each player's score is stored in a key formatted as `match:100:user:`. -Let's walk through a series of updates and see how the `QWATCH` command responds. Please note +Let's walk through a series of updates and see how the `Q.WATCH` command responds. Please note that the response will be RESP encoded and parsing will be handled by the SDK that you are using. 1. Initial state (empty leaderboard): `[]` @@ -152,9 +152,9 @@ that the response will be RESP encoded and parsing will be handled by the SDK th 127.0.0.1:7379> SET match:100:user:1 15 ``` - QWATCH Response: + Q.WATCH Response: ```bash - qwatch from SELECT $key, $value WHERE $key like 'match:100:*' and $value > 100 ORDER BY $value asc: `[["match:100:user:1", "15"]]` + q.watch from SELECT $key, $value WHERE $key like 'match:100:*' and $value > 100 ORDER BY $value asc: `[["match:100:user:1", "15"]]` ``` 4. Player 2 scores 20 points: @@ -163,9 +163,9 @@ that the response will be RESP encoded and parsing will be handled by the SDK th 127.0.0.1:7379> SET match:100:user:2 20 ``` - QWATCH Response: + Q.WATCH Response: ```bash - qwatch from SELECT $key, $value WHERE $key like 'match:100:*' and $value > 100 ORDER BY $value asc: `[["match:100:user:2", "20"], ["match:100:user:1", "15"]]` + q.watch from SELECT $key, $value WHERE $key like 'match:100:*' and $value > 100 ORDER BY $value asc: `[["match:100:user:2", "20"], ["match:100:user:1", "15"]]` ``` 5. Player 3 scores 12 points: @@ -174,9 +174,9 @@ that the response will be RESP encoded and parsing will be handled by the SDK th 127.0.0.1:7379> SET match:100:user:3 12 ``` - QWATCH Response: + Q.WATCH Response: ```bash - qwatch from SELECT $key, $value WHERE $key like 'match:100:*' and $value > 100 ORDER BY $value asc: `[["match:100:user:2", "20"], ["match:100:user:1", "15"], ["match:100:user:3", "12"]]` + q.watch from SELECT $key, $value WHERE $key like 'match:100:*' and $value > 100 ORDER BY $value asc: `[["match:100:user:2", "20"], ["match:100:user:1", "15"], ["match:100:user:3", "12"]]` ``` 6. Player 4 scores 25 points: @@ -185,9 +185,9 @@ that the response will be RESP encoded and parsing will be handled by the SDK th 127.0.0.1:7379> SET match:100:user:4 25 ``` - QWATCH Response: + Q.WATCH Response: ```bash - qwatch from SELECT $key, $value WHERE $key like 'match:100:*' and $value > 100 ORDER BY $value asc: `[["match:100:user:4", "25"], ["match:100:user:2", "20"], ["match:100:user:1", "15"]]` + q.watch from SELECT $key, $value WHERE $key like 'match:100:*' and $value > 100 ORDER BY $value asc: `[["match:100:user:4", "25"], ["match:100:user:2", "20"], ["match:100:user:1", "15"]]` ``` 7. Player 0 improves their score to 30: @@ -196,12 +196,12 @@ that the response will be RESP encoded and parsing will be handled by the SDK th 127.0.0.1:7379> SET match:100:user:0 30 ``` - QWATCH Response: + Q.WATCH Response: ```bash - qwatch from SELECT $key, $value WHERE $key like 'match:100:*' and $value > 100 ORDER BY $value asc: `[["match:100:user:0", "30"], ["match:100:user:4", "25"], ["match:100:user:2", "20"]]` + q.watch from SELECT $key, $value WHERE $key like 'match:100:*' and $value > 100 ORDER BY $value asc: `[["match:100:user:0", "30"], ["match:100:user:4", "25"], ["match:100:user:2", "20"]]` ``` -This example demonstrates how `QWATCH` provides real-time updates as the leaderboard changes, always keeping clients +This example demonstrates how `Q.WATCH` provides real-time updates as the leaderboard changes, always keeping clients informed of the top 3 scores above 10, without the need for constant polling. ## Best Practices diff --git a/integration_tests/commands/websocket/main_test.go b/integration_tests/commands/websocket/main_test.go index eec1e6336..cc330a19d 100644 --- a/integration_tests/commands/websocket/main_test.go +++ b/integration_tests/commands/websocket/main_test.go @@ -25,7 +25,6 @@ func TestMain(m *testing.M) { Logger: l, } ctx, cancel := context.WithCancel(context.Background()) - defer cancel() RunWebsocketServer(ctx, &wg, opts) // Wait for the server to start @@ -36,10 +35,10 @@ func TestMain(m *testing.M) { // Run the test suite exitCode := m.Run() - // abort conn := executor.ConnectToServer() executor.FireCommand(conn, "abort") + cancel() wg.Wait() os.Exit(exitCode) } diff --git a/integration_tests/commands/websocket/qwatch_test.go b/integration_tests/commands/websocket/qwatch_test.go new file mode 100644 index 000000000..27c8009f0 --- /dev/null +++ b/integration_tests/commands/websocket/qwatch_test.go @@ -0,0 +1,53 @@ +package websocket + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestQWatch(t *testing.T) { + exec := NewWebsocketCommandExecutor() + conn := exec.ConnectToServer() + + testCases := []struct { + name string + cmds []string + expect interface{} + }{ + { + name: "Wrong number of arguments", + cmds: []string{"Q.WATCH "}, + expect: "ERR wrong number of arguments for 'q.watch' command", + }, + { + name: "Invalid query", + cmds: []string{"Q.WATCH \"SELECT \""}, + expect: "error parsing SQL statement: syntax error at position 8", + }, + // TODO - once following query is registered, websocket will also attempt sending updates + // while keys are set for other tests in this package + // Add unregister test case to handle this scenario once qunwatch support is added + { + name: "Successful register", + cmds: []string{`Q.WATCH "SELECT $key, $value WHERE $key like 'test-key?'"`}, + expect: []interface{}{"q.watch", "SELECT $key, $value WHERE $key like 'test-key?'", []interface{}{}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, cmd := range tc.cmds { + result, err := exec.FireCommandAndReadResponse(conn, cmd) + assert.Nil(t, err) + if _, ok := tc.expect.(string); ok { + // compare strings + assert.Equal(t, tc.expect, result, "Value mismatch for cmd %s", cmd) + } else { + // compare lists + assert.ElementsMatch(t, tc.expect, result, "Value mismatch for cmd %s", cmd) + } + } + }) + } +} diff --git a/integration_tests/commands/websocket/setup.go b/integration_tests/commands/websocket/setup.go index f4292f6c1..485ca5183 100644 --- a/integration_tests/commands/websocket/setup.go +++ b/integration_tests/commands/websocket/setup.go @@ -12,6 +12,7 @@ import ( "github.com/dicedb/dice/config" derrors "github.com/dicedb/dice/internal/errors" + "github.com/dicedb/dice/internal/querymanager" "github.com/dicedb/dice/internal/server" "github.com/dicedb/dice/internal/shard" dstore "github.com/dicedb/dice/internal/store" @@ -108,7 +109,9 @@ func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerO globalErrChannel := make(chan error) watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Performance.WatchChanBufSize) shardManager := shard.NewShardManager(1, watchChan, nil, globalErrChannel, opt.Logger) - testServer := server.NewWebSocketServer(shardManager, watchChan, testPort1, opt.Logger) + queryWatcherLocal := querymanager.NewQueryManager(opt.Logger) + config.WebsocketPort = opt.Port + testServer := server.NewWebSocketServer(shardManager, testPort1, opt.Logger) shardManagerCtx, cancelShardManager := context.WithCancel(ctx) // run shard manager @@ -118,6 +121,13 @@ func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerO shardManager.Run(shardManagerCtx) }() + // run query manager + wg.Add(1) + go func() { + defer wg.Done() + queryWatcherLocal.Run(ctx, watchChan) + }() + // start websocket server wg.Add(1) go func() { diff --git a/internal/eval/eval.go b/internal/eval/eval.go index 6c1a02521..9b6168871 100644 --- a/internal/eval/eval.go +++ b/internal/eval/eval.go @@ -1958,7 +1958,7 @@ func evalMULTI(args []string, store *dstore.Store) []byte { // Every time a key in the watch list is modified, the client will be sent a response // containing the new value of the key along with the operation that was performed on it. // Contains only one argument, the query to be watched. -func EvalQWATCH(args []string, httpOp bool, client *comm.Client, store *dstore.Store) []byte { +func EvalQWATCH(args []string, httpOp, websocketOp bool, client *comm.Client, store *dstore.Store) []byte { if len(args) != 1 { return diceerrors.NewErrArity("Q.WATCH") } @@ -1977,7 +1977,7 @@ func EvalQWATCH(args []string, httpOp bool, client *comm.Client, store *dstore.S }) var watchSubscription querymanager.QuerySubscription - if httpOp { + if httpOp || websocketOp { watchSubscription = querymanager.QuerySubscription{ Subscribe: true, Query: query, diff --git a/internal/eval/execute.go b/internal/eval/execute.go index 2fcb7589c..35cdda04e 100644 --- a/internal/eval/execute.go +++ b/internal/eval/execute.go @@ -17,15 +17,6 @@ func ExecuteCommand(c *cmd.DiceDBCmd, client *comm.Client, store *dstore.Store, return &EvalResponse{Result: diceerrors.NewErrWithFormattedMessage("unknown command '%s', with args beginning with: %s", c.Cmd, strings.Join(c.Args, " ")), Error: nil} } - // Till the time we refactor to handle QWATCH differently for websocket - if websocketOp { - if diceCmd.IsMigrated { - return diceCmd.NewEval(c.Args, store) - } - - return &EvalResponse{Result: diceCmd.Eval(c.Args, store), Error: nil} - } - // Temporary logic till we move all commands to new eval logic. // MigratedDiceCmds map contains refactored eval commands // For any command we will first check in the existing map @@ -40,7 +31,7 @@ func ExecuteCommand(c *cmd.DiceDBCmd, client *comm.Client, store *dstore.Store, // Old implementation kept as it is, but we will be moving // to the new implementation soon for all commands case "SUBSCRIBE", "Q.WATCH": - return &EvalResponse{Result: EvalQWATCH(c.Args, httpOp, client, store), Error: nil} + return &EvalResponse{Result: EvalQWATCH(c.Args, httpOp, websocketOp, client, store), Error: nil} case "UNSUBSCRIBE", "Q.UNWATCH": return &EvalResponse{Result: EvalQUNWATCH(c.Args, httpOp, client), Error: nil} case auth.Cmd: diff --git a/internal/server/utils/redisCmdAdapter.go b/internal/server/utils/redisCmdAdapter.go index 6edb7d945..ce19a33bc 100644 --- a/internal/server/utils/redisCmdAdapter.go +++ b/internal/server/utils/redisCmdAdapter.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" "github.com/dicedb/dice/internal/cmd" @@ -33,6 +34,8 @@ const ( JSON = "json" ) +const QWatch string = "Q.WATCH" + func ParseHTTPRequest(r *http.Request) (*cmd.DiceDBCmd, error) { commandParts := strings.Split(strings.TrimPrefix(r.URL.Path, "/"), "/") if len(commandParts) == 0 { @@ -117,16 +120,39 @@ func ParseHTTPRequest(r *http.Request) (*cmd.DiceDBCmd, error) { } func ParseWebsocketMessage(msg []byte) (*cmd.DiceDBCmd, error) { - cmdStr := string(msg) - cmdStr = strings.TrimSpace(cmdStr) - + cmdStr := strings.TrimSpace(string(msg)) if cmdStr == "" { return nil, diceerrors.ErrEmptyCommand } - cmdArr := strings.Split(cmdStr, " ") - command := strings.ToUpper(cmdArr[0]) - cmdArr = cmdArr[1:] // args + var command string + idx := strings.Index(cmdStr, " ") + // handle commands with no args + if idx == -1 { + command = strings.ToUpper(cmdStr) + return &cmd.DiceDBCmd{ + Cmd: command, + Args: nil, + }, nil + } + + // handle commands with args + command = strings.ToUpper(cmdStr[:idx]) + cmdStr = cmdStr[idx+1:] + + var cmdArr []string // args + // handle qwatch commands + if command == QWatch { + // remove quotes from query string + cmdStr, err := strconv.Unquote(cmdStr) + if err != nil { + return nil, fmt.Errorf("error parsing q.watch query: %v", err) + } + cmdArr = []string{cmdStr} + } else { + // handle other commands + cmdArr = strings.Split(cmdStr, " ") + } // if key prefix is empty for JSON.INGEST command // add "" to cmdArr diff --git a/internal/server/utils/redisCmdAdapter_test.go b/internal/server/utils/redisCmdAdapter_test.go index e9f9b9fe1..929b3b1f9 100644 --- a/internal/server/utils/redisCmdAdapter_test.go +++ b/internal/server/utils/redisCmdAdapter_test.go @@ -248,12 +248,114 @@ func TestParseWebsocketMessage(t *testing.T) { expectedCmd: "SET", expectedArgs: []string{"k1", "v1", "nx"}, }, + { + name: "Test SET command with value as a map", + message: `set k0 {"k1":"v1"} nx`, + expectedCmd: "SET", + expectedArgs: []string{"k0", `{"k1":"v1"}`, "nx"}, + }, + { + name: "Test SET command with value as an array", + message: `set k1 ["v1","v2","v3"] nx`, + expectedCmd: "SET", + expectedArgs: []string{"k1", `["v1","v2","v3"]`, "nx"}, + }, + { + name: "Test SET command with value as a map containing an array", + message: `set k1 {"k2":["v1","v2"]} nx`, + expectedCmd: "SET", + expectedArgs: []string{"k1", `{"k2":["v1","v2"]}`, "nx"}, + }, + { + name: "Test SET command with value as a deeply nested map", + message: `set k1 {"k2":{"k3":{"k4":"value"}}} nx`, + expectedCmd: "SET", + expectedArgs: []string{"k1", `{"k2":{"k3":{"k4":"value"}}}`, "nx"}, + }, + { + name: "Test SET command with value as an array of maps", + message: `set k0 [{"k1":"v1"},{"k2":"v2"}] nx`, + expectedCmd: "SET", + expectedArgs: []string{"k0", `[{"k1":"v1"},{"k2":"v2"}]`, "nx"}, + }, { name: "Test GET command", message: "get k1", expectedCmd: "GET", expectedArgs: []string{"k1"}, }, + { + name: "Test DEL command", + message: "del k1", + expectedCmd: "DEL", + expectedArgs: []string{"k1"}, + }, + { + name: "Test DEL command with multiple keys", + message: `del k1 k2 k3`, + expectedCmd: "DEL", + expectedArgs: []string{"k1", "k2", "k3"}, + }, + { + name: "Test KEYS command", + message: "keys *", + expectedCmd: "KEYS", + expectedArgs: []string{"*"}, + }, + { + name: "Test MSET command", + message: "mset k1 v1 k2 v2", + expectedCmd: "MSET", + expectedArgs: []string{"k1", "v1", "k2", "v2"}, + }, + { + name: "Test MSET command with options", + message: "mset k1 v1 k2 v2 nx", + expectedCmd: "MSET", + expectedArgs: []string{"k1", "v1", "k2", "v2", "nx"}, + }, + { + name: "Test SLEEP command", + message: "sleep 1", + expectedCmd: "SLEEP", + expectedArgs: []string{"1"}, + }, + { + name: "Test PING command", + message: "ping", + expectedCmd: "PING", + expectedArgs: nil, + }, + { + name: "Test EXPIRE command", + message: "expire k1 1", + expectedCmd: "EXPIRE", + expectedArgs: []string{"k1", "1"}, + }, + { + name: "Test AUTH command", + message: "auth user password", + expectedCmd: "AUTH", + expectedArgs: []string{"user", "password"}, + }, + { + name: "Test LPUSH command", + message: "lpush k1 v1", + expectedCmd: "LPUSH", + expectedArgs: []string{"k1", "v1"}, + }, + { + name: "Test LPUSH command with multiple items", + message: `lpush k1 v1 v2 v3`, + expectedCmd: "LPUSH", + expectedArgs: []string{"k1", "v1", "v2", "v3"}, + }, + { + name: "Test JSON.ARRPOP command", + message: "json.arrpop k1 $ 1", + expectedCmd: "JSON.ARRPOP", + expectedArgs: []string{"k1", "$", "1"}, + }, { name: "Test JSON.SET command", message: `json.set k1 . {"field":"value"}`, @@ -284,6 +386,18 @@ func TestParseWebsocketMessage(t *testing.T) { expectedCmd: "JSON.INGEST", expectedArgs: []string{"", "$..field", `{"field":"value"}`}, }, + { + name: "Test simple Q.WATCH command", + message: "q.watch \"select $key, $value where $key like 'k?'\"", + expectedCmd: "Q.WATCH", + expectedArgs: []string{"select $key, $value where $key like 'k?'"}, + }, + { + name: "Test complex Q.WATCH command", + message: "q.watch \"SELECT $key, $value WHERE $key LIKE 'player:*' AND '$value.score' > 10 ORDER BY $value.score DESC LIMIT 5\"", + expectedCmd: "Q.WATCH", + expectedArgs: []string{"SELECT $key, $value WHERE $key LIKE 'player:*' AND '$value.score' > 10 ORDER BY $value.score DESC LIMIT 5"}, + }, } for _, tc := range commands { diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go index 35698ace5..b6f1d1df6 100644 --- a/internal/server/websocketServer.go +++ b/internal/server/websocketServer.go @@ -16,12 +16,12 @@ import ( "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/clientio" + "github.com/dicedb/dice/internal/cmd" + "github.com/dicedb/dice/internal/comm" diceerrors "github.com/dicedb/dice/internal/errors" "github.com/dicedb/dice/internal/ops" - "github.com/dicedb/dice/internal/querymanager" "github.com/dicedb/dice/internal/server/utils" "github.com/dicedb/dice/internal/shard" - dstore "github.com/dicedb/dice/internal/store" "github.com/gorilla/websocket" "golang.org/x/exp/rand" ) @@ -31,23 +31,20 @@ const Qunwatch = "Q.UNWATCH" const Subscribe = "SUBSCRIBE" var unimplementedCommandsWebsocket = map[string]bool{ - Qwatch: true, - Qunwatch: true, - Subscribe: true, + Qunwatch: true, } type WebsocketServer struct { - querymanager *querymanager.Manager - shardManager *shard.ShardManager - ioChan chan *ops.StoreResponse - watchChan chan dstore.QueryWatchEvent - websocketServer *http.Server - upgrader websocket.Upgrader - logger *slog.Logger - shutdownChan chan struct{} + shardManager *shard.ShardManager + ioChan chan *ops.StoreResponse + websocketServer *http.Server + upgrader websocket.Upgrader + qwatchResponseChan chan comm.QwatchResponse + shutdownChan chan struct{} + logger *slog.Logger } -func NewWebSocketServer(shardManager *shard.ShardManager, watchChan chan dstore.QueryWatchEvent, port int, logger *slog.Logger) *WebsocketServer { +func NewWebSocketServer(shardManager *shard.ShardManager, port int, logger *slog.Logger) *WebsocketServer { mux := http.NewServeMux() srv := &http.Server{ Addr: fmt.Sprintf(":%d", port), @@ -60,23 +57,16 @@ func NewWebSocketServer(shardManager *shard.ShardManager, watchChan chan dstore. } websocketServer := &WebsocketServer{ - shardManager: shardManager, - querymanager: querymanager.NewQueryManager(logger), - ioChan: make(chan *ops.StoreResponse, 1000), - watchChan: watchChan, - websocketServer: srv, - upgrader: upgrader, - logger: logger, - shutdownChan: make(chan struct{}), + shardManager: shardManager, + ioChan: make(chan *ops.StoreResponse, 1000), + websocketServer: srv, + upgrader: upgrader, + qwatchResponseChan: make(chan comm.QwatchResponse), + shutdownChan: make(chan struct{}), + logger: logger, } mux.HandleFunc("/", websocketServer.WebsocketHandler) - mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("OK")) - if err != nil { - return - } - }) return websocketServer } @@ -96,7 +86,7 @@ func (s *WebsocketServer) Run(ctx context.Context) error { case <-ctx.Done(): case <-s.shutdownChan: err = diceerrors.ErrAborted - s.logger.Debug("Shutting down Websocket Server") + s.logger.Debug("Shutting down Websocket Server", slog.Any("time", time.Now())) } shutdownErr := s.websocketServer.Shutdown(websocketCtx) @@ -111,6 +101,9 @@ func (s *WebsocketServer) Run(ctx context.Context) error { defer wg.Done() s.logger.Info("Websocket Server running", slog.String("port", s.websocketServer.Addr[1:])) err = s.websocketServer.ListenAndServe() + if err != nil { + s.logger.Debug("Error in Websocket Server", slog.Any("time", time.Now()), slog.Any("error", err)) + } }() wg.Wait() @@ -130,6 +123,7 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques conn.Close() }() + maxRetries := config.DiceConfig.WebSocket.MaxWriteResponseRetries for { // read incoming message _, msg, err := conn.ReadMessage() @@ -147,88 +141,153 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques if errors.Is(err, diceerrors.ErrEmptyCommand) { continue } else if err != nil { - writeResponse(conn, []byte("error: parsing failed")) + if err := WriteResponseWithRetries(conn, []byte("error: parsing failed"), maxRetries); err != nil { + s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) + } continue } + // TODO - on abort, close client connection instead of closing server? if diceDBCmd.Cmd == Abort { close(s.shutdownChan) break } if unimplementedCommandsWebsocket[diceDBCmd.Cmd] { - writeResponse(conn, []byte("Command is not implemented with Websocket")) + if err := WriteResponseWithRetries(conn, []byte("Command is not implemented with Websocket"), maxRetries); err != nil { + s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) + } continue } - // send request to Shard Manager - s.shardManager.GetShard(0).ReqChan <- &ops.StoreOp{ + // create request + sp := &ops.StoreOp{ Cmd: diceDBCmd, WorkerID: "wsServer", ShardID: 0, WebsocketOp: true, } - // Wait for response - resp := <-s.ioChan + // handle q.watch commands + if diceDBCmd.Cmd == Qwatch || diceDBCmd.Cmd == Subscribe { + clientIdentifierID := generateUniqueInt32(r) + sp.Client = comm.NewHTTPQwatchClient(s.qwatchResponseChan, clientIdentifierID) - _, ok := WorkerCmdsMeta[diceDBCmd.Cmd] - respArr := []string{ - "(nil)", // Represents a RESP Nil Bulk String, which indicates a null value. - "OK", // Represents a RESP Simple String with value "OK". - "QUEUED", // Represents a Simple String indicating that a command has been queued. - "0", // Represents a RESP Integer with value 0. - "1", // Represents a RESP Integer with value 1. - "-1", // Represents a RESP Integer with value -1. - "-2", // Represents a RESP Integer with value -2. - "*0", // Represents an empty RESP Array. + // start a goroutine for subsequent updates + go s.processQwatchUpdates(clientIdentifierID, conn, diceDBCmd) } - var rp *clientio.RESPParser - var responseValue interface{} - // TODO: Remove this conditional check and if (true) condition when all commands are migrated - if !ok { - var err error - if resp.EvalResponse.Error != nil { - rp = clientio.NewRESPParser(bytes.NewBuffer([]byte(resp.EvalResponse.Error.Error()))) - } else { - rp = clientio.NewRESPParser(bytes.NewBuffer(resp.EvalResponse.Result.([]byte))) - } + s.shardManager.GetShard(0).ReqChan <- sp + resp := <-s.ioChan + if err := s.processResponse(conn, diceDBCmd, resp); err != nil { + break + } + } +} - responseValue, err = rp.DecodeOne() - if err != nil { - s.logger.Error("Error decoding response", "error", err) - writeResponse(conn, []byte("error: Internal Server Error")) - return - } - } else { - if resp.EvalResponse.Error != nil { - responseValue = resp.EvalResponse.Error.Error() - } else { - responseValue = resp.EvalResponse.Result +func (s *WebsocketServer) processQwatchUpdates(clientIdentifierID uint32, conn *websocket.Conn, dicDBCmd *cmd.DiceDBCmd) { + for { + select { + case resp := <-s.qwatchResponseChan: + if resp.ClientIdentifierID == clientIdentifierID { + if err := s.processResponse(conn, dicDBCmd, resp); err != nil { + s.logger.Debug("Error writing response to client. Shutting down goroutine for q.watch updates", slog.Any("clientIdentifierID", clientIdentifierID), slog.Any("error", err)) + return + } } + case <-s.shutdownChan: + return } + } +} - if val, ok := responseValue.(clientio.RespType); ok { - responseValue = respArr[val] +func (s *WebsocketServer) processResponse(conn *websocket.Conn, diceDBCmd *cmd.DiceDBCmd, response interface{}) error { + var result interface{} + var err error + maxRetries := config.DiceConfig.WebSocket.MaxWriteResponseRetries + + // check response type + switch resp := response.(type) { + case comm.QwatchResponse: + result = resp.Result + err = resp.Error + case *ops.StoreResponse: + result = resp.EvalResponse.Result + err = resp.EvalResponse.Error + default: + s.logger.Debug("Unsupported response type") + if err := WriteResponseWithRetries(conn, []byte("error: 500 Internal Server Error"), maxRetries); err != nil { + s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) + return fmt.Errorf("error writing response: %v", err) } + return nil + } - if bt, ok := responseValue.([]byte); ok { - responseValue = string(bt) + _, ok := WorkerCmdsMeta[diceDBCmd.Cmd] + respArr := []string{ + "(nil)", // Represents a RESP Nil Bulk String, which indicates a null value. + "OK", // Represents a RESP Simple String with value "OK". + "QUEUED", // Represents a Simple String indicating that a command has been queued. + "0", // Represents a RESP Integer with value 0. + "1", // Represents a RESP Integer with value 1. + "-1", // Represents a RESP Integer with value -1. + "-2", // Represents a RESP Integer with value -2. + "*0", // Represents an empty RESP Array. + } + + var responseValue interface{} + // TODO: Remove this conditional check and if (true) condition when all commands are migrated + if !ok { + var rp *clientio.RESPParser + if err != nil { + rp = clientio.NewRESPParser(bytes.NewBuffer([]byte(err.Error()))) + } else { + rp = clientio.NewRESPParser(bytes.NewBuffer(result.([]byte))) } - respBytes, err := json.Marshal(responseValue) + responseValue, err = rp.DecodeOne() if err != nil { - writeResponse(conn, []byte("error: marshaling json response")) - continue + s.logger.Debug("Error decoding response", "error", err) + if err := WriteResponseWithRetries(conn, []byte("error: 500 Internal Server Error"), maxRetries); err != nil { + s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) + return fmt.Errorf("error writing response: %v", err) + } + return nil } + } else { + if err != nil { + responseValue = err.Error() + } else { + responseValue = result + } + } + + if val, ok := responseValue.(clientio.RespType); ok { + responseValue = respArr[val] + } + + if bt, ok := responseValue.([]byte); ok { + responseValue = string(bt) + } - // Write response with retries for transient errors - if err := WriteResponseWithRetries(conn, respBytes, config.DiceConfig.WebSocket.MaxWriteResponseRetries); err != nil { - s.logger.Error(fmt.Sprintf("Error reading message: %v", err)) - break // Exit the loop on write error + respBytes, err := json.Marshal(responseValue) + if err != nil { + s.logger.Debug("Error marshaling json", "error", err) + if err := WriteResponseWithRetries(conn, []byte("error: marshaling json"), maxRetries); err != nil { + s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) + return fmt.Errorf("error writing response: %v", err) } + return nil + } + + // success + // Write response with retries for transient errors + if err := WriteResponseWithRetries(conn, respBytes, config.DiceConfig.WebSocket.MaxWriteResponseRetries); err != nil { + s.logger.Debug(fmt.Sprintf("Error writing message: %v", err)) + return fmt.Errorf("error writing response: %v", err) } + + return nil } func WriteResponseWithRetries(conn *websocket.Conn, text []byte, maxRetries int) error { @@ -285,16 +344,3 @@ func WriteResponseWithRetries(conn *websocket.Conn, text []byte, maxRetries int) return nil } - -func writeResponse(conn *websocket.Conn, text []byte) { - // Set a write deadline to prevent hanging - if err := conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil { - slog.Error(fmt.Sprintf("Error setting write deadline: %v", err)) - return - } - - err := conn.WriteMessage(websocket.TextMessage, text) - if err != nil { - slog.Error(fmt.Sprintf("Error writing response: %v", err)) - } -} diff --git a/main.go b/main.go index 6e54f5419..893513591 100644 --- a/main.go +++ b/main.go @@ -207,7 +207,7 @@ func main() { }() } - websocketServer := server.NewWebSocketServer(shardManager, queryWatchChan, config.WebsocketPort, logr) + websocketServer := server.NewWebSocketServer(shardManager, config.WebsocketPort, logr) serverWg.Add(1) go func() { defer serverWg.Done()