Skip to content

Commit

Permalink
Merge branch 'master' into PARTIAL_UPDATE_ROWS_EVENT
Browse files Browse the repository at this point in the history
# Conflicts:
#	.github/workflows/ci.yml
  • Loading branch information
atercattus committed Apr 5, 2023
2 parents 1ca1b07 + 0e9bf02 commit 82c1282
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 27 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
strategy:
matrix:
go: [ 1.19, 1.18, 1.17, 1.16 ]
os: [ ubuntu-20.04 ]
os: [ ubuntu-22.04, ubuntu-20.04 ]
name: Tests Go ${{ matrix.go }} # This name is used in main branch protection rules
runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -48,3 +48,4 @@ jobs:
uses: golangci/golangci-lint-action@v2
with:
version: latest
args: --timeout=3m
64 changes: 42 additions & 22 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func authPluginAllowed(pluginName string) bool {
return false
}

// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
// See: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
func (c *Conn) readInitialHandshake() error {
data, err := c.ReadPacket()
if err != nil {
Expand All @@ -40,24 +40,28 @@ func (c *Conn) readInitialHandshake() error {
if data[0] < MinProtocolVersion {
return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
}
pos := 1

// skip mysql version
// mysql version end with 0x00
version := data[1 : bytes.IndexByte(data[1:], 0x00)+1]
version := data[pos : bytes.IndexByte(data[pos:], 0x00)+1]
c.serverVersion = string(version)
pos := 1 + len(version)
pos += len(version) + 1 /*trailing zero byte*/

// connection id length is 4
c.connectionID = binary.LittleEndian.Uint32(data[pos : pos+4])
pos += 4

c.salt = []byte{}
c.salt = append(c.salt, data[pos:pos+8]...)
// first 8 bytes of the plugin provided data (scramble)
c.salt = append(c.salt[:0], data[pos:pos+8]...)
pos += 8

// skip filter
pos += 8 + 1
if data[pos] != 0 { // 0x00 byte, terminating the first part of a scramble
return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos]))
}
pos++

// capability lower 2 bytes
// The lower 2 bytes of the Capabilities Flags
c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
// check protocol
if c.capability&CLIENT_PROTOCOL_41 == 0 {
Expand All @@ -69,35 +73,51 @@ func (c *Conn) readInitialHandshake() error {
pos += 2

if len(data) > pos {
// skip server charset
// default server a_protocol_character_set, only the lower 8-bits
// c.charset = data[pos]
pos += 1

c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
pos += 2
// capability flags (upper 2 bytes)

// The upper 2 bytes of the Capabilities Flags
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
pos += 2

// auth_data is end with 0x00, min data length is 13 + 8 = 21
// ref to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
maxAuthDataLen := 21
if c.capability&CLIENT_PLUGIN_AUTH != 0 && int(data[pos]) > maxAuthDataLen {
maxAuthDataLen = int(data[pos])
// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
authPluginDataLen := data[pos]
if (c.capability&CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) {
return errors.Errorf("invalid auth plugin data filler %d", authPluginDataLen)
}
pos++

// skip reserved (all [00])
pos += 10 + 1
pos += 6

// auth_data is end with 0x00, so we need to trim 0x00
resetOfAuthDataEndPos := pos + maxAuthDataLen - 8 - 1
c.salt = append(c.salt, data[pos:resetOfAuthDataEndPos]...)
// https://github.com/vapor/mysql-nio/blob/main/Sources/MySQLNIO/Protocol/MySQLProtocol%2BHandshakeV10.swift
if c.capability&CLIENT_LONG_PASSWORD != 0 {
// skip reserved (all [00])
pos += 4
} else {
// unknown
pos += 4
}

if rest := int(authPluginDataLen) - 8; rest > 0 {
authPluginDataPart2 := data[pos : pos+rest]
pos += rest

// skip reset of end pos
pos = resetOfAuthDataEndPos + 1
c.salt = append(c.salt, authPluginDataPart2...)
}

if c.capability&CLIENT_PLUGIN_AUTH != 0 {
c.authPluginName = string(data[pos : len(data)-1])
c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
pos += len(c.authPluginName)

if data[pos] != 0 {
return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos]))
}
pos++
}
}

