Type hints: Finish type hints and mark package as typed #3272

Merged: 8 commits, Nov 23, 2023
3 changes: 3 additions & 0 deletions altair/__init__.py
@@ -54,6 +54,7 @@
"CalculateTransform",
"Categorical",
"Chart",
"ChartDataType",
"Color",
"ColorDatum",
"ColorDef",
@@ -125,7 +126,9 @@
"Cyclical",
"Data",
"DataFormat",
"DataFrameLike",
"DataSource",
"DataType",
"Datasets",
"DateTime",
"DatumChannelMixin",
Empty file added altair/py.typed
Empty file.
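Reviewer note: the empty py.typed file is the PEP 561 marker. Its presence tells type checkers that altair ships inline annotations, so downstream projects get checking against them without separate stubs. A hypothetical downstream snippet (not part of this diff):

    # With py.typed present, mypy/pyright type-check calls into altair
    # instead of treating the package as untyped.
    import altair as alt

    chart: alt.Chart = alt.Chart("data.csv").mark_point()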
6 changes: 3 additions & 3 deletions altair/utils/_transformed_data.py
@@ -25,7 +25,7 @@
data_transformers,
)
from altair.utils._vegafusion_data import get_inline_tables, import_vegafusion
- from altair.utils.core import _DataFrameLike
+ from altair.utils.core import DataFrameLike
from altair.utils.schemapi import Undefined

Scope = Tuple[int, ...]
@@ -56,7 +56,7 @@ def transformed_data(
chart: Union[Chart, FacetChart],
row_limit: Optional[int] = None,
exclude: Optional[Iterable[str]] = None,
- ) -> Optional[_DataFrameLike]:
+ ) -> Optional[DataFrameLike]:
...


@@ -65,7 +65,7 @@ def transformed_data(
chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart],
row_limit: Optional[int] = None,
exclude: Optional[Iterable[str]] = None,
- ) -> List[_DataFrameLike]:
+ ) -> List[DataFrameLike]:
...


12 changes: 6 additions & 6 deletions altair/utils/_vegafusion_data.py
@@ -7,15 +7,15 @@
from typing import TypedDict, Final

from altair.utils._importers import import_vegafusion
- from altair.utils.core import _DataFrameLike
- from altair.utils.data import _DataType, _ToValuesReturnType, MaxRowsError
+ from altair.utils.core import DataFrameLike
+ from altair.utils.data import DataType, ToValuesReturnType, MaxRowsError
from altair.vegalite.data import default_data_transformer

# Temporary storage for dataframes that have been extracted
# from charts by the vegafusion data transformer. Use a WeakValueDictionary
# rather than a dict so that the Python interpreter is free to garbage
# collect the stored DataFrames.
- extracted_inline_tables: MutableMapping[str, _DataFrameLike] = WeakValueDictionary()
+ extracted_inline_tables: MutableMapping[str, DataFrameLike] = WeakValueDictionary()
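Reviewer note: a standalone illustration of the WeakValueDictionary behavior relied on here (not part of the diff). Entries disappear once no other strong reference to the DataFrame remains, so extracted tables never pin memory:

    import weakref

    import pandas as pd

    tables: "weakref.WeakValueDictionary[str, pd.DataFrame]" = weakref.WeakValueDictionary()
    df = pd.DataFrame({"x": [1, 2]})
    tables["table_0"] = df
    print("table_0" in tables)  # True: df is still strongly referenced
    del df                      # drop the last strong reference
    print("table_0" in tables)  # False on CPython: entry was collected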

# Special URL prefix that VegaFusion uses to denote that a
# dataset in a Vega spec corresponds to an entry in the `inline_datasets`
@@ -29,8 +29,8 @@ class _ToVegaFusionReturnUrlDict(TypedDict):

@curried.curry
def vegafusion_data_transformer(
- data: _DataType, max_rows: int = 100000
- ) -> Union[_ToVegaFusionReturnUrlDict, _ToValuesReturnType]:
+ data: DataType, max_rows: int = 100000
+ ) -> Union[_ToVegaFusionReturnUrlDict, ToValuesReturnType]:
"""VegaFusion Data Transformer"""
if hasattr(data, "__geo_interface__"):
# Use default transformer for geo interface objects
@@ -95,7 +95,7 @@ def get_inline_table_names(vega_spec: dict) -> Set[str]:
return table_names


- def get_inline_tables(vega_spec: dict) -> Dict[str, _DataFrameLike]:
+ def get_inline_tables(vega_spec: dict) -> Dict[str, DataFrameLike]:
"""Get the inline tables referenced by a Vega specification

Note: This function should only be called on a Vega spec that corresponds
18 changes: 9 additions & 9 deletions altair/utils/core.py
@@ -41,11 +41,11 @@
if TYPE_CHECKING:
from pandas.core.interchange.dataframe_protocol import Column as PandasColumn

- _V = TypeVar("_V")
- _P = ParamSpec("_P")
+ V = TypeVar("V")
+ P = ParamSpec("P")


- class _DataFrameLike(Protocol):
+ class DataFrameLike(Protocol):
def __dataframe__(self, *args, **kwargs) -> DfiDataFrame:
...
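Reviewer note: since DataFrameLike is a structural Protocol, anything implementing the DataFrame Interchange Protocol's __dataframe__ method matches it without subclassing. A sketch assuming polars is installed (not part of the diff):

    import polars as pl

    from altair.utils.core import DataFrameLike

    def row_count(data: DataFrameLike) -> int:
        # __dataframe__() returns a DataFrame Interchange Protocol object
        return data.__dataframe__().num_rows()

    print(row_count(pl.DataFrame({"x": [1, 2, 3]})))  # 3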

@@ -188,12 +188,12 @@ def __dataframe__(self, *args, **kwargs) -> DfiDataFrame:
]


- _InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]
+ InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]


def infer_vegalite_type(
data: object,
- ) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]:
+ ) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]:
"""
From an array-like input, infer the correct vega typecode
('ordinal', 'nominal', 'quantitative', or 'temporal')
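Reviewer note: narrowing the return annotation to the InferredVegaLiteType Literal lets callers exhaustively match on the four type codes. A usage sketch, with outputs as I understand the inference behavior (not part of the diff):

    import pandas as pd

    from altair.utils.core import infer_vegalite_type

    print(infer_vegalite_type(pd.Series([1.0, 2.5])))  # quantitative
    print(infer_vegalite_type(pd.Series(["a", "b"])))  # nominal
    print(infer_vegalite_type(pd.Series(pd.to_datetime(["2020-01-01"]))))  # temporal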
@@ -442,7 +442,7 @@ def sanitize_arrow_table(pa_table):

def parse_shorthand(
shorthand: Union[Dict[str, Any], str],
- data: Optional[Union[pd.DataFrame, _DataFrameLike]] = None,
+ data: Optional[Union[pd.DataFrame, DataFrameLike]] = None,
parse_aggregates: bool = True,
parse_window_ops: bool = False,
parse_timeunits: bool = True,
@@ -637,7 +637,7 @@ def parse_shorthand(

def infer_vegalite_type_for_dfi_column(
column: Union[Column, "PandasColumn"],
- ) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]:
+ ) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]:
from pyarrow.interchange.from_dataframe import column_to_array

try:
@@ -672,10 +672,10 @@ def infer_vegalite_type_for_dfi_column(
raise ValueError(f"Unexpected DtypeKind: {kind}")


- def use_signature(Obj: Callable[_P, Any]):
+ def use_signature(Obj: Callable[P, Any]):
"""Apply call signature and documentation of Obj to the decorated method"""

- def decorate(f: Callable[..., _V]) -> Callable[_P, _V]:
+ def decorate(f: Callable[..., V]) -> Callable[P, V]:
# call-signature of f is exposed via __wrapped__.
# we want it to mimic Obj.__init__
f.__wrapped__ = Obj.__init__ # type: ignore
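Reviewer note: the ParamSpec/TypeVar pair is what lets use_signature re-type the decorated function: the result keeps f's return type V but advertises Obj's parameters P. A minimal standalone sketch of the same pattern, with hypothetical names (not part of the diff):

    from typing import Any, Callable, TypeVar

    from typing_extensions import ParamSpec

    P = ParamSpec("P")
    V = TypeVar("V")

    def copies_signature(
        template: Callable[P, Any]
    ) -> Callable[[Callable[..., V]], Callable[P, V]]:
        def decorate(f: Callable[..., V]) -> Callable[P, V]:
            # Type checkers now see template's parameters with f's
            # return type; at runtime f is returned unchanged.
            return f  # type: ignore[return-value]
        return decorate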
30 changes: 15 additions & 15 deletions altair/utils/data.py
@@ -10,7 +10,7 @@
from typing import TypeVar

from ._importers import import_pyarrow_interchange
- from .core import sanitize_dataframe, sanitize_arrow_table, _DataFrameLike
+ from .core import sanitize_dataframe, sanitize_arrow_table, DataFrameLike
from .core import sanitize_geo_interface
from .deprecation import AltairDeprecationWarning
from .plugin_registry import PluginRegistry
@@ -23,15 +23,15 @@
import pyarrow.lib


- class _SupportsGeoInterface(Protocol):
+ class SupportsGeoInterface(Protocol):
__geo_interface__: MutableMapping
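Reviewer note: any object exposing a __geo_interface__ mapping (shapely geometries, geopandas frames) satisfies this protocol structurally. A sketch assuming shapely is installed (not part of the diff):

    from shapely.geometry import Point

    from altair.utils.data import SupportsGeoInterface

    def geo_type(obj: SupportsGeoInterface) -> str:
        return str(obj.__geo_interface__["type"])

    print(geo_type(Point(0.0, 1.0)))  # Point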


- _DataType = Union[dict, pd.DataFrame, _SupportsGeoInterface, _DataFrameLike]
- _TDataType = TypeVar("_TDataType", bound=_DataType)
+ DataType = Union[dict, pd.DataFrame, SupportsGeoInterface, DataFrameLike]
+ TDataType = TypeVar("TDataType", bound=DataType)

- _VegaLiteDataDict = Dict[str, Union[str, dict, List[dict]]]
- _ToValuesReturnType = Dict[str, Union[dict, List[dict]]]
+ VegaLiteDataDict = Dict[str, Union[str, dict, List[dict]]]
+ ToValuesReturnType = Dict[str, Union[dict, List[dict]]]


# ==============================================================================
@@ -46,7 +46,7 @@ class _SupportsGeoInterface(Protocol):
# form.
# ==============================================================================
class DataTransformerType(Protocol):
- def __call__(self, data: _DataType, **kwargs) -> _VegaLiteDataDict:
+ def __call__(self, data: DataType, **kwargs) -> VegaLiteDataDict:
pass
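Reviewer note: because DataTransformerType is a callable Protocol, a plain function with this shape conforms; no base class is needed. A hypothetical conforming transformer (url_transformer is an invented name, not part of the diff):

    from altair.utils.data import DataType, VegaLiteDataDict

    def url_transformer(data: DataType, url: str = "data.json", **kwargs) -> VegaLiteDataDict:
        # Ignore the inline data and point the spec at an external URL.
        return {"url": url}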


@@ -70,7 +70,7 @@ class MaxRowsError(Exception):


@curried.curry
- def limit_rows(data: _TDataType, max_rows: Optional[int] = 5000) -> _TDataType:
+ def limit_rows(data: TDataType, max_rows: Optional[int] = 5000) -> TDataType:
"""Raise MaxRowsError if the data model has more than max_rows.

If max_rows is None, then do not perform any check.
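Reviewer note: a usage sketch for the curried transformer (not part of the diff); passing data and max_rows together applies the check immediately:

    import pandas as pd

    from altair.utils.data import MaxRowsError, limit_rows

    df = pd.DataFrame({"x": range(10)})
    limit_rows(df, max_rows=100)  # under the limit: data passes through
    try:
        limit_rows(df, max_rows=5)
    except MaxRowsError:
        print("refusing to embed more than 5 rows")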
@@ -122,7 +122,7 @@ def raise_max_rows_error():

@curried.curry
def sample(
- data: _DataType, n: Optional[int] = None, frac: Optional[float] = None
+ data: DataType, n: Optional[int] = None, frac: Optional[float] = None
) -> Optional[Union[pd.DataFrame, Dict[str, Sequence], "pyarrow.lib.Table"]]:
"""Reduce the size of the data model by sampling without replacement."""
check_data_type(data)
@@ -180,7 +180,7 @@ class _ToCsvReturnUrlDict(TypedDict):

@curried.curry
def to_json(
- data: _DataType,
+ data: DataType,
prefix: str = "altair-data",
extension: str = "json",
filename: str = "{prefix}-{hash}.{extension}",
@@ -199,7 +199,7 @@

@curried.curry
def to_csv(
- data: Union[dict, pd.DataFrame, _DataFrameLike],
+ data: Union[dict, pd.DataFrame, DataFrameLike],
prefix: str = "altair-data",
extension: str = "csv",
filename: str = "{prefix}-{hash}.{extension}",
@@ -215,7 +215,7 @@


@curried.curry
- def to_values(data: _DataType) -> _ToValuesReturnType:
+ def to_values(data: DataType) -> ToValuesReturnType:
"""Replace a DataFrame by a data model with values."""
check_data_type(data)
if hasattr(data, "__geo_interface__"):
@@ -242,7 +242,7 @@ def to_values(data: _DataType) -> _ToValuesReturnType:
raise ValueError("Unrecognized data type: {}".format(type(data)))


- def check_data_type(data: _DataType) -> None:
+ def check_data_type(data: DataType) -> None:
if not isinstance(data, (dict, pd.DataFrame)) and not any(
hasattr(data, attr) for attr in ["__geo_interface__", "__dataframe__"]
):
@@ -260,7 +260,7 @@ def _compute_data_hash(data_str: str) -> str:
return hashlib.md5(data_str.encode()).hexdigest()


- def _data_to_json_string(data: _DataType) -> str:
+ def _data_to_json_string(data: DataType) -> str:
"""Return a JSON string representation of the input data"""
check_data_type(data)
if hasattr(data, "__geo_interface__"):
@@ -288,7 +288,7 @@ def _data_to_json_string(data: _DataType) -> str:
)


- def _data_to_csv_string(data: Union[dict, pd.DataFrame, _DataFrameLike]) -> str:
+ def _data_to_csv_string(data: Union[dict, pd.DataFrame, DataFrameLike]) -> str:
"""return a CSV string representation of the input data"""
check_data_type(data)
if hasattr(data, "__geo_interface__"):
10 changes: 3 additions & 7 deletions altair/utils/schemapi.py
@@ -44,7 +44,7 @@
else:
from typing_extensions import Self

- _TSchemaBase = TypeVar("_TSchemaBase", bound=Type["SchemaBase"])
+ TSchemaBase = TypeVar("TSchemaBase", bound=Type["SchemaBase"])

ValidationErrorList = List[jsonschema.exceptions.ValidationError]
GroupedValidationErrors = Dict[str, ValidationErrorList]
@@ -733,11 +733,7 @@ def __repr__(self):
return "Undefined"


- # In the future Altair may implement a more complete set of type hints.
- # But for now, we'll add an annotation to indicate that the type checker
- # should permit any value passed to a function argument whose default
- # value is Undefined.
- Undefined: Any = UndefinedType()
+ Undefined = UndefinedType()
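Reviewer note: dropping the Any escape hatch means Undefined now type-checks as UndefinedType, so signatures can state the sentinel-default pattern explicitly. A sketch of that pattern, with an invented helper name (not part of the diff):

    from typing import Union

    from altair.utils.schemapi import Undefined, UndefinedType

    def set_title(title: Union[str, UndefinedType] = Undefined) -> None:
        # The sentinel distinguishes "not provided" from an explicit None.
        if title is not Undefined:
            print(f"title set to {title!r}")

    set_title()            # prints nothing
    set_title("My chart")  # title set to 'My chart'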


class SchemaBase:
@@ -1329,7 +1325,7 @@ def __call__(self, *args, **kwargs):
return obj


- def with_property_setters(cls: _TSchemaBase) -> _TSchemaBase:
+ def with_property_setters(cls: TSchemaBase) -> TSchemaBase:
"""
Decorator to add property setters to a Schema class.
"""
6 changes: 3 additions & 3 deletions altair/vegalite/data.py
@@ -12,14 +12,14 @@
check_data_type,
)
from ..utils.data import DataTransformerRegistry as _DataTransformerRegistry
- from ..utils.data import _DataType, _ToValuesReturnType
+ from ..utils.data import DataType, ToValuesReturnType
from ..utils.plugin_registry import PluginEnabler


@curried.curry
def default_data_transformer(
- data: _DataType, max_rows: int = 5000
- ) -> _ToValuesReturnType:
+ data: DataType, max_rows: int = 5000
+ ) -> ToValuesReturnType:
return curried.pipe(data, limit_rows(max_rows=max_rows), to_values)
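Reviewer note: end to end, the default transformer is just limit_rows piped into to_values. A usage sketch (not part of the diff):

    import pandas as pd

    from altair.vegalite.data import default_data_transformer

    out = default_data_transformer(pd.DataFrame({"x": [1, 2]}))
    print(out)  # {'values': [{'x': 1}, {'x': 2}]}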

