From 6496d1509473b9a043939f110c090e8aa294e787 Mon Sep 17 00:00:00 2001 From: huiqing Date: Thu, 21 May 2020 15:16:23 -0700 Subject: [PATCH 1/2] SIGHUP for tls config update Signed-off-by: huiqing --- go/mysql/auth_server_clientcert_test.go | 4 ++-- go/mysql/handshake_test.go | 2 +- go/mysql/server.go | 10 ++++++---- go/mysql/server_test.go | 4 ++-- go/vt/vtgate/plugin_mysql_server.go | 22 ++++++++++++++++++++-- 5 files changed, 31 insertions(+), 11 deletions(-) diff --git a/go/mysql/auth_server_clientcert_test.go b/go/mysql/auth_server_clientcert_test.go index 73afb810ce0..9dbfdfe0d72 100644 --- a/go/mysql/auth_server_clientcert_test.go +++ b/go/mysql/auth_server_clientcert_test.go @@ -65,7 +65,7 @@ func TestValidCert(t *testing.T) { if err != nil { t.Fatalf("TLSServerConfig failed: %v", err) } - l.TLSConfig = serverConfig + l.TLSConfig.Store(serverConfig) go func() { l.Accept() }() @@ -147,7 +147,7 @@ func TestNoCert(t *testing.T) { if err != nil { t.Fatalf("TLSServerConfig failed: %v", err) } - l.TLSConfig = serverConfig + l.TLSConfig.Store(serverConfig) go func() { l.Accept() }() diff --git a/go/mysql/handshake_test.go b/go/mysql/handshake_test.go index 6c3f8b5ff93..ca36d4bf806 100644 --- a/go/mysql/handshake_test.go +++ b/go/mysql/handshake_test.go @@ -129,7 +129,7 @@ func TestSSLConnection(t *testing.T) { if err != nil { t.Fatalf("TLSServerConfig failed: %v", err) } - l.TLSConfig = serverConfig + l.TLSConfig.Store(serverConfig) go func() { l.Accept() }() diff --git a/go/mysql/server.go b/go/mysql/server.go index b40960fc39c..d47804b0280 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -22,6 +22,7 @@ import ( "io" "net" "strings" + "sync/atomic" "time" proxyproto "github.com/pires/go-proxyproto" @@ -139,7 +140,8 @@ type Listener struct { // TLSConfig is the server TLS config. If set, we will advertise // that we support SSL. - TLSConfig *tls.Config + // atomic value stores *tls.Config + TLSConfig atomic.Value // AllowClearTextWithoutTLS needs to be set for the // mysql_clear_password authentication method to be accepted @@ -292,7 +294,7 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti defer connCount.Add(-1) // First build and send the server handshake packet. - salt, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, l.TLSConfig != nil) + salt, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, l.TLSConfig.Load() != nil) if err != nil { if err != io.EOF { log.Errorf("Cannot send HandshakeV10 packet to %s: %v", c, err) @@ -638,9 +640,9 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by pos += 23 // Check for SSL. - if firstTime && l.TLSConfig != nil && clientFlags&CapabilityClientSSL > 0 { + if firstTime && l.TLSConfig.Load() != nil && clientFlags&CapabilityClientSSL > 0 { // Need to switch to TLS, and then re-read the packet. - conn := tls.Server(c.conn, l.TLSConfig) + conn := tls.Server(c.conn, l.TLSConfig.Load().(*tls.Config)) c.conn = conn c.bufferedReader.Reset(conn) c.Capabilities |= CapabilityClientSSL diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 6dce83b57de..a8d7dd3b0df 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -957,7 +957,7 @@ func TestTLSServer(t *testing.T) { if err != nil { t.Fatalf("TLSServerConfig failed: %v", err) } - l.TLSConfig = serverConfig + l.TLSConfig.Store(serverConfig) go l.Accept() // Setup the right parameters. @@ -1063,7 +1063,7 @@ func TestTLSRequired(t *testing.T) { if err != nil { t.Fatalf("TLSServerConfig failed: %v", err) } - l.TLSConfig = serverConfig + l.TLSConfig.Store(serverConfig) l.RequireSecureTransport = true go l.Accept() diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index dc5e614b42e..f1358c318b2 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -21,6 +21,7 @@ import ( "fmt" "net" "os" + "os/signal" "regexp" "strings" "sync" @@ -330,7 +331,7 @@ func (vh *vtgateHandler) session(c *mysql.Conn) *vtgatepb.Session { var mysqlListener *mysql.Listener var mysqlUnixListener *mysql.Listener - +var sigChan chan os.Signal var vtgateHandle *vtgateHandler // initiMySQLProtocol starts the mysql protocol. @@ -377,12 +378,26 @@ func initMySQLProtocol() { mysqlListener.ServerVersion = *mysqlServerVersion } if *mysqlSslCert != "" && *mysqlSslKey != "" { - mysqlListener.TLSConfig, err = vttls.ServerConfig(*mysqlSslCert, *mysqlSslKey, *mysqlSslCa) + serverConfig, err := vttls.ServerConfig(*mysqlSslCert, *mysqlSslKey, *mysqlSslCa) if err != nil { log.Exitf("grpcutils.TLSServerConfig failed: %v", err) return } + mysqlListener.TLSConfig.Store(serverConfig) mysqlListener.RequireSecureTransport = *mysqlServerRequireSecureTransport + sigChan = make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGHUP) + go func() { + for range sigChan { + serverConfig, err := vttls.ServerConfig(*mysqlSslCert, *mysqlSslKey, *mysqlSslCa) + if err != nil { + log.Errorf("grpcutils.TLSServerConfig failed: %v", err) + } else { + log.Info("grpcutils.TLSServerConfig updated") + mysqlListener.TLSConfig.Store(serverConfig) + } + } + }() } mysqlListener.AllowClearTextWithoutTLS.Set(*mysqlAllowClearTextWithoutTLS) // Check for the connection threshold @@ -449,6 +464,9 @@ func shutdownMysqlProtocolAndDrain() { mysqlUnixListener.Close() mysqlUnixListener = nil } + if sigChan != nil { + signal.Stop(sigChan) + } if atomic.LoadInt32(&busyConnections) > 0 { log.Infof("Waiting for all client connections to be idle (%d active)...", atomic.LoadInt32(&busyConnections)) From 7c2bc2b4943e940e7763523b5a9e7a9b51f2e24e Mon Sep 17 00:00:00 2001 From: huiqing Date: Tue, 26 May 2020 22:33:23 -0700 Subject: [PATCH 2/2] wrap into `initTlsConfig` to make unit test of `SIGHUP` handling easier Signed-off-by: huiqing --- go/vt/vtgate/plugin_mysql_server.go | 46 +++++++++++++----------- go/vt/vtgate/plugin_mysql_server_test.go | 32 +++++++++++++++++ 2 files changed, 58 insertions(+), 20 deletions(-) diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index f1358c318b2..e98c95727e2 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -334,6 +334,31 @@ var mysqlUnixListener *mysql.Listener var sigChan chan os.Signal var vtgateHandle *vtgateHandler +// initTLSConfig inits tls config for the given mysql listener +func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa string, mysqlServerRequireSecureTransport bool) error { + serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa) + if err != nil { + log.Exitf("grpcutils.TLSServerConfig failed: %v", err) + return err + } + mysqlListener.TLSConfig.Store(serverConfig) + mysqlListener.RequireSecureTransport = mysqlServerRequireSecureTransport + sigChan = make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGHUP) + go func() { + for range sigChan { + serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa) + if err != nil { + log.Errorf("grpcutils.TLSServerConfig failed: %v", err) + } else { + log.Info("grpcutils.TLSServerConfig updated") + mysqlListener.TLSConfig.Store(serverConfig) + } + } + }() + return nil +} + // initiMySQLProtocol starts the mysql protocol. // It should be called only once in a process. func initMySQLProtocol() { @@ -378,26 +403,7 @@ func initMySQLProtocol() { mysqlListener.ServerVersion = *mysqlServerVersion } if *mysqlSslCert != "" && *mysqlSslKey != "" { - serverConfig, err := vttls.ServerConfig(*mysqlSslCert, *mysqlSslKey, *mysqlSslCa) - if err != nil { - log.Exitf("grpcutils.TLSServerConfig failed: %v", err) - return - } - mysqlListener.TLSConfig.Store(serverConfig) - mysqlListener.RequireSecureTransport = *mysqlServerRequireSecureTransport - sigChan = make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGHUP) - go func() { - for range sigChan { - serverConfig, err := vttls.ServerConfig(*mysqlSslCert, *mysqlSslKey, *mysqlSslCa) - if err != nil { - log.Errorf("grpcutils.TLSServerConfig failed: %v", err) - } else { - log.Info("grpcutils.TLSServerConfig updated") - mysqlListener.TLSConfig.Store(serverConfig) - } - } - }() + initTLSConfig(mysqlListener, *mysqlSslCert, *mysqlSslKey, *mysqlSslCa, *mysqlServerRequireSecureTransport) } mysqlListener.AllowClearTextWithoutTLS.Set(*mysqlAllowClearTextWithoutTLS) // Check for the connection threshold diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index 993ab442792..6a43aa93015 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -19,8 +19,11 @@ package vtgate import ( "io/ioutil" "os" + "path" "strings" + "syscall" "testing" + "time" "github.com/stretchr/testify/assert" "vitess.io/vitess/go/trace" @@ -30,6 +33,7 @@ import ( "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/tlstest" ) type testHandler struct { @@ -223,3 +227,31 @@ func TestDefaultWorkloadOLAP(t *testing.T) { t.Fatalf("Expected default workload OLAP") } } + +func TestInitTLSConfig(t *testing.T) { + // Create the certs. + root, err := ioutil.TempDir("", "TestInitTLSConfig") + if err != nil { + t.Fatalf("TempDir failed: %v", err) + } + defer os.RemoveAll(root) + tlstest.CreateCA(root) + tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") + + listener := &mysql.Listener{} + if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), true); err != nil { + t.Fatalf("init tls config failure due to: +%v", err) + } + + serverConfig := listener.TLSConfig.Load() + if serverConfig == nil { + t.Fatalf("init tls config shouldn't create nil server config") + } + + sigChan <- syscall.SIGHUP + time.Sleep(100 * time.Millisecond) // wait for signal handler + + if listener.TLSConfig.Load() == serverConfig { + t.Fatalf("init tls config should have been recreated after SIGHUP") + } +}