Expand Down
132 changes: 129 additions & 3 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,18 @@ func (c *Conn) handshake() error {
var err error
if err = c.readInitialHandshake(); err != nil {
c.Close()
return errors.Trace(err)
return errors.Trace(fmt.Errorf("readInitialHandshake: %w", err))
}

if err := c.writeAuthHandshake(); err != nil {
c.Close()

return errors.Trace(err)
return errors.Trace(fmt.Errorf("writeAuthHandshake: %w", err))
}

if err := c.handleAuthResult(); err != nil {
c.Close()
return errors.Trace(err)
return errors.Trace(fmt.Errorf("handleAuthResult: %w", err))
}

return nil
Expand Down Expand Up @@ -408,3 +408,129 @@ func (c *Conn) exec(query string) (*Result, error) {

return c.readResult(false)
}

func (c *Conn) CapabilityString() string {
var caps []string
capability := c.capability
for i := 0; capability != 0; i++ {
field := uint32(1 << i)
if capability&field == 0 {
continue
}
capability ^= field

switch field {
case CLIENT_LONG_PASSWORD:
caps = append(caps, "CLIENT_LONG_PASSWORD")
case CLIENT_FOUND_ROWS:
caps = append(caps, "CLIENT_FOUND_ROWS")
case CLIENT_LONG_FLAG:
caps = append(caps, "CLIENT_LONG_FLAG")
case CLIENT_CONNECT_WITH_DB:
caps = append(caps, "CLIENT_CONNECT_WITH_DB")
case CLIENT_NO_SCHEMA:
caps = append(caps, "CLIENT_NO_SCHEMA")
case CLIENT_COMPRESS:
caps = append(caps, "CLIENT_COMPRESS")
case CLIENT_ODBC:
caps = append(caps, "CLIENT_ODBC")
case CLIENT_LOCAL_FILES:
caps = append(caps, "CLIENT_LOCAL_FILES")
case CLIENT_IGNORE_SPACE:
caps = append(caps, "CLIENT_IGNORE_SPACE")
case CLIENT_PROTOCOL_41:
caps = append(caps, "CLIENT_PROTOCOL_41")
case CLIENT_INTERACTIVE:
caps = append(caps, "CLIENT_INTERACTIVE")
case CLIENT_SSL:
caps = append(caps, "CLIENT_SSL")
case CLIENT_IGNORE_SIGPIPE:
caps = append(caps, "CLIENT_IGNORE_SIGPIPE")
case CLIENT_TRANSACTIONS:
caps = append(caps, "CLIENT_TRANSACTIONS")
case CLIENT_RESERVED:
caps = append(caps, "CLIENT_RESERVED")
case CLIENT_SECURE_CONNECTION:
caps = append(caps, "CLIENT_SECURE_CONNECTION")
case CLIENT_MULTI_STATEMENTS:
caps = append(caps, "CLIENT_MULTI_STATEMENTS")
case CLIENT_MULTI_RESULTS:
caps = append(caps, "CLIENT_MULTI_RESULTS")
case CLIENT_PS_MULTI_RESULTS:
caps = append(caps, "CLIENT_PS_MULTI_RESULTS")
case CLIENT_PLUGIN_AUTH:
caps = append(caps, "CLIENT_PLUGIN_AUTH")
case CLIENT_CONNECT_ATTRS:
caps = append(caps, "CLIENT_CONNECT_ATTRS")
case CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA:
caps = append(caps, "CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA")
case CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS:
caps = append(caps, "CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS")
case CLIENT_SESSION_TRACK:
caps = append(caps, "CLIENT_SESSION_TRACK")
case CLIENT_DEPRECATE_EOF:
caps = append(caps, "CLIENT_DEPRECATE_EOF")
case CLIENT_OPTIONAL_RESULTSET_METADATA:
caps = append(caps, "CLIENT_OPTIONAL_RESULTSET_METADATA")
case CLIENT_ZSTD_COMPRESSION_ALGORITHM:
caps = append(caps, "CLIENT_ZSTD_COMPRESSION_ALGORITHM")
case CLIENT_QUERY_ATTRIBUTES:
caps = append(caps, "CLIENT_QUERY_ATTRIBUTES")
case MULTI_FACTOR_AUTHENTICATION:
caps = append(caps, "MULTI_FACTOR_AUTHENTICATION")
case CLIENT_CAPABILITY_EXTENSION:
caps = append(caps, "CLIENT_CAPABILITY_EXTENSION")
case CLIENT_SSL_VERIFY_SERVER_CERT:
caps = append(caps, "CLIENT_SSL_VERIFY_SERVER_CERT")
case CLIENT_REMEMBER_OPTIONS:
caps = append(caps, "CLIENT_REMEMBER_OPTIONS")
default:
caps = append(caps, fmt.Sprintf("(%d)", field))
}
}

return strings.Join(caps, "|")
}

func (c *Conn) StatusString() string {
var stats []string
status := c.status
for i := 0; status != 0; i++ {
field := uint16(1 << i)
if status&field == 0 {
continue
}
status ^= field

switch field {
case SERVER_STATUS_IN_TRANS:
stats = append(stats, "SERVER_STATUS_IN_TRANS")
case SERVER_STATUS_AUTOCOMMIT:
stats = append(stats, "SERVER_STATUS_AUTOCOMMIT")
case SERVER_MORE_RESULTS_EXISTS:
stats = append(stats, "SERVER_MORE_RESULTS_EXISTS")
case SERVER_STATUS_NO_GOOD_INDEX_USED:
stats = append(stats, "SERVER_STATUS_NO_GOOD_INDEX_USED")
case SERVER_STATUS_NO_INDEX_USED:
stats = append(stats, "SERVER_STATUS_NO_INDEX_USED")
case SERVER_STATUS_CURSOR_EXISTS:
stats = append(stats, "SERVER_STATUS_CURSOR_EXISTS")
case SERVER_STATUS_LAST_ROW_SEND:
stats = append(stats, "SERVER_STATUS_LAST_ROW_SEND")
case SERVER_STATUS_DB_DROPPED:
stats = append(stats, "SERVER_STATUS_DB_DROPPED")
case SERVER_STATUS_NO_BACKSLASH_ESCAPED:
stats = append(stats, "SERVER_STATUS_NO_BACKSLASH_ESCAPED")
case SERVER_STATUS_METADATA_CHANGED:
stats = append(stats, "SERVER_STATUS_METADATA_CHANGED")
case SERVER_QUERY_WAS_SLOW:
stats = append(stats, "SERVER_QUERY_WAS_SLOW")
case SERVER_PS_OUT_PARAMS:
stats = append(stats, "SERVER_PS_OUT_PARAMS")
default:
stats = append(stats, fmt.Sprintf("(%d)", field))
}
}

return strings.Join(stats, "|")
}
14 changes: 13 additions & 1 deletion mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ const (
)

const (
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html

CLIENT_LONG_PASSWORD uint32 = 1 << iota
CLIENT_FOUND_ROWS
CLIENT_LONG_FLAG
Expand All @@ -98,6 +100,16 @@ const (
CLIENT_PLUGIN_AUTH
CLIENT_CONNECT_ATTRS
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS
CLIENT_SESSION_TRACK
CLIENT_DEPRECATE_EOF
CLIENT_OPTIONAL_RESULTSET_METADATA
CLIENT_ZSTD_COMPRESSION_ALGORITHM
CLIENT_QUERY_ATTRIBUTES
MULTI_FACTOR_AUTHENTICATION
CLIENT_CAPABILITY_EXTENSION
CLIENT_SSL_VERIFY_SERVER_CERT
CLIENT_REMEMBER_OPTIONS
)

const (
Expand All @@ -119,7 +131,7 @@ const (
MYSQL_TYPE_VARCHAR
MYSQL_TYPE_BIT

//mysql 5.6
// mysql 5.6
MYSQL_TYPE_TIMESTAMP2
MYSQL_TYPE_DATETIME2
MYSQL_TYPE_TIME2
Expand Down

0 comments on commit 82c1282

Please sign in to comment.