Skip to content

Commit

Permalink
Support "segment" cursor style
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet authored and hashhar committed Jan 13, 2025
1 parent 43ed692 commit 3a9de8b
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
23 changes: 23 additions & 0 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1861,6 +1861,29 @@ def test_select_query_spooled_segments(trino_connection):
assert isinstance(row[13], str), f"Expected string for shipinstruct, got {type(row[13])}"


# NOTE(review): the condition below also skips version 466 itself, while the
# reason text says the protocol was introduced in 466 -- confirm whether `<`
# was intended or the reason should read "after version 466".
@pytest.mark.skipif(
    trino_version() <= 466,
    reason="spooling protocol was introduced in version 466"
)
def test_segments_cursor(trino_connection):
    """Check that a "segment" cursor returns spooled segments.

    When the client session has no encoding configured, requesting a
    "segment" cursor must raise a ``ValueError`` mentioning the encoding and
    nothing else is tested. Otherwise every fetched item is expected to be a
    ``(spooled_data, spooled_segment)`` pair whose data encoding matches the
    session encoding and whose segment carries string ``uri`` / ``ack_uri``.
    """
    if trino_connection._client_session.encoding is None:
        with pytest.raises(ValueError, match=".*encoding.*"):
            trino_connection.cursor("segment")
        return

    cur = trino_connection.cursor("segment")
    # Cross join with a 5-element sequence to multiply the lineitem rows.
    cur.execute(
        "SELECT l.*\n"
        "FROM tpch.tiny.lineitem l, TABLE(sequence(\n"
        "start => 1,\n"
        "stop => 5,\n"
        "step => 1)) n"
    )
    rows = cur.fetchall()
    assert len(rows) > 0
    for spooled_data, spooled_segment in rows:
        assert spooled_data.encoding == trino_connection._client_session.encoding
        # Report the offending *type* on failure: the message says
        # "Expected string", so the type is the useful diagnostic.
        assert isinstance(spooled_segment.uri, str), (
            f"Expected string for uri, got {type(spooled_segment.uri)}"
        )
        assert isinstance(spooled_segment.ack_uri, str), (
            f"Expected string for ack_uri, got {type(spooled_segment.ack_uri)}"
        )


def get_cursor(legacy_prepared_statements, run_trino):
host, port = run_trino

Expand Down
4 changes: 4 additions & 0 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,7 @@ def __init__(
request: TrinoRequest,
query: str,
legacy_primitive_types: bool = False,
fetch_mode: Literal["mapped", "segments"] = "mapped"
) -> None:
self._query_id: Optional[str] = None
self._stats: Dict[Any, Any] = {}
Expand All @@ -815,6 +816,7 @@ def __init__(
self._result: Optional[TrinoResult] = None
self._legacy_primitive_types = legacy_primitive_types
self._row_mapper: Optional[RowMapper] = None
self._fetch_mode = fetch_mode

@property
def query_id(self) -> Optional[str]:
Expand Down Expand Up @@ -919,6 +921,8 @@ def fetch(self) -> List[Union[List[Any]], Any]:
# spooling protocol
rows = cast(_SpooledProtocolResponseTO, rows)
segments = self._to_segments(rows)
if self._fetch_mode == "segments":
return segments
return list(SegmentIterator(segments, self._row_mapper))
elif isinstance(status.rows, list):
return self._row_mapper.map(rows)
Expand Down
40 changes: 36 additions & 4 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def _create_request(self):
self.request_timeout,
)

def cursor(self, legacy_primitive_types: bool = None):
def cursor(self, cursor_style: str = "row", legacy_primitive_types: bool = None):
"""Return a new :py:class:`Cursor` object using the connection."""
if self.isolation_level != IsolationLevel.AUTOCOMMIT:
if self.transaction is None:
Expand All @@ -277,11 +277,21 @@ def cursor(self, legacy_primitive_types: bool = None):
request = self.transaction.request
else:
request = self._create_request()
return Cursor(

cursor_class = {
# Add any custom Cursor classes here
"segment": SegmentCursor,
"row": Cursor
}.get(cursor_style.lower(), Cursor)

return cursor_class(
self,
request,
# if legacy params are not explicitly set in Cursor, take them from Connection
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types
legacy_primitive_types=(
legacy_primitive_types
if legacy_primitive_types is not None
else self.legacy_primitive_types
)
)

def _use_legacy_prepared_statements(self):
Expand Down Expand Up @@ -714,6 +724,28 @@ def close(self):
# but also any other outstanding queries executed through this cursor.


class SegmentCursor(Cursor):
    """Cursor variant that runs its queries with ``fetch_mode="segments"``.

    Construction fails unless the connection's client session has an
    encoding configured.
    """

    def __init__(self, connection, request, legacy_primitive_types: bool = False):
        """Initialise the base cursor, then verify an encoding is configured.

        :raises ValueError: if the connection's client session encoding is
            ``None``.
        """
        super().__init__(connection, request, legacy_primitive_types=legacy_primitive_types)
        session_encoding = self.connection._client_session.encoding
        if session_encoding is None:
            raise ValueError("SegmentCursor can only be used if encoding is set on the connection")

    def execute(self, operation, params=None):
        """Start ``operation`` as a segment-mode query and return this cursor.

        :param operation: the SQL text to run.
        :param params: must be empty or ``None``; query parameters are not
            supported by this cursor.
        :raises ValueError: if ``params`` is non-empty.
        """
        if params:
            # TODO: refactor code to allow for params to be supported
            raise ValueError("params not supported")

        segment_query = trino.client.TrinoQuery(
            self._request,
            query=operation,
            legacy_primitive_types=self._legacy_primitive_types,
            fetch_mode="segments",
        )
        # Store the query before executing it so it stays reachable from the
        # cursor even if execution raises.
        self._query = segment_query
        self._iterator = iter(segment_query.execute())
        return self


# Expose the standard-library datetime types under the names Date, Time and
# Timestamp. NOTE(review): presumably these are the PEP 249 type constructors
# of this DB-API module -- confirm.
Date = datetime.date
Time = datetime.time
Timestamp = datetime.datetime
Expand Down

0 comments on commit 3a9de8b

Please sign in to comment.