Skip to content

Commit

Permalink
rewrite suite_tests parser (#12949)
Browse files Browse the repository at this point in the history
  • Loading branch information
zverevgeny authored Dec 25, 2024
1 parent 61f38fc commit cf03f2f
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 99 deletions.
167 changes: 71 additions & 96 deletions ydb/tests/functional/suite_tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import itertools
import json
import abc
import collections
import os
import random
import string
Expand Down Expand Up @@ -41,21 +40,53 @@ def mute_sdk_loggers():
mute_sdk_loggers()


@enum.unique
class StatementTypes(enum.Enum):
Skipped = 'statement skipped'
Ok = 'statement ok'
Error = 'statement error'
Query = 'statement query'
StreamQuery = 'statement stream query'
ImportTableData = 'statement import table data'
class StatementDefinition:
@enum.unique
class Type(enum.Enum):
Skipped = 'statement skipped'
Ok = 'statement ok'
Error = 'statement error'
Query = 'statement query'
StreamQuery = 'statement stream query'
ImportTableData = 'statement import table data'

def __init__(self, suite: str, at_line: int, type: Type, text: [str]):
self.suite_name = suite
self.at_line = at_line
self.s_type = type
self.text = text

def get_statement_type(line):
for s_type in list(StatementTypes):
if s_type.value in line.lower():
return s_type
raise RuntimeError("Can't find statement type for line %s" % line)
def __str__(self):
return f'''StatementDefinition:
suite: {self.suite_name}
line: {self.at_line}
type: {self.s_type}
text:
''' + '\n'.join([f' {row}' for row in self.text.split('\n')])

@staticmethod
def _parse_statement_type(statement_line: str) -> Type:
for t in list(StatementDefinition.Type):
if t.value in statement_line.lower():
return t
return None

@staticmethod
def parse(suite: str, at_line: int, lines: list[str]):
if not lines or not lines[0]:
raise RuntimeError(f'Invalid statement in {suite}, at line: {at_line}')
type = StatementDefinition._parse_statement_type(lines[0])
if type is None:
raise RuntimeError(f'Unknown statement type in {suite}, at line: {at_line}')
lines.pop(0)
at_line += 1
statement_lines = []
for line in lines:
if line.startswith('side effect: '): # side effects are not supported yet
pass
else:
statement_lines.append(line)
return StatementDefinition(suite, at_line, type, "\n".join(statement_lines))


def get_token(length=10):
Expand All @@ -67,12 +98,6 @@ def get_source_path(*args):
return os.path.join(arcadia_root, test_source_path(os.path.join(*args)))


def is_empty_line(line):
if line.split():
return False
return True


def get_lines(suite_path):
with open(suite_path) as reader:
for line_idx, line in enumerate(reader.readlines()):
Expand All @@ -97,79 +122,31 @@ def get_test_suites(directory):
return suites


def get_single_statement(lines):
def split_by_statement(lines):
statement_lines = []
statement_start_line_idx = 0
for line_idx, line in lines:
if is_empty_line(line):
statement = "\n".join(statement_lines)
return statement
statement_lines.append(line)
return "\n".join(statement_lines)


class ParsedStatement(collections.namedtuple('ParsedStatement', ["at_line", "s_type", "suite_name", "text"])):
def get_fields(self):
return self._fields

def __str__(self):
result = ["", "Parsed Statement"]
for field in self.get_fields():
value = str(getattr(self, field))
if field != 'text':
result.append(' ' * 4 + '%s: %s,' % (field, value))
else:
result.append(' ' * 4 + '%s:' % field)
result.extend([' ' * 8 + row for row in value.split('\n')])
return "\n".join(result)
if line:
if line.startswith("statement "):
statement_start_line_idx = line_idx
statement_lines = [line]
elif statement_lines:
statement_lines.append(line)
else:
if statement_lines:
yield (statement_start_line_idx, statement_lines)
statement_lines = []
if statement_lines:
yield (statement_start_line_idx, statement_lines)


def get_statements(suite_path, suite_name):
lines = get_lines(suite_path)
for line_idx, line in lines:
if is_empty_line(line) or not is_statement_definition(line):
# empty line or junk lines
continue
text = get_single_statement(lines)
yield ParsedStatement(
line_idx,
get_statement_type(line),
for statement_start_line_idx, statement_lines in split_by_statement(get_lines(suite_path)):
yield StatementDefinition.parse(
suite_name,
text)


def is_side_effect(statement_line):
return statement_line.startswith('side effect: ')


def parse_side_effect(se_line):
pieces = se_line.split(':')
if len(pieces) < 3:
raise RuntimeError("Invalid side effect description: %s" % se_line)
se_type = pieces[1].strip()
se_description = ':'.join(pieces[2:])
se_description = se_description.strip()

return se_type, se_description


def get_statement_and_side_effects(statement_text):
statement_lines = statement_text.split('\n')
side_effects = {}
filtered = []
for statement_line in statement_lines:
if not is_side_effect(statement_line):
filtered.append(statement_line)
continue

se_type, se_description = parse_side_effect(statement_line)

side_effects[se_type] = se_description

return '\n'.join(filtered), side_effects


def is_statement_definition(line):
return line.startswith("statement")
statement_start_line_idx,
statement_lines,
)


def patch_yql_statement(lines_or_statement, table_path_prefix):
Expand Down Expand Up @@ -307,12 +284,12 @@ def assert_statement_import_table_data(self, statement):
def assert_statement(self, parsed_statement):
start_time = time.time()
from_type = {
StatementTypes.Ok: self.assert_statement_ok,
StatementTypes.Query: self.assert_statement_query,
StatementTypes.StreamQuery: self.assert_statement_stream_query,
StatementTypes.Error: (lambda x: x),
StatementTypes.ImportTableData: self.assert_statement_import_table_data,
StatementTypes.Skipped: lambda x: x
StatementDefinition.Type.Ok: self.assert_statement_ok,
StatementDefinition.Type.Query: self.assert_statement_query,
StatementDefinition.Type.StreamQuery: self.assert_statement_stream_query,
StatementDefinition.Type.Error: (lambda x: x),
StatementDefinition.Type.ImportTableData: self.assert_statement_import_table_data,
StatementDefinition.Type.Skipped: lambda x: x
}
assert_method = from_type.get(parsed_statement.s_type)
assert_method(parsed_statement)
Expand All @@ -329,10 +306,8 @@ def assert_statement_ok(self, statement):
)

def assert_statement_error(self, statement):
# not supported yet
statement_text, side_effects = get_statement_and_side_effects(statement.text)
assert_that(
lambda: self.execute_query(statement_text),
lambda: self.execute_query(statement.text),
raises(
ydb.Error
)
Expand Down
5 changes: 2 additions & 3 deletions ydb/tests/functional/suite_tests/test_sql_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from hamcrest import assert_that, raises

from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute, get_statement_and_side_effects
from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute

"""
This module is a specific runner of sqllogic tests. Test suites for this
Expand Down Expand Up @@ -38,8 +38,7 @@ def assert_statement_ok(self, statement):
safe_execute(lambda: self.__execute_sqlitedb(statement.text), statement)

def assert_statement_error(self, statement):
statement_text, side_effects = get_statement_and_side_effects(statement.text)
assert_that(lambda: self.__execute_sqlitedb(statement_text), raises(sqlite3.Error), str(statement))
assert_that(lambda: self.__execute_sqlitedb(statement.text), raises(sqlite3.Error), str(statement))
super(TestSQLLogic, self).assert_statement_error(statement)

def get_query_and_output(self, statement_text):
Expand Down

0 comments on commit cf03f2f

Please sign in to comment.