From 9e51a4b1dc9c62e9b2512afbca1643fc5b414d6a Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Fri, 27 Sep 2024 20:45:04 +0200 Subject: [PATCH] Discover schema from `__table__` data classes --- labs.yml | 3 ++ pyproject.toml | 4 +++ src/databricks/labs/lsql/discovery.py | 45 +++++++++++++++++++++++++++ tests/unit/test_discovery.py | 9 ++++++ tests/unit/test_structs.py | 2 ++ 5 files changed, 63 insertions(+) create mode 100644 src/databricks/labs/lsql/discovery.py create mode 100644 tests/unit/test_discovery.py diff --git a/labs.yml b/labs.yml index b82ccc5d..e61495a6 100644 --- a/labs.yml +++ b/labs.yml @@ -34,3 +34,6 @@ commands: description: Publish the dashboard after creating by setting to `yes` or `y`. - name: open-browser description: Open the dashboard in the browser after creating by setting to `yes` or `y`. + + - name: deploy-schema + description: Create schema from dataclasses in the given folder \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index eb7512b1..f3857f8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,9 @@ dependencies = [ "sqlglot>=22.3.1" ] +[project.optional-dependencies] +df = ["sqlframe~=3.3.1"] + [project.urls] Documentation = "https://github.com/databrickslabs/lsql#readme" Issues = "https://github.com/databrickslabs/lsql/issues" @@ -43,6 +46,7 @@ path = "src/databricks/labs/lsql/__about__.py" [tool.hatch.envs.default] dependencies = [ + "databricks-labs-lsql[df]", "coverage[toml]>=6.5", "pytest", "pylint", diff --git a/src/databricks/labs/lsql/discovery.py b/src/databricks/labs/lsql/discovery.py new file mode 100644 index 00000000..564fd0fb --- /dev/null +++ b/src/databricks/labs/lsql/discovery.py @@ -0,0 +1,45 @@ +import ast +from pathlib import Path + + +class DataclassTableFinder(ast.NodeVisitor): + def __init__(self): + self.tables = [] + + def visit_ClassDef(self, node): + # Check if the class is a dataclass + is_dataclass = any(isinstance(decorator, ast.Name) and decorator.id == 'dataclass' + for decorator in node.decorator_list) + + # Look for __table__ assignment in class body + has_table_field = any(isinstance(n, ast.Assign) and + any(isinstance(t, ast.Name) and t.id == '__table__' for t in n.targets) + for n in node.body) + + # If both conditions are met, store the class name + if is_dataclass and has_table_field: + self.tables.append(node.name) + + # Continue visiting the rest of the AST + self.generic_visit(node) + + +class Scanner: + def __init__(self, start: Path): + self._start = start + + def find_all(self): + for f in self._start.glob('**/*.py'): # TODO: skip virtual environments + yield from self._find_dataclasses_with_table(f) + + def _find_dataclasses_with_table(self, path: Path): + # Parse the source code into an AST + tree = ast.parse(path.read_text()) + + # Create a finder instance and visit the parsed tree + finder = DataclassTableFinder() + finder.visit(tree) + + # Return the list of dataclasses with __table__ field + return finder.tables + diff --git a/tests/unit/test_discovery.py b/tests/unit/test_discovery.py new file mode 100644 index 00000000..067729c2 --- /dev/null +++ b/tests/unit/test_discovery.py @@ -0,0 +1,9 @@ +from pathlib import Path + +from databricks.labs.lsql.discovery import Scanner + + +def test_finds(): + s = Scanner(Path('/Users/serge.smertin/git/labs/lsql/tests')) + x = list(s.find_all()) + assert x == ['Nested'] \ No newline at end of file diff --git a/tests/unit/test_structs.py b/tests/unit/test_structs.py index 10c90570..a809d725 100644 --- a/tests/unit/test_structs.py +++ b/tests/unit/test_structs.py @@ -15,6 +15,8 @@ class Foo: @dataclass class Nested: + __table__ = 'x' + foo: Foo mapping: dict[str, int] array: list[int]