Skip to content

Commit

Permalink
wip: support zero-downtime socket handoff
Browse files Browse the repository at this point in the history
Add a generic helper for zero-downtime listener socket handoff. Wrap it in a
helper function so that callers can easily integrate and optionally enable the
feature.
  • Loading branch information
demmer committed Feb 10, 2025
1 parent aa0835a commit 257c0ad
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 3 deletions.
179 changes: 179 additions & 0 deletions go/handoff/handoff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package handoff

import (
"errors"
"fmt"
"net"
"os"
"syscall"
"time"

"golang.org/x/sys/unix"

"vitess.io/vitess/go/vt/log"
)

// handoff implements a no-downtime handoff of a TCP listener from one running
// process to another. It can be used for no-downtime deploys of HTTP servers
// on a single host/port.

// Listen opens a unix domain socket at socketPath and waits for handoff
// requests. When a valid request arrives, the underlying file descriptor of
// listener is sent to the requester over that socket.
//
// If the unix domain socket cannot be opened, the error is returned
// immediately. Errors on an individual handoff attempt (bad magic packet,
// timeout, send failure) are logged and Listen keeps waiting for the next
// request. Listen returns nil once a handoff has completed, and returns the
// error if one of the listeners has been closed, since no further handoff is
// possible at that point.
//
// After Listen returns nil, callers should drain any servers attached to
// listener and let in-flight requests finish before shutting down.
func Listen(socketPath string, listener net.Listener) error {
	// Best-effort removal of a stale socket left by a previous process; the
	// file not existing is the normal case, so the error is deliberately
	// ignored.
	_ = os.Remove(socketPath)

	unixListener, err := net.Listen("unix", socketPath)
	if err != nil {
		return err
	}
	// NOTE(review): this cleanup runs after the peer already holds the TCP
	// socket and may race with the new process re-creating socketPath —
	// confirm the ordering is acceptable.
	defer func() {
		unixListener.Close()
		os.Remove(socketPath)
	}()

	for {
		err := listen(unixListener, listener)
		if err == nil {
			return nil
		}
		// A closed listener can never recover; bail out instead of spinning
		// in a tight log loop.
		if errors.Is(err, net.ErrClosed) {
			return err
		}
		log.Errorf("handoff socket error: %v", err)
	}
}

// magicPacket is the fixed greeting a requester must send over the unix
// domain socket before the listener file descriptor is handed over. It is a
// constant so the protocol value cannot be mutated at runtime.
const magicPacket = "handoff"

// listen services a single handoff attempt: it accepts one connection on
// unixListener, gives the peer one second to present the magic packet, and on
// a match passes listener's descriptor to the peer via handoff. Any failure
// along the way is returned; the accepted connection is always closed.
func listen(unixListener, listener net.Listener) error {
	peer, err := unixListener.Accept()
	if err != nil {
		return err
	}
	defer peer.Close()

	// Bound the whole exchange so a stalled requester cannot block handoff
	// for later peers.
	deadline := time.Now().Add(1 * time.Second)
	if err := peer.SetDeadline(deadline); err != nil {
		return err
	}

	greeting := make([]byte, len(magicPacket))
	got, err := peer.Read(greeting)
	switch {
	case err != nil:
		return err
	case string(greeting[:got]) != magicPacket:
		return errors.New("bad magic packet")
	}

	return handoff(peer, listener)
}

// handoff sends the file descriptor backing listener to the peer on conn as
// SCM_RIGHTS ancillary data.
//
// conn must be a *net.UnixConn and listener must be a *net.TCPListener; any
// other type (including a nil listener) yields an error rather than a panic,
// so a misconfigured caller cannot crash the serving process.
func handoff(conn net.Conn, listener net.Listener) error {
	unixConn, ok := conn.(*net.UnixConn)
	if !ok {
		return fmt.Errorf("handoff: unsupported connection type %T", conn)
	}
	tcpListener, ok := listener.(*net.TCPListener)
	if !ok {
		return fmt.Errorf("handoff: unsupported listener type %T", listener)
	}

	unixFD, err := getFD(unixConn)
	if err != nil {
		return err
	}

	tcpFd, err := getFD(tcpListener)
	if err != nil {
		return err
	}

	// No regular payload is supplied; only the control message carrying the
	// TCP descriptor is passed to Sendmsg.
	rights := unix.UnixRights(tcpFd)
	return unix.Sendmsg(unixFD, nil, rights, nil, 0)
}

