diff --git a/airtabledb/adapter.py b/airtabledb/adapter.py index 431969e..4ad597f 100644 --- a/airtabledb/adapter.py +++ b/airtabledb/adapter.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple from pyairtable import Table from shillelagh.adapters.base import Adapter @@ -19,7 +19,8 @@ def __init__( table: str, base_id: str, api_key: str, - base_metadata: BaseMetadata, + base_metadata: Optional[BaseMetadata], + peek_rows: Optional[int], ): super().__init__() @@ -28,7 +29,7 @@ def __init__( self._table_api = Table(api_key, base_id, table) - fields: List[str] + fields: Iterable[str] if self.base_metadata is not None: # TODO(cancan101): Better error handling here # We search by name here. @@ -41,12 +42,26 @@ def __init__( columns_metadata = table_metadata["columns"] fields = [col["name"] for col in columns_metadata] self.strict_col = True + # Attempts introspection by looking at data. + # This is super not reliable + # as Airtable removes the key if the value is empty. else: - # This introspects the first row in the table. - # This is super not reliable - # as Airtable removes the key if the value is empty. - # We should probably look at more than one entry. - fields = self._table_api.first()["fields"] + # This introspects the just first row in the table. + if peek_rows is None or peek_rows == 1: + fields = self._table_api.first()["fields"].keys() + # Or peek at specified number of rows + else: + # We have an explicit type check here as the Airtable API + # just ignores the value if it isn't valid. + if not isinstance(peek_rows, int): + raise TypeError( + f"peek_rows should be an int. Got: {type(peek_rows)}" + ) + + fields = set() + for row in self._table_api.all(max_records=peek_rows): + fields |= row["fields"].keys() + self.strict_col = False # TODO(cancan101): parse out types diff --git a/airtabledb/dialect.py b/airtabledb/dialect.py index 0ff5b5c..2fea094 100644 --- a/airtabledb/dialect.py +++ b/airtabledb/dialect.py @@ -80,7 +80,13 @@ def create_connect_args( if url.password and self.airtable_api_key: raise ValueError("Both password and airtable_api_key were provided") - _, url_host = extract_query_host(url) + url_query, url_host = extract_query_host(url) + peek_rows = None + if "peek_rows" in url_query: + peek_rows_raw = url_query["peek_rows"] + if not isinstance(peek_rows_raw, str): + peek_rows_raw = peek_rows_raw[-1] + peek_rows = int(peek_rows_raw) # At some point we might have args adapter_kwargs = { @@ -88,6 +94,7 @@ def create_connect_args( "api_key": self.airtable_api_key or url.password, "base_id": url_host, "base_metadata": self.base_metadata, + "peek_rows": peek_rows, } } diff --git a/tests/test_dialect.py b/tests/test_dialect.py index c5d8133..a39e800 100644 --- a/tests/test_dialect.py +++ b/tests/test_dialect.py @@ -84,3 +84,21 @@ def test_extract_query_host_no_query(): query, host = extract_query_host(URL.create(drivername="drive", host="myhost")) assert host == "myhost" assert query == {} + + +def test_peek_rows_default(): + url_http = make_url("airtable://foo") + _, kwargs = APSWAirtableDialect().create_connect_args(url_http) + assert _get_adapter_kwargs(kwargs)["peek_rows"] is None + + +def test_peek_rows_single(): + url_http = make_url("airtable://foo?peek_rows=12") + _, kwargs = APSWAirtableDialect().create_connect_args(url_http) + assert _get_adapter_kwargs(kwargs)["peek_rows"] == 12 + + +def test_peek_rows_dupe(): + url_http = make_url("airtable://foo?peek_rows=12&peek_rows=13") + _, kwargs = APSWAirtableDialect().create_connect_args(url_http) + assert _get_adapter_kwargs(kwargs)["peek_rows"] == 13