Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SQLFluff /format endpoint to the server #4

Merged
merged 7 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 83 additions & 2 deletions src/dbt_core_interface/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,10 @@

try:
import dbt_core_interface.state as dci_state
from dbt_core_interface.sqlfluff_util import lint_command
from dbt_core_interface.sqlfluff_util import format_command, lint_command
except ImportError:
dci_state = None
format_command = None
lint_command = None

# dbt-core-interface is designed for non-standard use. There is no
Expand Down Expand Up @@ -6231,6 +6232,80 @@ def lint_sql(
lint_result = {"result": [error for error in result]}
return lint_result

if format_command:

@route("/format", method="POST")
def format_sql(
runners: DbtProjectContainer,
):
LOGGER.info(f"format_sql()")
# Project Support
project_runner = (
runners.get_project(request.get_header("X-dbt-Project"))
or runners.get_default_project()
)
LOGGER.info(f"got project: {project_runner}")
if not project_runner:
response.status = 400
return asdict(
ServerErrorContainer(
error=ServerError(
code=ServerErrorCode.ProjectNotRegistered,
message=(
"Project is not registered. Make a POST request to the /register"
" endpoint first to register a runner"
),
data={"registered_projects": runners.registered_projects()},
)
)
)

sql_path = request.query.get("sql_path")
LOGGER.info(f"sql_path: {sql_path}")
if sql_path:
# Format a file
# NOTE: Formatting a string is not supported.
LOGGER.info(f"formatting file: {sql_path}")
sql = Path(sql_path)
else:
# Format a string
LOGGER.info(f"formatting string")
sql = request.body.getvalue().decode("utf-8")
if not sql:
response.status = 400
return {
"error": {
"data": {},
"message": (
"No SQL provided. Either provide a SQL file path or a SQL string to lint."
),
}
}
try:
LOGGER.info(f"Calling format_command()")
temp_result, formatted_sql = format_command(
Path(project_runner.config.project_root),
sql=sql,
extra_config_path=(
Path(request.query.get("extra_config_path"))
if request.query.get("extra_config_path")
else None
),
)
except Exception as format_err:
logging.exception("Formatting failed")
response.status = 500
return {
"error": {
"data": {},
"message": str(format_err),
}
}
else:
LOGGER.info(f"Formatting succeeded")
format_result = {"result": temp_result, "sql": formatted_sql}
return format_result


def run_server(runner: Optional[DbtProject] = None, host="localhost", port=8581):
"""Run the dbt core interface server.
Expand All @@ -6256,7 +6331,13 @@ def run_server(runner: Optional[DbtProject] = None, host="localhost", port=8581)
if __name__ == "__main__":
import argparse

# logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.ERROR)

# Configure logging for 'dbt_core_interface' and 'dbt_core_interface.sqlfluff_util'
for logger_name in ['dbt_core_interface', 'dbt_core_interface.sqlfluff_util']:
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)

parser = argparse.ArgumentParser(
description="Run the dbt interface server. Defaults to the WSGIRefServer"
)
Expand Down
154 changes: 150 additions & 4 deletions src/dbt_core_interface/sqlfluff_util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import atexit
import logging
import os
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Optional, Tuple, Union

from sqlfluff.cli.outputstream import FileOutput
from sqlfluff.core import SQLLintError, SQLTemplaterError
from sqlfluff.core.config import ConfigLoader, FluffConfig


LOGGER = logging.getLogger(__name__)


# Cache linters (up to 50 though its arbitrary)
@lru_cache(maxsize=50)
def get_linter(
Expand All @@ -17,7 +22,7 @@ def get_linter(
):
"""Get linter."""
from sqlfluff.cli.commands import get_linter_and_formatter
return get_linter_and_formatter(config, stream)[0]
return get_linter_and_formatter(config, stream)

# Cache config to prevent wasted frames
@lru_cache(maxsize=50)
Expand Down Expand Up @@ -81,7 +86,7 @@ def lint_command(
but for now this should provide maximum compatibility with the command-line
tool. We can also propose changes to SQLFluff to make this easier.
"""
lnt = get_linter(
lnt, formatter = get_linter(
*get_config(
project_root,
extra_config_path,
Expand All @@ -104,6 +109,109 @@ def lint_command(
return records[0] if records else None


def format_command(
project_root: Path,
sql: Union[Path, str],
extra_config_path: Optional[Path] = None,
ignore_local_config: bool = False,
) -> Tuple[bool, Optional[str]]:
"""Format specified file or SQL string.

This is essentially a streamlined version of the SQLFluff command-line
format function, sqlfluff.cli.commands.cli_format().

This function uses a few SQLFluff internals, but it should be relatively
stable. The initial plan was to use the public API, but that was not
behaving well initially. Small details about how SQLFluff handles .sqlfluff
and dbt_project.yaml file locations and overrides generate lots of support
questions, so it seems better to use this approach for now.

Eventually, we can look at using SQLFluff's public, high-level APIs,
but for now this should provide maximum compatibility with the command-line
tool. We can also propose changes to SQLFluff to make this easier.
"""
LOGGER.info(f"""format_command(
{project_root},
{str(sql)[:100]},
{extra_config_path},
{ignore_local_config})
""")
lnt, formatter = get_linter(
*get_config(
project_root,
extra_config_path,
ignore_local_config,
require_dialect=False,
nocolor=True,
rules=(
# All of the capitalisation rules
"capitalisation,"
# All of the layout rules
"layout,"
# Safe rules from other groups
"ambiguous.union,"
"convention.not_equal,"
"convention.coalesce,"
"convention.select_trailing_comma,"
"convention.is_null,"
"jinja.padding,"
"structure.distinct,"
)
)
)

if isinstance(sql, str):
# Lint SQL passed in as a string
LOGGER.info(f"Formatting SQL string: {sql[:100]}")
result = lnt.lint_string_wrapped(sql, fname="stdin", fix=True)
total_errors, num_filtered_errors = result.count_tmp_prs_errors()
result.discard_fixes_for_lint_errors_in_files_with_tmp_or_prs_errors()
success = not num_filtered_errors
num_fixable = result.num_violations(types=SQLLintError, fixable=True)
if num_fixable > 0:
LOGGER.info(f"Fixing {num_fixable} errors in SQL string")
result_sql = result.paths[0].files[0].fix_string()[0]
LOGGER.info(f"Result string has changes? {result_sql != sql}")
else:
LOGGER.info("No fixable errors in SQL string")
result_sql = sql
else:
# Format a SQL file
LOGGER.info(f"Formatting SQL file: {sql}")
before_modified = datetime.fromtimestamp(sql.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S')
LOGGER.info(f"Before fixing, modified: {before_modified}")
result_sql = None
lint_result = lnt.lint_paths(
paths=[str(sql)],
fix=True,
ignore_non_existent_files=False,
#processes=processes,
# If --force is set, then apply the changes as we go rather
# than waiting until the end.
apply_fixes=True,
#fixed_file_suffix=fixed_suffix,
fix_even_unparsable=False,
)
total_errors, num_filtered_errors = lint_result.count_tmp_prs_errors()
lint_result.discard_fixes_for_lint_errors_in_files_with_tmp_or_prs_errors()
success = not num_filtered_errors
if success:
num_fixable = lint_result.num_violations(types=SQLLintError, fixable=True)
if num_fixable > 0:
LOGGER.info(f"Fixing {num_fixable} errors in SQL file")
res = lint_result.persist_changes(
formatter=formatter, fixed_file_suffix=""
)
after_modified = datetime.fromtimestamp(sql.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S')
LOGGER.info(f"After fixing, modified: {after_modified}")
LOGGER.info(f"File modification time has changes? {before_modified != after_modified}")
success = all(res.values())
else:
LOGGER.info("No fixable errors in SQL file")
LOGGER.info(f"format_command returning success={success}, result_sql={result_sql[:100] if result_sql is not None else 'n/a'}")
return success, result_sql


def test_lint_command():
"""Quick and dirty functional test for lint_command().

Expand Down Expand Up @@ -138,5 +246,43 @@ def test_lint_command():
print(f"{'*'*40} Lint result {'*'*40}")


def test_format_command():
"""Quick and dirty functional test for format_command().

Handy for seeing SQLFluff logs if something goes wrong. The automated tests
make it difficult to see the logs.
"""
logging.basicConfig(level=logging.DEBUG)
from dbt_core_interface.project import DbtProjectContainer
dbt = DbtProjectContainer()
dbt.add_project(
name_override="dbt_project",
project_dir="tests/sqlfluff_templater/fixtures/dbt/dbt_project/",
profiles_dir="tests/sqlfluff_templater/fixtures/dbt/profiles_yml/",
target="dev",
)
sql_path = Path(
"tests/sqlfluff_templater/fixtures/dbt/dbt_project/models/my_new_project/issue_1608.sql"
)

# Test formatting a string
success, result_sql = format_command(
Path("tests/sqlfluff_templater/fixtures/dbt/dbt_project"),
sql=sql_path.read_text(),
)
print(f"{'*'*40} Formatting result {'*'*40}")
print(success, result_sql)

# Test formatting a file
result = format_command(
Path("tests/sqlfluff_templater/fixtures/dbt/dbt_project"),
sql=sql_path,
)
print(f"{'*'*40} Formatting result {'*'*40}")
print(result)
print(f"{'*'*40} Formatting result {'*'*40}")


if __name__ == "__main__":
test_lint_command()
#test_lint_command()
test_format_command()
75 changes: 73 additions & 2 deletions tests/sqlfluff_templater/test_server_v2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import difflib
import os
import shutil
import urllib.parse
from pathlib import Path

Expand All @@ -24,8 +27,8 @@
@pytest.mark.parametrize(
"param_name, param_value",
[
("sql_path", SQL_PATH),
(None, SQL_PATH.read_text()),
pytest.param("sql_path", SQL_PATH, id="sql_file"),
pytest.param(None, SQL_PATH.read_text(), id="sql_string"),
],
)
def test_lint(param_name, param_value, profiles_dir, project_dir, sqlfluff_config_path, caplog):
Expand Down Expand Up @@ -94,6 +97,74 @@ def test_lint_parse_failure(profiles_dir, project_dir, sqlfluff_config_path, cap
assert response_json == {"result": []}


@pytest.mark.parametrize(
"param_name, param_value",
[
pytest.param("sql_path", SQL_PATH, id="sql_file"),
pytest.param(None, SQL_PATH.read_text(), id="sql_string"),
],
)
def test_format(param_name, param_value, profiles_dir, project_dir, sqlfluff_config_path, caplog):
if param_name:
# Make a copy of the file and format the copy so we don't modify a file in
# git.
destination_path = param_value.parent / f"{param_value.stem + '_new'}{param_value.suffix}"
shutil.copy(str(param_value), str(destination_path))
param_value = destination_path

params = {}
kwargs = {}
data = ''
if param_name:
# Formatting a file
params[param_name] = param_value
original_lines = param_value.read_text().splitlines()
else:
data = param_value
original_lines = param_value.splitlines()
response = client.post(
f"/format?{urllib.parse.urlencode(params)}",
data,
headers={"X-dbt-Project": "dbt_project"},
**kwargs,
)
try:
assert response.status_code == 200

# Compare "before and after" SQL and verify the expected changes were made.
if param_name:
formatted_lines = destination_path.read_text().splitlines()
else:
formatted_lines = response.json["sql"].splitlines()
differ = difflib.Differ()
diff = list(differ.compare(original_lines, formatted_lines))
assert diff == [
" {{ config(materialized='view') }}",
" ",
" with cte_example as (",
"- select 1 as col_name",
"? -\n",
"+ select 1 as col_name",
" ),",
" ",
"- final as",
"+ final as (",
"? ++\n",
"- (",
" select",
" col_name,",
" {{- echo('col_name') -}} as col_name2",
" from",
" cte_example",
" )",
" ",
" select * from final",
]
finally:
if param_name:
os.unlink(destination_path)


@pytest.mark.parametrize(
"param_name, param_value, clients, sample",
[
Expand Down