diff --git a/client/auth.go b/client/auth.go index e4fa908d3..1f4d7c1de 100644 --- a/client/auth.go +++ b/client/auth.go @@ -9,6 +9,7 @@ import ( . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/parser/charset" ) const defaultAuthPluginName = AUTH_NATIVE_PASSWORD @@ -268,8 +269,25 @@ func (c *Conn) writeAuthHandshake() error { data[11] = 0x00 // Charset [1 byte] - // use default collation id 33 here, is utf-8 - data[12] = DEFAULT_COLLATION_ID + // use default collation id 33 here, is `utf8mb3_general_ci` + collationName := c.collation + if len(collationName) == 0 { + collationName = DEFAULT_COLLATION_NAME + } + collation, err := charset.GetCollationByName(collationName) + if err != nil { + return fmt.Errorf("invalid collation name %s", collationName) + } + + // the MySQL protocol calls for the collation id to be sent as 1, where only the + // lower 8 bits are used in this field. But wireshark shows that the first byte of + // the 23 bytes of filler is used to send the right middle 8 bits of the collation id. + // see https://github.com/mysql/mysql-server/pull/541 + data[12] = byte(collation.ID & 0xff) + // if the collation ID is <= 255 the middle 8 bits are 0s so this is the equivalent of + // padding the filler with a 0. If ID is > 255 then the first byte of filler will contain + // the right middle 8 bits of the collation ID. + data[13] = byte((collation.ID & 0xff00) >> 8) // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest @@ -291,8 +309,12 @@ func (c *Conn) writeAuthHandshake() error { } // Filler [23 bytes] (all 0x00) - pos := 13 - for ; pos < 13+23; pos++ { + // the filler starts at position 13, but the first byte of the filler + // has been set with the collation id earlier, so position 13 at this point + // will be either 0x00, or the right middle 8 bits of the collation id. + // Therefore, we start at position 14 and fill the remaining 22 bytes with 0x00. + pos := 14 + for ; pos < 14+22; pos++ { data[pos] = 0 } diff --git a/client/auth_test.go b/client/auth_test.go index 85dba1e98..0837f1767 100644 --- a/client/auth_test.go +++ b/client/auth_test.go @@ -1,10 +1,14 @@ package client import ( + "net" "testing" - "github.com/go-mysql-org/go-mysql/mysql" + "github.com/pingcap/tidb/pkg/parser/charset" "github.com/stretchr/testify/require" + + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/packet" ) func TestConnGenAttributes(t *testing.T) { @@ -34,3 +38,75 @@ func TestConnGenAttributes(t *testing.T) { require.Subset(t, data, fixt) } } + +func TestConnCollation(t *testing.T) { + collations := []string{ + "big5_chinese_ci", + "utf8_general_ci", + "utf8mb4_0900_ai_ci", + "utf8mb4_de_pb_0900_ai_ci", + "utf8mb4_ja_0900_as_cs", + "utf8mb4_0900_bin", + "utf8mb4_zh_pinyin_tidb_as_cs", + } + + // test all supported collations by calling writeAuthHandshake() and reading the bytes + // sent to the server to ensure the collation id is set correctly + for _, c := range collations { + collation, err := charset.GetCollationByName(c) + require.NoError(t, err) + server := sendAuthResponse(t, collation.Name) + // read the all the bytes of the handshake response so that client goroutine can complete without blocking + // on the server read. + handShakeResponse := make([]byte, 128) + _, err = server.Read(handShakeResponse) + require.NoError(t, err) + + // validate the collation id is set correctly + // if the collation ID is <= 255 the collation ID is stored in the 12th byte + if collation.ID <= 255 { + require.Equal(t, byte(collation.ID), handShakeResponse[12]) + // the 13th byte should always be 0x00 + require.Equal(t, byte(0x00), handShakeResponse[13]) + } else { + // if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes + require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12]) + require.Equal(t, byte(collation.ID>>8), handShakeResponse[13]) + } + + // sanity check: validate the 22 bytes of filler with value 0x00 are set correctly + for i := 14; i < 14+22; i++ { + require.Equal(t, byte(0x00), handShakeResponse[i]) + } + + // and finally the username + username := string(handShakeResponse[36:40]) + require.Equal(t, "test", username) + + require.NoError(t, server.Close()) + } +} + +func sendAuthResponse(t *testing.T, collation string) net.Conn { + server, client := net.Pipe() + c := &Conn{ + Conn: &packet.Conn{ + Conn: client, + }, + authPluginName: "mysql_native_password", + user: "test", + db: "test", + password: "test", + proto: "tcp", + collation: collation, + salt: ([]byte)("123456781234567812345678"), + } + + go func() { + err := c.writeAuthHandshake() + require.NoError(t, err) + err = c.Close() + require.NoError(t, err) + }() + return server +} diff --git a/client/client_test.go b/client/client_test.go index c47c795ef..10515e622 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) { func (s *clientTestSuite) SetupSuite() { var err error addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - s.c, err = Connect(addr, *testUser, *testPassword, "") + s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + // test the collation logic, but this is essentially a no-op since + // the collation set is the default value + _ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME) + }) require.NoError(s.T(), err) var result *mysql.Result @@ -228,6 +232,22 @@ func (s *clientTestSuite) TestConn_SetCharset() { require.NoError(s.T(), err) } +func (s *clientTestSuite) TestConn_SetCollationAfterConnect() { + err := s.c.SetCollation("latin1_swedish_ci") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "cannot set collation after connection is established") +} + +func (s *clientTestSuite) TestConn_SetCollation() { + addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) + _, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + // test the collation logic + _ = conn.SetCollation("invalid_collation") + }) + + require.Error(s.T(), err) +} + func (s *clientTestSuite) testStmt_DropTable() { str := `drop table if exists mixer_test_stmt` diff --git a/client/conn.go b/client/conn.go index b1f3e52d1..9fc7faf16 100644 --- a/client/conn.go +++ b/client/conn.go @@ -37,6 +37,8 @@ type Conn struct { status uint16 charset string + // sets the collation to be set on the auth handshake, this does not issue a 'set names' command + collation string salt []byte authPluginName string @@ -67,15 +69,19 @@ func Connect(addr string, user string, password string, dbName string, options . ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - dialer := &net.Dialer{} + return ConnectWithContext(ctx, addr, user, password, dbName, options...) +} +// ConnectWithContext to a MySQL addr using the provided context. +func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) { + dialer := &net.Dialer{} return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...) } // Dialer connects to the address on the named network using the provided context. type Dialer func(ctx context.Context, network, address string) (net.Conn, error) -// Connect to a MySQL server using the given Dialer. +// ConnectWithDialer to a MySQL server using the given Dialer. func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) { c := new(Conn) @@ -357,6 +363,19 @@ func (c *Conn) SetCharset(charset string) error { } } +func (c *Conn) SetCollation(collation string) error { + if len(c.serverVersion) != 0 { + return errors.Trace(errors.Errorf("cannot set collation after connection is established")) + } + + c.collation = collation + return nil +} + +func (c *Conn) GetCollation() string { + return c.collation +} + func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) { if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil { return nil, errors.Trace(err)