Skip to content

Commit

Permalink
Implement ExceptionCollectorListener and make it default behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
surister committed Jun 3, 2024
1 parent 54e6351 commit c1b749d
Showing 1 changed file with 128 additions and 12 deletions.
140 changes: 128 additions & 12 deletions cratedb_sqlparse_py/cratedb_sqlparse/parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import List

from antlr4 import CommonTokenStream, InputStream, Token
from antlr4 import CommonTokenStream, InputStream, Token, RecognitionException
from antlr4.error.ErrorListener import ErrorListener

from cratedb_sqlparse.generated_parser.SqlBaseLexer import SqlBaseLexer
Expand Down Expand Up @@ -30,7 +31,49 @@ def END_DOLLAR_QUOTED_STRING_sempred(self, localctx, predIndex) -> bool:


class ParsingException(Exception):
pass
def __init__(self, *, query: str, msg: str, offending_token: Token, e: RecognitionException):
self.message = msg
self.offending_token = offending_token
self.e = e
self.query = query

@property
def error_message(self):
return f"{self!r}[line {self.line}:{self.column} {self.message}]"

@property
def original_query_with_error_marked(self):
query = self.offending_token.source[1].strdata
offending_token_text: str = query[self.offending_token.start: self.offending_token.stop + 1]
query_lines: list = query.split('\n')

offending_line: str = query_lines[self.line - 1]

# White spaces from the beginning of the offending line to the offending text, so the '^'
# chars are correctly placed below the offending token.
newline_offset = offending_line.index(offending_token_text)
newline = offending_line + '\n' + (
' ' * newline_offset + '^' * (
self.offending_token.stop - self.offending_token.start + 1))

query_lines[self.line - 1] = newline

msg = "\n".join(query_lines)
return msg

@property
def column(self):
return self.offending_token.column

@property
def line(self):
return self.offending_token.line

def __repr__(self):
return f'{type(self.e).__qualname__}'

def __str__(self):
return repr(self)


class CaseInsensitiveStream(InputStream):
Expand All @@ -47,16 +90,44 @@ class ExceptionErrorListener(ErrorListener):
"""

def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
raise ParsingException(f"line{line}:{column} {msg}")
error = ParsingException(
msg=msg,
offending_token=offendingSymbol,
e=e,
query=e.ctx.parser.getTokenStream().getText(e.ctx.start, e.offendingToken.tokenIndex)
)
raise error


class ExceptionCollectorListener(ErrorListener):
"""
Error listener that collects all errors into errors for further processing.
Based partially on https://github.com/antlr/antlr4/issues/396
"""

def __init__(self):
self.errors = []

def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
error = ParsingException(
msg=msg,
offending_token=offendingSymbol,
e=e,
query=e.ctx.parser.getTokenStream().getText(e.ctx.start, e.offendingToken.tokenIndex),
)

self.errors.append(error)


class Statement:
"""
Represents a CrateDB SQL statement.
"""

def __init__(self, ctx: SqlBaseParser.StatementContext):
def __init__(self, ctx: SqlBaseParser.StatementContext, exception: ParsingException = None):
self.ctx: SqlBaseParser.StatementContext = ctx
self.exception = exception

@property
def tree(self):
Expand All @@ -77,7 +148,8 @@ def query(self) -> str:
"""
Returns the query, comments and ';' are not included.
"""
return self.ctx.parser.getTokenStream().getText(start=self.ctx.start.tokenIndex, stop=self.ctx.stop.tokenIndex)
return self.ctx.parser.getTokenStream().getText(start=self.ctx.start,
stop=self.ctx.stop)

@property
def type(self):
Expand All @@ -90,7 +162,20 @@ def __repr__(self):
return f'{self.__class__.__qualname__}<{self.query if len(self.query) < 15 else self.query[:15] + "..."}>'


def sqlparse(query: str) -> List[Statement]:
def find_suitable_error(statement, errors):
for error in errors[:]:
# We clean the error_query of ';' and spaces because ironically,
# we can get the full query in the error handler but not in the context.
error_query = error.query
if error_query.endswith(';'):
error_query = error_query[:len(error_query) - 1]

if error_query.lstrip().rstrip() == statement.query:
statement.exception = error
errors.pop(errors.index(error))


def sqlparse(query: str, raise_exception: bool = False) -> List[Statement]:
"""
Parses a string into SQL `Statement`.
"""
Expand All @@ -101,12 +186,43 @@ def sqlparse(query: str) -> List[Statement]:

parser = SqlBaseParser(stream)
parser.removeErrorListeners()
parser.addErrorListener(ExceptionErrorListener())
error_listener = ExceptionErrorListener() if raise_exception else ExceptionCollectorListener()
parser.addErrorListener(error_listener)

tree = parser.statements()

# At this point, all errors are already raised; it's seasonably safe to assume
# that the statements are valid.
statements = list(filter(lambda children: isinstance(children, SqlBaseParser.StatementContext), tree.children))

return [Statement(statement) for statement in statements]
statements_context: list[SqlBaseParser.StatementContext] = list(
filter(lambda children: isinstance(children, SqlBaseParser.StatementContext),
tree.children)
)

statements = []
for statement_context in statements_context:
_stmt = Statement(statement_context)
find_suitable_error(_stmt, error_listener.errors)
statements.append(_stmt)

else:
# We might still have error(s) that we couldn't match with their origin statement,
# this happens when the query is composed of only one keyword, e.g. 'SELCT 1'
# the error.query will be 'SELCT' instead of 'SELCT 1'.
if len(error_listener.errors) == 1:
# This case has an edge case where we hypothetically assign the
# wrong error to a statement, for example:
# SELECT A FROM tbl1;
# SELEC 1;
# This would match both conditionals, this however is protected by
# by https://github.com/crate/cratedb-sqlparse/issues/28, but might
# change in the future.
error = error_listener.errors[0]
for _stmt in statements:
if _stmt.exception is None and error.query in _stmt.query:
_stmt.exception = error
break

if len(error_listener.errors) > 1:
logging.error(
'Could not match errors to queries, too much ambiguity, open an issue with this error and the query.'
)

return statements

0 comments on commit c1b749d

Please sign in to comment.