diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index fbd60b9d..d94f97f0 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -1861,6 +1861,29 @@ def test_select_query_spooled_segments(trino_connection): assert isinstance(row[13], str), f"Expected string for shipinstruct, got {type(row[13])}" +@pytest.mark.skipif( + trino_version() <= 466, + reason="spooling protocol was introduced in version 466" +) +def test_segments_cursor(trino_connection): + if trino_connection._client_session.encoding is None: + with pytest.raises(ValueError, match=".*encoding.*"): + trino_connection.cursor("segment") + return + cur = trino_connection.cursor("segment") + cur.execute("""SELECT l.* + FROM tpch.tiny.lineitem l, TABLE(sequence( + start => 1, + stop => 5, + step => 1)) n""") + rows = cur.fetchall() + assert len(rows) > 0 + for spooled_data, spooled_segment in rows: + assert spooled_data.encoding == trino_connection._client_session.encoding + assert isinstance(spooled_segment.uri, str), f"Expected string for uri, got {spooled_segment.uri}" + assert isinstance(spooled_segment.ack_uri, str), f"Expected string for ack_uri, got {spooled_segment.ack_uri}" + + def get_cursor(legacy_prepared_statements, run_trino): host, port = run_trino diff --git a/trino/client.py b/trino/client.py index b2b9f14c..29692056 100644 --- a/trino/client.py +++ b/trino/client.py @@ -799,6 +799,7 @@ def __init__( request: TrinoRequest, query: str, legacy_primitive_types: bool = False, + fetch_mode: Literal["mapped", "segments"] = "mapped" ) -> None: self._query_id: Optional[str] = None self._stats: Dict[Any, Any] = {} @@ -815,6 +816,7 @@ def __init__( self._result: Optional[TrinoResult] = None self._legacy_primitive_types = legacy_primitive_types self._row_mapper: Optional[RowMapper] = None + self._fetch_mode = fetch_mode @property def query_id(self) -> Optional[str]: @@ -919,6 +921,8 @@ def fetch(self) -> List[Union[List[Any]], Any]: # spooling protocol rows = cast(_SpooledProtocolResponseTO, rows) segments = self._to_segments(rows) + if self._fetch_mode == "segments": + return segments return list(SegmentIterator(segments, self._row_mapper)) elif isinstance(status.rows, list): return self._row_mapper.map(rows) diff --git a/trino/dbapi.py b/trino/dbapi.py index 426532f9..dee7cdb7 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -268,7 +268,7 @@ def _create_request(self): self.request_timeout, ) - def cursor(self, legacy_primitive_types: bool = None): + def cursor(self, cursor_style: str = "row", legacy_primitive_types: bool = None): """Return a new :py:class:`Cursor` object using the connection.""" if self.isolation_level != IsolationLevel.AUTOCOMMIT: if self.transaction is None: @@ -277,11 +277,21 @@ def cursor(self, legacy_primitive_types: bool = None): request = self.transaction.request else: request = self._create_request() - return Cursor( + + cursor_class = { + # Add any custom Cursor classes here + "segment": SegmentCursor, + "row": Cursor + }.get(cursor_style.lower(), Cursor) + + return cursor_class( self, request, - # if legacy params are not explicitly set in Cursor, take them from Connection - legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types + legacy_primitive_types=( + legacy_primitive_types + if legacy_primitive_types is not None + else self.legacy_primitive_types + ) ) def _use_legacy_prepared_statements(self): @@ -714,6 +724,28 @@ def close(self): # but also any other outstanding queries executed through this cursor. +class SegmentCursor(Cursor): + def __init__( + self, + connection, + request, + legacy_primitive_types: bool = False): + super().__init__(connection, request, legacy_primitive_types=legacy_primitive_types) + if self.connection._client_session.encoding is None: + raise ValueError("SegmentCursor can only be used if encoding is set on the connection") + + def execute(self, operation, params=None): + if params: + # TODO: refactor code to allow for params to be supported + raise ValueError("params not supported") + + self._query = trino.client.TrinoQuery(self._request, query=operation, + legacy_primitive_types=self._legacy_primitive_types, + fetch_mode="segments") + self._iterator = iter(self._query.execute()) + return self + + Date = datetime.date Time = datetime.time Timestamp = datetime.datetime