Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sqllab): Force trino client async execution #24859

Merged
merged 1 commit into from
Sep 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
fix(sqllab): Force trino client async execution
We are currently unable to stop trino queries, because the underlying
trino client blocks until the query completes, and doesn't make the
query ID or any other info available in the meantime. Unfortunately it
doesn't look like they plan to change that any time soon, either.

Make the following changes:
  - Add a new method execute_with_cursor to db_engine_spec which
    combines execute with handle_cursor, factoring it out of the one
    place it's used, deep in the execute query logic.
  - Make handle_cursor poll the cursor for query ID, as it's going to
    be populated asynchronously. Add warnings that the trino impl will
    require using execute_with_cursor. Currently nothing is directly
    calling handle_cursor, with the one original call eliminated.
  - Override execute_with_cursor for the trino engine and execute the
    two tasks in parallel to allow us to poll for the query ID
    while the query is still blocking.
  • Loading branch information
giftig committed Aug 2, 2023
commit 4919b58685887a8f2f3fd8a3ff2246bbd2af05f0
18 changes: 18 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
@@ -1026,6 +1026,24 @@ def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
query object"""
# TODO: Fix circular import error caused by importing sql_lab.Query

@classmethod
def execute_with_cursor(
cls, cursor: Any, sql: str, query: Query, session: Session
) -> None:
"""
Trigger execution of a query and handle the resulting cursor.

For most implementations this just makes calls to `execute` and
`handle_cursor` consecutively, but in some engines (e.g. Trino) we may
need to handle client limitations such as lack of async support and
perform a more complicated operation to get information from the cursor
in a timely manner and facilitate operations such as query stop
"""
logger.debug("Query %d: Running query: %s", query.id, sql)
cls.execute(cursor, sql, async_=True)
logger.debug("Query %d: Handling cursor", query.id)
cls.handle_cursor(cursor, query, session)

@classmethod
def extract_error_message(cls, ex: Exception) -> str:
return f"{cls.engine} error: {cls._extract_error_message(ex)}"
66 changes: 60 additions & 6 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@
from __future__ import annotations

import logging
import threading
import time
from typing import Any, TYPE_CHECKING

import simplejson as json
@@ -149,14 +151,21 @@ def get_tracking_url(cls, cursor: Cursor) -> str | None:

@classmethod
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url
"""
Handle a trino client cursor.

WARNING: if you execute a query, it will block until complete and you
will not be able to handle the cursor until complete. Use
`execute_with_cursor` instead, to handle this asynchronously.
"""

# Adds the executed query id to the extra payload so the query can be cancelled
query.set_extra_json_key(
key=QUERY_CANCEL_KEY,
value=(cancel_query_id := cursor.stats["queryId"]),
)
cancel_query_id = cursor.query_id
logger.debug("Query %d: queryId %s found in cursor", query.id, cancel_query_id)
query.set_extra_json_key(key=QUERY_CANCEL_KEY, value=cancel_query_id)

if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url

session.commit()

@@ -171,6 +180,51 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:

super().handle_cursor(cursor=cursor, query=query, session=session)

@classmethod
def execute_with_cursor(
cls, cursor: Any, sql: str, query: Query, session: Session
) -> None:
"""
Trigger execution of a query and handle the resulting cursor.

Trino's client blocks until the query is complete, so we need to run it
in another thread and invoke `handle_cursor` to poll for the query ID
to appear on the cursor in parallel.
"""
execute_result: dict[str, Any] = {}

def _execute(results: dict[str, Any]) -> None:
logger.debug("Query %d: Running query: %s", query.id, sql)

# Pass result / exception information back to the parent thread
try:
cls.execute(cursor, sql)
results["complete"] = True
except Exception as ex: # pylint: disable=broad-except
results["complete"] = True
results["error"] = ex

execute_thread = threading.Thread(target=_execute, args=(execute_result,))
execute_thread.start()

# Wait for a query ID to be available before handling the cursor, as
# it's required by that method; it may never become available on error.
while not cursor.query_id and not execute_result.get("complete"):
time.sleep(0.1)

logger.debug("Query %d: Handling cursor", query.id)
cls.handle_cursor(cursor, query, session)

# Block until the query completes; same behaviour as the client itself
logger.debug("Query %d: Waiting for query to complete", query.id)
while not execute_result.get("complete"):
time.sleep(0.5)

# Unfortunately we'll mangle the stack trace due to the thread, but
# throwing the original exception allows mapping database errors as normal
if err := execute_result.get("error"):
raise err

@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
if QUERY_CANCEL_KEY not in query.extra:
7 changes: 2 additions & 5 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
@@ -191,7 +191,7 @@ def get_sql_results( # pylint: disable=too-many-arguments
return handle_query_error(ex, query, session)


def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statements
def execute_sql_statement( # pylint: disable=too-many-arguments
sql_statement: str,
query: Query,
session: Session,
@@ -270,10 +270,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem
)
session.commit()
with stats_timing("sqllab.query.time_executing_query", stats_logger):
logger.debug("Query %d: Running query: %s", query.id, sql)
db_engine_spec.execute(cursor, sql, async_=True)
logger.debug("Query %d: Handling cursor", query.id)
db_engine_spec.handle_cursor(cursor, query, session)
db_engine_spec.execute_with_cursor(cursor, sql, query, session)

with stats_timing("sqllab.query.time_fetching_results", stats_logger):
logger.debug(
31 changes: 30 additions & 1 deletion tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
@@ -352,7 +352,7 @@ def test_handle_cursor_early_cancel(
query_id = "myQueryId"

cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.stats = {"queryId": query_id}
cursor_mock.query_id = query_id
session_mock = mocker.MagicMock()

query = Query()
@@ -366,3 +366,32 @@ def test_handle_cursor_early_cancel(
assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id
else:
assert cancel_query_mock.call_args is None


def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
"""Test that `execute_with_cursor` fetches query ID from the cursor"""
from superset.db_engine_specs.trino import TrinoEngineSpec

query_id = "myQueryId"

mock_cursor = mocker.MagicMock()
mock_cursor.query_id = None

mock_query = mocker.MagicMock()
mock_session = mocker.MagicMock()

def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id

mock_cursor.execute.side_effect = _mock_execute

TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
session=mock_session,
)

mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)
10 changes: 4 additions & 6 deletions tests/unit_tests/sql_lab_test.py
Original file line number Diff line number Diff line change
@@ -55,8 +55,8 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
)

database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True)
db_engine_spec.execute.assert_called_with(
cursor, "SELECT 42 AS answer LIMIT 2", async_=True
db_engine_spec.execute_with_cursor.assert_called_with(
cursor, "SELECT 42 AS answer LIMIT 2", query, session
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)

@@ -106,10 +106,8 @@ def test_execute_sql_statement_with_rls(
101,
force=True,
)
db_engine_spec.execute.assert_called_with(
cursor,
"SELECT * FROM sales WHERE organization_id=42 LIMIT 101",
async_=True,
db_engine_spec.execute_with_cursor.assert_called_with(
cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", query, session
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)