Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TLS to connect mysql/TiDB #894

Merged
merged 1 commit into from
Feb 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arbiter/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func NewServer(cfg *Config) (srv *Server, err error) {
up := cfg.Up
down := cfg.Down

srv.downDB, err = createDB(down.User, down.Password, down.Host, down.Port)
srv.downDB, err = createDB(down.User, down.Password, down.Host, down.Port, nil)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
7 changes: 4 additions & 3 deletions arbiter/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package arbiter

import (
"context"
"crypto/tls"
"database/sql"
"fmt"
"sync"
Expand Down Expand Up @@ -55,7 +56,7 @@ func (l *dummyLoader) Close() {
type testNewServerSuite struct {
db *sql.DB
dbMock sqlmock.Sqlmock
origCreateDB func(string, string, string, int) (*sql.DB, error)
origCreateDB func(string, string, string, int, *tls.Config) (*sql.DB, error)
origNewReader func(*reader.Config) (*reader.Reader, error)
origNewLoader func(*sql.DB, ...loader.Option) (loader.Loader, error)
}
Expand All @@ -71,7 +72,7 @@ func (s *testNewServerSuite) SetUpTest(c *C) {
s.dbMock = mock

s.origCreateDB = createDB
createDB = func(user string, password string, host string, port int) (*sql.DB, error) {
createDB = func(user string, password string, host string, port int, _ *tls.Config) (*sql.DB, error) {
return s.db, nil
}

Expand Down Expand Up @@ -105,7 +106,7 @@ func (s *testNewServerSuite) TestRejectInvalidAddr(c *C) {
}

func (s *testNewServerSuite) TestStopIfFailedtoConnectDownStream(c *C) {
createDB = func(user string, password string, host string, port int) (*sql.DB, error) {
createDB = func(user string, password string, host string, port int, _ *tls.Config) (*sql.DB, error) {
return nil, fmt.Errorf("Can't create db")
}

Expand Down
10 changes: 10 additions & 0 deletions cmd/drainer/drainer.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ port = 3306
# when setting SyncPartialColumn drainer will allow the downstream schema
# having more or less column numbers and relax sql mode by removing STRICT_TRANS_TABLES.
# sync-mode = 1
#
# Uncomment this part if you need TLS to connecting downstream MySQL/TiDB.
# You can only specified only `ssl-ca` if there is no client certificate and don't need server to authenticate client.
# [syncer.to.security]
# Path of file that contains list of trusted SSL CAs.
# ssl-ca = "/path/to/ca.pem"
# Path of file that contains X509 certificate in PEM format.
# ssl-cert = "/path/to/drainer.pem"
# Path of file that contains X509 key in PEM format.
# ssl-key = "/path/to/drainer-key.pem"

[syncer.to.checkpoint]
# only support mysql or tidb now, you can uncomment this to control where the checkpoint is saved.
Expand Down
7 changes: 7 additions & 0 deletions drainer/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,13 @@ func (cfg *Config) Parse(args []string) error {
return errors.Errorf("tls config %+v error %v", cfg.Security, err)
}

if cfg.SyncerCfg != nil && cfg.SyncerCfg.To != nil {
cfg.SyncerCfg.To.TLS, err = cfg.SyncerCfg.To.Security.ToTLSConfig()
if err != nil {
return errors.Errorf("tls config %+v error %v", cfg.SyncerCfg.To.Security, err)
}
}

if err = cfg.adjustConfig(); err != nil {
return errors.Trace(err)
}
Expand Down
2 changes: 1 addition & 1 deletion drainer/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func feedByRelayLogIfNeed(cfg *Config) error {
return errors.Annotate(err, "failed to create reader")
}

db, err := loader.CreateDBWithSQLMode(scfg.To.User, scfg.To.Password, scfg.To.Host, scfg.To.Port, scfg.StrSQLMode)
db, err := loader.CreateDBWithSQLMode(scfg.To.User, scfg.To.Password, scfg.To.Host, scfg.To.Port, scfg.To.TLS, scfg.StrSQLMode)
if err != nil {
return errors.Annotate(err, "failed to create SQL db")
}
Expand Down
8 changes: 6 additions & 2 deletions drainer/sync/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ func NewMysqlSyncer(
relayer relay.Relayer,
info *loopbacksync.LoopBackSync,
) (*MysqlSyncer, error) {
db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, sqlMode)
if cfg.TLS != nil {
log.Info("enable TLS to connect downstream MySQL/TiDB")
}

db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.TLS, sqlMode)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -103,7 +107,7 @@ func NewMysqlSyncer(

if newMode != oldMode {
db.Close()
db, err = createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, &newMode)
db, err = createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.TLS, &newMode)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
3 changes: 2 additions & 1 deletion drainer/sync/syncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package sync

import (
"crypto/tls"
"database/sql"
"reflect"
"sync/atomic"
Expand Down Expand Up @@ -58,7 +59,7 @@ func (s *syncerSuite) SetUpTest(c *check.C) {

// create mysql syncer
oldCreateDB := createDB
createDB = func(string, string, string, int, *string) (db *sql.DB, err error) {
createDB = func(string, string, string, int, *tls.Config, *string) (db *sql.DB, err error) {
db, s.mysqlMock, err = sqlmock.New()
return
}
Expand Down
11 changes: 8 additions & 3 deletions drainer/sync/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@
package sync

import (
"crypto/tls"

// mysql driver
_ "github.com/go-sql-driver/mysql"
"github.com/pingcap/tidb-binlog/pkg/security"
)

// DBConfig is the DB configuration.
type DBConfig struct {
Host string `toml:"host" json:"host"`
User string `toml:"user" json:"user"`
Password string `toml:"password" json:"password"`
Host string `toml:"host" json:"host"`
User string `toml:"user" json:"user"`
Password string `toml:"password" json:"password"`
Security security.Config `toml:"security" json:"security"`
TLS *tls.Config `toml:"-" json:"-"`
// if EncryptedPassword is not empty, Password will be ignore.
EncryptedPassword string `toml:"encrypted_password" json:"encrypted_password"`
SyncMode int `toml:"sync-mode" json:"sync-mode"`
Expand Down
2 changes: 1 addition & 1 deletion pkg/loader/example_loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import "log"

func Example() {
// create sql.DB
db, err := CreateDB("root", "", "localhost", 4000)
db, err := CreateDB("root", "", "localhost", 4000, nil /* *tls.Config */)
if err != nil {
log.Fatal(err)
}
Expand Down
21 changes: 18 additions & 3 deletions pkg/loader/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
package loader

import (
"crypto/tls"
gosql "database/sql"
"fmt"
"hash/crc32"
"net/url"
"strconv"
"strings"
"sync/atomic"

"github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
)

Expand Down Expand Up @@ -77,14 +81,25 @@ func getTableInfo(db *gosql.DB, schema string, table string) (info *tableInfo, e
return
}

var customID int64

// CreateDBWithSQLMode return sql.DB
func CreateDBWithSQLMode(user string, password string, host string, port int, sqlMode *string) (db *gosql.DB, err error) {
func CreateDBWithSQLMode(user string, password string, host string, port int, tlsConfig *tls.Config, sqlMode *string) (db *gosql.DB, err error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4,utf8&interpolateParams=true&readTimeout=1m&multiStatements=true", user, password, host, port)
if sqlMode != nil {
// same as "set sql_mode = '<sqlMode>'"
dsn += "&sql_mode='" + url.QueryEscape(*sqlMode) + "'"
}

if tlsConfig != nil {
name := "custom_" + strconv.FormatInt(atomic.AddInt64(&customID, 1), 10)
err := mysql.RegisterTLSConfig(name, tlsConfig)
if err != nil {
return nil, errors.Annotate(err, "failed to RegisterTLSConfig")
}
dsn += "&tls=" + name
}

db, err = gosql.Open("mysql", dsn)
if err != nil {
return nil, errors.Trace(err)
Expand All @@ -93,8 +108,8 @@ func CreateDBWithSQLMode(user string, password string, host string, port int, sq
}

// CreateDB return sql.DB
func CreateDB(user string, password string, host string, port int) (db *gosql.DB, err error) {
return CreateDBWithSQLMode(user, password, host, port, nil)
func CreateDB(user string, password string, host string, port int, tls *tls.Config) (db *gosql.DB, err error) {
return CreateDBWithSQLMode(user, password, host, port, tls, nil)
}

func quoteSchema(schema string, table string) string {
Expand Down
2 changes: 1 addition & 1 deletion reparo/syncer/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ var (
var createDB = loader.CreateDB

func newMysqlSyncer(cfg *DBConfig, worker int, batchSize int, safemode bool) (*mysqlSyncer, error) {
db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port)
db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port, nil)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
3 changes: 2 additions & 1 deletion reparo/syncer/mysql_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package syncer

import (
"crypto/tls"
"database/sql"
"time"

Expand All @@ -24,7 +25,7 @@ func (s *testMysqlSuite) testMysqlSyncer(c *check.C, safemode bool) {
)

oldCreateDB := createDB
createDB = func(string, string, string, int) (db *sql.DB, err error) {
createDB = func(string, string, string, int, *tls.Config) (db *sql.DB, err error) {
db, mock, err = sqlmock.New()
return
}
Expand Down