Skip to content

Commit

Permalink
allow setting the collation in auth handshake (#860)
Browse files Browse the repository at this point in the history
* allow setting the collation in auth handshake

* Allow connect with context in order to provide configurable connect timeouts

* fixing linting error

* support collations IDs greater than 255 on the auth handshake

* Update client/auth.go

Co-authored-by: Daniël van Eeden <github@myname.nl>

* address PR feedback

* fixing comments

* fixing comments

* fix linting errors

* restore tests that were commented out accidently

* fixing more typos in the comments

* Apply suggestions from code review

Co-authored-by: lance6716 <lance6716@gmail.com>

* addressing PR feedback

---------

Co-authored-by: dvilaverde <dvilaverde@adobe.com>
Co-authored-by: Daniël van Eeden <github@myname.nl>
Co-authored-by: lance6716 <lance6716@gmail.com>
  • Loading branch information
4 people authored Apr 30, 2024
1 parent 7c31dc4 commit 8551be2
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 8 deletions.
30 changes: 26 additions & 4 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/charset"
)

const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
Expand Down Expand Up @@ -268,8 +269,25 @@ func (c *Conn) writeAuthHandshake() error {
data[11] = 0x00

// Charset [1 byte]
// use default collation id 33 here, is utf-8
data[12] = DEFAULT_COLLATION_ID
// use default collation id 33 here, is `utf8mb3_general_ci`
collationName := c.collation
if len(collationName) == 0 {
collationName = DEFAULT_COLLATION_NAME
}
collation, err := charset.GetCollationByName(collationName)
if err != nil {
return fmt.Errorf("invalid collation name %s", collationName)
}

// the MySQL protocol calls for the collation id to be sent as 1, where only the
// lower 8 bits are used in this field. But wireshark shows that the first byte of
// the 23 bytes of filler is used to send the right middle 8 bits of the collation id.
// see https://github.com/mysql/mysql-server/pull/541
data[12] = byte(collation.ID & 0xff)
// if the collation ID is <= 255 the middle 8 bits are 0s so this is the equivalent of
// padding the filler with a 0. If ID is > 255 then the first byte of filler will contain
// the right middle 8 bits of the collation ID.
data[13] = byte((collation.ID & 0xff00) >> 8)

// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
Expand All @@ -291,8 +309,12 @@ func (c *Conn) writeAuthHandshake() error {
}

// Filler [23 bytes] (all 0x00)
pos := 13
for ; pos < 13+23; pos++ {
// the filler starts at position 13, but the first byte of the filler
// has been set with the collation id earlier, so position 13 at this point
// will be either 0x00, or the right middle 8 bits of the collation id.
// Therefore, we start at position 14 and fill the remaining 22 bytes with 0x00.
pos := 14
for ; pos < 14+22; pos++ {
data[pos] = 0
}

Expand Down
78 changes: 77 additions & 1 deletion client/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package client

import (
"net"
"testing"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/stretchr/testify/require"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
)

func TestConnGenAttributes(t *testing.T) {
Expand Down Expand Up @@ -34,3 +38,75 @@ func TestConnGenAttributes(t *testing.T) {
require.Subset(t, data, fixt)
}
}

func TestConnCollation(t *testing.T) {
collations := []string{
"big5_chinese_ci",
"utf8_general_ci",
"utf8mb4_0900_ai_ci",
"utf8mb4_de_pb_0900_ai_ci",
"utf8mb4_ja_0900_as_cs",
"utf8mb4_0900_bin",
"utf8mb4_zh_pinyin_tidb_as_cs",
}

// test all supported collations by calling writeAuthHandshake() and reading the bytes
// sent to the server to ensure the collation id is set correctly
for _, c := range collations {
collation, err := charset.GetCollationByName(c)
require.NoError(t, err)
server := sendAuthResponse(t, collation.Name)
// read the all the bytes of the handshake response so that client goroutine can complete without blocking
// on the server read.
handShakeResponse := make([]byte, 128)
_, err = server.Read(handShakeResponse)
require.NoError(t, err)

// validate the collation id is set correctly
// if the collation ID is <= 255 the collation ID is stored in the 12th byte
if collation.ID <= 255 {
require.Equal(t, byte(collation.ID), handShakeResponse[12])
// the 13th byte should always be 0x00
require.Equal(t, byte(0x00), handShakeResponse[13])
} else {
// if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes
require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12])
require.Equal(t, byte(collation.ID>>8), handShakeResponse[13])
}

// sanity check: validate the 22 bytes of filler with value 0x00 are set correctly
for i := 14; i < 14+22; i++ {
require.Equal(t, byte(0x00), handShakeResponse[i])
}

// and finally the username
username := string(handShakeResponse[36:40])
require.Equal(t, "test", username)

require.NoError(t, server.Close())
}
}

func sendAuthResponse(t *testing.T, collation string) net.Conn {
server, client := net.Pipe()
c := &Conn{
Conn: &packet.Conn{
Conn: client,
},
authPluginName: "mysql_native_password",
user: "test",
db: "test",
password: "test",
proto: "tcp",
collation: collation,
salt: ([]byte)("123456781234567812345678"),
}

go func() {
err := c.writeAuthHandshake()
require.NoError(t, err)
err = c.Close()
require.NoError(t, err)
}()
return server
}
22 changes: 21 additions & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) {
func (s *clientTestSuite) SetupSuite() {
var err error
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
s.c, err = Connect(addr, *testUser, *testPassword, "")
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic, but this is essentially a no-op since
// the collation set is the default value
_ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
})
require.NoError(s.T(), err)

var result *mysql.Result
Expand Down Expand Up @@ -228,6 +232,22 @@ func (s *clientTestSuite) TestConn_SetCharset() {
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {
err := s.c.SetCollation("latin1_swedish_ci")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "cannot set collation after connection is established")
}

func (s *clientTestSuite) TestConn_SetCollation() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic
_ = conn.SetCollation("invalid_collation")
})

require.Error(s.T(), err)
}

func (s *clientTestSuite) testStmt_DropTable() {
str := `drop table if exists mixer_test_stmt`

Expand Down
23 changes: 21 additions & 2 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type Conn struct {
status uint16

charset string
// sets the collation to be set on the auth handshake, this does not issue a 'set names' command
collation string

salt []byte
authPluginName string
Expand Down Expand Up @@ -67,15 +69,19 @@ func Connect(addr string, user string, password string, dbName string, options .
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

dialer := &net.Dialer{}
return ConnectWithContext(ctx, addr, user, password, dbName, options...)
}

// ConnectWithContext to a MySQL addr using the provided context.
func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
dialer := &net.Dialer{}
return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...)
}

// Dialer connects to the address on the named network using the provided context.
type Dialer func(ctx context.Context, network, address string) (net.Conn, error)

// Connect to a MySQL server using the given Dialer.
// ConnectWithDialer to a MySQL server using the given Dialer.
func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) {
c := new(Conn)

Expand Down Expand Up @@ -357,6 +363,19 @@ func (c *Conn) SetCharset(charset string) error {
}
}

func (c *Conn) SetCollation(collation string) error {
if len(c.serverVersion) != 0 {
return errors.Trace(errors.Errorf("cannot set collation after connection is established"))
}

c.collation = collation
return nil
}

func (c *Conn) GetCollation() string {
return c.collation
}

func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
return nil, errors.Trace(err)
Expand Down

0 comments on commit 8551be2

Please sign in to comment.