diff --git a/altair/__init__.py b/altair/__init__.py index bde21db06..c9b0da123 100644 --- a/altair/__init__.py +++ b/altair/__init__.py @@ -1,13 +1,6 @@ # ruff: noqa __version__ = "5.4.0dev" -from typing import Any - -# Necessary as mypy would see expr as the module alt.expr although due to how -# the imports are set up it is expr in the alt.expr module -expr: Any - - # The content of __all__ is automatically written by # tools/update_init_file.py. Do not modify directly. __all__ = [ @@ -54,6 +47,7 @@ "BrushConfig", "CalculateTransform", "Categorical", + "ChainedWhen", "Chart", "ChartDataType", "ChartType", @@ -488,6 +482,7 @@ "TextDef", "TextDirection", "TextValue", + "Then", "Theta", "Theta2", "Theta2Datum", @@ -565,6 +560,7 @@ "VegaLiteSchema", "ViewBackground", "ViewConfig", + "When", "WindowEventType", "WindowFieldDef", "WindowOnlyOp", @@ -622,7 +618,6 @@ "load_ipython_extension", "load_schema", "mixins", - "overload", "param", "parse_shorthand", "renderers", @@ -645,6 +640,7 @@ "vconcat", "vegalite", "vegalite_compilers", + "when", "with_property_setters", ] @@ -654,7 +650,9 @@ def __dir__(): from altair.vegalite import * +from altair.vegalite.v5.schema.core import Dict from altair.jupyter import JupyterChart +from altair.expr import expr from altair.utils import AltairDeprecationWarning diff --git a/altair/utils/__init__.py b/altair/utils/__init__.py index b9b7269a6..2b3cd6e85 100644 --- a/altair/utils/__init__.py +++ b/altair/utils/__init__.py @@ -8,6 +8,7 @@ update_nested, display_traceback, SchemaBase, + SHORTHAND_KEYS, ) from .html import spec_to_html from .plugin_registry import PluginRegistry @@ -16,6 +17,7 @@ __all__ = ( + "SHORTHAND_KEYS", "AltairDeprecationWarning", "Optional", "PluginRegistry", diff --git a/altair/utils/core.py b/altair/utils/core.py index 91a94391a..739b839cb 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -12,17 +12,7 @@ import sys import traceback import warnings -from typing import ( - Callable, - TypeVar, - Any, - Iterator, - cast, - Literal, - Protocol, - TYPE_CHECKING, - runtime_checkable, -) +from typing import Callable, TypeVar, Any, Iterator, cast, Literal, TYPE_CHECKING from itertools import groupby from operator import itemgetter @@ -33,6 +23,10 @@ from altair.utils.schemapi import SchemaBase, Undefined +if sys.version_info >= (3, 12): + from typing import runtime_checkable, Protocol +else: + from typing_extensions import runtime_checkable, Protocol if sys.version_info >= (3, 10): from typing import ParamSpec else: @@ -199,6 +193,22 @@ def __dataframe__( "utcsecondsmilliseconds", ] +VALID_TYPECODES = list(itertools.chain(iter(TYPECODE_MAP), iter(INV_TYPECODE_MAP))) + +SHORTHAND_UNITS = { + "field": "(?P.*)", + "type": "(?P{})".format("|".join(VALID_TYPECODES)), + "agg_count": "(?Pcount)", + "op_count": "(?Pcount)", + "aggregate": "(?P{})".format("|".join(AGGREGATES)), + "window_op": "(?P{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)), + "timeUnit": "(?P{})".format("|".join(TIMEUNITS)), +} + +SHORTHAND_KEYS: frozenset[Literal["field", "aggregate", "type", "timeUnit"]] = ( + frozenset(("field", "aggregate", "type", "timeUnit")) +) + def infer_vegalite_type_for_pandas( data: object, @@ -577,18 +587,6 @@ def parse_shorthand( if not shorthand: return {} - valid_typecodes = list(TYPECODE_MAP) + list(INV_TYPECODE_MAP) - - units = { - "field": "(?P.*)", - "type": "(?P{})".format("|".join(valid_typecodes)), - "agg_count": "(?Pcount)", - "op_count": "(?Pcount)", - "aggregate": "(?P{})".format("|".join(AGGREGATES)), - "window_op": "(?P{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)), - "timeUnit": "(?P{})".format("|".join(TIMEUNITS)), - } - patterns = [] if parse_aggregates: @@ -606,7 +604,8 @@ def parse_shorthand( patterns = list(itertools.chain(*((p + ":{type}", p) for p in patterns))) regexps = ( - re.compile(r"\A" + p.format(**units) + r"\Z", re.DOTALL) for p in patterns + re.compile(r"\A" + p.format(**SHORTHAND_UNITS) + r"\Z", re.DOTALL) + for p in patterns ) # find matches depending on valid fields passed diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index bd5241fb3..a0bd88506 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -1,14 +1,28 @@ from __future__ import annotations +import sys import warnings import hashlib import io import json import jsonschema import itertools -from typing import Union, cast, Any, Iterable, Literal, IO, TYPE_CHECKING +from typing import ( + Any, + cast, + overload, + Literal, + Union, + TYPE_CHECKING, + TypeVar, + Sequence, + Protocol, +) from typing_extensions import TypeAlias -import typing +import typing as t +import functools +import operator +from copy import deepcopy as _deepcopy from .schema import core, channels, mixins, Undefined, SCHEMA_URL @@ -27,20 +41,31 @@ from altair.utils.core import ( to_eager_narwhals_dataframe as _to_eager_narwhals_dataframe, ) +from .schema._typing import Map + +if sys.version_info >= (3, 13): + from typing import TypedDict +else: + from typing_extensions import TypedDict +if sys.version_info >= (3, 12): + from typing import TypeAliasType +else: + from typing_extensions import TypeAliasType if TYPE_CHECKING: from ...utils.core import DataFrameLike - import sys from pathlib import Path + from typing import Iterable, IO, Iterator if sys.version_info >= (3, 13): - from typing import TypeIs + from typing import TypeIs, Required else: - from typing_extensions import TypeIs + from typing_extensions import TypeIs, Required if sys.version_info >= (3, 11): - from typing import Self + from typing import Self, Never else: - from typing_extensions import Self + from typing_extensions import Self, Never + from .schema.channels import Facet, Row, Column from .schema.core import ( SchemaBase, @@ -85,8 +110,14 @@ TopLevelSelectionParameter, SelectionParameter, InlineDataset, + UndefinedType, + ) + from altair.expr.core import ( + BinaryExpression, + Expression, + GetAttrExpression, + GetItemExpression, ) - from altair.expr.core import Expression, GetAttrExpression from .schema._typing import ( ImputeMethod_T, SelectionType_T, @@ -96,7 +127,24 @@ ResolveMode_T, ) + ChartDataType: TypeAlias = Optional[Union[DataType, core.Data, str, core.Generator]] +_TSchemaBase = TypeVar("_TSchemaBase", bound=core.SchemaBase) +_T = TypeVar("_T") +_OneOrSeq = TypeAliasType("_OneOrSeq", Union[_T, Sequence[_T]], type_params=(_T,)) +"""One of ``_T`` specified type(s), or a `Sequence` of such. + +Examples +-------- +The parameters ``short``, ``long`` accept the same range of types:: + + # ruff: noqa: UP006, UP007 + + def func( + short: _OneOrSeq[str | bool | float], + long: Union[str, bool, float, Sequence[Union[str, bool, float]], + ): ... +""" # ------------------------------------------------------------------------ @@ -324,7 +372,7 @@ def __getattr__(self, field_name: str) -> GetAttrExpression | SelectionExpressio # TODO: Are there any special cases to consider for __getitem__? # This was copied from v4. - def __getitem__(self, field_name: str) -> _expr_core.GetItemExpression: + def __getitem__(self, field_name: str) -> GetItemExpression: return _expr_core.GetItemExpression(self.name, field_name) @@ -379,26 +427,636 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: return False -_TestPredicateType = Union[str, _expr_core.Expression, core.PredicateComposition] -_PredicateType = Union[ +# ------------------------------------------------------------------------- +# Tools for working with conditions +_TestPredicateType: TypeAlias = Union[ + str, _expr_core.Expression, core.PredicateComposition +] +"""https://vega.github.io/vega-lite/docs/predicate.html""" + +_PredicateType: TypeAlias = Union[ Parameter, core.Expr, - typing.Dict[str, Any], + Map, _TestPredicateType, _expr_core.OperatorMixin, ] -_ConditionType = typing.Dict[str, Union[_TestPredicateType, Any]] -_DictOrStr = Union[typing.Dict[str, Any], str] -_DictOrSchema = Union[core.SchemaBase, typing.Dict[str, Any]] -_StatementType = Union[core.SchemaBase, _DictOrStr] +"""Permitted types for `predicate`.""" + +_ComposablePredicateType: TypeAlias = Union[ + _expr_core.OperatorMixin, SelectionPredicateComposition +] +"""Permitted types for `&` reduced predicates.""" + +_StatementType: TypeAlias = Union[core.SchemaBase, Map, str] +"""Permitted types for `if_true`/`if_false`. + +In python terms: +```py +if _PredicateType: + return _StatementType +elif _PredicateType: + return _StatementType +else: + return _StatementType +``` +""" + + +_ConditionType: TypeAlias = t.Dict[str, Union[_TestPredicateType, Any]] +"""Intermediate type representing a converted `_PredicateType`. + +Prior to parsing any `_StatementType`. +""" + +_LiteralValue: TypeAlias = Union[str, bool, float, int] +"""Primitive python value types.""" + +_FieldEqualType: TypeAlias = Union[_LiteralValue, Map, Parameter, core.SchemaBase] +"""Permitted types for equality checks on field values: + +- `datum.field == ...` +- `FieldEqualPredicate(equal=...)` +- `when(**constraints=...)` +""" + + +def _is_test_predicate(obj: Any) -> TypeIs[_TestPredicateType]: + return isinstance(obj, (str, _expr_core.Expression, core.PredicateComposition)) + + +def _is_undefined(obj: Any) -> TypeIs[UndefinedType]: + """Type-safe singleton check for `UndefinedType`. + + Notes + ----- + - Using `obj is Undefined` does not narrow from `UndefinedType` in a union. + - Due to the assumption that other `UndefinedType`'s could exist. + - Current [typing spec advises](https://typing.readthedocs.io/en/latest/spec/concepts.html#support-for-singleton-types-in-unions) using an `Enum`. + - Otherwise, requires an explicit guard to inform the type checker. + """ + return obj is Undefined + + +def _get_predicate_expr(p: Parameter) -> Optional[str | SchemaBase]: + # https://vega.github.io/vega-lite/docs/predicate.html + return getattr(p.param, "expr", Undefined) + + +def _predicate_to_condition( + predicate: _PredicateType, *, empty: Optional[bool] = Undefined +) -> _ConditionType: + condition: _ConditionType + if isinstance(predicate, Parameter): + predicate_expr = _get_predicate_expr(predicate) + if predicate.param_type == "selection" or _is_undefined(predicate_expr): + condition = {"param": predicate.name} + if isinstance(empty, bool): + condition["empty"] = empty + elif isinstance(predicate.empty, bool): + condition["empty"] = predicate.empty + else: + condition = {"test": predicate_expr} + elif _is_test_predicate(predicate): + condition = {"test": predicate} + elif isinstance(predicate, dict): + condition = predicate + elif isinstance(predicate, _expr_core.OperatorMixin): + condition = {"test": predicate._to_expr()} + else: + msg = ( + f"Expected a predicate, but got: {type(predicate).__name__!r}\n\n" + f"From `predicate={predicate!r}`." + ) + raise TypeError(msg) + return condition + + +def _condition_to_selection( + condition: _ConditionType, + if_true: _StatementType, + if_false: _StatementType, + **kwargs: Any, +) -> SchemaBase | dict[str, _ConditionType | Any]: + selection: SchemaBase | dict[str, _ConditionType | Any] + if isinstance(if_true, core.SchemaBase): + if_true = if_true.to_dict() + elif isinstance(if_true, str): + if isinstance(if_false, str): + msg = ( + "A field cannot be used for both the `if_true` and `if_false` " + "values of a condition. " + "One of them has to specify a `value` or `datum` definition." + ) + raise ValueError(msg) + else: + if_true = utils.parse_shorthand(if_true) + if_true.update(kwargs) + condition.update(if_true) + if isinstance(if_false, core.SchemaBase): + # For the selection, the channel definitions all allow selections + # already. So use this SchemaBase wrapper if possible. + selection = if_false.copy() + selection.condition = condition + elif isinstance(if_false, (str, dict)): + if isinstance(if_false, str): + if_false = utils.parse_shorthand(if_false) + if_false.update(kwargs) + selection = dict(condition=condition, **if_false) + else: + raise TypeError(if_false) + return selection + + +class _ConditionClosed(TypedDict, closed=True, total=False): # type: ignore[call-arg] + # https://peps.python.org/pep-0728/ + # Parameter {"param", "value", "empty"} + # Predicate {"test", "value"} + empty: Optional[bool] + param: Parameter | str + test: _TestPredicateType + value: Any + + +class _ConditionExtra(TypedDict, closed=True, total=False): # type: ignore[call-arg] + # https://peps.python.org/pep-0728/ + # Likely a Field predicate + empty: Optional[bool] + param: Parameter | str + test: _TestPredicateType + value: Any + __extra_items__: _StatementType | _OneOrSeq[_LiteralValue] + + +_Condition: TypeAlias = _ConditionExtra +"""A singular, non-chainable condition produced by ``.when()``.""" + +_Conditions: TypeAlias = t.List[_ConditionClosed] +"""Chainable conditions produced by ``.when()`` and ``Then.when()``.""" + +_C = TypeVar("_C", _Conditions, _Condition) + + +class _Conditional(TypedDict, t.Generic[_C], total=False): + condition: Required[_C] + value: Any + + +class _Value(TypedDict, closed=True, total=False): # type: ignore[call-arg] + # https://peps.python.org/pep-0728/ + value: Required[Any] + __extra_items__: Any + + +def _reveal_parsed_shorthand(obj: Map, /) -> dict[str, Any]: + # Helper for producing error message on multiple field collision. + return {k: v for k, v in obj.items() if k in utils.SHORTHAND_KEYS} + + +def _is_extra(*objs: Any, kwds: Map) -> Iterator[bool]: + for el in objs: + if isinstance(el, (core.SchemaBase, t.Mapping)): + item = el.to_dict(validate=False) if isinstance(el, core.SchemaBase) else el + yield not (item.keys() - kwds.keys()).isdisjoint(utils.SHORTHAND_KEYS) + else: + continue + + +def _is_condition_extra(obj: Any, *objs: Any, kwds: Map) -> TypeIs[_Condition]: + # NOTE: Short circuits on the first conflict. + # 1 - Originated from parse_shorthand + # 2 - Used a wrapper or `dict` directly, including `extra_keys` + return isinstance(obj, str) or any(_is_extra(obj, *objs, kwds=kwds)) + + +def _parse_when_constraints( + constraints: dict[str, _FieldEqualType], / +) -> Iterator[BinaryExpression]: + """Wrap kwargs with `alt.datum`. + + ```py + # before + alt.when(alt.datum.Origin == "Europe") + + # after + alt.when(Origin = "Europe") + ``` + """ + for name, value in constraints.items(): + yield _expr_core.GetAttrExpression("datum", name) == value + + +def _validate_composables( + predicates: Iterable[Any], / +) -> Iterator[_ComposablePredicateType]: + for p in predicates: + if isinstance(p, (_expr_core.OperatorMixin, SelectionPredicateComposition)): + yield p + else: + msg = ( + f"Predicate composition is not permitted for " + f"{type(p).__name__!r}.\n" + f"Try wrapping {p!r} in a `Parameter` first." + ) + raise TypeError(msg) + + +def _parse_when_compose( + predicates: tuple[Any, ...], + constraints: dict[str, _FieldEqualType], + /, +) -> BinaryExpression: + """Compose an `&` reduction predicate. + + Parameters + ---------- + predicates + Collected positional arguments. + constraints + Collected keyword arguments. + + Raises + ------ + TypeError + On the first non ``_ComposablePredicateType`` of `predicates` + """ + iters = [] + if predicates: + iters.append(_validate_composables(predicates)) + if constraints: + iters.append(_parse_when_constraints(constraints)) + r = functools.reduce(operator.and_, itertools.chain.from_iterable(iters)) + return cast(_expr_core.BinaryExpression, r) + + +def _parse_when( + predicate: Optional[_PredicateType], + *more_predicates: _ComposablePredicateType, + empty: Optional[bool], + **constraints: _FieldEqualType, +) -> _ConditionType: + composed: _PredicateType + if _is_undefined(predicate): + if more_predicates or constraints: + composed = _parse_when_compose(more_predicates, constraints) + else: + msg = ( + f"At least one predicate or constraint must be provided, " + f"but got: {predicate=}" + ) + raise TypeError(msg) + elif more_predicates or constraints: + predicates = predicate, *more_predicates + composed = _parse_when_compose(predicates, constraints) + else: + composed = predicate + return _predicate_to_condition(composed, empty=empty) + + +def _parse_literal(val: Any, /) -> dict[str, Any]: + if isinstance(val, str): + return utils.parse_shorthand(val) + else: + msg = ( + f"Expected a shorthand `str`, but got: {type(val).__name__!r}\n\n" + f"From `statement={val!r}`." + ) + raise TypeError(msg) + + +def _parse_then(statement: _StatementType, kwds: dict[str, Any], /) -> dict[str, Any]: + if isinstance(statement, core.SchemaBase): + statement = statement.to_dict() + elif not isinstance(statement, dict): + statement = _parse_literal(statement) + statement.update(kwds) + return statement + + +def _parse_otherwise( + statement: _StatementType, conditions: _Conditional[Any], kwds: dict[str, Any], / +) -> SchemaBase | _Conditional[Any]: + selection: SchemaBase | _Conditional[Any] + if isinstance(statement, core.SchemaBase): + selection = statement.copy() + conditions.update(**kwds) # type: ignore[call-arg] + selection.condition = conditions["condition"] + else: + if not isinstance(statement, t.Mapping): + statement = _parse_literal(statement) + selection = conditions + selection.update(**statement, **kwds) # type: ignore[call-arg] + return selection + + +class _BaseWhen(Protocol): + # NOTE: Temporary solution to non-SchemaBase copy + _condition: _ConditionType + + def _when_then( + self, statement: _StatementType, kwds: dict[str, Any], / + ) -> _ConditionClosed | _Condition: + condition: Any = _deepcopy(self._condition) + then = _parse_then(statement, kwds) + condition.update(then) + return condition + + +class When(_BaseWhen): + """Utility class for ``when-then-otherwise`` conditions. + + Represents the state after calling :func:`.when()`. + + This partial state requires calling :meth:`When.then()` to finish the condition. + + References + ---------- + `polars.expr.whenthen `__ + """ + + def __init__(self, condition: _ConditionType, /) -> None: + self._condition = condition + + @overload + def then(self, statement: str, /, **kwds: Any) -> Then[_Condition]: ... + @overload + def then(self, statement: _Value, /, **kwds: Any) -> Then[_Conditions]: ... + @overload + def then( + self, statement: dict[str, Any] | SchemaBase, /, **kwds: Any + ) -> Then[Any]: ... + def then(self, statement: _StatementType, /, **kwds: Any) -> Then[Any]: + """Attach a statement to this predicate. + + Parameters + ---------- + statement + A spec or value to use when the preceding :func:`.when()` clause is true. + + .. note:: + ``str`` will be encoded as `shorthand`__. + **kwds + Additional keyword args are added to the resulting ``dict``. + + Returns + ------- + :class:`Then` + """ + condition = self._when_then(statement, kwds) + if _is_condition_extra(condition, statement, kwds=kwds): + return Then(_Conditional(condition=condition)) + else: + return Then(_Conditional(condition=[condition])) + + +class Then(core.SchemaBase, t.Generic[_C]): + """Utility class for ``when-then-otherwise`` conditions. + + Represents the state after calling :func:`.when().then()`. + + This state is a valid condition on its own. + + It can be further specified, via multiple chained `when-then` calls, + or finalized with :meth:`Then.otherwise()`. + + References + ---------- + `polars.expr.whenthen `__ + """ + + _schema = {"type": "object"} + + def __init__(self, conditions: _Conditional[_C], /) -> None: + super().__init__(**conditions) + self.condition: _C + + @overload + def otherwise(self, statement: _TSchemaBase, /, **kwds: Any) -> _TSchemaBase: ... + @overload + def otherwise(self, statement: str, /, **kwds: Any) -> _Conditional[_Condition]: ... + @overload + def otherwise( + self, statement: _Value, /, **kwds: Any + ) -> _Conditional[_Conditions]: ... + @overload + def otherwise( + self, statement: dict[str, Any], /, **kwds: Any + ) -> _Conditional[Any]: ... + def otherwise( + self, statement: _StatementType, /, **kwds: Any + ) -> SchemaBase | _Conditional[Any]: + """Finalize the condition with a default value. + + Parameters + ---------- + statement + A spec or value to use when no predicates were met. + + .. note:: + Roughly equivalent to an ``else`` clause. + + .. note:: + ``str`` will be encoded as `shorthand`__. + **kwds + Additional keyword args are added to the resulting ``dict``. + """ + conditions: _Conditional[Any] + is_extra = functools.partial(_is_condition_extra, kwds=kwds) + if is_extra(self.condition, statement): + current = self.condition + if isinstance(current, list) and len(current) == 1: + # This case is guaranteed to have come from `When` and not `ChainedWhen` + # The `list` isn't needed if we complete the condition here + conditions = _Conditional(condition=current[0]) + elif isinstance(current, dict): + if not is_extra(statement): + conditions = self.to_dict() + else: + cond = _reveal_parsed_shorthand(current) + msg = ( + f"Only one field may be used within a condition.\n" + f"Shorthand {statement!r} would conflict with {cond!r}\n\n" + f"Use `alt.value({statement!r})` if this is not a shorthand string." + ) + raise TypeError(msg) + else: + # Generic message to cover less trivial cases + msg = ( + f"Chained conditions cannot be mixed with field conditions.\n" + f"{self!r}\n\n{statement!r}" + ) + raise TypeError(msg) + else: + conditions = self.to_dict() + return _parse_otherwise(statement, conditions, kwds) + + def when( + self, + predicate: Optional[_PredicateType] = Undefined, + *more_predicates: _ComposablePredicateType, + empty: Optional[bool] = Undefined, + **constraints: _FieldEqualType, + ) -> ChainedWhen: + """Attach another predicate to the condition. + + The resulting predicate is an ``&`` reduction over ``predicate`` and optional ``*``, ``**``, arguments. + + Parameters + ---------- + predicate + A selection or test predicate. ``str`` input will be treated as a test operand. + + .. note:: + accepts the same range of inputs as in :func:`.condition()`. + *more_predicates + Additional predicates, restricted to types supporting ``&``. + empty + For selection parameters, the predicate of empty selections returns ``True`` by default. + Override this behavior, with ``empty=False``. + **constraints + Specify `Field Equal Predicate `__'s. + Shortcut for ``alt.datum.field_name == value``, see examples for usage. + + Returns + ------- + :class:`ChainedWhen` + A partial state which requires calling :meth:`ChainedWhen.then()` to finish the condition. + """ + condition = _parse_when(predicate, *more_predicates, empty=empty, **constraints) + conditions = self.to_dict() + current = conditions["condition"] + if isinstance(current, list): + conditions = t.cast(_Conditional[_Conditions], conditions) + return ChainedWhen(condition, conditions) + elif isinstance(current, dict): + cond = _reveal_parsed_shorthand(current) + msg = ( + f"Chained conditions cannot be mixed with field conditions.\n" + f"Additional conditions would conflict with {cond!r}\n\n" + f"Must finalize by calling `.otherwise()`." + ) + raise TypeError(msg) + else: + msg = ( + f"The internal structure has been modified.\n" + f"{type(current).__name__!r} found, but only `dict | list` are valid." + ) + raise NotImplementedError(msg) + + def to_dict(self, *args, **kwds) -> _Conditional[_C]: # type: ignore[override] + m = super().to_dict(*args, **kwds) + return _Conditional(condition=m["condition"]) + + +class ChainedWhen(_BaseWhen): + """ + Utility class for ``when-then-otherwise`` conditions. + + Represents the state after calling :func:`.when().then().when()`. + + This partial state requires calling :meth:`ChainedWhen.then()` to finish the condition. + + References + ---------- + `polars.expr.whenthen `__ + """ + + def __init__( + self, + condition: _ConditionType, + conditions: _Conditional[_Conditions], + /, + ) -> None: + self._condition = condition + self._conditions = conditions + + def then(self, statement: _StatementType, /, **kwds: Any) -> Then[_Conditions]: + """Attach a statement to this predicate. + + Parameters + ---------- + statement + A spec or value to use when the preceding :meth:`Then.when()` clause is true. + + .. note:: + ``str`` will be encoded as `shorthand`__. + **kwds + Additional keyword args are added to the resulting ``dict``. + + Returns + ------- + :class:`Then` + """ + condition = self._when_then(statement, kwds) + conditions = self._conditions.copy() + conditions["condition"].append(condition) + return Then(conditions) + + +def when( + predicate: Optional[_PredicateType] = Undefined, + *more_predicates: _ComposablePredicateType, + empty: Optional[bool] = Undefined, + **constraints: _FieldEqualType, +) -> When: + """Start a ``when-then-otherwise`` condition. + + The resulting predicate is an ``&`` reduction over ``predicate`` and optional ``*``, ``**``, arguments. + + Parameters + ---------- + predicate + A selection or test predicate. ``str`` input will be treated as a test operand. + + .. note:: + Accepts the same range of inputs as in :func:`.condition()`. + *more_predicates + Additional predicates, restricted to types supporting ``&``. + empty + For selection parameters, the predicate of empty selections returns ``True`` by default. + Override this behavior, with ``empty=False``. + **constraints + Specify `Field Equal Predicate `__'s. + Shortcut for ``alt.datum.field_name == value``, see examples for usage. + + Returns + ------- + :class:`When` + A partial state which requires calling :meth:`When.then()` to finish the condition. + + Notes + ----- + - Directly inspired by the ``when-then-otherwise`` syntax used in ``polars.when``. + + References + ---------- + `polars.when `__ + + Examples + -------- + Using keyword-argument ``constraints`` can simplify compositions like:: + + import altair as alt + verbose_composition = ( + (alt.datum.Name == "Name_1") + & (alt.datum.Color == "Green") + & (alt.datum.Age == 25) + & (alt.datum.StartDate == "2000-10-01") + ) + when_verbose = alt.when(verbose_composition) + when_concise = alt.when(Name="Name_1", Color="Green", Age=25, StartDate="2000-10-01") + """ + condition = _parse_when(predicate, *more_predicates, empty=empty, **constraints) + return When(condition) + # ------------------------------------------------------------------------ # Top-Level Functions -def value(value, **kwargs) -> dict: +def value(value, **kwargs) -> _Value: """Specify a value for use in an encoding""" - return dict(value=value, **kwargs) + return _Value(value=value, **kwargs) # type: ignore[typeddict-item] def param( @@ -810,26 +1468,23 @@ def binding_range(**kwargs): return core.BindRange(input="range", **kwargs) -_TSchemaBase = typing.TypeVar("_TSchemaBase", bound=core.SchemaBase) - - -@typing.overload +@overload def condition( predicate: _PredicateType, if_true: _StatementType, if_false: _TSchemaBase, **kwargs ) -> _TSchemaBase: ... -@typing.overload +@overload def condition( predicate: _PredicateType, if_true: str, if_false: str, **kwargs -) -> typing.NoReturn: ... -@typing.overload +) -> Never: ... +@overload def condition( - predicate: _PredicateType, if_true: _DictOrSchema, if_false: _DictOrStr, **kwargs + predicate: _PredicateType, if_true: Map | SchemaBase, if_false: Map | str, **kwargs ) -> dict[str, _ConditionType | Any]: ... -@typing.overload +@overload def condition( predicate: _PredicateType, - if_true: _DictOrStr, - if_false: dict[str, Any], + if_true: Map | str, + if_false: Map, **kwargs, ) -> dict[str, _ConditionType | Any]: ... # TODO: update the docstring @@ -838,7 +1493,7 @@ def condition( if_true: _StatementType, if_false: _StatementType, **kwargs, -) -> dict[str, Any] | SchemaBase: +) -> SchemaBase | dict[str, _ConditionType | Any]: """A conditional attribute or encoding Parameters @@ -858,56 +1513,9 @@ def condition( spec: dict or VegaLiteSchema the spec that describes the condition """ - - test_predicates = (str, _expr_core.Expression, core.PredicateComposition) - - condition: dict[str, Optional[bool | str | Expression | PredicateComposition]] - if isinstance(predicate, Parameter): - if ( - predicate.param_type == "selection" - or getattr(predicate.param, "expr", Undefined) is Undefined - ): - condition = {"param": predicate.name} - if "empty" in kwargs: - condition["empty"] = kwargs.pop("empty") - elif isinstance(predicate.empty, bool): - condition["empty"] = predicate.empty - else: - condition = {"test": getattr(predicate.param, "expr", Undefined)} - elif isinstance(predicate, test_predicates): - condition = {"test": predicate} - elif isinstance(predicate, dict): - condition = predicate - else: - msg = f"condition predicate of type {type(predicate)}" "" - raise NotImplementedError(msg) - - if isinstance(if_true, core.SchemaBase): - # convert to dict for now; the from_dict call below will wrap this - # dict in the appropriate schema - if_true = if_true.to_dict() - elif isinstance(if_true, str): - if isinstance(if_false, str): - msg = "A field cannot be used for both the `if_true` and `if_false` values of a condition. One of them has to specify a `value` or `datum` definition." - raise ValueError(msg) - else: - if_true = utils.parse_shorthand(if_true) - if_true.update(kwargs) - condition.update(if_true) - - selection: dict | SchemaBase - if isinstance(if_false, core.SchemaBase): - # For the selection, the channel definitions all allow selections - # already. So use this SchemaBase wrapper if possible. - selection = if_false.copy() - selection.condition = condition - elif isinstance(if_false, str): - selection = {"condition": condition, "shorthand": if_false} - selection.update(kwargs) - else: - selection = dict(condition=condition, **if_false) - - return selection + empty = kwargs.pop("empty", Undefined) + condition = _predicate_to_condition(predicate, empty=empty) + return _condition_to_selection(condition, if_true, if_false, **kwargs) # -------------------------------------------------------------------- diff --git a/altair/vegalite/v5/schema/core.py b/altair/vegalite/v5/schema/core.py index 5d2fe8c50..89f35d0dd 100644 --- a/altair/vegalite/v5/schema/core.py +++ b/altair/vegalite/v5/schema/core.py @@ -135,7 +135,6 @@ "Day", "DensityTransform", "DerivedStream", - "Dict", "DictInlineDataset", "DictSelectionInit", "DictSelectionInitInterval", diff --git a/doc/user_guide/api.rst b/doc/user_guide/api.rst index c926c0d56..ab964736f 100644 --- a/doc/user_guide/api.rst +++ b/doc/user_guide/api.rst @@ -166,6 +166,7 @@ API Functions topo_feature value vconcat + when Low-Level Schema Wrappers ------------------------- @@ -627,3 +628,15 @@ Low-Level Schema Wrappers WindowFieldDef WindowOnlyOp WindowTransform + +API Utility Classes +------------------- +.. currentmodule:: altair + +.. autosummary:: + :toctree: generated/api-cls/ + :nosignatures: + + When + Then + ChainedWhen diff --git a/doc/user_guide/marks/area.rst b/doc/user_guide/marks/area.rst index 78c77866b..8991d25c8 100644 --- a/doc/user_guide/marks/area.rst +++ b/doc/user_guide/marks/area.rst @@ -78,7 +78,6 @@ to ``true`` or an object defining a property of the overlaying point marks, we c .. altair-plot:: import altair as alt from vega_datasets import data - from altair.expr import datum source = data.stocks.url diff --git a/doc/user_guide/marks/geoshape.rst b/doc/user_guide/marks/geoshape.rst index 37aee0567..4db48f98c 100644 --- a/doc/user_guide/marks/geoshape.rst +++ b/doc/user_guide/marks/geoshape.rst @@ -202,14 +202,12 @@ Altair also contains expressions related to geographical features. We can for ex .. altair-plot:: - from altair.expr import datum, geoCentroid - basemap = alt.Chart(gdf_sel).mark_geoshape( fill='lightgray', stroke='white', strokeWidth=0.5 ) bubbles = alt.Chart(gdf_sel).transform_calculate( - centroid=geoCentroid(None, datum) + centroid=alt.expr.geoCentroid(None, alt.datum) ).mark_circle( stroke='black' ).encode( diff --git a/tests/utils/test_core.py b/tests/utils/test_core.py index a2344a218..d71e4d822 100644 --- a/tests/utils/test_core.py +++ b/tests/utils/test_core.py @@ -300,9 +300,11 @@ def test_infer_encoding_types_with_condition(): ), ), "color": alt.Color( - "cfield:N", + field=alt.FieldName("cfield"), + type=alt.StandardType("nominal"), condition=alt.ConditionalPredicateValueDefGradientstringnullExprRef( - value="red", test=alt.Predicate("pred2") + value="red", + test=alt.Predicate("pred2"), ), ), "opacity": alt.OpacityValue( diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index 0d7e248b6..990e0965c 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -1,5 +1,8 @@ """Unit tests for altair API""" +from __future__ import annotations + + from datetime import date import io import ibis @@ -8,6 +11,7 @@ import operator import os import pathlib +import re import tempfile from importlib.metadata import version as importlib_version from packaging.version import Version @@ -18,7 +22,7 @@ import pandas as pd import polars as pl -import altair.vegalite.v5 as alt +import altair as alt try: import vl_convert as vlc @@ -87,6 +91,100 @@ def basic_chart(): return alt.Chart(data).mark_bar().encode(x="a", y="b") +@pytest.fixture +def cars(): + return pd.DataFrame( + { + "Name": [ + "chevrolet chevelle malibu", + "buick skylark 320", + "plymouth satellite", + "amc rebel sst", + "ford torino", + "ford galaxie 500", + "chevrolet impala", + "plymouth fury iii", + "pontiac catalina", + "amc ambassador dpl", + ], + "Miles_per_Gallon": [ + 18.0, + 15.0, + 18.0, + 16.0, + 17.0, + 15.0, + 14.0, + 14.0, + 14.0, + 15.0, + ], + "Cylinders": [8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + "Displacement": [ + 307.0, + 350.0, + 318.0, + 304.0, + 302.0, + 429.0, + 454.0, + 440.0, + 455.0, + 390.0, + ], + "Horsepower": [ + 130.0, + 165.0, + 150.0, + 150.0, + 140.0, + 198.0, + 220.0, + 215.0, + 225.0, + 190.0, + ], + "Weight_in_lbs": [ + 3504, + 3693, + 3436, + 3433, + 3449, + 4341, + 4354, + 4312, + 4425, + 3850, + ], + "Acceleration": [12.0, 11.5, 11.0, 12.0, 10.5, 10.0, 9.0, 8.5, 10.0, 8.5], + "Year": [ + pd.Timestamp("1970-01-01 00:00:00"), + pd.Timestamp("1970-01-01 00:00:00"), + pd.Timestamp("1970-01-01 00:00:00"), + pd.Timestamp("1970-01-01 00:00:00"), + pd.Timestamp("1970-01-01 00:00:00"), + pd.Timestamp("1970-01-01 00:00:00"), + pd.Timestamp("1970-01-01 00:00:00"), + pd.Timestamp("1970-01-01 00:00:00"), + pd.Timestamp("1970-01-01 00:00:00"), + pd.Timestamp("1970-01-01 00:00:00"), + ], + "Origin": [ + "USA", + "USA", + "USA", + "USA", + "USA", + "USA", + "USA", + "USA", + "USA", + "USA", + ], + } + ) + + def test_chart_data_types(): def Chart(data): return alt.Chart(data).mark_point().encode(x="x:Q", y="y:Q") @@ -263,6 +361,341 @@ def test_chart_operations(): assert len(chart.vconcat) == 4 +def test_when() -> None: + select = alt.selection_point(name="select", on="click") + condition = alt.condition(select, alt.value(1), "two", empty=False)["condition"] + condition.pop("value") + when = alt.when(select, empty=False) + when_constraint = alt.when(Origin="Europe") + when_constraints = alt.when( + Name="Name_1", Color="Green", Age=25, StartDate="2000-10-01" + ) + expected_constraint = alt.datum.Origin == "Europe" + expected_constraints = ( + (alt.datum.Name == "Name_1") + & (alt.datum.Color == "Green") + & (alt.datum.Age == 25) + & (alt.datum.StartDate == "2000-10-01") + ) + + assert isinstance(when, alt.When) + assert condition == when._condition + assert isinstance(when_constraint, alt.When) + assert when_constraint._condition["test"] == expected_constraint + assert when_constraints._condition["test"] == expected_constraints + with pytest.raises((NotImplementedError, TypeError), match="list"): + alt.when([1, 2, 3]) # type: ignore + with pytest.raises(TypeError, match="Undefined"): + alt.when() + with pytest.raises(TypeError, match="int"): + alt.when(select, alt.datum.Name == "Name_1", 99, TestCon=5.901) # type: ignore + + +def test_when_then() -> None: + select = alt.selection_point(name="select", on="click") + when = alt.when(select) + when_then = when.then(alt.value(5)) + + assert isinstance(when_then, alt.Then) + condition = when_then.condition + assert isinstance(condition, list) + assert condition[-1].get("value") == 5 + + with pytest.raises(TypeError, match=r"Path"): + when.then(pathlib.Path("some")) # type: ignore + + with pytest.raises(TypeError, match="float"): + when_then.when(select, alt.datum.Name != "Name_2", 86.123, empty=True) # type: ignore + + +def test_when_then_only(basic_chart) -> None: + """`Then` is an acceptable encode argument.""" + + select = alt.selection_point(name="select", on="click") + + basic_chart.encode(fillOpacity=alt.when(select).then(alt.value(5))).to_dict() + + +def test_when_then_otherwise() -> None: + select = alt.selection_point(name="select", on="click") + when_then = alt.when(select).then(alt.value(2, empty=False)) + when_then_otherwise = when_then.otherwise(alt.value(0)) + + expected = alt.condition(select, alt.value(2, empty=False), alt.value(0)) + with pytest.raises(TypeError, match="list"): + when_then.otherwise([1, 2, 3]) # type: ignore + + # Needed to modify to a list of conditions, + # which isn't possible in `condition` + single_condition = expected.pop("condition") + expected["condition"] = [single_condition] + + assert expected == when_then_otherwise + + +def test_when_then_when_then_otherwise() -> None: + """Test for [#3301](https://github.com/vega/altair/issues/3301).""" + + data = { + "values": [ + {"a": "A", "b": 28}, + {"a": "B", "b": 55}, + {"a": "C", "b": 43}, + {"a": "D", "b": 91}, + {"a": "E", "b": 81}, + {"a": "F", "b": 53}, + {"a": "G", "b": 19}, + {"a": "H", "b": 87}, + {"a": "I", "b": 52}, + ] + } + + select = alt.selection_point(name="select", on="click") + highlight = alt.selection_point(name="highlight", on="pointerover") + when_then_when_then = ( + alt.when(select) + .then(alt.value(2, empty=False)) + .when(highlight) + .then(alt.value(1, empty=False)) + ) + with pytest.raises(TypeError, match="set"): + when_then_when_then.otherwise({"five", "six"}) # type: ignore + + expected_stroke = { + "condition": [ + {"param": "select", "empty": False, "value": 2}, + {"param": "highlight", "empty": False, "value": 1}, + ], + "value": 0, + } + actual_stroke = when_then_when_then.otherwise(alt.value(0)) + fill_opacity = alt.when(select).then(alt.value(1)).otherwise(alt.value(0.3)) + + assert expected_stroke == actual_stroke + chart = ( + alt.Chart(data) + .mark_bar(fill="#4C78A8", stroke="black", cursor="pointer") + .encode(x="a:O", y="b:Q", fillOpacity=fill_opacity, strokeWidth=actual_stroke) + .configure_scale(bandPaddingInner=0.2) + .add_params(select, highlight) + ) + chart.to_dict() + + +def test_when_multi_channel_param(cars): + """Adapted from [2236376458](https://github.com/vega/altair/pull/3427#issuecomment-2236376458)""" + brush = alt.selection_interval() + hover = alt.selection_point(on="pointerover", nearest=True, empty=False) + + chart_1 = ( + alt.Chart(cars) + .mark_rect() + .encode( + x="Cylinders:N", + y="Origin:N", + color=alt.when(brush).then("count()").otherwise(alt.value("grey")), + opacity=alt.when(brush).then(alt.value(1)).otherwise(alt.value(0.6)), + ) + .add_params(brush) + ) + chart_1.to_dict() + + color = alt.when(hover).then(alt.value("coral")).otherwise(alt.value("lightgray")) + + chart_2 = ( + alt.Chart(cars, title="Selection obscured by other points") + .mark_circle(opacity=1) + .encode( + x="Horsepower:Q", + y="Miles_per_Gallon:Q", + color=color, + size=alt.when(hover).then(alt.value(300)).otherwise(alt.value(30)), + ) + .add_params(hover) + ) + + chart_3 = chart_2 | chart_2.encode( + order=alt.when(hover).then(alt.value(1)).otherwise(alt.value(0)) + ).properties(title="Selection brought to front") + + chart_3.to_dict() + + +def test_when_labels_position_based_on_condition() -> None: + """Test for [2144026368-1](https://github.com/vega/altair/pull/3427#issuecomment-2144026368) + + Original [labels-position-based-on-condition](https://altair-viz.github.io/user_guide/marks/text.html#labels-position-based-on-condition) + """ + import numpy as np + import pandas as pd + from altair.utils.schemapi import SchemaValidationError + + rand = np.random.RandomState(42) + df = pd.DataFrame({"xval": range(100), "yval": rand.randn(100).cumsum()}) + + bind_range = alt.binding_range(min=100, max=300, name="Slider value: ") + param_width = alt.param(bind=bind_range) + param_width_lt_200 = param_width < 200 + + # Examples of how to write both js and python expressions + param_color_js_expr = alt.param(expr=f"{param_width.name} < 200 ? 'red' : 'black'") + param_color_py_expr = alt.param( + expr=alt.expr.if_(param_width_lt_200, "red", "black") + ) + when = ( + alt.when(param_width_lt_200) + .then(alt.value("red")) + .otherwise(alt.value("black")) + ) + + # NOTE: If the `@overload` signatures change, + # `mypy` will flag structural errors here + cond = when["condition"][0] + otherwise = when["value"] + param_color_py_when = alt.param( + expr=alt.expr.if_(cond["test"], cond["value"], otherwise) + ) + assert param_color_py_expr.expr == param_color_py_when.expr + + chart = ( + alt.Chart(df) + .mark_point() + .encode( + alt.X("xval").axis(titleColor=param_color_js_expr), + alt.Y("yval").axis(titleColor=param_color_py_when), + ) + .add_params(param_width, param_color_js_expr, param_color_py_when) + ) + chart.to_dict() + fail_condition = alt.condition( + param_width < 200, alt.value("red"), alt.value("black") + ) + with pytest.raises(SchemaValidationError, match="invalid value for `expr`"): + alt.param(expr=fail_condition) # type: ignore + + +def test_when_expressions_inside_parameters() -> None: + """Test for [2144026368-2](https://github.com/vega/altair/pull/3427#issuecomment-2144026368) + + Original [expressions-inside-parameters](https://altair-viz.github.io/user_guide/interactions.html#expressions-inside-parameters) + """ + import polars as pl + + source = pl.DataFrame({"a": ["A", "B", "C"], "b": [28, -5, 10]}) + + bar = ( + alt.Chart(source) + .mark_bar() + .encode(y="a:N", x=alt.X("b:Q").scale(domain=[-10, 35])) + ) + when_then_otherwise = ( + alt.when(alt.datum.b >= 0).then(alt.value(10)).otherwise(alt.value(-20)) + ) + cond = when_then_otherwise["condition"][0] + otherwise = when_then_otherwise["value"] + expected = alt.expr(alt.expr.if_(alt.datum.b >= 0, 10, -20)) + actual = alt.expr(alt.expr.if_(cond["test"], cond["value"], otherwise)) + assert expected == actual + + text_conditioned = bar.mark_text(align="left", baseline="middle", dx=actual).encode( + text="b" + ) + + chart = bar + text_conditioned + chart.to_dict() + + +def test_when_multiple_fields(): + # Triggering structural errors + # https://vega.github.io/vega-lite/docs/condition.html#field + brush = alt.selection_interval() + select_x = alt.selection_interval(encodings=["x"]) + when = alt.when(brush) + reveal_msg = re.compile(r"Only one field.+Shorthand 'max\(\)'", flags=re.DOTALL) + with pytest.raises(TypeError, match=reveal_msg): + when.then("count()").otherwise("max()") + + chain_mixed_msg = re.compile( + r"Chained.+mixed.+conflict.+\{'field': 'field_1', 'type': 'quantitative'\}.+otherwise", + flags=re.DOTALL, + ) + with pytest.raises(TypeError, match=chain_mixed_msg): + when.then({"field": "field_1", "type": "quantitative"}).when( + select_x, field_2=99 + ) + + with pytest.raises(TypeError, match=chain_mixed_msg): + when.then("field_1:Q").when(Genre="pop") + + chain_otherwise_msg = re.compile( + r"Chained.+mixed.+field.+AggregatedFieldDef.+'this_field_here'", + flags=re.DOTALL, + ) + with pytest.raises(TypeError, match=chain_otherwise_msg): + when.then(alt.value(5)).when( + alt.selection_point(fields=["b"]) | brush, empty=False, b=63812 + ).then("min(foo):Q").otherwise( + alt.AggregatedFieldDef( + "argmax", field="field_9", **{"as": "this_field_here"} + ) + ) + + +@pytest.mark.parametrize( + ("channel", "then", "otherwise"), + [ + ("color", alt.ColorValue("red"), alt.ColorValue("blue")), + ("opacity", alt.value(0.5), alt.value(1.0)), + ("text", alt.TextValue("foo"), alt.value("bar")), + ("color", alt.Color("col1:N"), alt.value("blue")), + ("opacity", "col1:N", alt.value(0.5)), + ("text", alt.value("abc"), alt.Text("Name:N")), + ("size", alt.value(20), "Name:N"), + ("size", "count()", alt.value(0)), + ], +) +@pytest.mark.parametrize( + "when", + [ + alt.selection_interval(), + alt.selection_point(), + alt.datum.Displacement > alt.value(350), + alt.selection_point(name="select", on="click"), + alt.selection_point(fields=["Horsepower"]), + ], +) +@pytest.mark.parametrize("empty", [alt.Undefined, True, False]) +def test_when_condition_parity( + cars, channel: str, when, empty: alt.Optional[bool], then, otherwise +): + params = [when] if isinstance(when, alt.Parameter) else () + kwds = {"x": "Cylinders:N", "y": "Origin:N"} + + input_condition = alt.condition(when, then, otherwise, empty=empty) + chart_condition = ( + alt.Chart(cars) + .mark_rect() + .encode(**kwds, **{channel: input_condition}) + .add_params(*params) + .to_dict() + ) + + input_when = alt.when(when, empty=empty).then(then).otherwise(otherwise) + chart_when = ( + alt.Chart(cars) + .mark_rect() + .encode(**kwds, **{channel: input_when}) + .add_params(*params) + .to_dict() + ) + + if isinstance(input_when["condition"], list): + input_when["condition"] = input_when["condition"][0] + assert input_condition == input_when + else: + assert chart_condition == chart_when + + def test_selection_to_dict(): brush = alt.selection_interval() diff --git a/tools/generate_api_docs.py b/tools/generate_api_docs.py index 2f923d8c2..09c4d6fd0 100644 --- a/tools/generate_api_docs.py +++ b/tools/generate_api_docs.py @@ -64,6 +64,16 @@ :nosignatures: {lowlevel_wrappers} + +API Utility Classes +------------------- +.. currentmodule:: altair + +.. autosummary:: + :toctree: generated/api-cls/ + :nosignatures: + + {api_classes} """ @@ -95,15 +105,21 @@ def encoding_wrappers() -> list[str]: def api_functions() -> list[str]: - # Exclude typing.cast + # Exclude `typing` functions/SpecialForm(s) altair_api_functions = [ obj_name for obj_name in iter_objects(alt.api, restrict_to_type=types.FunctionType) # type: ignore[attr-defined] - if obj_name != "cast" + if obj_name not in {"cast", "overload", "NamedTuple", "TypedDict"} ] return sorted(altair_api_functions) +def api_classes() -> list[str]: + # classes defined in `api` and returned by `API Functions`, + # but not covered in other groups + return ["When", "Then", "ChainedWhen"] + + def lowlevel_wrappers() -> list[str]: objects = sorted(iter_objects(alt.schema.core, restrict_to_subclass=alt.SchemaBase)) # type: ignore[attr-defined] # The names of these two classes are also used for classes in alt.channels. Due to @@ -124,6 +140,7 @@ def write_api_file() -> None: api_functions=sep.join(api_functions()), encoding_wrappers=sep.join(encoding_wrappers()), lowlevel_wrappers=sep.join(lowlevel_wrappers()), + api_classes=sep.join(api_classes()), ), encoding="utf-8", ) diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index a8e951d7a..9244c3261 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -501,11 +501,8 @@ def generate_vegalite_schema_wrapper(schema_file: Path) -> str: # of exported classes which are also defined in the channels or api modules which takes # precedent in the generated __init__.py files one and two levels up. # Importing these classes from multiple modules confuses type checkers. - it = ( - c - for c in definitions.keys() - {"Color", "Text", "LookupData"} - if not c.startswith("_") - ) + EXCLUDE = {"Color", "Text", "LookupData", "Dict"} + it = (c for c in definitions.keys() - EXCLUDE if not c.startswith("_")) all_ = [*sorted(it), "Root", "VegaLiteSchema", "SchemaBase", "load_schema"] contents = [ diff --git a/tools/update_init_file.py b/tools/update_init_file.py index 28ba165f0..e58cb14dc 100644 --- a/tools/update_init_file.py +++ b/tools/update_init_file.py @@ -7,18 +7,7 @@ from inspect import ismodule, getattr_static from pathlib import Path -from typing import ( - IO, - Any, - Iterable, - List, - Sequence, - TypeVar, - Union, - cast, - TYPE_CHECKING, - Literal, -) +from typing import TYPE_CHECKING import typing as t import typing_extensions as te @@ -26,19 +15,28 @@ _TYPING_CONSTRUCTS = { te.TypeAlias, - TypeVar, - cast, - List, - Any, - Literal, - Union, - Iterable, + t.TypeVar, + t.cast, + t.overload, + te.runtime_checkable, + t.List, + t.Dict, + t.Tuple, + t.Any, + t.Literal, + t.Union, + t.Iterable, t.Protocol, te.Protocol, - Sequence, - IO, + t.Sequence, + t.IO, annotations, + te.Required, + te.TypedDict, + t.TypedDict, + te.Self, te.deprecated, + te.TypeAliasType, } @@ -80,7 +78,7 @@ def update__all__variable() -> None: ruff_write_lint_format_str(init_path, new_lines) -def relevant_attributes(namespace: dict[str, Any], /) -> list[str]: +def relevant_attributes(namespace: dict[str, t.Any], /) -> list[str]: """Figure out which attributes in `__all__` are relevant. Returns an alphabetically sorted list, to insert into `__all__`. @@ -90,6 +88,17 @@ def relevant_attributes(namespace: dict[str, Any], /) -> list[str]: namespace A module dict, like `altair.__dict__` """ + from altair.vegalite.v5.schema import _typing + + # NOTE: Exclude any `TypeAlias` that were reused in a runtime definition. + # Required for imports from `_typing`, outside of a `TYPE_CHECKING` block. + _TYPING_CONSTRUCTS.update( + ( + v + for k, v in _typing.__dict__.items() + if (not k.startswith("__")) and _is_hashable(v) + ) + ) it = ( name for name, attr in namespace.items() @@ -98,7 +107,7 @@ def relevant_attributes(namespace: dict[str, Any], /) -> list[str]: return sorted(it) -def _is_hashable(obj: Any) -> bool: +def _is_hashable(obj: t.Any) -> bool: """Guard to prevent an `in` check occuring on mutable objects.""" try: return bool(hash(obj)) @@ -106,7 +115,7 @@ def _is_hashable(obj: Any) -> bool: return False -def _is_relevant(attr: Any, name: str, /) -> bool: +def _is_relevant(attr: t.Any, name: str, /) -> bool: """Predicate logic for filtering attributes.""" if ( getattr_static(attr, "_deprecated", False)