From cc2445c9b070df5bd7c94591b5cc22c1116c0596 Mon Sep 17 00:00:00 2001 From: atercattus Date: Wed, 5 Apr 2023 13:55:18 +0400 Subject: [PATCH 1/7] add all known capability flags and fix readInitialHandshake (mysql8 compatibility) --- client/auth.go | 64 +++++++++++++++-------- client/conn.go | 139 +++++++++++++++++++++++++++++++++++++++++++++++-- go.mod | 1 + go.sum | 2 + mysql/const.go | 14 ++++- mysql/util.go | 18 +++++++ 6 files changed, 211 insertions(+), 27 deletions(-) diff --git a/client/auth.go b/client/auth.go index 54e961bf9..0585fc41c 100644 --- a/client/auth.go +++ b/client/auth.go @@ -26,7 +26,7 @@ func authPluginAllowed(pluginName string) bool { return false } -// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake +// See: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html func (c *Conn) readInitialHandshake() error { data, err := c.ReadPacket() if err != nil { @@ -40,24 +40,28 @@ func (c *Conn) readInitialHandshake() error { if data[0] < MinProtocolVersion { return errors.Errorf("invalid protocol version %d, must >= 10", data[0]) } + pos := 1 // skip mysql version // mysql version end with 0x00 - version := data[1 : bytes.IndexByte(data[1:], 0x00)+1] + version := data[pos : bytes.IndexByte(data[pos:], 0x00)+1] c.serverVersion = string(version) - pos := 1 + len(version) + pos += len(version) + 1 /*trailing zero byte*/ // connection id length is 4 c.connectionID = binary.LittleEndian.Uint32(data[pos : pos+4]) pos += 4 - c.salt = []byte{} - c.salt = append(c.salt, data[pos:pos+8]...) + // first 8 bytes of the plugin provided data (scramble) + c.salt = append(c.salt[:0], data[pos:pos+8]...) + pos += 8 - // skip filter - pos += 8 + 1 + if data[pos] != 0 { // 0x00 byte, terminating the first part of a scramble + return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos])) + } + pos++ - // capability lower 2 bytes + // The lower 2 bytes of the Capabilities Flags c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2])) // check protocol if c.capability&CLIENT_PROTOCOL_41 == 0 { @@ -69,35 +73,51 @@ func (c *Conn) readInitialHandshake() error { pos += 2 if len(data) > pos { - // skip server charset + // default server a_protocol_character_set, only the lower 8-bits // c.charset = data[pos] pos += 1 c.status = binary.LittleEndian.Uint16(data[pos : pos+2]) pos += 2 - // capability flags (upper 2 bytes) + + // The upper 2 bytes of the Capabilities Flags c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability pos += 2 - // auth_data is end with 0x00, min data length is 13 + 8 = 21 - // ref to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake - maxAuthDataLen := 21 - if c.capability&CLIENT_PLUGIN_AUTH != 0 && int(data[pos]) > maxAuthDataLen { - maxAuthDataLen = int(data[pos]) + // length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0 + authPluginDataLen := data[pos] + if (c.capability&CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) { + return errors.Errorf("invalid auth plugin data filler %d", authPluginDataLen) } + pos++ // skip reserved (all [00]) - pos += 10 + 1 + pos += 6 - // auth_data is end with 0x00, so we need to trim 0x00 - resetOfAuthDataEndPos := pos + maxAuthDataLen - 8 - 1 - c.salt = append(c.salt, data[pos:resetOfAuthDataEndPos]...) + // https://github.com/vapor/mysql-nio/blob/main/Sources/MySQLNIO/Protocol/MySQLProtocol%2BHandshakeV10.swift + if c.capability&CLIENT_LONG_PASSWORD != 0 { + // skip reserved (all [00]) + pos += 4 + } else { + // unknown + pos += 4 + } + + if rest := int(authPluginDataLen) - 8; rest > 0 { + authPluginDataPart2 := data[pos : pos+rest] + pos += rest - // skip reset of end pos - pos = resetOfAuthDataEndPos + 1 + c.salt = append(c.salt, authPluginDataPart2...) + } if c.capability&CLIENT_PLUGIN_AUTH != 0 { - c.authPluginName = string(data[pos : len(data)-1]) + c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) + pos += len(c.authPluginName) + + if data[pos] != 0 { + return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos])) + } + pos++ } } diff --git a/client/conn.go b/client/conn.go index 7716bf387..8b2940ee8 100644 --- a/client/conn.go +++ b/client/conn.go @@ -9,10 +9,11 @@ import ( "strings" "time" + "github.com/pingcap/errors" + . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" "github.com/go-mysql-org/go-mysql/utils" - "github.com/pingcap/errors" ) type Conn struct { @@ -118,18 +119,18 @@ func (c *Conn) handshake() error { var err error if err = c.readInitialHandshake(); err != nil { c.Close() - return errors.Trace(err) + return errors.Trace(fmt.Errorf("readInitialHandshake: %w", err)) } if err := c.writeAuthHandshake(); err != nil { c.Close() - return errors.Trace(err) + return errors.Trace(fmt.Errorf("writeAuthHandshake: %w", err)) } if err := c.handleAuthResult(); err != nil { c.Close() - return errors.Trace(err) + return errors.Trace(fmt.Errorf("handleAuthResult: %w", err)) } return nil @@ -198,6 +199,10 @@ func (c *Conn) GetServerVersion() string { return c.serverVersion } +func (c *Conn) CompareServerVersion(v string) (int, error) { + return CompareServerVersions(c.serverVersion, v) +} + func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { if len(args) == 0 { return c.exec(command) @@ -403,3 +408,129 @@ func (c *Conn) exec(query string) (*Result, error) { return c.readResult(false) } + +func (c *Conn) CapabilityString() string { + var caps []string + capability := c.capability + for i := 0; capability != 0; i++ { + field := uint32(1 << i) + if capability&field == 0 { + continue + } + capability ^= field + + switch field { + case CLIENT_LONG_PASSWORD: + caps = append(caps, "CLIENT_LONG_PASSWORD") + case CLIENT_FOUND_ROWS: + caps = append(caps, "CLIENT_FOUND_ROWS") + case CLIENT_LONG_FLAG: + caps = append(caps, "CLIENT_LONG_FLAG") + case CLIENT_CONNECT_WITH_DB: + caps = append(caps, "CLIENT_CONNECT_WITH_DB") + case CLIENT_NO_SCHEMA: + caps = append(caps, "CLIENT_NO_SCHEMA") + case CLIENT_COMPRESS: + caps = append(caps, "CLIENT_COMPRESS") + case CLIENT_ODBC: + caps = append(caps, "CLIENT_ODBC") + case CLIENT_LOCAL_FILES: + caps = append(caps, "CLIENT_LOCAL_FILES") + case CLIENT_IGNORE_SPACE: + caps = append(caps, "CLIENT_IGNORE_SPACE") + case CLIENT_PROTOCOL_41: + caps = append(caps, "CLIENT_PROTOCOL_41") + case CLIENT_INTERACTIVE: + caps = append(caps, "CLIENT_INTERACTIVE") + case CLIENT_SSL: + caps = append(caps, "CLIENT_SSL") + case CLIENT_IGNORE_SIGPIPE: + caps = append(caps, "CLIENT_IGNORE_SIGPIPE") + case CLIENT_TRANSACTIONS: + caps = append(caps, "CLIENT_TRANSACTIONS") + case CLIENT_RESERVED: + caps = append(caps, "CLIENT_RESERVED") + case CLIENT_SECURE_CONNECTION: + caps = append(caps, "CLIENT_SECURE_CONNECTION") + case CLIENT_MULTI_STATEMENTS: + caps = append(caps, "CLIENT_MULTI_STATEMENTS") + case CLIENT_MULTI_RESULTS: + caps = append(caps, "CLIENT_MULTI_RESULTS") + case CLIENT_PS_MULTI_RESULTS: + caps = append(caps, "CLIENT_PS_MULTI_RESULTS") + case CLIENT_PLUGIN_AUTH: + caps = append(caps, "CLIENT_PLUGIN_AUTH") + case CLIENT_CONNECT_ATTRS: + caps = append(caps, "CLIENT_CONNECT_ATTRS") + case CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA: + caps = append(caps, "CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA") + case CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS: + caps = append(caps, "CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS") + case CLIENT_SESSION_TRACK: + caps = append(caps, "CLIENT_SESSION_TRACK") + case CLIENT_DEPRECATE_EOF: + caps = append(caps, "CLIENT_DEPRECATE_EOF") + case CLIENT_OPTIONAL_RESULTSET_METADATA: + caps = append(caps, "CLIENT_OPTIONAL_RESULTSET_METADATA") + case CLIENT_ZSTD_COMPRESSION_ALGORITHM: + caps = append(caps, "CLIENT_ZSTD_COMPRESSION_ALGORITHM") + case CLIENT_QUERY_ATTRIBUTES: + caps = append(caps, "CLIENT_QUERY_ATTRIBUTES") + case MULTI_FACTOR_AUTHENTICATION: + caps = append(caps, "MULTI_FACTOR_AUTHENTICATION") + case CLIENT_CAPABILITY_EXTENSION: + caps = append(caps, "CLIENT_CAPABILITY_EXTENSION") + case CLIENT_SSL_VERIFY_SERVER_CERT: + caps = append(caps, "CLIENT_SSL_VERIFY_SERVER_CERT") + case CLIENT_REMEMBER_OPTIONS: + caps = append(caps, "CLIENT_REMEMBER_OPTIONS") + default: + caps = append(caps, fmt.Sprintf("(%d)", field)) + } + } + + return strings.Join(caps, "|") +} + +func (c *Conn) StatusString() string { + var stats []string + status := c.status + for i := 0; status != 0; i++ { + field := uint16(1 << i) + if status&field == 0 { + continue + } + status ^= field + + switch field { + case SERVER_STATUS_IN_TRANS: + stats = append(stats, "SERVER_STATUS_IN_TRANS") + case SERVER_STATUS_AUTOCOMMIT: + stats = append(stats, "SERVER_STATUS_AUTOCOMMIT") + case SERVER_MORE_RESULTS_EXISTS: + stats = append(stats, "SERVER_MORE_RESULTS_EXISTS") + case SERVER_STATUS_NO_GOOD_INDEX_USED: + stats = append(stats, "SERVER_STATUS_NO_GOOD_INDEX_USED") + case SERVER_STATUS_NO_INDEX_USED: + stats = append(stats, "SERVER_STATUS_NO_INDEX_USED") + case SERVER_STATUS_CURSOR_EXISTS: + stats = append(stats, "SERVER_STATUS_CURSOR_EXISTS") + case SERVER_STATUS_LAST_ROW_SEND: + stats = append(stats, "SERVER_STATUS_LAST_ROW_SEND") + case SERVER_STATUS_DB_DROPPED: + stats = append(stats, "SERVER_STATUS_DB_DROPPED") + case SERVER_STATUS_NO_BACKSLASH_ESCAPED: + stats = append(stats, "SERVER_STATUS_NO_BACKSLASH_ESCAPED") + case SERVER_STATUS_METADATA_CHANGED: + stats = append(stats, "SERVER_STATUS_METADATA_CHANGED") + case SERVER_QUERY_WAS_SLOW: + stats = append(stats, "SERVER_QUERY_WAS_SLOW") + case SERVER_PS_OUT_PARAMS: + stats = append(stats, "SERVER_PS_OUT_PARAMS") + default: + stats = append(stats, fmt.Sprintf("(%d)", field)) + } + } + + return strings.Join(stats, "|") +} diff --git a/go.mod b/go.mod index b79f997e2..15ab27a2f 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.16 require ( github.com/BurntSushi/toml v0.3.1 github.com/DataDog/zstd v1.5.2 + github.com/Masterminds/semver v1.5.0 github.com/go-sql-driver/mysql v1.6.0 github.com/google/uuid v1.3.0 github.com/jmoiron/sqlx v1.3.3 diff --git a/go.sum b/go.sum index ea24a843c..3df58f187 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/DataDog/zstd v1.5.2 h1:vUG4lAyuPCXO0TLbXvPv7EB7cNK1QV/luu55UHLrrn8= github.com/DataDog/zstd v1.5.2/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM= diff --git a/mysql/const.go b/mysql/const.go index e2f3a0afe..a1a5bde42 100644 --- a/mysql/const.go +++ b/mysql/const.go @@ -76,6 +76,8 @@ const ( ) const ( + // https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html + CLIENT_LONG_PASSWORD uint32 = 1 << iota CLIENT_FOUND_ROWS CLIENT_LONG_FLAG @@ -98,6 +100,16 @@ const ( CLIENT_PLUGIN_AUTH CLIENT_CONNECT_ATTRS CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA + CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS + CLIENT_SESSION_TRACK + CLIENT_DEPRECATE_EOF + CLIENT_OPTIONAL_RESULTSET_METADATA + CLIENT_ZSTD_COMPRESSION_ALGORITHM + CLIENT_QUERY_ATTRIBUTES + MULTI_FACTOR_AUTHENTICATION + CLIENT_CAPABILITY_EXTENSION + CLIENT_SSL_VERIFY_SERVER_CERT + CLIENT_REMEMBER_OPTIONS ) const ( @@ -119,7 +131,7 @@ const ( MYSQL_TYPE_VARCHAR MYSQL_TYPE_BIT - //mysql 5.6 + // mysql 5.6 MYSQL_TYPE_TIMESTAMP2 MYSQL_TYPE_DATETIME2 MYSQL_TYPE_TIME2 diff --git a/mysql/util.go b/mysql/util.go index e8d436fa2..6d8ec4471 100644 --- a/mysql/util.go +++ b/mysql/util.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/Masterminds/semver" "github.com/pingcap/errors" "github.com/siddontang/go/hack" ) @@ -379,6 +380,23 @@ func ErrorEqual(err1, err2 error) bool { return e1.Error() == e2.Error() } +func CompareServerVersions(a, b string) (int, error) { + var ( + aVer, bVer *semver.Version + err error + ) + + if aVer, err = semver.NewVersion(a); err != nil { + return 0, fmt.Errorf("cannot parse %q as semver: %w", a, err) + } + + if bVer, err = semver.NewVersion(b); err != nil { + return 0, fmt.Errorf("cannot parse %q as semver: %w", b, err) + } + + return aVer.Compare(bVer), nil +} + var encodeRef = map[byte]byte{ '\x00': '0', '\'': '\'', From 136935a9e96856ab21b5304156787e98a1ae148f Mon Sep 17 00:00:00 2001 From: atercattus Date: Wed, 5 Apr 2023 13:55:18 +0400 Subject: [PATCH 2/7] add all known capability flags and fix readInitialHandshake (mysql8 compatibility) --- client/auth.go | 64 +++++++++++++++-------- client/conn.go | 139 +++++++++++++++++++++++++++++++++++++++++++++++-- go.mod | 1 + go.sum | 2 + mysql/const.go | 14 ++++- mysql/util.go | 18 +++++++ 6 files changed, 211 insertions(+), 27 deletions(-) diff --git a/client/auth.go b/client/auth.go index 54e961bf9..d69000d23 100644 --- a/client/auth.go +++ b/client/auth.go @@ -26,7 +26,7 @@ func authPluginAllowed(pluginName string) bool { return false } -// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake +// See: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html func (c *Conn) readInitialHandshake() error { data, err := c.ReadPacket() if err != nil { @@ -40,24 +40,28 @@ func (c *Conn) readInitialHandshake() error { if data[0] < MinProtocolVersion { return errors.Errorf("invalid protocol version %d, must >= 10", data[0]) } + pos := 1 // skip mysql version // mysql version end with 0x00 - version := data[1 : bytes.IndexByte(data[1:], 0x00)+1] + version := data[pos : bytes.IndexByte(data[pos:], 0x00)+1] c.serverVersion = string(version) - pos := 1 + len(version) + pos += len(version) + 1 /*trailing zero byte*/ // connection id length is 4 c.connectionID = binary.LittleEndian.Uint32(data[pos : pos+4]) pos += 4 - c.salt = []byte{} - c.salt = append(c.salt, data[pos:pos+8]...) + // first 8 bytes of the plugin provided data (scramble) + c.salt = append(c.salt[:0], data[pos:pos+8]...) + pos += 8 - // skip filter - pos += 8 + 1 + if data[pos] != 0 { // 0x00 byte, terminating the first part of a scramble + return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos])) + } + pos++ - // capability lower 2 bytes + // The lower 2 bytes of the Capabilities Flags c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2])) // check protocol if c.capability&CLIENT_PROTOCOL_41 == 0 { @@ -69,35 +73,51 @@ func (c *Conn) readInitialHandshake() error { pos += 2 if len(data) > pos { - // skip server charset + // default server a_protocol_character_set, only the lower 8-bits // c.charset = data[pos] pos += 1 c.status = binary.LittleEndian.Uint16(data[pos : pos+2]) pos += 2 - // capability flags (upper 2 bytes) + + // The upper 2 bytes of the Capabilities Flags c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability pos += 2 - // auth_data is end with 0x00, min data length is 13 + 8 = 21 - // ref to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake - maxAuthDataLen := 21 - if c.capability&CLIENT_PLUGIN_AUTH != 0 && int(data[pos]) > maxAuthDataLen { - maxAuthDataLen = int(data[pos]) + // length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0 + authPluginDataLen := data[pos] + if (c.capability&CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) { + return errors.Errorf("invalid auth plugin data filler %d", authPluginDataLen) } + pos++ // skip reserved (all [00]) - pos += 10 + 1 + pos += 6 - // auth_data is end with 0x00, so we need to trim 0x00 - resetOfAuthDataEndPos := pos + maxAuthDataLen - 8 - 1 - c.salt = append(c.salt, data[pos:resetOfAuthDataEndPos]...) + // https://github.com/vapor/mysql-nio/blob/main/Sources/MySQLNIO/Protocol/MySQLProtocol%2BHandshakeV10.swift + if c.capability&CLIENT_LONG_PASSWORD != 0 { + // skip reserved (all [00]) + pos += 4 + } else { + // unknown + pos += 4 + } + + if rest := int(authPluginDataLen) - 8; rest > 0 { + authPluginDataPart2 := data[pos : pos+rest] + pos += rest - // skip reset of end pos - pos = resetOfAuthDataEndPos + 1 + c.salt = append(c.salt, authPluginDataPart2...) + } if c.capability&CLIENT_PLUGIN_AUTH != 0 { - c.authPluginName = string(data[pos : len(data)-1]) + c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) + pos += len(c.authPluginName) + + if data[pos] != 0 { + return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos])) + } + // pos++ // ineffectual } } diff --git a/client/conn.go b/client/conn.go index 7716bf387..8b2940ee8 100644 --- a/client/conn.go +++ b/client/conn.go @@ -9,10 +9,11 @@ import ( "strings" "time" + "github.com/pingcap/errors" + . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" "github.com/go-mysql-org/go-mysql/utils" - "github.com/pingcap/errors" ) type Conn struct { @@ -118,18 +119,18 @@ func (c *Conn) handshake() error { var err error if err = c.readInitialHandshake(); err != nil { c.Close() - return errors.Trace(err) + return errors.Trace(fmt.Errorf("readInitialHandshake: %w", err)) } if err := c.writeAuthHandshake(); err != nil { c.Close() - return errors.Trace(err) + return errors.Trace(fmt.Errorf("writeAuthHandshake: %w", err)) } if err := c.handleAuthResult(); err != nil { c.Close() - return errors.Trace(err) + return errors.Trace(fmt.Errorf("handleAuthResult: %w", err)) } return nil @@ -198,6 +199,10 @@ func (c *Conn) GetServerVersion() string { return c.serverVersion } +func (c *Conn) CompareServerVersion(v string) (int, error) { + return CompareServerVersions(c.serverVersion, v) +} + func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { if len(args) == 0 { return c.exec(command) @@ -403,3 +408,129 @@ func (c *Conn) exec(query string) (*Result, error) { return c.readResult(false) } + +func (c *Conn) CapabilityString() string { + var caps []string + capability := c.capability + for i := 0; capability != 0; i++ { + field := uint32(1 << i) + if capability&field == 0 { + continue + } + capability ^= field + + switch field { + case CLIENT_LONG_PASSWORD: + caps = append(caps, "CLIENT_LONG_PASSWORD") + case CLIENT_FOUND_ROWS: + caps = append(caps, "CLIENT_FOUND_ROWS") + case CLIENT_LONG_FLAG: + caps = append(caps, "CLIENT_LONG_FLAG") + case CLIENT_CONNECT_WITH_DB: + caps = append(caps, "CLIENT_CONNECT_WITH_DB") + case CLIENT_NO_SCHEMA: + caps = append(caps, "CLIENT_NO_SCHEMA") + case CLIENT_COMPRESS: + caps = append(caps, "CLIENT_COMPRESS") + case CLIENT_ODBC: + caps = append(caps, "CLIENT_ODBC") + case CLIENT_LOCAL_FILES: + caps = append(caps, "CLIENT_LOCAL_FILES") + case CLIENT_IGNORE_SPACE: + caps = append(caps, "CLIENT_IGNORE_SPACE") + case CLIENT_PROTOCOL_41: + caps = append(caps, "CLIENT_PROTOCOL_41") + case CLIENT_INTERACTIVE: + caps = append(caps, "CLIENT_INTERACTIVE") + case CLIENT_SSL: + caps = append(caps, "CLIENT_SSL") + case CLIENT_IGNORE_SIGPIPE: + caps = append(caps, "CLIENT_IGNORE_SIGPIPE") + case CLIENT_TRANSACTIONS: + caps = append(caps, "CLIENT_TRANSACTIONS") + case CLIENT_RESERVED: + caps = append(caps, "CLIENT_RESERVED") + case CLIENT_SECURE_CONNECTION: + caps = append(caps, "CLIENT_SECURE_CONNECTION") + case CLIENT_MULTI_STATEMENTS: + caps = append(caps, "CLIENT_MULTI_STATEMENTS") + case CLIENT_MULTI_RESULTS: + caps = append(caps, "CLIENT_MULTI_RESULTS") + case CLIENT_PS_MULTI_RESULTS: + caps = append(caps, "CLIENT_PS_MULTI_RESULTS") + case CLIENT_PLUGIN_AUTH: + caps = append(caps, "CLIENT_PLUGIN_AUTH") + case CLIENT_CONNECT_ATTRS: + caps = append(caps, "CLIENT_CONNECT_ATTRS") + case CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA: + caps = append(caps, "CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA") + case CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS: + caps = append(caps, "CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS") + case CLIENT_SESSION_TRACK: + caps = append(caps, "CLIENT_SESSION_TRACK") + case CLIENT_DEPRECATE_EOF: + caps = append(caps, "CLIENT_DEPRECATE_EOF") + case CLIENT_OPTIONAL_RESULTSET_METADATA: + caps = append(caps, "CLIENT_OPTIONAL_RESULTSET_METADATA") + case CLIENT_ZSTD_COMPRESSION_ALGORITHM: + caps = append(caps, "CLIENT_ZSTD_COMPRESSION_ALGORITHM") + case CLIENT_QUERY_ATTRIBUTES: + caps = append(caps, "CLIENT_QUERY_ATTRIBUTES") + case MULTI_FACTOR_AUTHENTICATION: + caps = append(caps, "MULTI_FACTOR_AUTHENTICATION") + case CLIENT_CAPABILITY_EXTENSION: + caps = append(caps, "CLIENT_CAPABILITY_EXTENSION") + case CLIENT_SSL_VERIFY_SERVER_CERT: + caps = append(caps, "CLIENT_SSL_VERIFY_SERVER_CERT") + case CLIENT_REMEMBER_OPTIONS: + caps = append(caps, "CLIENT_REMEMBER_OPTIONS") + default: + caps = append(caps, fmt.Sprintf("(%d)", field)) + } + } + + return strings.Join(caps, "|") +} + +func (c *Conn) StatusString() string { + var stats []string + status := c.status + for i := 0; status != 0; i++ { + field := uint16(1 << i) + if status&field == 0 { + continue + } + status ^= field + + switch field { + case SERVER_STATUS_IN_TRANS: + stats = append(stats, "SERVER_STATUS_IN_TRANS") + case SERVER_STATUS_AUTOCOMMIT: + stats = append(stats, "SERVER_STATUS_AUTOCOMMIT") + case SERVER_MORE_RESULTS_EXISTS: + stats = append(stats, "SERVER_MORE_RESULTS_EXISTS") + case SERVER_STATUS_NO_GOOD_INDEX_USED: + stats = append(stats, "SERVER_STATUS_NO_GOOD_INDEX_USED") + case SERVER_STATUS_NO_INDEX_USED: + stats = append(stats, "SERVER_STATUS_NO_INDEX_USED") + case SERVER_STATUS_CURSOR_EXISTS: + stats = append(stats, "SERVER_STATUS_CURSOR_EXISTS") + case SERVER_STATUS_LAST_ROW_SEND: + stats = append(stats, "SERVER_STATUS_LAST_ROW_SEND") + case SERVER_STATUS_DB_DROPPED: + stats = append(stats, "SERVER_STATUS_DB_DROPPED") + case SERVER_STATUS_NO_BACKSLASH_ESCAPED: + stats = append(stats, "SERVER_STATUS_NO_BACKSLASH_ESCAPED") + case SERVER_STATUS_METADATA_CHANGED: + stats = append(stats, "SERVER_STATUS_METADATA_CHANGED") + case SERVER_QUERY_WAS_SLOW: + stats = append(stats, "SERVER_QUERY_WAS_SLOW") + case SERVER_PS_OUT_PARAMS: + stats = append(stats, "SERVER_PS_OUT_PARAMS") + default: + stats = append(stats, fmt.Sprintf("(%d)", field)) + } + } + + return strings.Join(stats, "|") +} diff --git a/go.mod b/go.mod index b79f997e2..15ab27a2f 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.16 require ( github.com/BurntSushi/toml v0.3.1 github.com/DataDog/zstd v1.5.2 + github.com/Masterminds/semver v1.5.0 github.com/go-sql-driver/mysql v1.6.0 github.com/google/uuid v1.3.0 github.com/jmoiron/sqlx v1.3.3 diff --git a/go.sum b/go.sum index ea24a843c..3df58f187 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/DataDog/zstd v1.5.2 h1:vUG4lAyuPCXO0TLbXvPv7EB7cNK1QV/luu55UHLrrn8= github.com/DataDog/zstd v1.5.2/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM= diff --git a/mysql/const.go b/mysql/const.go index e2f3a0afe..a1a5bde42 100644 --- a/mysql/const.go +++ b/mysql/const.go @@ -76,6 +76,8 @@ const ( ) const ( + // https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html + CLIENT_LONG_PASSWORD uint32 = 1 << iota CLIENT_FOUND_ROWS CLIENT_LONG_FLAG @@ -98,6 +100,16 @@ const ( CLIENT_PLUGIN_AUTH CLIENT_CONNECT_ATTRS CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA + CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS + CLIENT_SESSION_TRACK + CLIENT_DEPRECATE_EOF + CLIENT_OPTIONAL_RESULTSET_METADATA + CLIENT_ZSTD_COMPRESSION_ALGORITHM + CLIENT_QUERY_ATTRIBUTES + MULTI_FACTOR_AUTHENTICATION + CLIENT_CAPABILITY_EXTENSION + CLIENT_SSL_VERIFY_SERVER_CERT + CLIENT_REMEMBER_OPTIONS ) const ( @@ -119,7 +131,7 @@ const ( MYSQL_TYPE_VARCHAR MYSQL_TYPE_BIT - //mysql 5.6 + // mysql 5.6 MYSQL_TYPE_TIMESTAMP2 MYSQL_TYPE_DATETIME2 MYSQL_TYPE_TIME2 diff --git a/mysql/util.go b/mysql/util.go index e8d436fa2..6d8ec4471 100644 --- a/mysql/util.go +++ b/mysql/util.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/Masterminds/semver" "github.com/pingcap/errors" "github.com/siddontang/go/hack" ) @@ -379,6 +380,23 @@ func ErrorEqual(err1, err2 error) bool { return e1.Error() == e2.Error() } +func CompareServerVersions(a, b string) (int, error) { + var ( + aVer, bVer *semver.Version + err error + ) + + if aVer, err = semver.NewVersion(a); err != nil { + return 0, fmt.Errorf("cannot parse %q as semver: %w", a, err) + } + + if bVer, err = semver.NewVersion(b); err != nil { + return 0, fmt.Errorf("cannot parse %q as semver: %w", b, err) + } + + return aVer.Compare(bVer), nil +} + var encodeRef = map[byte]byte{ '\x00': '0', '\'': '\'', From 6cff2315097e5445eff7ce984bf4b448be4eb360 Mon Sep 17 00:00:00 2001 From: atercattus Date: Wed, 5 Apr 2023 17:44:43 +0400 Subject: [PATCH 3/7] CI - remove ubuntu-18.04 from a test matrix (https://github.com/go-mysql-org/go-mysql/issues/775) --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7273d23fb..67a2b264c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,7 +6,7 @@ jobs: strategy: matrix: go: [ 1.19, 1.18, 1.17, 1.16 ] - os: [ ubuntu-18.04, ubuntu-20.04 ] + os: [ ubuntu-20.04 ] name: Tests Go ${{ matrix.go }} # This name is used in main branch protection rules runs-on: ${{ matrix.os }} From ea490bf1aa40e7ecbb512453ebc36206aa2799b6 Mon Sep 17 00:00:00 2001 From: atercattus Date: Thu, 6 Apr 2023 12:15:51 +0400 Subject: [PATCH 4/7] More precise authorization processing --- client/auth.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/auth.go b/client/auth.go index d69000d23..8b1601a85 100644 --- a/client/auth.go +++ b/client/auth.go @@ -103,7 +103,11 @@ func (c *Conn) readInitialHandshake() error { pos += 4 } + // Rest of the plugin provided data (scramble) if rest := int(authPluginDataLen) - 8; rest > 0 { + if max := 13; rest > max { // $len=MAX(13, length of auth-plugin-data - 8) + rest = max + } authPluginDataPart2 := data[pos : pos+rest] pos += rest From 687a2b94b57ac137e902dff1eab7a013336e69da Mon Sep 17 00:00:00 2001 From: atercattus Date: Mon, 10 Apr 2023 20:59:45 +0400 Subject: [PATCH 5/7] client - fix an initial handshake auth logic --- client/auth.go | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/client/auth.go b/client/auth.go index 8b1601a85..17940b8ee 100644 --- a/client/auth.go +++ b/client/auth.go @@ -26,7 +26,11 @@ func authPluginAllowed(pluginName string) bool { return false } -// See: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html +// See: +// - https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html +// - https://github.com/alibaba/canal/blob/0ec46991499a22870dde4ae736b2586cbcbfea94/driver/src/main/java/com/alibaba/otter/canal/parse/driver/mysql/packets/server/HandshakeInitializationPacket.java#L89 +// - https://github.com/vapor/mysql-nio/blob/main/Sources/MySQLNIO/Protocol/MySQLProtocol%2BHandshakeV10.swift +// - https://github.com/github/vitess-gh/blob/70ae1a2b3a116ff6411b0f40852d6e71382f6e07/go/mysql/client.go func (c *Conn) readInitialHandshake() error { data, err := c.ReadPacket() if err != nil { @@ -91,24 +95,22 @@ func (c *Conn) readInitialHandshake() error { } pos++ - // skip reserved (all [00]) - pos += 6 + // skip reserved (all [00] ?) + pos += 10 - // https://github.com/vapor/mysql-nio/blob/main/Sources/MySQLNIO/Protocol/MySQLProtocol%2BHandshakeV10.swift - if c.capability&CLIENT_LONG_PASSWORD != 0 { - // skip reserved (all [00]) - pos += 4 - } else { - // unknown - pos += 4 - } + if c.capability&CLIENT_SECURE_CONNECTION != 0 { + // Rest of the plugin provided data (scramble) - // Rest of the plugin provided data (scramble) - if rest := int(authPluginDataLen) - 8; rest > 0 { - if max := 13; rest > max { // $len=MAX(13, length of auth-plugin-data - 8) + // $len=MAX(13, length of auth-plugin-data - 8) + rest := int(authPluginDataLen) - 8 + if max := 13; rest > max { rest = max } - authPluginDataPart2 := data[pos : pos+rest] + if data[pos+rest-1] != 0 { + return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos])) + } + + authPluginDataPart2 := data[pos : pos+rest-1] pos += rest c.salt = append(c.salt, authPluginDataPart2...) @@ -119,7 +121,7 @@ func (c *Conn) readInitialHandshake() error { pos += len(c.authPluginName) if data[pos] != 0 { - return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos])) + return errors.Errorf("expect 0x00 after authPluginName, got %q", rune(data[pos])) } // pos++ // ineffectual } From 963974ed4eb1842f73d79a49845990e997631919 Mon Sep 17 00:00:00 2001 From: Aleksey Akulovich Date: Tue, 11 Apr 2023 11:42:59 +0400 Subject: [PATCH 6/7] Update client/auth.go Ooooops. I saw only 13 and 26 bytes variants before... Co-authored-by: lance6716 --- client/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/auth.go b/client/auth.go index 17940b8ee..2fb53b2b1 100644 --- a/client/auth.go +++ b/client/auth.go @@ -103,7 +103,7 @@ func (c *Conn) readInitialHandshake() error { // $len=MAX(13, length of auth-plugin-data - 8) rest := int(authPluginDataLen) - 8 - if max := 13; rest > max { + if max := 13; rest < max { rest = max } if data[pos+rest-1] != 0 { From 7df3e00d63d53eb598823f33755aac50de92e4bf Mon Sep 17 00:00:00 2001 From: atercattus Date: Tue, 11 Apr 2023 20:01:36 +0400 Subject: [PATCH 7/7] client - fix an initial handshake auth logic #2 --- client/auth.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/client/auth.go b/client/auth.go index 2fb53b2b1..ff7ebcd01 100644 --- a/client/auth.go +++ b/client/auth.go @@ -101,16 +101,18 @@ func (c *Conn) readInitialHandshake() error { if c.capability&CLIENT_SECURE_CONNECTION != 0 { // Rest of the plugin provided data (scramble) + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html // $len=MAX(13, length of auth-plugin-data - 8) + // + // https://github.com/mysql/mysql-server/blob/1bfe02bdad6604d54913c62614bde57a055c8332/sql/auth/sql_authentication.cc#L1641-L1642 + // the first packet *must* have at least 20 bytes of a scramble. + // if a plugin provided less, we pad it to 20 with zeros rest := int(authPluginDataLen) - 8 - if max := 13; rest < max { + if max := 12 + 1; rest < max { rest = max } - if data[pos+rest-1] != 0 { - return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos])) - } - authPluginDataPart2 := data[pos : pos+rest-1] + authPluginDataPart2 := data[pos : pos+rest] pos += rest c.salt = append(c.salt, authPluginDataPart2...)