diff --git a/go/handoff/handoff.go b/go/handoff/handoff.go new file mode 100644 index 00000000000..455e9440134 --- /dev/null +++ b/go/handoff/handoff.go @@ -0,0 +1,179 @@ +package handoff + +import ( + "errors" + "fmt" + "net" + "os" + "syscall" + "time" + + "golang.org/x/sys/unix" + + "vitess.io/vitess/go/vt/log" +) + +// handoff implements a no-downtime handoff of a TCP listener from one running +// process to another. It can be used for no-downtime deploys of HTTP servers +// on a single host/port. + +// Listen opens a unix domain socket and listens for handoff requests. When a +// handoff request is received, the underlying file descriptor of `listener` is +// handed off over the socket. +// +// If an error occurs while opening the unix domain +// socket, or during handoff, it will be logged and the listener will resume +// listening. Otherwise, Listen will return nil when the handoff is complete. +// +// Callers should drain any servers +// connected to the net.Listener, and in-flight requests should be resolved +// before shutting down. +func Listen(socketPath string, listener net.Listener) error { + // Clean up any leftover sockets that might have gotten left from previous + // processes. + os.Remove(socketPath) + + unixListener, err := net.Listen("unix", socketPath) + if err != nil { + return err + } + defer func() { + unixListener.Close() + os.Remove(socketPath) + }() + + for { + err := listen(unixListener, listener) + if err != nil { + log.Error("handoff socket error", "error", err) + continue + } + + return nil + } +} + +var magicPacket = "handoff" + +func listen(unixListener, listener net.Listener) error { + conn, err := unixListener.Accept() + if err != nil { + return err + } + defer conn.Close() + err = conn.SetDeadline(time.Now().Add(1 * time.Second)) + if err != nil { + return err + } + + b := make([]byte, len(magicPacket)) + n, err := conn.Read(b) + if err != nil { + return err + } + if string(b[:n]) != magicPacket { + return errors.New("bad magic packet") + } + + return handoff(conn, listener) +} + +func handoff(conn net.Conn, listener net.Listener) error { + unixFD, err := getFD(conn.(*net.UnixConn)) + if err != nil { + return err + } + + tcpListener := listener.(*net.TCPListener) + + tcpFd, err := getFD(tcpListener) + if err != nil { + return err + } + + rights := unix.UnixRights(tcpFd) + err = unix.Sendmsg(unixFD, nil, rights, nil, 0) + if err != nil { + return err + } + + return nil +} + +// Request checks for the presence of a unix domain socket at `socketPath` and +// opens a connection. The server side of the socket will immediately send a +// file descriptor of a TCP socket over the unix domain socket. This file +// descriptor is converted into a net.Listener and returned to the caller for +// immediate use. +// +// During the time between socket handoff and startup of the new server, +// requests to the socket will block. Requests will only fail if the client +// timeout is shorter than the duration of the handoff period. +// +// If nothing is listening on the other end of the unix domain socket, +// ErrNoHandoff is returned. Clients should check for this condition, and dial +// the TCP socket themselves. +func Request(socketPath string) (net.Listener, error) { + conn, err := net.Dial("unix", socketPath) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrNoHandoff, err) + } + defer conn.Close() + err = conn.SetDeadline(time.Now().Add(1 * time.Second)) + if err != nil { + return nil, err + } + + _, err = conn.Write([]byte(magicPacket)) + if err != nil { + return nil, fmt.Errorf("%w: failed to send magic packet", err) + } + + f, err := (conn.(*net.UnixConn)).File() + if err != nil { + return nil, fmt.Errorf("%w: fd not read", err) + } + defer f.Close() + + b := make([]byte, unix.CmsgSpace(4)) + //nolint:dogsled + _, _, _, _, err = unix.Recvmsg(int(f.Fd()), nil, b, 0) + if err != nil { + return nil, fmt.Errorf("%w: msg not received", err) + } + + cmsgs, err := unix.ParseSocketControlMessage(b) + if err != nil { + return nil, fmt.Errorf("%w: control msg not parsed", err) + } + fds, err := unix.ParseUnixRights(&cmsgs[0]) + if err != nil { + return nil, fmt.Errorf("%w: invalid unix rights", err) + } + fd := fds[0] + + listenerFD := os.NewFile(uintptr(fd), "listener") + defer f.Close() + + l, err := net.FileListener(listenerFD) + if err != nil { + return nil, fmt.Errorf("%w: failed to acquire new fd", err) + } + + return l, nil +} + +// ErrNoHandoff indicates that no handoff was performed. +var ErrNoHandoff = errors.New("no handoff") + +func getFD(conn syscall.Conn) (fd int, err error) { + raw, err := conn.SyscallConn() + if err != nil { + return -1, err + } + + err = raw.Control(func(ptr uintptr) { + fd = int(ptr) + }) + return fd, err +} diff --git a/go/vt/servenv/grpc_server.go b/go/vt/servenv/grpc_server.go index 3ad79dc3641..72e1ffff310 100644 --- a/go/vt/servenv/grpc_server.go +++ b/go/vt/servenv/grpc_server.go @@ -290,7 +290,7 @@ func serveGRPC() { // listen on the port log.Infof("Listening for gRPC calls on port %v", gRPCPort) - listener, err := net.Listen("tcp", net.JoinHostPort(gRPCBindAddress, strconv.Itoa(gRPCPort))) + listener, err := HandoffOrListen("grpc", "tcp", net.JoinHostPort(gRPCBindAddress, strconv.Itoa(gRPCPort))) if err != nil { log.Exitf("Cannot listen on port %v for gRPC: %v", gRPCPort, err) } diff --git a/go/vt/servenv/handoff.go b/go/vt/servenv/handoff.go new file mode 100644 index 00000000000..99853478d1a --- /dev/null +++ b/go/vt/servenv/handoff.go @@ -0,0 +1,85 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package servenv + +import ( + "errors" + "net" + + "github.com/spf13/pflag" + + "vitess.io/vitess/go/handoff" + "vitess.io/vitess/go/vt/log" +) + +var ( + // handoffPath specifies the filesystem path where handoff sockets are to be created, + // if zero-downtime handoff is enabled. + // + // To expose this flag, call RegisterHandoffFlags before ParseFlags. + rootHandoffPath string +) + +func RegisterHandoffFlags() { + OnParse(func(fs *pflag.FlagSet) { + fs.StringVar(&rootHandoffPath, "handoff_path", rootHandoffPath, "Root path to enable zero-downtime handoff sockets.") + }) +} + +// HandoffOrListen implements optional support for zero-downtime socket handoff. +// +// If enabled by configuring a root path for the handoff, this first attempts to +// take over the socket from a running process, otherwise this just calls the regular +// net.Listen. Once it takes over the socket or creates a new one, it will also start +// listening for requests to hand off the socket to future runs. +func HandoffOrListen(serviceName, protocol, address string) (net.Listener, error) { + + // If there is no path to handoff, then just pass through to the core net.listen. + if rootHandoffPath == "" { + return net.Listen(protocol, address) + } + + handoffPath := rootHandoffPath + serviceName + + // Request socket from an already running process, or start a new + // listener to serve requests on. + log.Infof("handoff: requesting handoff socket from %s", handoffPath) + listener, err := handoff.Request(handoffPath) + if err == nil { + log.Infof("handoff: received handoff from %s", handoffPath) + } else { + if errors.Is(err, handoff.ErrNoHandoff) { + log.Infof("handoff: no handoff, listening on %s %s", protocol, address) + listener, err = net.Listen(protocol, address) + } else { + log.Exitf("handoff: fatal error: %v", err) + } + } + + // Advertise unix domain socket for handoff by future processes. + go func() { + err := handoff.Listen(handoffPath, listener) + if err != nil { + log.Errorf("Handoff failed: %v", err) + return + } + + log.Infof("handed off socket %s %s", protocol, address) + }() + + return listener, err +} diff --git a/go/vt/servenv/run.go b/go/vt/servenv/run.go index 6f028786eaf..210158312cc 100644 --- a/go/vt/servenv/run.go +++ b/go/vt/servenv/run.go @@ -44,7 +44,7 @@ func Run(bindAddress string, port int) { serveGRPC() serveSocketFile() - l, err := net.Listen("tcp", net.JoinHostPort(bindAddress, strconv.Itoa(port))) + l, err := HandoffOrListen("servenv-"+strconv.Itoa(port), "tcp", net.JoinHostPort(bindAddress, strconv.Itoa(port))) if err != nil { log.Exit(err) } diff --git a/go/vt/vtgateproxy/mysql_server.go b/go/vt/vtgateproxy/mysql_server.go index 65771400036..a988608dc53 100644 --- a/go/vt/vtgateproxy/mysql_server.go +++ b/go/vt/vtgateproxy/mysql_server.go @@ -30,6 +30,8 @@ import ( "syscall" "time" + "github.com/pires/go-proxyproto" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vterrors" @@ -504,7 +506,15 @@ func initMySQLProtocol() { } if *mysqlServerPort >= 0 { log.Infof("Mysql Server listening on Port %d", *mysqlServerPort) - mysqlListener, err = mysql.NewListener(*mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)), authServer, proxyHandle, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, *mysqlProxyProtocol, *mysqlConnBufferPooling, *mysqlKeepAlivePeriod, *mysqlServerFlushDelay) + listener, err := servenv.HandoffOrListen("mysql", *mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort))) + if err != nil { + log.Exitf("HandoffOrListen failed: %v", err) + } + if *mysqlProxyProtocol { + listener = &proxyproto.Listener{Listener: listener} + } + + mysqlListener, err = mysql.NewFromListener(listener, authServer, proxyHandle, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, *mysqlConnBufferPooling, *mysqlKeepAlivePeriod, *mysqlServerFlushDelay) if err != nil { log.Exitf("mysql.NewListener failed: %v", err) }