// Request connects to the unix domain socket at socketPath, sends the magic
// packet, and waits for the serving process to reply with the file descriptor
// of its TCP listening socket. That descriptor is converted into a
// net.Listener and returned to the caller for immediate use.
//
// During the time between socket handoff and startup of the new server,
// requests to the socket will block. Requests will only fail if the client
// timeout is shorter than the duration of the handoff period.
//
// If nothing is listening on the other end of the unix domain socket, the
// returned error wraps ErrNoHandoff. Callers should check for this condition
// with errors.Is and open the TCP socket themselves. Any other error means a
// handoff was attempted but did not complete.
//
// The whole exchange is bounded by a one second deadline on the connection.
func Request(socketPath string) (net.Listener, error) {
	conn, err := net.Dial("unix", socketPath)
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrNoHandoff, err)
	}
	defer conn.Close()

	unixConn, ok := conn.(*net.UnixConn)
	if !ok {
		return nil, fmt.Errorf("handoff: unexpected connection type %T", conn)
	}

	err = unixConn.SetDeadline(time.Now().Add(1 * time.Second))
	if err != nil {
		return nil, err
	}

	_, err = unixConn.Write([]byte(magicPacket))
	if err != nil {
		return nil, fmt.Errorf("%w: failed to send magic packet", err)
	}

	// Read through the net package rather than a raw recvmsg on a dup'ed
	// descriptor so the deadline set above is honored. The one-byte data
	// buffer absorbs any placeholder byte sent alongside the control message.
	data := make([]byte, 1)
	oob := make([]byte, unix.CmsgSpace(4))
	_, oobn, _, _, err := unixConn.ReadMsgUnix(data, oob)
	if err != nil {
		return nil, fmt.Errorf("%w: msg not received", err)
	}

	cmsgs, err := unix.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, fmt.Errorf("%w: control msg not parsed", err)
	}
	if len(cmsgs) == 0 {
		return nil, errors.New("handoff: no control message received")
	}
	fds, err := unix.ParseUnixRights(&cmsgs[0])
	if err != nil {
		return nil, fmt.Errorf("%w: invalid unix rights", err)
	}
	if len(fds) == 0 {
		return nil, errors.New("handoff: no file descriptor received")
	}
	// Only one descriptor is expected; release any extras so they do not leak.
	for _, extra := range fds[1:] {
		_ = unix.Close(extra)
	}

	// net.FileListener duplicates the descriptor, so the received one must be
	// closed here regardless of the outcome.
	listenerFD := os.NewFile(uintptr(fds[0]), "listener")
	defer listenerFD.Close()

	l, err := net.FileListener(listenerFD)
	if err != nil {
		return nil, fmt.Errorf("%w: failed to acquire new fd", err)
	}

	return l, nil
}

// ErrNoHandoff indicates that no handoff was performed. Request wraps it when
// it cannot connect to the handoff socket, so callers should test for it with
// errors.Is rather than by direct comparison.
var ErrNoHandoff = errors.New("no handoff")

