From 86e249d904166ee90b6d38378e328342e108e143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= <16805946+edgarrmondragon@users.noreply.github.com> Date: Tue, 7 May 2024 17:06:14 -0600 Subject: [PATCH] chore: Standardized `record` and `context` types (#2415) chore: Standardize `record` and `context` types` --- singer_sdk/streams/core.py | 74 +++++++++++++++++++++-------------- singer_sdk/streams/graphql.py | 5 ++- singer_sdk/streams/rest.py | 21 +++++----- singer_sdk/streams/sql.py | 3 +- 4 files changed, 61 insertions(+), 42 deletions(-) diff --git a/singer_sdk/streams/core.py b/singer_sdk/streams/core.py index f1b82a2c8..886466c5d 100644 --- a/singer_sdk/streams/core.py +++ b/singer_sdk/streams/core.py @@ -6,6 +6,7 @@ import copy import datetime import json +import sys import typing as t from os import PathLike from pathlib import Path @@ -50,6 +51,11 @@ from singer_sdk.helpers._util import utc_now from singer_sdk.mapper import RemoveRecordTransform, SameRecordTransform, StreamMap +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias # noqa: ICN003 + if t.TYPE_CHECKING: import logging @@ -62,6 +68,8 @@ REPLICATION_LOG_BASED = "LOG_BASED" FactoryType = t.TypeVar("FactoryType", bound="Stream") +Record: TypeAlias = t.Dict[str, t.Any] +Context: TypeAlias = t.Dict class Stream(metaclass=abc.ABCMeta): # noqa: PLR0904 @@ -227,7 +235,7 @@ def is_timestamp_replication_key(self) -> bool: def get_starting_replication_key_value( self, - context: dict | None, + context: Context | None, ) -> t.Any | None: # noqa: ANN401 """Get starting replication key. @@ -252,7 +260,9 @@ def get_starting_replication_key_value( else None ) - def get_starting_timestamp(self, context: dict | None) -> datetime.datetime | None: + def get_starting_timestamp( + self, context: Context | None + ) -> datetime.datetime | None: """Get starting replication timestamp. Will return the value of the stream's replication key when `--state` is passed. @@ -330,7 +340,7 @@ def descendent_streams(self) -> list[Stream]: def _write_replication_key_signpost( self, - context: dict | None, + context: Context | None, value: datetime.datetime | str | int | float, ) -> None: """Write the signpost value, if available. @@ -371,7 +381,7 @@ def compare_start_date(self, value: str, start_date_value: str) -> str: return value - def _write_starting_replication_value(self, context: dict | None) -> None: + def _write_starting_replication_value(self, context: Context | None) -> None: """Write the starting replication value, if available. Args: @@ -399,7 +409,7 @@ def _write_starting_replication_value(self, context: dict | None) -> None: def get_replication_key_signpost( self, - context: dict | None, # noqa: ARG002 + context: Context | None, # noqa: ARG002 ) -> datetime.datetime | t.Any | None: # noqa: ANN401 """Get the replication signpost. @@ -646,7 +656,7 @@ def tap_state(self) -> dict: """ return self._tap_state - def get_context_state(self, context: dict | None) -> dict: + def get_context_state(self, context: Context | None) -> dict: """Return a writable state dict for the given context. Gives a partitioned context state if applicable; else returns stream state. @@ -701,7 +711,7 @@ def stream_state(self) -> dict: # Partitions @property - def partitions(self) -> list[dict] | None: + def partitions(self) -> list[Context] | None: """Get stream partitions. Developers may override this property to provide a default partitions list. @@ -724,9 +734,9 @@ def partitions(self) -> list[dict] | None: def _increment_stream_state( self, - latest_record: dict[str, t.Any], + latest_record: Record, *, - context: dict | None = None, + context: Context | None = None, ) -> None: """Update state of stream or partition with data from the provided record. @@ -817,7 +827,7 @@ def mask(self) -> singer.SelectionMask: def _generate_record_messages( self, - record: dict, + record: Record, ) -> t.Generator[singer.RecordMessage, None, None]: """Write out a RECORD message. @@ -846,7 +856,7 @@ def _generate_record_messages( time_extracted=utc_now(), ) - def _write_record_message(self, record: dict) -> None: + def _write_record_message(self, record: Record) -> None: """Write out a RECORD message. Args: @@ -963,7 +973,7 @@ def reset_state_progress_markers(self, state: dict | None = None) -> None: state: State object to promote progress markers with. """ if state is None or state == {}: - context: dict | None + context: Context | None for context in self.partitions or [{}]: state = self.get_context_state(context or None) reset_state_progress_markers(state) @@ -992,7 +1002,7 @@ def finalize_state_progress_markers(self, state: dict | None = None) -> None: for child_stream in self.child_streams or []: child_stream.finalize_state_progress_markers() - context: dict | None + context: Context | None for context in self.partitions or [{}]: state = self.get_context_state(context or None) self._finalize_state(state) @@ -1005,9 +1015,9 @@ def finalize_state_progress_markers(self, state: dict | None = None) -> None: def _process_record( self, - record: dict, - child_context: dict | None = None, - partition_context: dict | None = None, + record: Record, + child_context: Context | None = None, + partition_context: Context | None = None, ) -> None: """Process a record. @@ -1032,7 +1042,7 @@ def _process_record( def _sync_records( # noqa: C901 self, - context: dict | None = None, + context: Context | None = None, *, write_messages: bool = True, ) -> t.Generator[dict, t.Any, t.Any]: @@ -1054,7 +1064,7 @@ def _sync_records( # noqa: C901 timer = metrics.sync_timer(self.name) record_index = 0 - context_element: dict | None + context_element: Context | None context_list: list[dict] | None context_list = [context] if context is not None else self.partitions selected = self.selected @@ -1070,7 +1080,7 @@ def _sync_records( # noqa: C901 current_context, ) self._write_starting_replication_value(current_context) - child_context: dict | None = ( + child_context: Context | None = ( None if current_context is None else copy.copy(current_context) ) @@ -1131,7 +1141,7 @@ def _sync_records( # noqa: C901 def _sync_batches( self, batch_config: BatchConfig, - context: dict | None = None, + context: Context | None = None, ) -> None: """Sync batches, emitting BATCH messages. @@ -1148,7 +1158,7 @@ def _sync_batches( # Public methods ("final", not recommended to be overridden) @t.final - def sync(self, context: dict | None = None) -> None: + def sync(self, context: Context | None = None) -> None: """Sync this stream. This method is internal to the SDK and should not need to be overridden. @@ -1188,7 +1198,7 @@ def sync(self, context: dict | None = None) -> None: ) raise - def _sync_children(self, child_context: dict | None) -> None: + def _sync_children(self, child_context: Context | None) -> None: if child_context is None: self.logger.warning( "Context for child streams of '%s' is null, " @@ -1223,7 +1233,7 @@ def apply_catalog(self, catalog: singer.Catalog) -> None: if catalog_entry.replication_method: self.forced_replication_method = catalog_entry.replication_method - def _get_state_partition_context(self, context: dict | None) -> dict | None: + def _get_state_partition_context(self, context: Context | None) -> dict | None: """Override state handling if Stream.state_partitioning_keys is specified. Args: @@ -1240,7 +1250,11 @@ def _get_state_partition_context(self, context: dict | None) -> dict | None: return {k: v for k, v in context.items() if k in self.state_partitioning_keys} - def get_child_context(self, record: dict, context: dict | None) -> dict | None: + def get_child_context( + self, + record: Record, + context: Context | None, + ) -> dict | None: """Return a child context object from the record and optional provided context. By default, will return context if provided and otherwise the record dict. @@ -1281,8 +1295,8 @@ def get_child_context(self, record: dict, context: dict | None) -> dict | None: def generate_child_contexts( self, - record: dict, - context: dict | None, + record: Record, + context: Context | None, ) -> t.Iterable[dict | None]: """Generate child contexts. @@ -1300,7 +1314,7 @@ def generate_child_contexts( @abc.abstractmethod def get_records( self, - context: dict | None, + context: Context | None, ) -> t.Iterable[dict | tuple[dict, dict | None]]: """Abstract record generator function. Must be overridden by the child class. @@ -1346,7 +1360,7 @@ def get_batch_config(self, config: t.Mapping) -> BatchConfig | None: # noqa: PL def get_batches( self, batch_config: BatchConfig, - context: dict | None = None, + context: Context | None = None, ) -> t.Iterable[tuple[BaseBatchFileEncoding, list[str]]]: """Batch generator function. @@ -1371,8 +1385,8 @@ def get_batches( def post_process( # noqa: PLR6301 self, - row: dict, - context: dict | None = None, # noqa: ARG002 + row: Record, + context: Context | None = None, # noqa: ARG002 ) -> dict | None: """As needed, append or transform raw data to match expected structure. diff --git a/singer_sdk/streams/graphql.py b/singer_sdk/streams/graphql.py index fde4f99b9..04a2e80d6 100644 --- a/singer_sdk/streams/graphql.py +++ b/singer_sdk/streams/graphql.py @@ -8,6 +8,9 @@ from singer_sdk.helpers._classproperty import classproperty from singer_sdk.streams.rest import RESTStream +if t.TYPE_CHECKING: + from singer_sdk.streams.core import Context + _TToken = t.TypeVar("_TToken") @@ -44,7 +47,7 @@ def query(self) -> str: def prepare_request_payload( self, - context: dict | None, + context: Context | None, next_page_token: _TToken | None, ) -> dict | None: """Prepare the data payload for the GraphQL API request. diff --git a/singer_sdk/streams/rest.py b/singer_sdk/streams/rest.py index cbc64086c..e96537246 100644 --- a/singer_sdk/streams/rest.py +++ b/singer_sdk/streams/rest.py @@ -32,6 +32,7 @@ from backoff.types import Details from singer_sdk._singerlib import Schema + from singer_sdk.streams.core import Context from singer_sdk.tap_base import Tap if sys.version_info >= (3, 10): @@ -110,7 +111,7 @@ def _url_encode(val: str | datetime | bool | int | list[str]) -> str: # noqa: F """ return val.replace("/", "%2F") if isinstance(val, str) else str(val) - def get_url(self, context: dict | None) -> str: + def get_url(self, context: Context | None) -> str: """Get stream entity URL. Developers override this method to perform dynamic URL generation. @@ -245,7 +246,7 @@ def request_decorator(self, func: t.Callable) -> t.Callable: def _request( self, prepared_request: requests.PreparedRequest, - context: dict | None, + context: Context | None, ) -> requests.Response: """TODO. @@ -271,7 +272,7 @@ def _request( def get_url_params( # noqa: PLR6301 self, - context: dict | None, # noqa: ARG002 + context: Context | None, # noqa: ARG002 next_page_token: _TToken | None, # noqa: ARG002 ) -> dict[str, t.Any] | str: """Return a dictionary or string of URL query parameters. @@ -325,7 +326,7 @@ def build_prepared_request( def prepare_request( self, - context: dict | None, + context: Context | None, next_page_token: _TToken | None, ) -> requests.PreparedRequest: """Prepare a request object for this stream. @@ -357,7 +358,7 @@ def prepare_request( json=request_data, ) - def request_records(self, context: dict | None) -> t.Iterable[dict]: + def request_records(self, context: Context | None) -> t.Iterable[dict]: """Request records from REST endpoint(s), returning response records. If pagination is detected, pages will be recursed automatically. @@ -403,7 +404,7 @@ def _write_request_duration_log( self, endpoint: str, response: requests.Response, - context: dict | None, + context: Context | None, extra_tags: dict | None, ) -> None: """TODO. @@ -440,7 +441,7 @@ def update_sync_costs( self, request: requests.PreparedRequest, response: requests.Response, - context: dict | None, + context: Context | None, ) -> dict[str, int]: """Update internal calculation of Sync costs. @@ -465,7 +466,7 @@ def calculate_sync_cost( # noqa: PLR6301 self, request: requests.PreparedRequest, # noqa: ARG002 response: requests.Response, # noqa: ARG002 - context: dict | None, # noqa: ARG002 + context: Context | None, # noqa: ARG002 ) -> dict[str, int]: """Calculate the cost of the last API call made. @@ -494,7 +495,7 @@ def calculate_sync_cost( # noqa: PLR6301 def prepare_request_payload( self, - context: dict | None, + context: Context | None, next_page_token: _TToken | None, ) -> dict | None: """Prepare the data payload for the REST API request. @@ -560,7 +561,7 @@ def timeout(self) -> int: # Records iterator - def get_records(self, context: dict | None) -> t.Iterable[dict[str, t.Any]]: + def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]: """Return a generator of record-type dictionary objects. Each record emitted should be a dictionary of property names to their values. diff --git a/singer_sdk/streams/sql.py b/singer_sdk/streams/sql.py index 28fd0dc33..2b610a2a5 100644 --- a/singer_sdk/streams/sql.py +++ b/singer_sdk/streams/sql.py @@ -14,6 +14,7 @@ from singer_sdk.streams.core import Stream if t.TYPE_CHECKING: + from singer_sdk.streams.core import Context from singer_sdk.tap_base import Tap @@ -157,7 +158,7 @@ def get_selected_schema(self) -> dict: ) # Get records from stream - def get_records(self, context: dict | None) -> t.Iterable[dict[str, t.Any]]: + def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]: """Return a generator of record-type dictionary objects. If the stream has a replication_key value defined, records will be sorted by the