diff --git a/backend/src/apiserver/client/sql.go b/backend/src/apiserver/client/sql.go index a7a93bce7bf0..b973d8dc8db9 100644 --- a/backend/src/apiserver/client/sql.go +++ b/backend/src/apiserver/client/sql.go @@ -21,18 +21,25 @@ import ( ) func CreateMySQLConfig(user, password string, mysqlServiceHost string, - mysqlServicePort string, dbName string, mysqlGroupConcatMaxLen string) *mysql.Config { + mysqlServicePort string, dbName string, mysqlGroupConcatMaxLen string, mysqlExtraParams map[string]string) *mysql.Config { + + params := map[string]string{ + "charset": "utf8", + "parseTime": "True", + "loc": "Local", + "group_concat_max_len": mysqlGroupConcatMaxLen, + } + + for k, v := range mysqlExtraParams { + params[k] = v + } + return &mysql.Config{ User: user, Passwd: password, Net: "tcp", Addr: fmt.Sprintf("%s:%s", mysqlServiceHost, mysqlServicePort), - Params: map[string]string{ - "charset": "utf8", - "parseTime": "True", - "loc": "Local", - "group_concat_max_len": mysqlGroupConcatMaxLen, - }, + Params: params, DBName: dbName, AllowNativePasswords: true, } diff --git a/backend/src/apiserver/client/sql_test.go b/backend/src/apiserver/client/sql_test.go new file mode 100644 index 000000000000..81d19cc8b5d0 --- /dev/null +++ b/backend/src/apiserver/client/sql_test.go @@ -0,0 +1,81 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "reflect" + "testing" + + "github.com/go-sql-driver/mysql" +) + +func TestCreateMySQLConfig(t *testing.T) { + type args struct { + user string + password string + host string + port string + dbName string + mysqlGroupConcatMaxLen string + mysqlExtraParams map[string]string + } + tests := []struct { + name string + args args + want *mysql.Config + }{ + { + name: "default config", + args: args{ + user: "root", + host: "mysql", + port: "3306", + mysqlGroupConcatMaxLen: "1024", + mysqlExtraParams: nil, + }, + want: &mysql.Config{ + User: "root", + Net: "tcp", + Addr: "mysql:3306", + Params: map[string]string{"charset": "utf8", "parseTime": "True", "loc": "Local", "group_concat_max_len": "1024"}, + AllowNativePasswords: true, + }, + }, + { + name: "extra parameters", + args: args{ + user: "root", + host: "mysql", + port: "3306", + mysqlGroupConcatMaxLen: "1024", + mysqlExtraParams: map[string]string{"tls": "true"}, + }, + want: &mysql.Config{ + User: "root", + Net: "tcp", + Addr: "mysql:3306", + Params: map[string]string{"charset": "utf8", "parseTime": "True", "loc": "Local", "group_concat_max_len": "1024", "tls": "true"}, + AllowNativePasswords: true, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CreateMySQLConfig(tt.args.user, tt.args.password, tt.args.host, tt.args.port, tt.args.dbName, tt.args.mysqlGroupConcatMaxLen, tt.args.mysqlExtraParams); !reflect.DeepEqual(got, tt.want) { + t.Errorf("CreateMySQLConfig() = %#v, want %v", got, tt.want) + } + }) + } +} diff --git a/backend/src/apiserver/client_manager.go b/backend/src/apiserver/client_manager.go index 14f72ce85b88..81f414b0bbaf 100644 --- a/backend/src/apiserver/client_manager.go +++ b/backend/src/apiserver/client_manager.go @@ -47,6 +47,7 @@ const ( mysqlGroupConcatMaxLen = "DBConfig.GroupConcatMaxLen" kfamServiceHost = "PROFILES_KFAM_SERVICE_HOST" kfamServicePort = "PROFILES_KFAM_SERVICE_PORT" + mysqlExtraParams = "DBConfig.ExtraParams" visualizationServiceHost = "ML_PIPELINE_VISUALIZATIONSERVER_SERVICE_HOST" visualizationServicePort = "ML_PIPELINE_VISUALIZATIONSERVER_SERVICE_PORT" @@ -263,6 +264,7 @@ func initMysql(driverName string, initConnectionTimeout time.Duration) string { common.GetStringConfigWithDefault(mysqlServicePort, "3306"), "", common.GetStringConfigWithDefault(mysqlGroupConcatMaxLen, "1024"), + common.GetMapConfig(mysqlExtraParams), ) var db *sql.DB diff --git a/backend/src/apiserver/common/config.go b/backend/src/apiserver/common/config.go index 05bced88c862..5e7b79d259b4 100644 --- a/backend/src/apiserver/common/config.go +++ b/backend/src/apiserver/common/config.go @@ -40,6 +40,14 @@ func GetStringConfigWithDefault(configName, value string) string { return viper.GetString(configName) } +func GetMapConfig(configName string) map[string]string { + if !viper.IsSet(configName) { + glog.Infof("Config %s not specified, skipping", configName) + return nil + } + return viper.GetStringMapString(configName) +} + func GetBoolConfigWithDefault(configName string, value bool) bool { if !viper.IsSet(configName) { return value