Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SIGHUP support for tls config update #6215

Merged
merged 2 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestValidCert(t *testing.T) {
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
l.TLSConfig = serverConfig
l.TLSConfig.Store(serverConfig)
go func() {
l.Accept()
}()
Expand Down Expand Up @@ -147,7 +147,7 @@ func TestNoCert(t *testing.T) {
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
l.TLSConfig = serverConfig
l.TLSConfig.Store(serverConfig)
go func() {
l.Accept()
}()
Expand Down
2 changes: 1 addition & 1 deletion go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func TestSSLConnection(t *testing.T) {
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
l.TLSConfig = serverConfig
l.TLSConfig.Store(serverConfig)
go func() {
l.Accept()
}()
Expand Down
10 changes: 6 additions & 4 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"io"
"net"
"strings"
"sync/atomic"
"time"

proxyproto "github.com/pires/go-proxyproto"
Expand Down Expand Up @@ -139,7 +140,8 @@ type Listener struct {

// TLSConfig is the server TLS config. If set, we will advertise
// that we support SSL.
TLSConfig *tls.Config
// atomic value stores *tls.Config
TLSConfig atomic.Value

// AllowClearTextWithoutTLS needs to be set for the
// mysql_clear_password authentication method to be accepted
Expand Down Expand Up @@ -292,7 +294,7 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
defer connCount.Add(-1)

// First build and send the server handshake packet.
salt, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, l.TLSConfig != nil)
salt, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, l.TLSConfig.Load() != nil)
if err != nil {
if err != io.EOF {
log.Errorf("Cannot send HandshakeV10 packet to %s: %v", c, err)
Expand Down Expand Up @@ -638,9 +640,9 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
pos += 23

// Check for SSL.
if firstTime && l.TLSConfig != nil && clientFlags&CapabilityClientSSL > 0 {
if firstTime && l.TLSConfig.Load() != nil && clientFlags&CapabilityClientSSL > 0 {
// Need to switch to TLS, and then re-read the packet.
conn := tls.Server(c.conn, l.TLSConfig)
conn := tls.Server(c.conn, l.TLSConfig.Load().(*tls.Config))
c.conn = conn
c.bufferedReader.Reset(conn)
c.Capabilities |= CapabilityClientSSL
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ func TestTLSServer(t *testing.T) {
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
l.TLSConfig = serverConfig
l.TLSConfig.Store(serverConfig)
go l.Accept()

// Setup the right parameters.
Expand Down Expand Up @@ -1063,7 +1063,7 @@ func TestTLSRequired(t *testing.T) {
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
l.TLSConfig = serverConfig
l.TLSConfig.Store(serverConfig)
l.RequireSecureTransport = true
go l.Accept()

Expand Down
38 changes: 31 additions & 7 deletions go/vt/vtgate/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"net"
"os"
"os/signal"
"regexp"
"strings"
"sync"
Expand Down Expand Up @@ -330,9 +331,34 @@ func (vh *vtgateHandler) session(c *mysql.Conn) *vtgatepb.Session {

var mysqlListener *mysql.Listener
var mysqlUnixListener *mysql.Listener

var sigChan chan os.Signal
var vtgateHandle *vtgateHandler

// initTLSConfig inits tls config for the given mysql listener
func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa string, mysqlServerRequireSecureTransport bool) error {
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa)
if err != nil {
log.Exitf("grpcutils.TLSServerConfig failed: %v", err)
return err
}
mysqlListener.TLSConfig.Store(serverConfig)
mysqlListener.RequireSecureTransport = mysqlServerRequireSecureTransport
sigChan = make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGHUP)
go func() {
for range sigChan {
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa)
if err != nil {
log.Errorf("grpcutils.TLSServerConfig failed: %v", err)
} else {
log.Info("grpcutils.TLSServerConfig updated")
mysqlListener.TLSConfig.Store(serverConfig)
}
}
}()
return nil
}

// initiMySQLProtocol starts the mysql protocol.
// It should be called only once in a process.
func initMySQLProtocol() {
Expand Down Expand Up @@ -377,12 +403,7 @@ func initMySQLProtocol() {
mysqlListener.ServerVersion = *mysqlServerVersion
}
if *mysqlSslCert != "" && *mysqlSslKey != "" {
mysqlListener.TLSConfig, err = vttls.ServerConfig(*mysqlSslCert, *mysqlSslKey, *mysqlSslCa)
if err != nil {
log.Exitf("grpcutils.TLSServerConfig failed: %v", err)
return
}
mysqlListener.RequireSecureTransport = *mysqlServerRequireSecureTransport
initTLSConfig(mysqlListener, *mysqlSslCert, *mysqlSslKey, *mysqlSslCa, *mysqlServerRequireSecureTransport)
}
mysqlListener.AllowClearTextWithoutTLS.Set(*mysqlAllowClearTextWithoutTLS)
// Check for the connection threshold
Expand Down Expand Up @@ -449,6 +470,9 @@ func shutdownMysqlProtocolAndDrain() {
mysqlUnixListener.Close()
mysqlUnixListener = nil
}
if sigChan != nil {
signal.Stop(sigChan)
}

if atomic.LoadInt32(&busyConnections) > 0 {
log.Infof("Waiting for all client connections to be idle (%d active)...", atomic.LoadInt32(&busyConnections))
Expand Down
32 changes: 32 additions & 0 deletions go/vt/vtgate/plugin_mysql_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ package vtgate
import (
"io/ioutil"
"os"
"path"
"strings"
"syscall"
"testing"
"time"

"github.com/stretchr/testify/assert"
"vitess.io/vitess/go/trace"
Expand All @@ -30,6 +33,7 @@ import (
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/tlstest"
)

type testHandler struct {
Expand Down Expand Up @@ -223,3 +227,31 @@ func TestDefaultWorkloadOLAP(t *testing.T) {
t.Fatalf("Expected default workload OLAP")
}
}

func TestInitTLSConfig(t *testing.T) {
// Create the certs.
root, err := ioutil.TempDir("", "TestInitTLSConfig")
if err != nil {
t.Fatalf("TempDir failed: %v", err)
}
defer os.RemoveAll(root)
tlstest.CreateCA(root)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")

listener := &mysql.Listener{}
if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), true); err != nil {
t.Fatalf("init tls config failure due to: +%v", err)
}

serverConfig := listener.TLSConfig.Load()
if serverConfig == nil {
t.Fatalf("init tls config shouldn't create nil server config")
}

sigChan <- syscall.SIGHUP
time.Sleep(100 * time.Millisecond) // wait for signal handler

if listener.TLSConfig.Load() == serverConfig {
t.Fatalf("init tls config should have been recreated after SIGHUP")
}
}