Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow setting the collation in auth handshake #860

Merged
merged 13 commits into from
Apr 30, 2024
26 changes: 24 additions & 2 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/charset"
)

const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
Expand Down Expand Up @@ -268,8 +269,24 @@ func (c *Conn) writeAuthHandshake() error {
data[11] = 0x00

// Charset [1 byte]
// use default collation id 33 here, is utf-8
data[12] = DEFAULT_COLLATION_ID
// use default collation id 33 here, is `utf8mb3_general_ci`
collationName := c.collation
if len(collationName) == 0 {
collationName = DEFAULT_COLLATION_NAME
}
collation, err := charset.GetCollationByName(collationName)
if err != nil {
return fmt.Errorf("invalid collation name %s", collationName)
}

// the MySQL protocol calls for the collation id to be sent as 1, where only the
// lower 8 bits are used in this field. But wireshark shows that the first by of
// the 23 bytes of filler is used to send the upper 8 bits of the collation id.
// see https://github.com/mysql/mysql-server/pull/541
data[12] = byte(collation.ID & 0xff)
if collation.ID > 255 {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be possible always encode the collation.ID to a 2 byte integer, if the right endianness is used. That might simplify the code a bit

data[13] = byte(collation.ID >> 8)
}

// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
Expand All @@ -292,6 +309,11 @@ func (c *Conn) writeAuthHandshake() error {

// Filler [23 bytes] (all 0x00)
pos := 13
if collation.ID > 255 {
// skip setting the first byte of the filler to 0x00 since it is used to
// send the upper 8 bits of the collation id
pos++
}
for ; pos < 13+23; pos++ {
dvilaverde marked this conversation as resolved.
Show resolved Hide resolved
data[pos] = 0
}
Expand Down
73 changes: 73 additions & 0 deletions client/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package client

import (
"github.com/go-mysql-org/go-mysql/packet"

Check failure on line 4 in client/auth_test.go

View workflow job for this annotation

GitHub Actions / golangci

File is not `goimports`-ed (goimports)
"github.com/pingcap/tidb/pkg/parser/charset"
"net"
"testing"

"github.com/go-mysql-org/go-mysql/mysql"
Expand Down Expand Up @@ -34,3 +37,73 @@
require.Subset(t, data, fixt)
}
}

func TestConnCollation(t *testing.T) {
collations := []string{"big5_chinese_ci",
"utf8_general_ci",
"utf8mb4_0900_ai_ci",
"utf8mb4_de_pb_0900_ai_ci",
"utf8mb4_ja_0900_as_cs",
"utf8mb4_0900_bin",
"utf8mb4_zh_pinyin_tidb_as_cs"}

// test all supported collations by calling writeAuthHandshake() and reading the bytes
// sent to the server to ensure the collation id is set correctly
for _, c := range collations {
collation, err := charset.GetCollationByName(c)
require.NoError(t, err)
server := sendAuthResponse(t, collation.Name)
// read the all the bytes of the handshake response so that client goroutine can complete without blocking
// on the server read.
handShakeResponse := make([]byte, 128)
_, err = server.Read(handShakeResponse)
require.NoError(t, err)

// validate the collation id is set correctly
// if the collation ID is <= 255 the collation ID is stored in the 12th byte
if collation.ID <= 255 {
require.Equal(t, byte(collation.ID), handShakeResponse[12])
// sanity check: validate the 23 bytes of filler with value 0x00 are set correctly
for i := 13; i < 13+23; i++ {
require.Equal(t, byte(0x00), handShakeResponse[i])
}
} else {
// if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes
require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12])
require.Equal(t, byte(collation.ID>>8), handShakeResponse[13])

// sanity check: validate the 22 bytes of filler with value 0x00 are set correctly
for i := 14; i < 14+22; i++ {
require.Equal(t, byte(0x00), handShakeResponse[i])
}
}

// and finally the username
password := string(handShakeResponse[36:40])
require.Equal(t, "test", password)
dvilaverde marked this conversation as resolved.
Show resolved Hide resolved

require.NoError(t, server.Close())
}
}

func sendAuthResponse(t *testing.T, collation string) net.Conn {
server, client := net.Pipe()
c := &Conn{
Conn: &packet.Conn{
Conn: client,
},
authPluginName: "mysql_native_password",
user: "test",
db: "test",
password: "test",
proto: "tcp",
collation: collation,
salt: ([]byte)("123456781234567812345678"),
}

go func() {
err := c.writeAuthHandshake()
require.NoError(t, err)
dvilaverde marked this conversation as resolved.
Show resolved Hide resolved
}()
return server
}
21 changes: 20 additions & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) {
func (s *clientTestSuite) SetupSuite() {
var err error
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
s.c, err = Connect(addr, *testUser, *testPassword, "")
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic, but this is essentially a no-op since
// the collation set is the default value
_ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
})
require.NoError(s.T(), err)

var result *mysql.Result
Expand Down Expand Up @@ -228,6 +232,21 @@ func (s *clientTestSuite) TestConn_SetCharset() {
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {
err := s.c.SetCollation("latin1_swedish_ci")
require.Error(s.T(), err)
dvilaverde marked this conversation as resolved.
Show resolved Hide resolved
}

func (s *clientTestSuite) TestConn_SetCollation() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic
_ = conn.SetCollation("invalid_collation")
})

require.Error(s.T(), err)
}

func (s *clientTestSuite) testStmt_DropTable() {
str := `drop table if exists mixer_test_stmt`

Expand Down
24 changes: 22 additions & 2 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type Conn struct {
status uint16

charset string
// sets the collation to be set on the auth handshake, this does not issue a 'set names' command
collation string

salt []byte
authPluginName string
Expand Down Expand Up @@ -67,15 +69,19 @@ func Connect(addr string, user string, password string, dbName string, options .
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

dialer := &net.Dialer{}
return ConnectWithContext(ctx, addr, user, password, dbName, options...)
}

// ConnectWithContext to a MySQL addr using the provided context.
func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
dialer := &net.Dialer{}
return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...)
}

// Dialer connects to the address on the named network using the provided context.
type Dialer func(ctx context.Context, network, address string) (net.Conn, error)

// Connect to a MySQL server using the given Dialer.
// ConnectWithDialer to a MySQL server using the given Dialer.
func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) {
c := new(Conn)

Expand Down Expand Up @@ -357,6 +363,20 @@ func (c *Conn) SetCharset(charset string) error {
}
}

func (c *Conn) SetCollation(collation string) error {
if c.status == 0 {
dvilaverde marked this conversation as resolved.
Show resolved Hide resolved
c.collation = collation
} else {
return errors.Trace(errors.Errorf("cannot set collation after connection is established"))
}

return nil
}

func (c *Conn) GetCollation() string {
return c.collation
}

func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
return nil, errors.Trace(err)
Expand Down
Loading