diff --git a/server/handshake_resp.go b/server/handshake_resp.go index 13fd8e6c3..e4f1e9be4 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -128,10 +128,15 @@ func (c *Conn) readPluginName(data []byte, pos int) int { return pos } -func (c *Conn) readAuthData(data []byte, pos int) ([]byte, int, int, error) { +func (c *Conn) readAuthData(data []byte, pos int) (auth []byte, authLen int, newPos int, err error) { + // prevent 'panic: runtime error: index out of range' error + defer func() { + if recover() != nil { + err = NewDefaultError(ER_HANDSHAKE_ERROR) + } + }() + // length encoded data - var auth []byte - var authLen int if c.capability&CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 { authData, isNULL, readBytes, err := LengthEncodedString(data[pos:]) if err != nil { @@ -144,7 +149,7 @@ func (c *Conn) readAuthData(data []byte, pos int) ([]byte, int, int, error) { auth = authData authLen = readBytes } else if c.capability&CLIENT_SECURE_CONNECTION != 0 { - //auth length and auth + // auth length and auth authLen = int(data[pos]) pos++ auth = data[pos : pos+authLen] diff --git a/server/handshake_resp_test.go b/server/handshake_resp_test.go new file mode 100644 index 000000000..f1c7fdf4a --- /dev/null +++ b/server/handshake_resp_test.go @@ -0,0 +1,30 @@ +package server + +import ( + "testing" + + "github.com/go-mysql-org/go-mysql/mysql" +) + +func TestReadAuthData(t *testing.T) { + c := &Conn{ + capability: mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA, + } + + data := []byte{141, 174, 255, 1, 0, 0, 0, 1, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 114, 111, 111, 116, 0, 20, 190, 183, 72, 209, 170, 60, 191, 100, 227, 81, 203, 221, 190, 14, 213, 116, 244, 140, 90, 121, 109, 121, 115, 113, 108, 95, 112, 101, 114, 102, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + + // test out of range index returns 'bad handshake' error + _, _, _, err := c.readAuthData(data, len(data)) + if err == nil || err.Error() != "ERROR 1043 (08S01): Bad handshake" { + t.Fatal("expected error, got nil") + } + + // test good index position reads auth data + _, _, readBytes, err := c.readAuthData(data, len(data)-1) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if readBytes != len(data)-1 { + t.Fatalf("expected %d read bytes, got %d", len(data)-1, readBytes) + } +}