diff --git a/config/config.go b/config/config.go
index 12a9f872f..45b17e1fe 100644
--- a/config/config.go
+++ b/config/config.go
@@ -108,7 +108,7 @@ var baseConfig = Config{
 		KeepAlive:              int32(300),
 		Timeout:                int32(300),
 		MaxConn:                int32(0),
-		ShardCronFrequency:     1 * time.Second,
+		ShardCronFrequency:     30 * time.Second,
 		MultiplexerPollTimeout: 100 * time.Millisecond,
 		MaxClients:             int32(20000),
 		MaxMemory:              0,
diff --git a/internal/clientio/iohandler/netconn/netconn.go b/internal/clientio/iohandler/netconn/netconn.go
index 14ba9c620..bac5c553d 100644
--- a/internal/clientio/iohandler/netconn/netconn.go
+++ b/internal/clientio/iohandler/netconn/netconn.go
@@ -9,6 +9,7 @@ import (
 	"log/slog"
 	"net"
 	"os"
+	"sync"
 	"syscall"
 	"time"
 
@@ -18,24 +19,31 @@ import (
 
 const (
 	maxRequestSize = 512 * 1024 // 512 KB
-	readBufferSize = 4 * 1024   // 4 KB
+	bufferSize     = 4 * 1024   // 4 KB
 	idleTimeout    = 10 * time.Minute
 )
 
+var bufferPool = sync.Pool{
+	New: func() interface{} {
+		return make([]byte, bufferSize)
+	},
+}
+
 var (
 	ErrRequestTooLarge = errors.New("request too large")
 	ErrIdleTimeout     = errors.New("connection idle timeout")
-	ErrorClosed        = errors.New("connection closed")
+	ErrorConnClosed    = errors.New("connection closed")
 )
 
 // IOHandler handles I/O operations for a network connection
 type IOHandler struct {
-	fd     int
-	file   *os.File
-	conn   net.Conn
-	reader *bufio.Reader
-	writer *bufio.Writer
-	logger *slog.Logger
+	fd         int
+	file       *os.File
+	conn       net.Conn
+	reader     *bufio.Reader
+	writer     *bufio.Writer
+	bufferPool sync.Pool
+	logger     *slog.Logger
 }
 
 var _ iohandler.IOHandler = (*IOHandler)(nil)
@@ -69,8 +77,13 @@ func NewIOHandler(clientFD int, logger *slog.Logger) (*IOHandler, error) {
 		fd:     clientFD,
 		file:   file,
 		conn:   conn,
-		reader: bufio.NewReader(conn),
-		writer: bufio.NewWriter(conn),
+		reader: bufio.NewReaderSize(conn, bufferSize),
+		writer: bufio.NewWriterSize(conn, bufferSize),
+		bufferPool: sync.Pool{
+			New: func() interface{} {
+				return make([]byte, maxRequestSize)
+			},
+		},
 		logger: logger,
 	}, nil
 }
@@ -78,8 +91,8 @@ func NewIOHandler(clientFD int, logger *slog.Logger) (*IOHandler, error) {
 func NewIOHandlerWithConn(conn net.Conn) *IOHandler {
 	return &IOHandler{
 		conn:   conn,
-		reader: bufio.NewReader(conn),
-		writer: bufio.NewWriter(conn),
+		reader: bufio.NewReaderSize(conn, bufferSize),
+		writer: bufio.NewWriterSize(conn, bufferSize),
 	}
 }
 
@@ -90,7 +103,8 @@ func (h *IOHandler) FileDescriptor() int {
 // ReadRequest reads data from the network connection
 func (h *IOHandler) Read(ctx context.Context) ([]byte, error) {
 	var data []byte
-	buf := make([]byte, readBufferSize)
+	buf := bufferPool.Get().([]byte)
+	defer bufferPool.Put(buf)
 
 	for {
 		select {
@@ -112,12 +126,12 @@ func (h *IOHandler) Read(ctx context.Context) ([]byte, error) {
 					// No more data to read at this time
 					return data, nil
 				case errors.Is(err, net.ErrClosed), errors.Is(err, syscall.EPIPE), errors.Is(err, syscall.ECONNRESET):
-					h.logger.Error("Connection closed", slog.Any("error", err))
+					h.logger.Info("Connection closed", slog.Any("error", err))
 					cerr := h.Close()
 					if cerr != nil {
 						h.logger.Warn("Error closing connection", slog.Any("error", errors.Join(err, cerr)))
 					}
-					return nil, ErrorClosed
+					return nil, ErrorConnClosed
 				case errors.Is(err, syscall.ETIMEDOUT):
 					h.logger.Info("Connection idle timeout", slog.Any("error", err))
 					cerr := h.Close()
@@ -146,8 +160,6 @@ func (h *IOHandler) Read(ctx context.Context) ([]byte, error) {
 
 // WriteResponse writes the response back to the network connection
 func (h *IOHandler) Write(ctx context.Context, response interface{}) error {
-	errChan := make(chan error, 1)
-
 	// Process the incoming response by calling the handleResponse function.
 	// This function checks the response against known RESP formatted values
 	// and returns the corresponding byte array representation. The result
@@ -166,32 +178,23 @@ func (h *IOHandler) Write(ctx context.Context, response interface{}) error {
 		resp = clientio.Encode(response, true)
 	}
 
-	go func(errChan chan error) {
-		_, err := h.writer.Write(resp)
-		if err == nil {
-			err = h.writer.Flush()
-		}
-
-		errChan <- err
-	}(errChan)
-
-	select {
-	case <-ctx.Done():
-		return ctx.Err()
-	case err := <-errChan:
-		if err != nil {
-			if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) {
-				cerr := h.Close()
-				if cerr != nil {
-					err = errors.Join(err, cerr)
-				}
+	_, err := h.writer.Write(resp)
+	if err == nil {
+		err = h.writer.Flush()
+	}
 
-				h.logger.Error("Connection closed", slog.Any("error", err))
-				return err
+	if err != nil {
+		if errors.Is(err, net.ErrClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) {
+			cerr := h.Close()
+			if cerr != nil {
+				err = errors.Join(err, cerr)
 			}
 
-			return fmt.Errorf("error writing response: %w", err)
+			h.logger.Info("Connection closed", slog.Any("error", err)) // Connection closed, logging as info
+			return nil
 		}
+
+		return fmt.Errorf("error writing response: %w", err)
 	}
 
 	return nil
diff --git a/internal/shard/shard_manager.go b/internal/shard/shard_manager.go
index d5010bcbe..82ccbcaa8 100644
--- a/internal/shard/shard_manager.go
+++ b/internal/shard/shard_manager.go
@@ -89,8 +89,8 @@ func (manager *ShardManager) GetShardInfo(key string) (id ShardID, c chan *ops.S
 }
 
 // GetShardCount returns the number of shards managed by this ShardManager.
-func (manager *ShardManager) GetShardCount() int8 {
-	return int8(len(manager.shards))
+func (manager *ShardManager) GetShardCount() uint8 {
+	return uint8(len(manager.shards))
 }
 
 // GetShard returns the ShardThread for the given ShardID.
diff --git a/internal/worker/worker.go b/internal/worker/worker.go
index dd379a187..bb1146bef 100644
--- a/internal/worker/worker.go
+++ b/internal/worker/worker.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"github.com/dicedb/dice/internal/clientio/iohandler/netconn"
 	"log/slog"
 	"net"
 	"syscall"
@@ -80,9 +81,15 @@ func (w *BaseWorker) Start(ctx context.Context) error {
 		default:
 			data, err := w.ioHandler.Read(ctx)
 			if err != nil {
+				if errors.Is(err, netconn.ErrorConnClosed) {
+					w.logger.Debug("Connection closed", slog.String("workerID", w.id))
+					return nil
+				}
+
 				w.logger.Debug("Read error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err))
 				return err
 			}
+
 			cmds, err := w.parser.Parse(data)
 			if err != nil {
 				err = w.ioHandler.Write(ctx, err)
diff --git a/main.go b/main.go
index d68485092..52450e57e 100644
--- a/main.go
+++ b/main.go
@@ -8,6 +8,8 @@ import (
 	"os"
 	"os/signal"
 	"runtime"
+	"runtime/pprof"
+	"runtime/trace"
 	"sync"
 	"syscall"
 
@@ -149,6 +151,50 @@ func main() {
 			}
 		}()
 	} else {
+		// Start CPU profiling
+		cpuFile, err := os.Create("cpu.prof")
+		if err != nil {
+			logr.Warn("could not create CPU profile", "error", err)
+		}
+		defer cpuFile.Close()
+
+		if err := pprof.StartCPUProfile(cpuFile); err != nil {
+			logr.Warn("could not start CPU profile", "error", err)
+		}
+		defer pprof.StopCPUProfile()
+
+		// Start memory profiling
+		memFile, err := os.Create("mem.prof")
+		if err != nil {
+			logr.Warn("could not create memory profile", "error", err)
+		}
+		defer memFile.Close()
+
+		// Start block profiling
+		runtime.SetBlockProfileRate(1)
+		defer func() {
+			blockFile, err := os.Create("block.prof")
+			if err != nil {
+				logr.Warn("could not create block profile", "error", err)
+			}
+			defer blockFile.Close()
+			if err := pprof.Lookup("block").WriteTo(blockFile, 0); err != nil {
+				logr.Warn("could not write block profile", "error", err)
+			}
+		}()
+
+		// Start execution trace
+		traceFile, err := os.Create("trace.out")
+		if err != nil {
+			logr.Warn("could not create trace output file", "error", err)
+		}
+		defer traceFile.Close()
+
+		if err := trace.Start(traceFile); err != nil {
+			logr.Warn("could not start trace", "error", err)
+		}
+		defer trace.Stop()
+
 		workerManager := worker.NewWorkerManager(config.DiceConfig.Server.MaxClients, shardManager)
 		// Initialize the RESP Server
 		respServer := resp.NewServer(shardManager, workerManager, serverErrCh, logr)
@@ -181,6 +227,12 @@ func main() {
 			respServer.Shutdown()
 			cancel()
 		}()
+
+		// Ensure all profiling data is written before exiting
+		runtime.GC()
+		if err := pprof.WriteHeapProfile(memFile); err != nil {
+			logr.Warn("could not write memory profile", "error", err)
+		}
 	}
 
 	websocketServer := server.NewWebSocketServer(shardManager, watchChan, logr)