From b233acd43679a4e6ae846d7f92420a4b00fae7de Mon Sep 17 00:00:00 2001 From: Jason Mo Date: Fri, 8 Sep 2023 15:28:08 +0800 Subject: [PATCH 1/5] update parser --- parser/ast/format_test.go | 2 +- parser/consistent_test.go | 6 + parser/misc.go | 264 +++++++++++++++++++- parser/parser_test.go | 2 + parser/test_driver/test_driver_mydecimal.go | 4 +- 5 files changed, 264 insertions(+), 14 deletions(-) diff --git a/parser/ast/format_test.go b/parser/ast/format_test.go index 59424d2876d34..2227ae6e575a6 100644 --- a/parser/ast/format_test.go +++ b/parser/ast/format_test.go @@ -88,8 +88,8 @@ func TestAstFormat(t *testing.T) { expr := fmt.Sprintf("select %s", tt.input) charset, collation := getDefaultCharsetAndCollate() stmts, _, err := parser.New().Parse(expr, charset, collation) - node := stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr require.NoError(t, err) + node := stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr writer := bytes.NewBufferString("") node.Format(writer) diff --git a/parser/consistent_test.go b/parser/consistent_test.go index 1acc1a58bc850..fa142510cc2e8 100644 --- a/parser/consistent_test.go +++ b/parser/consistent_test.go @@ -46,6 +46,12 @@ func TestKeywordConsistent(t *testing.T) { requires.NotEqual(t, k, v) requires.Equal(t, tokenMap[v], tokenMap[k]) } + + requires.Len(t, reservedKeywords, len(reservedTokenMap)) + for _, v := range reservedKeywords { + requires.NotNil(t, reservedTokenMap[v]) + } + keywordCount := len(reservedKeywords) + len(unreservedKeywords) + len(notKeywordTokens) + len(tidbKeywords) requires.Equal(t, keywordCount-len(windowFuncTokenMap), len(tokenMap)-len(aliases)) diff --git a/parser/misc.go b/parser/misc.go index 462f024130978..1ec6fa2c81927 100644 --- a/parser/misc.go +++ b/parser/misc.go @@ -142,6 +142,244 @@ func isInTokenMap(target string) bool { return ok } +func isReservedKeyWord(target string) bool { + _, ok := reservedTokenMap[target] + return ok +} + +var reservedTokenMap = map[string]int{ + "ADD": add, + "ALL": all, 
+ "ALTER": alter, + "ANALYZE": analyze, + "AND": and, + "ARRAY": array, + "AS": as, + "ASC": asc, + "BETWEEN": between, + "BIGINT": bigIntType, + "BINARY": binaryType, + "BLOB": blobType, + "BOTH": both, + "BY": by, + "CALL": call, + "CASCADE": cascade, + "CASE": caseKwd, + "CHANGE": change, + "CHARACTER": character, + "CHAR": charType, + "CHECK": check, + "COLLATE": collate, + "COLUMN": column, + "CONSTRAINT": constraint, + "CONTINUE": continueKwd, + "CONVERT": convert, + "CREATE": create, + "CROSS": cross, + "CUME_DIST": cumeDist, + "CURRENT_DATE": currentDate, + "CURRENT_TIME": currentTime, + "CURRENT_TIMESTAMP": currentTs, + "CURRENT_USER": currentUser, + "CURRENT_ROLE": currentRole, + "CURSOR": cursor, + "DATABASE": database, + "DATABASES": databases, + "DAY_HOUR": dayHour, + "DAY_MICROSECOND": dayMicrosecond, + "DAY_MINUTE": dayMinute, + "DAY_SECOND": daySecond, + "DECIMAL": decimalType, + "DEFAULT": defaultKwd, + "DELAYED": delayed, + "DELETE": deleteKwd, + "DENSE_RANK": denseRank, + "DESC": desc, + "DESCRIBE": describe, + "DISTINCT": distinct, + "DISTINCTROW": distinctRow, + "DIV": div, + "DOUBLE": doubleType, + "DROP": drop, + "DUAL": dual, + "ELSEIF": elseIfKwd, + "ELSE": elseKwd, + "ENCLOSED": enclosed, + "ESCAPED": escaped, + "EXISTS": exists, + "EXIT": exit, + "EXPLAIN": explain, + "EXCEPT": except, + "FALSE": falseKwd, + "FETCH": fetch, + "FIRST_VALUE": firstValue, + "FLOAT": floatType, + "FOR": forKwd, + "FORCE": force, + "FOREIGN": foreign, + "FROM": from, + "FULLTEXT": fulltext, + "GENERATED": generated, + "GRANT": grant, + "GROUP": group, + "GROUPS": groups, + "HAVING": having, + "HIGH_PRIORITY": highPriority, + "HOUR_MICROSECOND": hourMicrosecond, + "HOUR_MINUTE": hourMinute, + "HOUR_SECOND": hourSecond, + "IF": ifKwd, + "IGNORE": ignore, + "IN": in, + "INDEX": index, + "INFILE": infile, + "INNER": inner, + "INOUT": inout, + "INTEGER": integerType, + "INTERSECT": intersect, + "INTERVAL": interval, + "INTO": into, + "OUTFILE": outfile, + "IS": is, 
+ "INSERT": insert, + "INT": intType, + "INT1": int1Type, + "INT2": int2Type, + "INT3": int3Type, + "INT4": int4Type, + "INT8": int8Type, + "ITERATE": iterate, + "JOIN": join, + "KEY": key, + "KEYS": keys, + "KILL": kill, + "LAG": lag, + "LAST_VALUE": lastValue, + "LEAD": lead, + "LEADING": leading, + "LEAVE": leave, + "LEFT": left, + "LIKE": like, + "ILIKE": ilike, + "LIMIT": limit, + "LINES": lines, + "LINEAR": linear, + "LOAD": load, + "LOCALTIME": localTime, + "LOCALTIMESTAMP": localTs, + "LOCK": lock, + "LONGBLOB": longblobType, + "LONGTEXT": longtextType, + "LOW_PRIORITY": lowPriority, + "MATCH": match, + "MAXVALUE": maxValue, + "MEDIUMBLOB": mediumblobType, + "MEDIUMINT": mediumIntType, + "MEDIUMTEXT": mediumtextType, + "MINUTE_MICROSECOND": minuteMicrosecond, + "MINUTE_SECOND": minuteSecond, + "MOD": mod, + "NOT": not, + "NO_WRITE_TO_BINLOG": noWriteToBinLog, + "NTH_VALUE": nthValue, + "NTILE": ntile, + "NULL": null, + "NUMERIC": numericType, + "OF": of, + "ON": on, + "OPTIMIZE": optimize, + "OPTION": option, + "OPTIONALLY": optionally, + "OR": or, + "ORDER": order, + "OUT": out, + "OUTER": outer, + "OVER": over, + "PARTITION": partition, + "PERCENT_RANK": percentRank, + "PRECISION": precisionType, + "PRIMARY": primary, + "PROCEDURE": procedure, + "RANGE": rangeKwd, + "RANK": rank, + "READ": read, + "REAL": realType, + "RECURSIVE": recursive, + "REFERENCES": references, + "REGEXP": regexpKwd, + "RELEASE": release, + "RENAME": rename, + "REPEAT": repeat, + "REPLACE": replace, + "REQUIRE": require, + "RESTRICT": restrict, + "REVOKE": revoke, + "RIGHT": right, + "RLIKE": rlike, + "ROW": row, + "ROWS": rows, + "ROW_NUMBER": rowNumber, + "SECOND_MICROSECOND": secondMicrosecond, + "SELECT": selectKwd, + "SET": set, + "SHOW": show, + "SMALLINT": smallIntType, + "SPATIAL": spatial, + "SQL": sql, + "SQL_BIG_RESULT": sqlBigResult, + "SQL_CALC_FOUND_ROWS": sqlCalcFoundRows, + "SQL_SMALL_RESULT": sqlSmallResult, + "SQLEXCEPTION": sqlexception, + "SQLSTATE": sqlstate, + 
"SQLWARNING": sqlwarning, + "SSL": ssl, + "STARTING": starting, + "STATS_EXTENDED": statsExtended, + "STRAIGHT_JOIN": straightJoin, + "TiDB_CURRENT_TSO": tidbCurrentTSO, + "TABLE": tableKwd, + "TABLESAMPLE": tableSample, + "STORED": stored, + "TERMINATED": terminated, + "THEN": then, + "TINYBLOB": tinyblobType, + "TINYINT": tinyIntType, + "TINYTEXT": tinytextType, + "TO": to, + "TRAILING": trailing, + "TRIGGER": trigger, + "TRUE": trueKwd, + "UNIQUE": unique, + "UNION": union, + "UNLOCK": unlock, + "UNSIGNED": unsigned, + "UNTIL": until, + "UPDATE": update, + "USAGE": usage, + "USE": use, + "USING": using, + "UTC_DATE": utcDate, + "UTC_TIMESTAMP": utcTimestamp, + "UTC_TIME": utcTime, + "VALUES": values, + "LONG": long, + "VARCHAR": varcharType, + "VARCHARACTER": varcharacter, + "VARBINARY": varbinaryType, + "VARYING": varying, + "VIRTUAL": virtual, + "WHEN": when, + "WHERE": where, + "WHILE": while, + "WRITE": write, + "WINDOW": window, + "WITH": with, + "XOR": xor, + "YEAR_MONTH": yearMonth, + "ZEROFILL": zerofill, + "NATURAL": natural, +} + // tokenMap is a map of known identifiers to the parser token ID. // Please try to keep the map in alphabetical order. var tokenMap = map[string]int{ @@ -1066,20 +1304,10 @@ var hintTokenMap = map[string]int{ func (s *Scanner) isTokenIdentifier(lit string, offset int) int { // An identifier before or after '.' means it is part of a qualified identifier. // We do not parse it as keyword. - if s.r.peek() == '.' { + if s.r.peek() == '.' || (offset != 0 && s.r.s[offset-1] == '.') { return 0 } - for idx := offset - 1; idx >= 0; idx-- { - if s.r.s[idx] == ' ' { - continue - } else if s.r.s[idx] == '.' { - return 0 - } else { - break - } - } - buf := &s.buf buf.Reset() buf.Grow(len(lit)) @@ -1094,6 +1322,20 @@ func (s *Scanner) isTokenIdentifier(lit string, offset int) int { } } + // select * from t where t. status = 1; -- parse ok, unreserved keyword + // select * from t where t. and = 1; -- parse failed, reserverd keyword. 
+ if !isReservedKeyWord(string(data)) { + for idx := offset - 1; idx >= 0; idx-- { + if s.r.s[idx] == ' ' { + continue + } else if s.r.s[idx] == '.' { + return 0 + } else { + break + } + } + } + checkBtFuncToken := s.r.peek() == '(' if !checkBtFuncToken && s.sqlMode.HasIgnoreSpaceMode() { s.skipWhitespace() diff --git a/parser/parser_test.go b/parser/parser_test.go index f32f2cb7381f7..646092b002723 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1059,6 +1059,8 @@ AAAAAAAAAAAA5gm5Mg== {"select `t`.`1a`.1 from t;", true, "SELECT `t`.`1a`.`1` FROM `t`"}, {"select * from 1db.1table;", true, "SELECT * FROM `1db`.`1table`"}, {"select * from t where t. status = 1;", true, "SELECT * FROM `t` WHERE `t`.`status`=1"}, + {"select * from t where t. and = 1;", false, ""}, + {"select * from t where a between 1. and 2;", true, "SELECT * FROM `t` WHERE `a` BETWEEN 1. AND 2"}, // for show placement {"SHOW PLACEMENT", true, "SHOW PLACEMENT"}, diff --git a/parser/test_driver/test_driver_mydecimal.go b/parser/test_driver/test_driver_mydecimal.go index 91bd04486689e..21189182d35ed 100644 --- a/parser/test_driver/test_driver_mydecimal.go +++ b/parser/test_driver/test_driver_mydecimal.go @@ -124,7 +124,7 @@ func (d *MyDecimal) ToString() (str []byte) { if d.negative { length++ } - if digitsFrac > 0 { + if digitsFrac >= 0 { length++ } str = str[:length] @@ -134,7 +134,7 @@ func (d *MyDecimal) ToString() (str []byte) { strIdx++ } var fill int - if digitsFrac > 0 { + if digitsFrac >= 0 { fracIdx := strIdx + digitsIntLen fill = digitsFracLen - digitsFrac wordIdx := wordStartIdx + digitsToWords(digitsInt) From a783b84842ad994e4035218fa1d59fdfc0c77106 Mon Sep 17 00:00:00 2001 From: Jason Mo Date: Fri, 8 Sep 2023 17:23:41 +0800 Subject: [PATCH 2/5] upate --- parser/misc.go | 4 ++-- parser/parser_test.go | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/parser/misc.go b/parser/misc.go index 1ec6fa2c81927..a59034708d2f5 100644 --- a/parser/misc.go +++ 
b/parser/misc.go @@ -1322,8 +1322,8 @@ func (s *Scanner) isTokenIdentifier(lit string, offset int) int { } } - // select * from t where t. status = 1; -- parse ok, unreserved keyword - // select * from t where t. and = 1; -- parse failed, reserverd keyword. + // select * from t where t. status = 1; -- parse ok, unreserved keyword. + // select * from t where t. and = 1; -- parse failed, reserved keyword. if !isReservedKeyWord(string(data)) { for idx := offset - 1; idx >= 0; idx-- { if s.r.s[idx] == ' ' { diff --git a/parser/parser_test.go b/parser/parser_test.go index 646092b002723..0a706f6ecd48b 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1058,8 +1058,10 @@ AAAAAAAAAAAA5gm5Mg== {"select `t`.`1a`.1 from t;", true, "SELECT `t`.`1a`.`1` FROM `t`"}, {"select * from 1db.1table;", true, "SELECT * FROM `1db`.`1table`"}, - {"select * from t where t. status = 1;", true, "SELECT * FROM `t` WHERE `t`.`status`=1"}, + {"select * from t where t.and = 1;", true, "SELECT * FROM `t` WHERE `t`.`and`=1"}, {"select * from t where t. and = 1;", false, ""}, + {"select * from t where t.status = 1;", true, "SELECT * FROM `t` WHERE `t`.`status`=1"}, + {"select * from t where t. status = 1;", true, "SELECT * FROM `t` WHERE `t`.`status`=1"}, {"select * from t where a between 1. and 2;", true, "SELECT * FROM `t` WHERE `a` BETWEEN 1. 
AND 2"}, // for show placement From c521d486b3ee7c4bc53e14477ab8c89952efeb3d Mon Sep 17 00:00:00 2001 From: Jason Mo Date: Fri, 8 Sep 2023 17:40:44 +0800 Subject: [PATCH 3/5] update --- parser/consistent_test.go | 5 +++-- parser/misc.go | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/parser/consistent_test.go b/parser/consistent_test.go index fa142510cc2e8..1a9a6d1527b8b 100644 --- a/parser/consistent_test.go +++ b/parser/consistent_test.go @@ -14,6 +14,7 @@ package parser import ( + "fmt" gio "io" "os" "sort" @@ -47,10 +48,10 @@ func TestKeywordConsistent(t *testing.T) { requires.Equal(t, tokenMap[v], tokenMap[k]) } - requires.Len(t, reservedKeywords, len(reservedTokenMap)) for _, v := range reservedKeywords { - requires.NotNil(t, reservedTokenMap[v]) + requires.NotEqual(t, reservedTokenMap[v], 0, fmt.Sprintf("Not found %s in reservedTokenMap", v)) } + requires.Len(t, reservedKeywords, len(reservedTokenMap)) keywordCount := len(reservedKeywords) + len(unreservedKeywords) + len(notKeywordTokens) + len(tidbKeywords) requires.Equal(t, keywordCount-len(windowFuncTokenMap), len(tokenMap)-len(aliases)) diff --git a/parser/misc.go b/parser/misc.go index 58e19a3bc9096..aebd39a8f9d63 100644 --- a/parser/misc.go +++ b/parser/misc.go @@ -214,6 +214,8 @@ var reservedTokenMap = map[string]int{ "FETCH": fetch, "FIRST_VALUE": firstValue, "FLOAT": floatType, + "FLOAT4": float4Type, + "FLOAT8": float8Type, "FOR": forKwd, "FORCE": force, "FOREIGN": foreign, @@ -276,6 +278,7 @@ var reservedTokenMap = map[string]int{ "MEDIUMBLOB": mediumblobType, "MEDIUMINT": mediumIntType, "MEDIUMTEXT": mediumtextType, + "MIDDLEINT": middleIntType, "MINUTE_MICROSECOND": minuteMicrosecond, "MINUTE_SECOND": minuteSecond, "MOD": mod, From 9d21d6ca9b2d48cd758f5343b76e54511bc10af3 Mon Sep 17 00:00:00 2001 From: Jason Mo Date: Fri, 8 Sep 2023 17:52:48 +0800 Subject: [PATCH 4/5] update --- parser/misc.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git 
a/parser/misc.go b/parser/misc.go index aebd39a8f9d63..9450acd68f3d9 100644 --- a/parser/misc.go +++ b/parser/misc.go @@ -142,11 +142,16 @@ func isInTokenMap(target string) bool { return ok } -func isReservedKeyWord(target string) bool { +func isInReservedTokenMap(target string) bool { _, ok := reservedTokenMap[target] return ok } +func isInWindowFuncTokenMap(target string) bool { + _, ok := windowFuncTokenMap[target] + return ok +} + var reservedTokenMap = map[string]int{ "ADD": add, "ALL": all, @@ -1330,7 +1335,7 @@ func (s *Scanner) isTokenIdentifier(lit string, offset int) int { // select * from t where t. status = 1; -- parse ok, unreserved keyword. // select * from t where t. and = 1; -- parse failed, reserved keyword. - if !isReservedKeyWord(string(data)) { + if !isInReservedTokenMap(string(data)) || (!s.supportWindowFunc && isInWindowFuncTokenMap(string(data))) { for idx := offset - 1; idx >= 0; idx-- { if s.r.s[idx] == ' ' { continue From 15899e9013006f46f0bf9961de3b7a252731539f Mon Sep 17 00:00:00 2001 From: Jason Mo Date: Fri, 8 Sep 2023 18:05:03 +0800 Subject: [PATCH 5/5] udpate --- parser/parser_test.go | 11 ++++++++++- parser/test_driver/test_driver_mydecimal.go | 4 ++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/parser/parser_test.go b/parser/parser_test.go index 122396aa6a94c..ac141dea8a6f8 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1062,7 +1062,6 @@ AAAAAAAAAAAA5gm5Mg== {"select * from t where t. and = 1;", false, ""}, {"select * from t where t.status = 1;", true, "SELECT * FROM `t` WHERE `t`.`status`=1"}, {"select * from t where t. status = 1;", true, "SELECT * FROM `t` WHERE `t`.`status`=1"}, - {"select * from t where a between 1. and 2;", true, "SELECT * FROM `t` WHERE `a` BETWEEN 1. 
AND 2"}, // for show placement {"SHOW PLACEMENT", true, "SHOW PLACEMENT"}, @@ -7468,6 +7467,16 @@ func TestMultiStmt(t *testing.T) { require.Equal(t, "1", stmt4.Fields.Fields[0].Text()) } +func TestIssue46789(t *testing.T) { + p := parser.New() + // parse `1.` get a datum with type decimal, + // parse and restore `1.` get a datum with type int, + // so add a new test case here. + sql := "select * from t where a between 1. and 2" + _, _, err := p.Parse(sql, "", "") + require.NoError(t, err) +} + // https://dev.mysql.com/doc/refman/8.1/en/other-vendor-data-types.html func TestCompatTypes(t *testing.T) { table := []testCase{ diff --git a/parser/test_driver/test_driver_mydecimal.go b/parser/test_driver/test_driver_mydecimal.go index 21189182d35ed..91bd04486689e 100644 --- a/parser/test_driver/test_driver_mydecimal.go +++ b/parser/test_driver/test_driver_mydecimal.go @@ -124,7 +124,7 @@ func (d *MyDecimal) ToString() (str []byte) { if d.negative { length++ } - if digitsFrac >= 0 { + if digitsFrac > 0 { length++ } str = str[:length] @@ -134,7 +134,7 @@ func (d *MyDecimal) ToString() (str []byte) { strIdx++ } var fill int - if digitsFrac >= 0 { + if digitsFrac > 0 { fracIdx := strIdx + digitsIntLen fill = digitsFracLen - digitsFrac wordIdx := wordStartIdx + digitsToWords(digitsInt)