// getFD extracts the integer file descriptor behind conn by way of its
// syscall.RawConn. It returns -1 if the raw connection cannot be obtained,
// and otherwise whatever descriptor the Control callback observed together
// with Control's error.
//
// The descriptor is only borrowed: conn still owns it, so the caller must
// keep conn alive and open for as long as the returned value is used.
func getFD(conn syscall.Conn) (fd int, err error) {
	rawConn, err := conn.SyscallConn()
	if err != nil {
		return -1, err
	}

	capture := func(descriptor uintptr) {
		fd = int(descriptor)
	}
	if err = rawConn.Control(capture); err != nil {
		return fd, err
	}
	return fd, nil
}
2 changes: 1 addition & 1 deletion go/vt/servenv/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ func serveGRPC() {

// listen on the port
log.Infof("Listening for gRPC calls on port %v", gRPCPort)
listener, err := net.Listen("tcp", net.JoinHostPort(gRPCBindAddress, strconv.Itoa(gRPCPort)))
listener, err := HandoffOrListen("grpc", "tcp", net.JoinHostPort(gRPCBindAddress, strconv.Itoa(gRPCPort)))
if err != nil {
log.Exitf("Cannot listen on port %v for gRPC: %v", gRPCPort, err)
}
Expand Down
85 changes: 85 additions & 0 deletions go/vt/servenv/handoff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
Copyright 2025 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package servenv

import (
"errors"
"net"

"github.com/spf13/pflag"

"vitess.io/vitess/go/handoff"
"vitess.io/vitess/go/vt/log"
)

var (
	// rootHandoffPath specifies the filesystem path prefix under which
	// handoff sockets are created. Zero-downtime handoff is enabled only when
	// it is non-empty.
	//
	// To expose this flag, call RegisterHandoffFlags before ParseFlags.
	rootHandoffPath string
)

func RegisterHandoffFlags() {
OnParse(func(fs *pflag.FlagSet) {
fs.StringVar(&rootHandoffPath, "handoff_path", rootHandoffPath, "Root path to enable zero-downtime handoff sockets.")
})
}

// HandoffOrListen implements optional support for zero-downtime socket handoff.
//
// If no root handoff path is configured, it is a plain pass-through to
// net.Listen. Otherwise it first tries to take over the socket from a running
// process via handoff.Request; if no process answers (handoff.ErrNoHandoff) it
// falls back to net.Listen. Any other handoff error is fatal and terminates
// the process through log.Exitf.
//
// Once a listener has been obtained either way, a background goroutine starts
// advertising the handoff socket so a future process can take the listener
// over. That goroutine runs until a handoff completes or handoff.Listen fails.
// If net.Listen fails, the error is returned and no goroutine is started.
//
// serviceName is appended to the root handoff path to form the socket path;
// protocol and address are passed to net.Listen unchanged.
func HandoffOrListen(serviceName, protocol, address string) (net.Listener, error) {
	// If there is no path to handoff, then just pass through to the core net.Listen.
	if rootHandoffPath == "" {
		return net.Listen(protocol, address)
	}

	// NOTE(review): this is plain string concatenation, so the root path is
	// used as a prefix rather than a directory — confirm that is intended.
	handoffPath := rootHandoffPath + serviceName

	// Request socket from an already running process, or start a new
	// listener to serve requests on.
	log.Infof("handoff: requesting handoff socket from %s", handoffPath)
	listener, err := handoff.Request(handoffPath)
	switch {
	case err == nil:
		log.Infof("handoff: received handoff from %s", handoffPath)
	case errors.Is(err, handoff.ErrNoHandoff):
		log.Infof("handoff: no handoff, listening on %s %s", protocol, address)
		listener, err = net.Listen(protocol, address)
		if err != nil {
			// Without a listener there is nothing to advertise; starting the
			// handoff goroutine here would hand a nil listener to future
			// requesters.
			return nil, err
		}
	default:
		log.Exitf("handoff: fatal error: %v", err)
		// Not expected to be reached; keeps the function correct should
		// Exitf ever return.
		return nil, err
	}

	// Advertise unix domain socket for handoff by future processes.
	go func() {
		if err := handoff.Listen(handoffPath, listener); err != nil {
			log.Errorf("Handoff failed: %v", err)
			return
		}

		log.Infof("handed off socket %s %s", protocol, address)
	}()

	return listener, nil
}
2 changes: 1 addition & 1 deletion go/vt/servenv/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func Run(bindAddress string, port int) {
serveGRPC()
serveSocketFile()

l, err := net.Listen("tcp", net.JoinHostPort(bindAddress, strconv.Itoa(port)))
l, err := HandoffOrListen("servenv-"+strconv.Itoa(port), "tcp", net.JoinHostPort(bindAddress, strconv.Itoa(port)))
if err != nil {
log.Exit(err)
}
Expand Down
12 changes: 11 additions & 1 deletion go/vt/vtgateproxy/mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import (
"syscall"
"time"

"github.com/pires/go-proxyproto"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtenv"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -504,7 +506,15 @@ func initMySQLProtocol() {
}
if *mysqlServerPort >= 0 {
log.Infof("Mysql Server listening on Port %d", *mysqlServerPort)
mysqlListener, err = mysql.NewListener(*mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)), authServer, proxyHandle, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, *mysqlProxyProtocol, *mysqlConnBufferPooling, *mysqlKeepAlivePeriod, *mysqlServerFlushDelay)
listener, err := servenv.HandoffOrListen("mysql", *mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)))
if err != nil {
log.Exitf("HandoffOrListen failed: %v", err)
}
if *mysqlProxyProtocol {
listener = &proxyproto.Listener{Listener: listener}
}

mysqlListener, err = mysql.NewFromListener(listener, authServer, proxyHandle, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, *mysqlConnBufferPooling, *mysqlKeepAlivePeriod, *mysqlServerFlushDelay)
if err != nil {
log.Exitf("mysql.NewListener failed: %v", err)
}
Expand Down

0 comments on commit 257c0ad

Please sign in to comment.