diff --git a/doc/api-operator.rst b/doc/api-operator.rst index bc453259..4460ec7d 100644 --- a/doc/api-operator.rst +++ b/doc/api-operator.rst @@ -36,22 +36,24 @@ Operators .. autosummary:: apply_units - as_quantity assign_units concat convert_units drop_vars + expand_dims relabel rename rename_dims select unique_units_from_dim - Input and output: + Input, output, and generating new quantities: .. autosummary:: + as_quantity load_file - add_load_file + random_qty + wildcard_qty write_report Helper functions for adding tasks to Computers diff --git a/doc/api-testing.rst b/doc/api-testing.rst index 5133d550..3ccfe0d5 100644 --- a/doc/api-testing.rst +++ b/doc/api-testing.rst @@ -1,6 +1,8 @@ Test and documentation utilities ******************************** +.. autodata:: genno.core.computer.DEFAULT_WARN_ON_RESULT_TUPLE + .. automodule:: genno.testing :members: diff --git a/doc/api.rst b/doc/api.rst index 76cb3cc1..f6f476a8 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -7,7 +7,7 @@ Top-level classes and functions Computer Key - KeySeq + Keys Quantity assert_quantity configure @@ -48,7 +48,9 @@ Also: apply cache describe + duplicate eval + insert visualize Executing computations: @@ -205,42 +207,43 @@ Also: 2. zero or more ordered :attr:`dims`, and 3. an optional :attr:`tag`. - For example, for a :math:`\text{foo}` with with three dimensions :math:`a, b, c`: + For example, for a quantity :math:`\text{foo}` with with three dimensions :math:`a, b, c`: .. math:: \text{foo}^{abc} - Key allows a specific, explicit reference to various forms of “foo”: + …Key allows a specific, explicit reference to various forms of “foo”: - - in its full resolution, i.e. indexed by a, b, and c: + - in its *full resolution*; that is, indexed by a, b, and c: >>> k1 = Key("foo", ["a", "b", "c"]) >>> k1 - - in a partial sum over one dimension, e.g. 
summed across dimension c, with remaining dimensions a and b: + - in a partial sum over one dimension, for instance summed across dimension c, with remaining dimensions a and b: - >>> k2 = k1.drop('c') - >>> k2 == 'foo:a-b' + >>> k2 = k1 / "c" + >>> k2 == "foo:a-b" True - in a partial sum over multiple dimensions, etc.: - >>> k1.drop('a', 'c') == k2.drop('a') == 'foo:b' + >>> k1.drop("a", "c") == k1 / ("a", "c") == k2 / "a" == "foo:b" True - after it has been manipulated by other computations, e.g. - >>> k3 = k1.add_tag('normalized') + >>> k3 = k1.add_tag("normalized") >>> k3 - >>> k4 = k3.add_tag('rescaled') + >>> k4 = k3 + "rescaled" >>> k4 - **Notes:** + **Key comparison.** + + - Keys with the same name, dimensions, and tag compare and :func:`hash` equal—even if the dimensions are in a different order. + - A key compares (but does *not* :func:`hash`) equal to a :class:`str` with the same name, dimensions (in any order) and tag. - A Key has the same hash, and compares equal to its :class:`str` representation. - A Key also compares equal to another key or :class:`str` with the same dimensions in any other order. :py:`repr(key)` prints the Key in angle brackets ('<>') to signify that it is a Key object. >>> str(k1) @@ -263,7 +266,8 @@ Also: .. _key-arithmethic: - Keys can also be manipulated using some of the Python arithmetic operators: + **Key arithmetic.** + Keys can be manipulated using some of the Python arithmetic operators: - :py:`+`: and :py:`-`: manipulate :attr:`.tag`, same as :meth:`.add_tag` and :meth:`.remove_tag` respectively: @@ -293,46 +297,63 @@ Also: >>> k1 / Key("baz", "cde") -.. autoclass:: genno.KeySeq - :members: - + **Key generation and derivation.** When preparing chains or complicated graphs of computations, it can be useful to use a sequence or set of similar keys to refer to the intermediate steps. - The :class:`.KeySeq` class is provided for this purpose. 
- It supports several ways to create related keys starting from a *base key*: - - >>> ks = KeySeq("foo:x-y-z:bar") - - One may: - - - Use item access syntax: - - >>> ks["a"] - - >>> ks["b"] - - - - Use the Python built-in :func:`.next`. - This always returns the next key in a sequence of integers, starting with :py:`0` and continuing from the *highest previously created Key*: - - >>> next(ks) - + Python item-access syntax (:py:`[...]`) and the built-in function :func:`next` can be used to generate or derive keys from an original one, in any order: + + >>> k = Key("foo:a-b-c") + >>> k[0] + + >>> k[1] + + >>> k["bar"] + + >>> k[99] + + + :func:`.next` always returns the next key in a sequence of integers, starting with :py:`0` and continuing from the *highest previously created tag/Key*: + + >>> next(k) + + # Same + + A Key is callable, with any value that has a :class:`str` representation: + + >>> k() + + # Same as item-access syntax + >>> k("baz") + + + The attributes :attr:`.last` and :attr:`.generated` allow inspecting one or all of the keys that have been derived from an original: + + >>> k.last + + >>> k.generated + (, + , + , + , + , + , + ) + +.. autoclass:: genno.Keys + :members: - # Skip some values - >>> ks[5] - + >>> k = Keys(foo="X:a-b-c-d-e-f", bar="Y:a-b-c:long+sequence+of+tags") + >>> k.baz = "Z:a-b-c-e-f" - # next() continues from the highest - >>> next(ks) - +.. autoclass:: genno.KeySeq + :members: - - Treat the KeySeq as callable, optionally with any value that has a :class:`.str` representation: + .. note:: As of genno 1.28.0, :class:`.Key` provides most of the conveniences and shorthand that were previously provided by KeySeq. + Users *should* prefer use of Key and Keys. + KeySeq *may* eventually be deprecated and removed. 
- >>> ks("c") - + KeySeq supports several ways to create related keys starting from a *base key*: - # Same as next() - >>> ks() - + >>> ks = KeySeq("foo:x-y-z:bar") - Access the most recently generated item: diff --git a/doc/compat-sdmx.rst b/doc/compat-sdmx.rst index 94c0a743..becd8162 100644 --- a/doc/compat-sdmx.rst +++ b/doc/compat-sdmx.rst @@ -28,6 +28,7 @@ To ensure the operators are available: .. autosummary:: codelist_to_groups + coords_to_codelists dataset_to_quantity quantity_to_dataset quantity_to_message diff --git a/doc/index.rst b/doc/index.rst index b004c5e0..0f60ed1b 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -4,7 +4,7 @@ **genno** is a Python package for describing and executing complex calculations on labelled, multi-dimensional data. It aims to make these calculations efficient, transparent, modular, and easily validated as part of scientific research. -genno is built on high-quality Python data packages including :py:`dask`, :mod:`xarray`, :mod:`pandas`, and :py:`pint`; and provides (current or planned) compatibility with packages including :mod:`plotnine <.compat.plotnine>`, :mod:`sdmx1 <.compat.sdmx>`, :mod:`matplotlib`, :mod:`ixmp`, and :mod:`pyam <.compat.pyam>`. +genno is built on high-quality Python data packages including :mod:`dask `, :mod:`xarray`, :mod:`pandas`, :mod:`pint`, and :mod:`sparse`; and provides (current or planned) compatibility with packages including :mod:`plotnine <.compat.plotnine>`, :mod:`sdmx1 <.compat.sdmx>`, :mod:`matplotlib`, :mod:`ixmp`, and :mod:`pyam <.compat.pyam>`. .. toctree:: :maxdepth: 2 diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index 4f782d9a..7c4b0b5c 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -4,8 +4,37 @@ What's new Next release ============ +Future deprecations +------------------- + +The following usage will be deprecated and removed in some future version(s) of genno. + +- Returning multiple Keys (instead of a single Key) from :meth:`.Computer.add`. 
+ + - This change is planned to simplify type hinting and checking of code that uses :mod:`genno`. + - In order to preview this change, warnings (possibly many) can be generated by enabling :data:`.DEFAULT_WARN_ON_RESULT_TUPLE`. +- Import :func:`.random_qty` from :mod:`genno.testing`. + Instead, import from :mod:`genno.operator`. + +All changes +----------- + - :class:`.SparseDataArray` is tested on Python 3.13 (:pull:`158`). -- :meth:`.AttrSeries.sum` supports the same use of :any:`Ellipsis` as :meth:`xarray.DataArray.sum`, for instance :py:`qty.sum(...)` (:pull:`158`) +- Improvements to :class:`.Computer` (:pull:`157`) + + - Item-setter syntax is supported as a shorthand for :meth:`~.Computer.add`, for example :py:`c["X:a-b"] = "mul", "Y:a-b", "Z:b"` (:issue:`160`). + - New methods :meth:`~.Computer.duplicate` and :meth:`~.Computer.insert` (:issue:`129`). + - :meth:`.Computer.describe` avoids :class:`RecursionError` when called on malformed (cyclic) graphs. +- :class:`.Key` directly provides many of the key-generation features previously provided by :class:`.KeySeq` (:pull:`157`). +- :class:`.Key` hashes the same, regardless of dimension order (:pull:`157`, :issue:`159`). +- New class :class:`.Keys`, a typed namespace of :class:`.Key` (:pull:`157`). +- New operators :func:`.expand_dims`, :func:`.random_qty` (previously in :mod:`genno.testing`), :func:`.wildcard_qty`, and :func:`.compat.sdmx.operator.coords_to_codelists` (:pull:`157`). +- Operator :func:`.write_report` gains :py:`header_datetime=...` and :py:`header_units=...` keywords for writing to CSV (:pull:`157`). +- :meth:`.AttrSeries.squeeze` supports the :py:`dim=...` argument (:pull:`157`, :issue:`144`). +- :meth:`.AttrSeries.sum` supports the same use of :any:`Ellipsis` as :meth:`xarray.DataArray.sum`, for instance :py:`qty.sum(...)` (:pull:`158`). +- New type variables :class:`.TKeyLike` and :class:`.TQuantity` (:pull:`157`). 
+ These should be used in downstream code when the return value of a function is the *same* type as its inputs. + For example, a function that returns Key when passed Key; or str when passed str, should use :class:`.TKeyLike`; a function that returns :class:`.AttrSeries` when passed AttrSeries should use :class:`.TQuantity`. v1.27.1 (2024-11-12) ==================== diff --git a/genno/__init__.py b/genno/__init__.py index c53b41f3..06cf7dc9 100644 --- a/genno/__init__.py +++ b/genno/__init__.py @@ -3,7 +3,7 @@ from .config import configure from .core.computer import Computer from .core.exceptions import ComputationError, KeyExistsError, MissingKeyError -from .core.key import Key, KeySeq +from .core.key import Key, Keys, KeySeq from .core.operator import Operator from .core.quantity import Quantity, assert_quantity, get_class, set_class @@ -11,6 +11,7 @@ "ComputationError", "Computer", "Key", + "Keys", "KeySeq", "KeyExistsError", "MissingKeyError", diff --git a/genno/compat/pyam/__init__.py b/genno/compat/pyam/__init__.py index e981fad2..c191071a 100644 --- a/genno/compat/pyam/__init__.py +++ b/genno/compat/pyam/__init__.py @@ -53,7 +53,7 @@ def iamc(c: Computer, info): collapse_info = info.pop("collapse", {}) collapse_func = collapse_info.pop("callback", util.collapse) - # Use the Computer method to add the coversion step + # Use the Computer method to add the conversion step # NB convert_pyam() returns a single key when applied to a single key keys.append( single_key( diff --git a/genno/compat/pyam/operator.py b/genno/compat/pyam/operator.py index 1ccaf3df..33ef8c1b 100644 --- a/genno/compat/pyam/operator.py +++ b/genno/compat/pyam/operator.py @@ -19,7 +19,7 @@ import pandas from genno.core.computer import Computer - from genno.core.quantity import AnyQuantity + from genno.types import AnyQuantity, TQuantity log = logging.getLogger(__name__) @@ -226,11 +226,11 @@ def _(*args: pyam.IamDataFrame, **kwargs) -> "pyam.IamDataFrame": def quantity_from_iamc( - qty: 
Union["AnyQuantity", "pyam.IamDataFrame", "pandas.DataFrame"], + qty: Union["TQuantity", "pyam.IamDataFrame", "pandas.DataFrame"], variable: str, *, fail: Union[int, str] = "warning", -) -> "AnyQuantity": +) -> "TQuantity": """Extract data for a single measure from `qty` with IAMC-like structure. Parameters diff --git a/genno/compat/sdmx/operator.py b/genno/compat/sdmx/operator.py index 593f2e3a..0c12dcab 100644 --- a/genno/compat/sdmx/operator.py +++ b/genno/compat/sdmx/operator.py @@ -1,8 +1,7 @@ -from collections.abc import Hashable, Iterable, Mapping -from typing import Optional, Union +from collections.abc import Callable, Hashable, Iterable, Mapping +from typing import TYPE_CHECKING, Any, Optional, Union -import genno -from genno import Quantity +from genno.operator import write_report try: import sdmx @@ -13,8 +12,12 @@ from . import util +if TYPE_CHECKING: + from genno.types import AnyQuantity + __all__ = [ "codelist_to_groups", + "coords_to_codelists", "dataset_to_quantity", "quantity_to_dataset", "quantity_to_message", @@ -56,7 +59,29 @@ def codelist_to_groups( return {dim: groups} -def dataset_to_quantity(ds: "sdmx.model.common.BaseDataSet") -> Quantity: +def coords_to_codelists( + qty: "AnyQuantity", *, id_transform: Optional[Callable] = str.upper, **kwargs +) -> list["sdmx.model.common.Codelist"]: + """Convert the coordinates of `qty` to a collection of :class:`.Codelist`.""" + from sdmx.model.common import Codelist + + result = [] + + def _transform(value: Any) -> str: + if id_transform is None: + return str(value) + else: + return id_transform(value) + + for dim_id, labels in qty.coords.items(): + cl = Codelist(id=_transform(dim_id), **kwargs) + [cl.setdefault(id=str(label)) for label in labels.data] + result.append(cl) + + return result + + +def dataset_to_quantity(ds: "sdmx.model.common.BaseDataSet") -> "AnyQuantity": """Convert :class:`DataSet ` to :class:`.Quantity`. 
Returns @@ -74,6 +99,8 @@ def dataset_to_quantity(ds: "sdmx.model.common.BaseDataSet") -> Quantity: :attr:`structured_by ` attribute of `ds`, if any. """ + from genno import Quantity + # Assemble attributes attrs: dict[str, str] = {} if ds.described_by: # pragma: no cover @@ -85,7 +112,7 @@ def dataset_to_quantity(ds: "sdmx.model.common.BaseDataSet") -> Quantity: def quantity_to_dataset( - qty: Quantity, + qty: "AnyQuantity", structure: "sdmx.model.common.BaseDataStructureDefinition", *, observation_dimension: Optional[str] = None, @@ -170,7 +197,7 @@ def as_obs(key, value): def quantity_to_message( - qty: Quantity, structure: "sdmx.model.v21.DataStructureDefinition", **kwargs + qty: "AnyQuantity", structure: "sdmx.model.v21.DataStructureDefinition", **kwargs ) -> "sdmx.message.DataMessage": """Convert :class:`.Quantity` to :class:`DataMessage `. @@ -197,7 +224,7 @@ def quantity_to_message( return sdmx.message.DataMessage(data=[ds], **kwargs) -@genno.operator.write_report.register +@write_report.register def _(obj: "sdmx.message.DataMessage", path, kwargs=None) -> None: """Write `obj` to the file at `path`. @@ -205,8 +232,6 @@ def _(obj: "sdmx.message.DataMessage", path, kwargs=None) -> None: use :mod:`sdmx` methods to write the file to SDMX-ML. Otherwise, equivalent to :func:`genno.operator.write_report`. 
""" - import genno.compat.sdmx.operator # noqa: F401 - assert path.suffix.lower() == ".xml" kwargs = kwargs or {} diff --git a/genno/config.py b/genno/config.py index e3407083..81f18f9e 100644 --- a/genno/config.py +++ b/genno/config.py @@ -212,6 +212,7 @@ def aggregate(c: Computer, info): kw = dict( fail=info.pop("_fail", None), groups={info.pop("_dim"): info}, + keep=True, strict=True, sums=True, ) diff --git a/genno/core/attrseries.py b/genno/core/attrseries.py index b8f9ff70..02c02445 100644 --- a/genno/core/attrseries.py +++ b/genno/core/attrseries.py @@ -1,7 +1,7 @@ import logging from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence from functools import partial -from itertools import tee +from itertools import product, tee from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np @@ -43,6 +43,11 @@ def _ensure_multiindex(obj): if len(obj.index) > 1 and obj.index.name is None: kw["names"] = ["dim_0"] obj.index = pd.MultiIndex.from_product([obj.index], **kw) + else: + # From a ≥2-dim index, drop a dimension with name `None` and only 1 level + if len(obj.index.names) > 1 and None in obj.index.names: + obj.index = obj.index.droplevel(obj.index.names.index(None)) + return obj @@ -178,6 +183,9 @@ def _perform_binary_op( # Invoke a pd.Series method like .mul() fv = dict(fill_value=0.0) if rank(op) == 1 else {} + # FIXME In downstream code this occasionally warns RuntimeWarning: The values + # in the array are unorderable. Pass `sort=False` to suppress this + # warning. Address. return getattr(left, op.__name__)(right, **fv).dropna().reorder_levels(order) def assign_coords(self, coords=None, **coord_kwargs): @@ -282,34 +290,49 @@ def drop_vars( return self.droplevel(names) - def expand_dims(self, dim=None, axis=None, **dim_kwargs: Any) -> "AttrSeries": - """Like :meth:`xarray.DataArray.expand_dims`. - - .. todo:: Support passing a mapping of length > 1 to `dim`. 
- """ - if isinstance(dim, list): - dim = dict.fromkeys(dim, []) - dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims") + def expand_dims( + self, + dim: Union[Hashable, Sequence[Hashable], Mapping[Any, Any], None] = None, + axis: Union[int, Sequence[int], None] = None, + create_index_for_new_dim: bool = True, + **dim_kwargs: Any, + ) -> "AttrSeries": + """Like :meth:`xarray.DataArray.expand_dims`.""" if axis is not None: - raise NotImplementedError # pragma: no cover + raise NotImplementedError( # pragma: no cover + "AttrSeries.expand_dims(…, axis=…) keyword argument" + ) - result = self - for name, values in reversed(list(dim.items())): - N = len(values) - if N == 0: # Dimension without labels - N, values = 1, [None] - result = pd.concat([result] * N, keys=values, names=[name], sort=False) - - # Ensure `result` is multiindexed - try: - i = result.index.names.index(None) - except ValueError: - pass - else: - assert 2 == len(result.index.names) - result.index = pd.MultiIndex.from_product([result.index.droplevel(i)]) + # Handle inputs. This block identical to part of xr.DataArray.expand_dims. 
+ if isinstance(dim, int): + raise TypeError("dim should be Hashable or sequence/mapping of Hashables") + elif isinstance(dim, Sequence) and not isinstance(dim, str): + if len(dim) != len(set(dim)): + raise ValueError("dims should not contain duplicate values.") + dim = dict.fromkeys(dim, 1) + elif dim is not None and not isinstance(dim, Mapping): + dim = {dim: 1} + + if dim is None or 0 == len(dim): + # Nothing to do → return early + return self.copy() + + # Assemble names → keys mapping for added dimensions + n_k = {} + for dim, value in either_dict_or_kwargs(dim, dim_kwargs, "expand_dims").items(): + if isinstance(value, int): + n_k[dim] = range(value) + elif isinstance(value, (list, pd.Index)) and 0 == len(value): + log.warning(f'Insert length-1 dimension for {{"{dim}": []}}') + n_k[dim] = range(1) + else: + n_k[dim] = value + keys = list(product(*n_k.values())) + names = list(n_k.keys()) - return result + return _ensure_multiindex( + pd.concat([self] * len(keys), keys=keys, names=names, sort=False) + ) def ffill(self, dim: Hashable, limit: Optional[int] = None): """Like :meth:`xarray.DataArray.ffill`.""" @@ -543,9 +566,14 @@ def squeeze(self, dim=None, drop=False, axis=None): idx = self.index.remove_unused_levels() + if isinstance(dim, Iterable) and not isinstance(dim, str): + dim = list(dim) + elif dim is not None: + dim = [dim] + to_drop = [] for i, name in enumerate(filter(None, idx.names)): - if dim and name != dim: + if dim and name not in dim: continue elif len(idx.levels[i]) > 1: if dim is None: diff --git a/genno/core/base.py b/genno/core/base.py index 70534b4c..0b5e7791 100644 --- a/genno/core/base.py +++ b/genno/core/base.py @@ -9,7 +9,7 @@ import pint if TYPE_CHECKING: - from genno.types import Unit + from genno.types import TQuantity, Unit from .quantity import AnyQuantity @@ -176,11 +176,11 @@ def __init__( def _keep( self, - target: "AnyQuantity", + target: "TQuantity", attrs: Optional[Any] = False, name: Optional[Any] = False, units: 
Optional[Any] = False, - ) -> "AnyQuantity": + ) -> "TQuantity": """Preserve `name`, `units`, and/or other `attrs` from `self` to `target`. The action for each argument is: diff --git a/genno/core/computer.py b/genno/core/computer.py index ec864a94..984cf7a2 100644 --- a/genno/core/computer.py +++ b/genno/core/computer.py @@ -30,15 +30,22 @@ from .describe import describe_recursive from .exceptions import ComputationError, KeyExistsError, MissingKeyError from .graph import Graph -from .key import Key, KeyLike +from .key import Key if TYPE_CHECKING: import genno.core.graph import genno.core.key + from genno.core.key import KeyLike + from genno.types import TKeyLike log = logging.getLogger(__name__) +#: Emit :class:`.FutureWarning` from :meth:`.Computer.add` when :class:`.tuple` is +#: returned. This default value can be overridden with +#: :py:`c.configure(config={"warn on result tuple": False})`. +DEFAULT_WARN_ON_RESULT_TUPLE = False + class Computer: """Class for describing and executing computations. @@ -53,7 +60,7 @@ class Computer: graph: "genno.core.graph.Graph" = Graph(config=dict()) #: The default key to :meth:`.get` with no argument. - default_key: Optional["genno.core.key.KeyLike"] = None + default_key: Optional["KeyLike"] = None #: List of modules containing operators. 
#: @@ -75,9 +82,17 @@ def __init__(self, **kwargs): # Python data model - def __contains__(self, item): + def __contains__(self, item) -> bool: return self.graph.__contains__(item) + def __setitem__(self, data: "KeyLike", *args) -> None: + _args, kwargs = args[0], {} + + if isinstance(_args[-1], dict): + *_args, kwargs = _args + + self.add(data, *_args, **kwargs) + # Dask data model def __dask_keys__(self): @@ -241,7 +256,7 @@ def require_compat(self, pkg: Union[str, types.ModuleType]): # Add computations to the Computer - def add(self, data, *args, **kwargs) -> Union[KeyLike, tuple[KeyLike, ...]]: + def add(self, data, *args, **kwargs) -> Union["KeyLike", tuple["KeyLike", ...]]: """General-purpose method to add computations. :meth:`add` can be called in several ways; its behaviour depends on `data`; see @@ -261,13 +276,14 @@ def add(self, data, *args, **kwargs) -> Union[KeyLike, tuple[KeyLike, ...]]: .iter_keys .single_key """ + # Other methods if isinstance(data, Sequence) and not isinstance(data, str): # Sequence of (args, kwargs) or args; use add_queue() - return self.add_queue(data, *args, **kwargs) + return _warn_on_result(self, self.add_queue(data, *args, **kwargs)) elif isinstance(data, str) and data in dir(self) and data != "add": # Name of another method such as "apply" or "eval" - return getattr(self, data)(*args, **kwargs) + return _warn_on_result(self, getattr(self, data)(*args, **kwargs)) # Possibly identify a named or direct callable in `data` or `args[0]` func: Optional[Callable] = None @@ -293,10 +309,13 @@ def add(self, data, *args, **kwargs) -> Union[KeyLike, tuple[KeyLike, ...]]: if func: try: - # Use an implementation of Computation.add_task() - return func.add_tasks(self, *args, **kwargs) # type: ignore [attr-defined] + # Use an implementation of Operator.add_task() + return _warn_on_result( + self, + func.add_tasks(self, *args, **kwargs), # type: ignore [attr-defined] + ) except (AttributeError, NotImplementedError): - # Computation obj that 
doesn't implement .add_tasks(), or plain callable + # Operator obj that doesn't implement .add_tasks(), or plain callable _partialed_func, kw = partial_split(func, kwargs) key = args[0] computation = (_partialed_func,) + args[1:] @@ -317,10 +336,12 @@ def add(self, data, *args, **kwargs) -> Union[KeyLike, tuple[KeyLike, ...]]: # Optionally add sums if isinstance(result, Key) and sums: # Add one entry for each of the partial sums of `result` - return (result,) + self.add_queue(result.iter_sums(), fail=fail) + return _warn_on_result( + self, (result,) + self.add_queue(result.iter_sums(), fail=fail) + ) else: # NB This might be deprecated to simplify expectations of calling code - return result + return _warn_on_result(self, result) def cache(self, func): """Decorate `func` so that its return value is cached. @@ -336,7 +357,7 @@ def add_queue( # noqa: C901 FIXME reduce complexity from 11 → ≤10 queue: Iterable[tuple], max_tries: int = 1, fail: Optional[Union[str, int]] = None, - ) -> tuple[KeyLike, ...]: + ) -> tuple["KeyLike", ...]: """Add tasks from a list or `queue`. Parameters @@ -360,7 +381,7 @@ def add_queue( # noqa: C901 FIXME reduce complexity from 11 → ≤10 fail = self._queue_fail[-1] # Use the same value as an outer call. # Accumulate added keys - added: list[KeyLike] = [] + added: list["KeyLike"] = [] class Item: """Container for queue items.""" @@ -424,8 +445,8 @@ def _log(msg: str, i: Item, e: Optional[Exception] = None, level=logging.DEBUG): # Generic graph manipulations def add_single( - self, key: KeyLike, *computation, strict=False, index=False - ) -> KeyLike: + self, key: "KeyLike", *computation, strict=False, index=False + ) -> "KeyLike": """Add a single `computation` at `key`. Parameters @@ -451,8 +472,11 @@ def add_single( If `strict` is :obj:`True` and any key referred to by `computation` does not exist. 
""" - if len(computation) == 1 and not callable(computation[0]): - # Unpack a length-1 tuple + # Unpack a length-1 tuple, except for a tuple starting with a callable (task + # with no arguments) + if len(computation) == 1 and ( + isinstance(computation[0], Key) or not callable(computation[0]) + ): computation = computation[0] if index: @@ -498,7 +522,7 @@ def _rewrite_comp(self, computation): def apply( self, generator: Callable, *keys, **kwargs - ) -> Union[KeyLike, tuple[KeyLike, ...]]: + ) -> Union["KeyLike", tuple["KeyLike", ...]]: """Add computations by applying `generator` to `keys`. Parameters @@ -556,6 +580,31 @@ def apply( return tuple(result) if len(result) > 1 else result[0] + def duplicate(self, key: "TKeyLike", tag: str) -> "TKeyLike": + """Duplicate the task at `key` and all of its inputs. + + Re + + Parameters + ---------- + key + Starting key to duplicate. + tag + :attr:`~.Key.tag` to add to duplicated keys. + """ + + comp = self.graph[key] # Retrieve the existing computation at `key` + new_key = type(key)(Key(key) + tag) # Identify the new key; same type as `key` + + if isinstance(comp, (list, tuple)): + # Rewrite the computation + new_comp = [self.duplicate(x, tag) if x in self.graph else x for x in comp] + self.graph[new_key] = type(comp)(new_comp) + else: + self.graph[new_key] = comp + + return new_key + def eval(self, expr: str) -> tuple[Key, ...]: r"""Evaluate `expr` to add tasks and keys. 
@@ -635,7 +684,8 @@ def get(self, key=None): log.debug(f"Cull {len(self.graph)} -> {len(dsk)} keys") try: - result = dask.get(dsk, key) + # Dask doesn't know about genno.Key; pass a str with original dim order + result = dask.get(dsk, str(key)) except Exception as exc: raise ComputationError(exc) from None else: @@ -644,13 +694,63 @@ def get(self, key=None): # Unwrap config from protection applied above self.graph["config"] = self.graph["config"][0].data + def insert(self, key: "KeyLike", *args, tag: str = "pre", **kwargs) -> None: + """Insert a task before `key`, using `args`, `kwargs`. + + The existing task at `key` is moved to :py:`key + tag`. The `args` and `kwargs` + are passed to :meth:`add` to insert a new task at `key`. The `args` must include + at least 2 items: + + 1. the new :class:`callable` or :class:`Operator`, and + 2. the :any:`.Ellipsis` (:py:`...`), which is replaced by the shifted + :py:`key + tag`. + + If there are more than 2 items, each instance of the :class:`.Ellipsis` is + replaced per (2); all other items (and `kwargs`) are passed on as-is. + + The effect is that all existing tasks to which `key` are input will receive, + instead, the output of the added task. + + One way to use :func:`insert` is with a ‘pass-through’ `operation` that, for + instance, performs logging, assertions, or other steps, then returns its input + unchanged. It is also possible to insert a new task that mutates its input in + certain ways. + """ + # Determine a key for the task to be shifted + k_pre = self.infer_keys(key) + tag + if k_pre in self: + # Cannot shift `key` because the target key already exists + raise KeyExistsError(k_pre) + + # Construct the arguments for the add() call + if len(args) < 2: + raise ValueError( + "Must supply at least 2 args (operator, ...) 
to Computer.insert(); " + f"got {args}" + ) + elif Ellipsis not in args: + raise ValueError(f"One arg must be '...'; got {args}") + + _args = [k_pre if a is Ellipsis else a for a in args] + + try: + # Preserve the existing task at `key` + existing = self.graph[key].copy() + # Add `operation` at `key`, operating on the output of the original task + self.add(key, *_args, **kwargs) + except Exception: + raise + else: + # Move the existing task at `key` to `k_pre` + self.graph[k_pre] = existing + # Convenience methods for the graph and its keys def keys(self): """Return the keys of :attr:`~genno.Computer.graph`.""" return self.graph.keys() - def full_key(self, name_or_key: KeyLike) -> KeyLike: + def full_key(self, name_or_key: "KeyLike") -> "KeyLike": """Return the full-dimensionality key for `name_or_key`. An quantity 'foo' with dimensions (a, c, n, q, x) is available in the Computer @@ -672,7 +772,7 @@ def full_key(self, name_or_key: KeyLike) -> KeyLike: def check_keys( self, *keys: Union[str, Key], predicate=None, action="raise" - ) -> list[KeyLike]: + ) -> list["KeyLike"]: """Check that `keys` are in the Computer. Parameters @@ -734,7 +834,9 @@ def _check(value): return result def infer_keys( - self, key_or_keys: Union[KeyLike, Iterable[KeyLike]], dims: Iterable[str] = [] + self, + key_or_keys: Union["KeyLike", Iterable["KeyLike"]], + dims: Iterable[str] = [], ): """Infer complete `key_or_keys`. @@ -777,7 +879,9 @@ def describe(self, key=None, quiet=True): Returns ------- str - Description of computations. + Description of computations. If a malformed :attr:`.graph` is detected (one + key is its own direct ancestor), the text “← CYCLE DETECTED” is shown, and + recursion stops. 
""" # TODO accept a list of keys, like get() if key is None: @@ -868,7 +972,7 @@ def add_product(self, *args, **kwargs): def aggregate( self, - qty: KeyLike, + qty: "KeyLike", tag: str, dims_or_groups: Union[Mapping, str, Sequence[str]], weights: Optional[xr.DataArray] = None, @@ -1012,3 +1116,16 @@ def disaggregate(self, qty, new_dim, method="shares", args=[]): warn(f"Computer.disaggregate(…, {msg}", DeprecationWarning, stacklevel=2) return self.add(key, method, qty, *args, sums=False, strict=True) + + +def _warn_on_result(computer: Computer, result): + if isinstance(result, tuple) and computer.graph.get("config", {}).get( + "warn on result tuple", DEFAULT_WARN_ON_RESULT_TUPLE + ): + warn( + f"Return {len(result)}-tuple from Computer.add(); in a future version of " + f"genno only the first added Key ({result[0]}) will be returned", + FutureWarning, + stacklevel=2, + ) + return result diff --git a/genno/core/describe.py b/genno/core/describe.py index 5e0d33bc..d2cc0f58 100644 --- a/genno/core/describe.py +++ b/genno/core/describe.py @@ -37,16 +37,27 @@ def describe_recursive(graph, comp, depth=0, seen=None): result = [] for arg in comp: + try: + # Record whether `arg` has been seen already + arg_seen = arg in seen + # Update `seen` so that `arg` is not handled in recursive calls below + seen.add(arg) + except TypeError: # `arg` is unhashable, e.g. 
dict, list + arg_seen = False + # Don't fully reprint keys and their ancestors that have been seen - if isinstance(arg, Hashable) and arg in seen: + if isinstance(arg, Hashable) and arg_seen: if depth > 0: # Don't print top-level items that have been seen result.append(f"{indent}'{arg}' (above)") continue elif isinstance(arg, (str, Key)) and arg in graph: # key that exists in the graph → recurse - item = "'{}':\n{}".format( - arg, describe_recursive(graph, graph[arg], depth + 1, seen) + item = f"'{arg}'" + sub_item = describe_recursive(graph, graph[arg], depth + 1, seen) + # A direct recurrence of `item` in `subtree` indicates a cycle + item += ":\n" + sub_item.replace( + f"{indent}{item} (above)", f"{indent}{item} ← CYCLE DETECTED" ) elif is_list_of_keys(arg, graph): # list → collection of items @@ -57,10 +68,6 @@ def describe_recursive(graph, comp, depth=0, seen=None): # Anything else: use a readable string representation item = label(arg) - try: - seen.add(arg) - except TypeError: - pass # `arg` is unhashable, e.g. 
a list result.append(indent + item) # Combine items diff --git a/genno/core/graph.py b/genno/core/graph.py index 2239a81b..17b3b779 100644 --- a/genno/core/graph.py +++ b/genno/core/graph.py @@ -1,12 +1,15 @@ from collections.abc import Generator, Iterable, Sequence from itertools import chain, tee from operator import itemgetter -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union -from .key import Key, KeyLike +from .key import Key +if TYPE_CHECKING: + from .key import KeyLike -def _key_arg(key: KeyLike) -> Union[str, Key]: + +def _key_arg(key: "KeyLike") -> Union[str, Key]: return Key.bare_name(key) or Key(key) @@ -29,10 +32,10 @@ class Graph(dict): infer """ - _unsorted: dict[KeyLike, KeyLike] = dict() + _unsorted: dict["KeyLike", "KeyLike"] = dict() _full: dict[Key, Key] = dict() - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: # Initialize members super().__init__(*args, **kwargs) @@ -44,7 +47,7 @@ def __init__(self, *args, **kwargs): for k in kwargs.keys(): self._index(k) - def _index(self, key: KeyLike): + def _index(self, key: "KeyLike") -> None: """Add `key` to the indices.""" k = _key_arg(key) if isinstance(k, Key): @@ -55,7 +58,7 @@ def _index(self, key: KeyLike): else: self._unsorted[k] = key - def _deindex(self, key: KeyLike): + def _deindex(self, key: "KeyLike") -> None: """Remove `key` from the indices.""" k = _key_arg(key) if isinstance(k, Key): @@ -64,11 +67,11 @@ def _deindex(self, key: KeyLike): else: self._unsorted.pop(k, None) - def __setitem__(self, key: KeyLike, value: Any): + def __setitem__(self, key: "KeyLike", value: Any) -> None: super().__setitem__(key, value) self._index(key) - def __delitem__(self, key: KeyLike): + def __delitem__(self, key: "KeyLike") -> None: super().__delitem__(key) self._deindex(key) @@ -87,7 +90,7 @@ def pop(self, *args): self._deindex(args[0]) def update(self, arg=None, **kwargs): - """Overload :meth:`dict.pop` to also call 
:meth:`_index`.""" + """Overload :meth:`dict.update` to also call :meth:`_index`.""" if isinstance(arg, (Sequence, Generator)): arg0, arg1 = tee(arg) arg_keys = map(itemgetter(0), arg0) @@ -100,18 +103,18 @@ def update(self, arg=None, **kwargs): super().update(arg1, **kwargs) - def unsorted_key(self, key: KeyLike) -> Optional[KeyLike]: + def unsorted_key(self, key: "KeyLike") -> Optional["KeyLike"]: """Return `key` with its original or unsorted dimensions.""" k = _key_arg(key) return self._unsorted.get(k.sorted if isinstance(k, Key) else k) - def full_key(self, name_or_key: KeyLike) -> Optional[KeyLike]: + def full_key(self, name_or_key: "KeyLike") -> Optional["KeyLike"]: """Return `name_or_key` with its full dimensions.""" return self._full.get(Key(name_or_key).drop_all()) def infer( self, key: Union[str, Key], dims: Iterable[str] = [] - ) -> Optional[KeyLike]: + ) -> Optional["KeyLike"]: """Infer a `key`. Parameters diff --git a/genno/core/key.py b/genno/core/key.py index 31fe7b6b..623364e2 100644 --- a/genno/core/key.py +++ b/genno/core/key.py @@ -53,11 +53,52 @@ def _(value: "AnyQuantity"): # register() only handles bare AnyQuantity in Pyth return str(value.name), tuple(map(str, value.dims)), None -class Key: +class KeyGeneratorMixIn: + """Mix-in class for classes that can derive :class:`.Key` from a base.""" + + __slots__ = ("_base", "_generated") + + _base: "Key" + _generated: list[Hashable] + + def __init__(self) -> None: + self._generated = [] + + def __call__(self, value: Optional[Hashable] = None) -> "Key": + return next(self) if value is None else self[value] + + def __getitem__(self, value: Hashable) -> "Key": + value = int(value) if isinstance(value, SupportsInt) else str(value) + if value not in self._generated: + self._generated.append(value) + return self._base.add_tag(str(value)) + + def __next__(self) -> "Key": + return self[self._next_int_tag()] + + def _next_int_tag(self) -> int: + return max([-1] + [t for t in self._generated if isinstance(t, 
int)]) + 1 + + @property + def generated(self) -> tuple["Key", ...]: + """Sequence of previously-created :class:`Keys <.Key>`.""" + return tuple(self._base.add_tag(str(k)) for k in self._generated) + + @property + def last(self) -> "Key": + """The most recently created :class:`.Key`.""" + return self._base.add_tag(str(self._generated[-1])) + + +class Key(KeyGeneratorMixIn): """A hashable key for a quantity that includes its dimensionality.""" - _name: str + __slots__ = ("_dims", "_hash", "_name", "_str", "_tag") + _dims: tuple[str, ...] + _hash: int + _name: str + _str: str _tag: Optional[str] def __init__( @@ -92,6 +133,9 @@ def __init__( self._dims = _dims or tuple(dims) self._tag = _tag or tag + super().__init__() + self._base = self + # Pre-compute string representation and hash self._str = ( self._name @@ -99,7 +143,13 @@ def __init__( + "-".join(self._dims) + (f":{self._tag}" if self._tag else "") ) - self._hash = hash(self._str) + # Hash is independent of dim order + self._hash = hash( + self._name + + ":" + + "-".join(sorted(self._dims)) + + (f":{self._tag}" if self._tag else "") + ) # Class methods @@ -185,9 +235,19 @@ def product(cls, new_name: str, *keys, tag: Optional[str] = None) -> "Key": ---------- new_name : str Name for the new Key. The names of *keys* are discarded. + keys + May include instances of :class:`.Key`, :class:`str` (converted to Key), or + :class:`Quantity` (the dimensions of the quantity are used directly). """ # Iterable of dimension names from all keys, in order, with repetitions - dims = chain(*map(lambda k: cls(k).dims, keys)) + dims = chain( + *map( + lambda k: cls(k).dims + if isinstance(k, (AttrSeries, SparseDataArray, Key, str)) + else (), + keys, + ) + ) # Return new key. 
Use dict to keep only unique *dims*, in same order return cls(new_name, dict.fromkeys(dims).keys()).add_tag(tag) @@ -294,7 +354,7 @@ def drop(self, *dims: Union[str, bool]) -> "Key": """Return a new Key with `dims` dropped.""" return Key( self._name, - [] if dims == (True,) else filter(lambda d: d not in dims, self._dims), + tuple() if dims == (True,) else filter(lambda d: d not in dims, self._dims), self._tag, _fast=True, ) @@ -345,35 +405,49 @@ def _(value: Key): return value._name, value._dims, value._tag -class KeySeq: - """Utility class for generating similar :class:`Keys <.Key>`.""" +class Keys: + """A collection of :class:`.Key`. - #: Base :class:`.Key` of the sequence. - base: Key + This is essentially the same as :class:`.types.SimpleNamespace`, except every + attribute is a :class:`.Key`. + """ - # Keys that have been created. - _keys: dict[Hashable, Key] + __slots__ = ("_keys",) - def __init__(self, *args, **kwargs): - self.base = Key(*args, **kwargs) - self._keys = {} + _keys: dict[str, Key] - def _next_int_tag(self) -> int: - return max([-1] + [t for t in self._keys if isinstance(t, int)]) + 1 + def __init__(self, **kwargs: "KeyLike") -> None: + object.__setattr__(self, "_keys", {}) + for name, value in kwargs.items(): + setattr(self, name, value) - def __next__(self) -> Key: - return self[self._next_int_tag()] + def __delattr__(self, name: str) -> None: + self._keys.pop(name) - def __call__(self, value: Optional[Hashable] = None) -> Key: - return next(self) if value is None else self[value] + def __getattr__(self, name: str) -> "Key": + try: + return self._keys[name] + except KeyError: + raise AttributeError(name) from None - def __getitem__(self, value: Hashable) -> Key: - tag = int(value) if isinstance(value, SupportsInt) else str(value) - result = self._keys[tag] = self.base + str(tag) - return result + def __repr__(self) -> str: + return f"<{len(self._keys)} keys: {' '.join(sorted(self._keys))}>" + + def __setattr__(self, name: str, value: "Key") 
-> None: + self._keys[name] = value if isinstance(value, Key) else Key(value) + + +class KeySeq(KeyGeneratorMixIn): + """Utility class for generating similar :class:`Keys <.Key>`.""" + + def __init__(self, *args, **kwargs): + super().__init__() + self._base = Key(*args, **kwargs) def __repr__(self) -> str: - return f"" + return f"" + + # Particular to KeySeq @property def keys(self) -> MappingProxyType: @@ -382,40 +456,47 @@ def keys(self) -> MappingProxyType: In the form of a :class:`dict` mapping tags (:class:`int` or :class:`str`) to :class:`.Key` values. """ - return MappingProxyType(self._keys) + return MappingProxyType( + {k: self._base.add_tag(str(k)) for k in self._generated} + ) @property def prev(self) -> Key: - """The most recently created :class:`.Key`.""" - return next(reversed(self._keys.values())) + """Alias of :attr:`.KeyGeneratorMixin.last`.""" + return self.last # Access to Key properties + @property + def base(self) -> Key: + """The base Key.""" + return self._base + @property def name(self) -> str: """Name of the :attr:`.base` Key.""" - return self.base.name + return self._base.name @property def dims(self) -> tuple[str, ...]: """Dimensions of the :attr:`.base` Key.""" - return self.base.dims + return self._base.dims @property def tag(self) -> Optional[str]: """Tag of the :attr:`.base` Key.""" - return self.base.tag + return self._base.tag def __add__(self, other: str) -> "KeySeq": - return KeySeq(self.base + other) + return KeySeq(self._base.__add__(other)) def __mul__(self, other) -> "KeySeq": - return KeySeq(self.base * other) + return KeySeq(self._base.__mul__(other)) def __sub__(self, other: Union[str, Iterable[str]]) -> "KeySeq": - return KeySeq(self.base - other) + return KeySeq(self._base.__sub__(other)) def __truediv__(self, other) -> "KeySeq": - return KeySeq(self.base / other) + return KeySeq(self._base.__truediv__(other)) #: Type shorthand for :class:`Key` or any other value that can be used as a key. 
diff --git a/genno/operator.py b/genno/operator.py index 75494230..ef178a22 100644 --- a/genno/operator.py +++ b/genno/operator.py @@ -6,13 +6,15 @@ import operator import os import re -from collections.abc import Callable, Collection, Hashable, Iterable, Mapping +from collections.abc import Callable, Collection, Hashable, Iterable, Mapping, Sequence +from datetime import datetime from functools import partial, reduce, singledispatch from itertools import chain from os import PathLike from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union, cast +import numpy as np import pandas as pd import pint import xarray as xr @@ -23,12 +25,13 @@ from .core.attrseries import AttrSeries from .core.key import Key, KeyLike, iter_keys, single_key from .core.operator import Operator -from .core.quantity import AnyQuantity, assert_quantity +from .core.quantity import assert_quantity from .core.sparsedataarray import SparseDataArray from .util import UnitLike, collect_units, filter_concat_args, units_with_multiplier if TYPE_CHECKING: from genno import types + from genno.types import AnyQuantity, TQuantity __all__ = [ "add", @@ -44,6 +47,7 @@ "disaggregate_shares", "div", "drop_vars", + "expand_dims", "group_sum", "index_to", "interpolate", @@ -51,6 +55,7 @@ "mul", "pow", "product", + "random_qty", "ratio", "relabel", "rename", @@ -61,6 +66,7 @@ "sum", "unique_units_from_dim", "where", + "wildcard_qty", "write_report", ] @@ -101,7 +107,7 @@ def add_binop(func, c: "genno.Computer", key, *quantities, **kwargs) -> Key: """ # Fetch the full key for each quantity base_keys = c.check_keys( - *quantities, predicate=lambda v: isinstance(v, genno.Quantity) + *quantities, predicate=lambda v: isinstance(v, (genno.Quantity, int, float)) ) # Compute a key for the result @@ -122,7 +128,7 @@ def add_binop(func, c: "genno.Computer", key, *quantities, **kwargs) -> Key: @Operator.define(helper=add_binop) -def add(*quantities: "AnyQuantity", fill_value: float = 0.0) -> 
"AnyQuantity": +def add(*quantities: "TQuantity", fill_value: float = 0.0) -> "TQuantity": """Sum across multiple `quantities`. Raises @@ -146,8 +152,8 @@ def add(*quantities: "AnyQuantity", fill_value: float = 0.0) -> "AnyQuantity": def aggregate( - quantity: "AnyQuantity", groups: Mapping[str, Mapping], keep: bool -) -> "AnyQuantity": + quantity: "TQuantity", groups: Mapping[str, Mapping], keep: bool +) -> "TQuantity": """Aggregate `quantity` by `groups`. Parameters @@ -216,7 +222,7 @@ def _unit_args(qty, units): return *result, getattr(result[1], "dimensionality", {}), result[0].Unit(units) -def apply_units(qty: "AnyQuantity", units: UnitLike) -> "AnyQuantity": +def apply_units(qty: "TQuantity", units: UnitLike) -> "TQuantity": """Apply `units` to `qty`. If `qty` has existing units… @@ -301,7 +307,7 @@ def as_quantity(info: Union[dict, float, str]) -> "AnyQuantity": raise TypeError(type(info)) -def assign_units(qty: "AnyQuantity", units: UnitLike) -> "AnyQuantity": +def assign_units(qty: "TQuantity", units: UnitLike) -> "TQuantity": """Set the `units` of `qty` without changing magnitudes. Logs on level ``INFO`` if `qty` has existing units. @@ -330,11 +336,11 @@ def assign_units(qty: "AnyQuantity", units: UnitLike) -> "AnyQuantity": def broadcast_map( - quantity: "AnyQuantity", - map: "AnyQuantity", + quantity: "TQuantity", + map: "TQuantity", rename: Mapping = {}, strict: bool = False, -) -> "AnyQuantity": +) -> "TQuantity": """Broadcast `quantity` using a `map`. 
The `map` must be a 2-dimensional Quantity with dimensions (``d1``, ``d2``), such as @@ -360,21 +366,21 @@ def broadcast_map( def clip( - qty: "AnyQuantity", + qty: "TQuantity", min: Optional["types.ScalarOrArray"] = None, max: Optional["types.ScalarOrArray"] = None, *, keep_attrs: Optional[bool] = None, -) -> "AnyQuantity": +) -> "TQuantity": """Call :meth:`.Quantity.clip`.""" return qty.clip(min, max, keep_attrs=keep_attrs) def combine( - *quantities: "AnyQuantity", + *quantities: "TQuantity", select: Optional[list[Mapping]] = None, weights: Optional[list[float]] = None, -) -> "AnyQuantity": # noqa: F811 +) -> "TQuantity": # noqa: F811 """Sum distinct `quantities` by `weights`. Parameters @@ -428,7 +434,7 @@ def combine( @singledispatch -def concat(*objs: "AnyQuantity", **kwargs) -> "AnyQuantity": +def concat(*objs: "TQuantity", **kwargs) -> "TQuantity": """Concatenate Quantity `objs`. Any strings included amongst `objs` are discarded, with a logged warning; these @@ -472,7 +478,7 @@ def concat(*objs: "AnyQuantity", **kwargs) -> "AnyQuantity": return objs[0]._keep(result, name=True, **to_keep) -def convert_units(qty: "AnyQuantity", units: UnitLike) -> "AnyQuantity": +def convert_units(qty: "TQuantity", units: UnitLike) -> "TQuantity": """Convert magnitude of `qty` from its current units to `units`. Parameters @@ -499,9 +505,7 @@ def convert_units(qty: "AnyQuantity", units: UnitLike) -> "AnyQuantity": return qty._keep(qty * factor, name=True, attrs=True, units=new_units) -def disaggregate_shares( - quantity: "AnyQuantity", shares: "AnyQuantity" -) -> "AnyQuantity": +def disaggregate_shares(quantity: "TQuantity", shares: "TQuantity") -> "TQuantity": """Deprecated: Disaggregate `quantity` by `shares`. This operator is identical to :func:`mul`; use :func:`mul` and its helper instead. 
@@ -510,9 +514,7 @@ def disaggregate_shares( @Operator.define(helper=add_binop) -def div( - numerator: Union["AnyQuantity", float], denominator: "AnyQuantity" -) -> "AnyQuantity": +def div(numerator: Union["TQuantity", float], denominator: "TQuantity") -> "TQuantity": """Compute the ratio `numerator` / `denominator`. Parameters @@ -534,15 +536,15 @@ def div( def drop_vars( - qty: "AnyQuantity", + qty: "TQuantity", names: Union[ str, Iterable[Hashable], - Callable[["AnyQuantity"], Union[str, Iterable[Hashable]]], + Callable[["TQuantity"], Union[str, Iterable[Hashable]]], ], *, errors="raise", -) -> "AnyQuantity": +) -> "TQuantity": """Return a Quantity with dropped variables (coordinates). Like :meth:`xarray.DataArray.drop_vars`. @@ -550,23 +552,36 @@ def drop_vars( return qty.drop_vars(names) -def group_sum(qty: "AnyQuantity", group: str, sum: str) -> "AnyQuantity": +def expand_dims( + qty: "types.TQuantity", + dim: Union[Hashable, Sequence[Hashable], Mapping[Any, Any], None] = None, + axis: Union[int, Sequence[int], None] = None, + create_index_for_new_dim: bool = True, + **dim_kwargs: Any, +) -> "types.TQuantity": + """Return a new object with (an) additional dimension(s). + + Like :meth:`xarray.DataArray.expand_dims`. + """ + return qty.expand_dims(dim, axis, create_index_for_new_dim, **dim_kwargs) + + +def group_sum(qty: "TQuantity", group: str, sum: str) -> "TQuantity": """Group by dimension `group`, then sum across dimension `sum`. The result drops the latter dimension. 
""" - kw = dict(squeeze=False) if isinstance(qty, SparseDataArray) else {} return concat( - *[values.sum(dim=[sum]) for _, values in qty.groupby(group, **kw)], # type: ignore [arg-type] + *[cast("TQuantity", values.sum(dim=[sum])) for _, values in qty.groupby(group)], dim=group, ) def index_to( - qty: "AnyQuantity", + qty: "TQuantity", dim_or_selector: Union[str, Mapping], label: Optional[Hashable] = None, -) -> "AnyQuantity": +) -> "TQuantity": """Compute an index of `qty` against certain of its values. If the label is not provided, :func:`index_to` uses the label in the first position @@ -605,13 +620,13 @@ def index_to( def interpolate( - qty: "AnyQuantity", + qty: "TQuantity", coords: Optional[Mapping[Hashable, Any]] = None, method: "types.InterpOptions" = "linear", assume_sorted: bool = True, kwargs: Optional[Mapping[str, Any]] = None, **coords_kwargs: Any, -) -> "AnyQuantity": +) -> "TQuantity": """Interpolate `qty`. For the meaning of arguments, see :meth:`xarray.DataArray.interp`. When @@ -777,7 +792,7 @@ def _load_file_csv( @Operator.define(helper=add_binop) -def mul(*quantities: "AnyQuantity") -> "AnyQuantity": +def mul(*quantities: "TQuantity") -> "TQuantity": """Compute the product of any number of `quantities`. See also @@ -793,7 +808,7 @@ def mul(*quantities: "AnyQuantity") -> "AnyQuantity": product = mul -def pow(a: "AnyQuantity", b: Union["AnyQuantity", int]) -> "AnyQuantity": +def pow(a: "TQuantity", b: Union["TQuantity", int]) -> "TQuantity": """Compute `a` raised to the power of `b`. Returns @@ -807,11 +822,38 @@ def pow(a: "AnyQuantity", b: Union["AnyQuantity", int]) -> "AnyQuantity": return a**b +def random_qty(shape: dict[str, int], **kwargs) -> "AnyQuantity": + """Return a Quantity with `shape` and random contents. + + Parameters + ---------- + shape : dict + Mapping from dimension names (:class:`str`) to lengths along each dimension + (:class:`int`). + **kwargs + Other keyword arguments to :class:`.Quantity`. 
+ + Returns + ------- + .Quantity + Random data with one dimension for each key in `shape`, and coords along those + dimensions like "foo1", "foo2", with total length matching the value from + `shape`. If `shape` is empty, a scalar (0-dimensional) Quantity. + """ + return genno.Quantity( + np.random.rand(*shape.values()) if len(shape) else np.random.rand(1)[0], + coords={ + dim: [f"{dim}{i}" for i in range(length)] for dim, length in shape.items() + }, + **kwargs, + ) + + def relabel( - qty: "AnyQuantity", + qty: "TQuantity", labels: Optional[Mapping[Hashable, Mapping]] = None, **dim_labels: Mapping, -) -> "AnyQuantity": +) -> "TQuantity": """Replace specific labels along dimensions of `qty`. Parameters @@ -861,10 +903,10 @@ def map_labels(mapper, values): def rename( - qty: "AnyQuantity", + qty: "TQuantity", new_name_or_name_dict: Union[Hashable, Mapping[Any, Hashable]] = None, **names: Hashable, -) -> "AnyQuantity": +) -> "TQuantity": """Returns a new Quantity with renamed dimensions or a new name. Like :meth:`.xarray.DataArray.rename`, and identical in behaviour to @@ -874,10 +916,10 @@ def rename( def rename_dims( - qty: "AnyQuantity", + qty: "TQuantity", name_dict: Union[Hashable, Mapping[Any, Hashable]] = None, **names: Hashable, -) -> "AnyQuantity": +) -> "TQuantity": """Returns a new Quantity with renamed dimensions or a new name. Like :meth:`.xarray.DataArray.rename`, and identical in behaviour to @@ -886,18 +928,18 @@ def rename_dims( return qty.rename(name_dict, **names) -def round(qty: "AnyQuantity", *args, **kwargs) -> "AnyQuantity": +def round(qty: "TQuantity", *args, **kwargs) -> "TQuantity": """Like :meth:`xarray.DataArray.round`.""" return qty.round(*args, **kwargs) def select( - qty: "AnyQuantity", + qty: "TQuantity", indexers: Mapping[Hashable, Iterable[Hashable]], *, inverse: bool = False, drop: bool = False, -) -> "AnyQuantity": +) -> "TQuantity": """Select from `qty` based on `indexers`. 
Parameters @@ -953,7 +995,7 @@ def select( @Operator.define(helper=add_binop) -def sub(a: "AnyQuantity", b: "AnyQuantity") -> "AnyQuantity": +def sub(a: "TQuantity", b: "TQuantity") -> "TQuantity": """Subtract `b` from `a`. See also @@ -965,10 +1007,10 @@ def sub(a: "AnyQuantity", b: "AnyQuantity") -> "AnyQuantity": @Operator.define() def sum( - quantity: "AnyQuantity", - weights: Optional["AnyQuantity"] = None, + quantity: "TQuantity", + weights: Optional["TQuantity"] = None, dimensions: Optional[list[str]] = None, -) -> "AnyQuantity": +) -> "TQuantity": """Sum `quantity` over `dimensions`, with optional `weights`. Parameters @@ -981,8 +1023,8 @@ def sum( dimensions. """ if weights is None: - _w: "AnyQuantity" = genno.Quantity(1.0) - w_total: "AnyQuantity" = genno.Quantity(1.0) + _w: "TQuantity" = genno.Quantity(1.0) + w_total: "TQuantity" = genno.Quantity(1.0) else: _w, w_total = weights, weights.sum(dim=dimensions) if w_total.shape == (): @@ -1012,8 +1054,8 @@ def add_sum( def unique_units_from_dim( - qty: "AnyQuantity", dim: str, *, fail: Union[str, int] = "raise" -) -> "AnyQuantity": + qty: "TQuantity", dim: str, *, fail: Union[str, int] = "raise" +) -> "TQuantity": """Assign :attr:`.Quantity.units` using coords from the dimension `dim`. The dimension `dim` is dropped from the result. 
@@ -1051,13 +1093,31 @@ def unique_units_from_dim( def where( - qty: "AnyQuantity", cond: Any, other: Any = dtypes.NA, drop: bool = False -) -> "AnyQuantity": + qty: "TQuantity", cond: Any, other: Any = dtypes.NA, drop: bool = False +) -> "TQuantity": """Call :meth:`.Quantity.where`.""" return qty.where(cond, other, drop) -def _format_header_comment(value: str) -> str: +def wildcard_qty(value, units, dims: Sequence[Hashable]) -> "AnyQuantity": + """Return a Quantity with 1 label "*" along each of `dims`.""" + if genno.Quantity is SparseDataArray: + # Convert `value` into a list-of-lists of appropriate depth + value = reduce(lambda x, y: [x], range(len(dims)), value) + return genno.Quantity(value, coords={d: ["*"] for d in dims}, units=units) + + +def _format_header_comment(kwargs) -> str: + value = kwargs.pop("header_comment", "") + + if kwargs.pop("header_datetime", False): + tz = datetime.now().astimezone().tzinfo + value += os.linesep + f"Generated: {datetime.now(tz).isoformat()}" + os.linesep + + units = kwargs.pop("units") + if kwargs.pop("header_units", False): + value += os.linesep + f"Units: {units}" + os.linesep + if not len(value): return value @@ -1125,12 +1185,13 @@ def _( kwargs.setdefault("index", False) with open(path, "wb") as f: - f.write(_format_header_comment(kwargs.pop("header_comment", "")).encode()) + f.write(_format_header_comment(kwargs).encode()) quantity.to_csv(f, **kwargs) elif path.suffix == ".xlsx": kwargs = kwargs or dict() kwargs.setdefault("merge_cells", False) kwargs.setdefault("index", False) + kwargs.pop("units", None) quantity.to_excel(path, **kwargs) else: @@ -1145,4 +1206,6 @@ def _( kwargs: Optional[dict] = None, ) -> None: # Convert the Quantity to a pandas.DataFrame, then write + kwargs = kwargs or dict() + kwargs.setdefault("units", f"{quantity.units:~}") write_report(quantity.to_dataframe().reset_index(), path, kwargs) diff --git a/genno/testing/__init__.py b/genno/testing/__init__.py index 862ffa91..d58cf134 100644 --- 
a/genno/testing/__init__.py +++ b/genno/testing/__init__.py @@ -417,33 +417,6 @@ def assert_units(qty: "AnyQuantity", exp: str) -> None: ) -def random_qty(shape: dict[str, int], **kwargs) -> "AnyQuantity": - """Return a Quantity with `shape` and random contents. - - Parameters - ---------- - shape : dict - Mapping from dimension names (:class:`str`) to lengths along each dimension - (:class:`int`). - **kwargs - Other keyword arguments to :class:`.Quantity`. - - Returns - ------- - .Quantity - Random data with one dimension for each key in `shape`, and coords along those - dimensions like "foo1", "foo2", with total length matching the value from - `shape`. If `shape` is empty, a scalar (0-dimensional) Quantity. - """ - return genno.Quantity( - np.random.rand(*shape.values()) if len(shape) else np.random.rand(1)[0], - coords={ - dim: [f"{dim}{i}" for i in range(length)] for dim, length in shape.items() - }, - **kwargs, - ) - - def raises_or_warns(value, *args, **kwargs) -> ContextManager: """Context manager for tests that :func:`.pytest.raises` or :func:`.pytest.warns`. 
@@ -552,3 +525,19 @@ def quantity_is_sparsedataarray(request): yield finally: set_class(pre) + + +def __getattr__(name): + if name == "random_qty": + from warnings import warn + + warn( + "Import random_qty from genno.testing; import from genno.operator instead", + DeprecationWarning, + stacklevel=2, + ) + + from genno.operator import random_qty + + return random_qty + raise AttributeError(name) diff --git a/genno/tests/compat/test_sdmx.py b/genno/tests/compat/test_sdmx.py index 0e439f19..e7c05bd1 100644 --- a/genno/tests/compat/test_sdmx.py +++ b/genno/tests/compat/test_sdmx.py @@ -8,6 +8,22 @@ from genno.compat.sdmx import operator from genno.testing import add_test_data +VERSION = (None, Version["2.1"], Version["3.0"], "2.1", "3.0") + + +@pytest.fixture(scope="session") +def dm(test_data_path, dsd): + # Read the data message + yield sdmx.read_sdmx(test_data_path.joinpath("22_289.xml"), structure=dsd) + + +@pytest.fixture(scope="session") +def dsd(test_data_path): + # Read the data structure definition + yield sdmx.read_sdmx(test_data_path.joinpath("22_289-structure.xml")).structure[ + "DCIS_POPRES1" + ] + def test_codelist_to_groups() -> None: c = Computer() @@ -44,18 +60,26 @@ def test_codelist_to_groups() -> None: assert {"foo", "bar"} == set(result1.coords["t"].data) -@pytest.fixture(scope="session") -def dsd(test_data_path): - # Read the data structure definition - yield sdmx.read_sdmx(test_data_path.joinpath("22_289-structure.xml")).structure[ - "DCIS_POPRES1" - ] +@pytest.mark.parametrize( + "kwargs, cl0_id", + ( + (dict(), "X"), + (dict(id_transform=None), "x"), + ), +) +def test_coords_to_codelists(kwargs, cl0_id: str) -> None: + q_in = genno.operator.random_qty(dict(x=3, y=4, z=5)) + result = operator.coords_to_codelists(q_in, **kwargs) -@pytest.fixture(scope="session") -def dm(test_data_path, dsd): - # Read the data message - yield sdmx.read_sdmx(test_data_path.joinpath("22_289.xml"), structure=dsd) + # Result is a sequence of Codelist objects + assert 
len(q_in.dims) == len(result) + cl0 = result[0] + assert isinstance(cl0, Codelist) + + # Code list has the expected ID and items + assert cl0_id == cl0.id + assert {"x0": Code(id="x0"), "x1": Code(id="x1"), "x2": Code(id="x2")} == cl0.items def test_dataset_to_quantity(dsd, dm) -> None: @@ -78,9 +102,6 @@ def test_dataset_to_quantity(dsd, dm) -> None: assert len(ds.obs) == result.size -VERSION = (None, Version["2.1"], Version["3.0"], "2.1", "3.0") - - @pytest.mark.parametrize("observation_dimension", (None, "TIME_PERIOD")) @pytest.mark.parametrize("version", VERSION) @pytest.mark.parametrize("with_attrs", (True, False)) diff --git a/genno/tests/core/test_attrseries.py b/genno/tests/core/test_attrseries.py index 0793ed96..07f234b7 100644 --- a/genno/tests/core/test_attrseries.py +++ b/genno/tests/core/test_attrseries.py @@ -120,11 +120,21 @@ def test_squeeze(self, foo) -> None: https://github.com/iiasa/message_ix/issues/788 """ # Squeeze the length-1 dimension "a" - result = foo.sel(a=["a1"]).squeeze() + result0 = foo.sel(a=["a1"]).squeeze() # Result is 1-D but multi-indexed - assert 1 == len(result.dims) - assert isinstance(result.index, pd.MultiIndex) + assert 1 == len(result0.dims) + assert isinstance(result0.index, pd.MultiIndex) + + # Squeeze only 1 of >1 dimensions with length + result1 = foo.sel(a=["a1"], b=["b2"]).squeeze("a") + assert 1 == len(result1.dims) + + # Squeeze both + result2 = foo.sel(a=["a1"], b=["b2"]).squeeze() + assert 0 == len(result2.dims) + result3 = foo.sel(a=["a1"], b=["b2"]).squeeze(dim=["a", "b"]) + assert 0 == len(result3.dims) def test_sum(self, foo, bar): # AttrSeries can be summed across all dimensions diff --git a/genno/tests/core/test_computer.py b/genno/tests/core/test_computer.py index b4c79c8c..dcf49fc9 100644 --- a/genno/tests/core/test_computer.py +++ b/genno/tests/core/test_computer.py @@ -1,6 +1,7 @@ import logging import re from functools import partial +from typing import TYPE_CHECKING, Generator import numpy as np 
import pandas as pd @@ -17,6 +18,7 @@ operator, ) from genno.compat.pint import ApplicationRegistry +from genno.core.key import single_key from genno.testing import ( add_dantzig, add_test_data, @@ -24,6 +26,9 @@ assert_qty_equal, ) +if TYPE_CHECKING: + from genno.types import TQuantity + log = logging.getLogger(__name__) @@ -73,6 +78,36 @@ def test_add_aggregate(self, c): agg3 = c.get(key3) assert set(agg3.coords["t"].values) == set(t_groups.keys()) + def test_add_div_dims(self, c: Computer) -> None: + """Dimensions are inferred when :meth:`.add`-ing a :func:`.div` task.""" + c["X:a-b"] = (None,) + c["Y:b-c"] = (None,) + + key = single_key(c.add("Z", "div", "X:a-b", "Y:b-c")) + assert set("abc") == set(key.dims) + + def test_add_single(self, c: Computer) -> None: + """:meth:`.add_single` unwraps a single :class:`.Key`.""" + foo = Key("foo:a-b-c") + bar = Key("bar:x-y-z") + + # Python built-in type stored as-is + c.add_single(foo, 1.0) + assert c.graph[foo] == 1.0 + + # Key also stored as-is + c.add_single(bar, foo) + assert c.graph[bar] is foo + + def test_add_warn(self, recwarn, c: Computer) -> None: + # No warning emitted with DEFAULT_WARN_ON_RESULT_TUPLE = False + assert 0 == len(recwarn) + + # Warning emitted when configured + c.configure(config={"warn on result tuple": True}) + with pytest.warns(FutureWarning, match="Return 8-tuple from Computer.add"): + c.add("foo:x-y-z", None, sums=True) + @pytest.mark.parametrize("suffix", [".json", ".yaml"]) def test_configure(self, test_data_path, c: Computer, suffix) -> None: # Configuration can be read from file @@ -85,6 +120,22 @@ def test_configure(self, test_data_path, c: Computer, suffix) -> None: with pytest.raises(ValueError, match="cannot give both"): c.configure(path, config={"path": path}) + def test_contains(self) -> None: + """:meth:`Computer.__contains__` works regardless of dimension order.""" + c = Computer() + + c.add("a:x-y", 1) + assert "a:x-y" in c + assert "a:y-x" in c + assert Key("a:x-y") in c + 
assert Key("a:y-x") in c + + c.add(Key("b:z-y-x"), 1) + assert "b:x-y-z" in c + assert "b:y-x-z" in c + assert Key("b:x-y-z") in c + assert Key("b:y-x-z") in c + def test_deprecated_add_file(self, tmp_path, c): # Path to a temporary file p = tmp_path / "foo.csv" @@ -193,6 +244,94 @@ def func(qty): with pytest.raises(TypeError): c.disaggregate("x:", "d", method=None) + @pytest.fixture + def c2(self, c) -> Generator[Computer, None, None]: + import genno + + c.add("A:x-y", genno.Quantity([1.0], coords={"x": ["x0"], "y": ["y0"]})) + c.add("B:y-z", genno.Quantity([1.0], coords={"y": ["y0"], "z": ["z0"]})) + c.add("C", "mul", "A:x-y", "B:y-z") + yield c + + def test_duplicate(self, c2): + """Test :meth:`.Computer.duplicate`.""" + N = len(c2.graph) + + k1 = c2.full_key("C") + + # Method runs without error + k2 = c2.duplicate(k1, "duplicated") + + # 3 keys/tasks have been added + assert N + 3 == len(c2.graph) + + # Added tasks have derived keys + k2_desc = c2.describe(k2) + assert "'A:x-y:duplicated'" in k2_desc + assert "'B:y-z:duplicated'" in k2_desc + assert "'C:x-y-z:duplicated'" in k2_desc + + # Original tasks are not modified + k1_desc = c2.describe(k1) + assert "'A:x-y'" in k1_desc + assert "'B:y-z'" in k1_desc + assert "'C:x-y-z'" in k1_desc + + # Both the original and duplicated keys can be computed + c2["check"] = ([k1, k2],) + result = c2.get("check") + + # The results are identical + assert_qty_equal(result[0], result[1]) + + def test_insert0(self, caplog, c2) -> None: + def inserted(qty: "TQuantity", *, x, y) -> "TQuantity": + log.info(f"Inserted function, {x=} {y=}") + return x * qty + + # print(c2.describe("C")) # DEBUG + c2.insert("A:x-y", inserted, ..., x=2.0, y="foo") + # print(c2.describe("C")) # DEBUG + + with caplog.at_level(logging.INFO): + # Result can be obtained + result = c2.get("C") + + # Inserted function/operator ran, generating a log message and altering the + # result + assert ["Inserted function, x=2.0 y='foo'"] == caplog.messages + assert 
2.0 == result.item() + + def test_insert1(self, caplog, c2) -> None: + def inserted(qty: "TQuantity", *, x, y) -> "TQuantity": # pragma: no cover + log.info(f"Inserted function, {x=} {y=}") + return x * qty + + # Key to be inserted already exists + c2.add("A:x-y:pre", None) + with pytest.raises(KeyExistsError): + c2.insert("A:x-y", inserted, ..., x=2.0, y="foo") + + # Too few positional arguments + with pytest.raises(ValueError, match="Must supply at least 2 args"): + c2.insert("A:x-y", tag="foo") + with pytest.raises(ValueError, match="Must supply at least 2 args"): + c2.insert("A:x-y", inserted, tag="foo") + + # 2+ positional arguments, but without `...` + with pytest.raises(ValueError, match=r"One arg must be '\.\.\.'; got"): + c2.insert("A:x-y", inserted, "bla", tag="foo") + + # Incorrect kwargs + with pytest.raises(TypeError, match="unexpected keyword argument 'z'"): + c2.insert("A:x-y", inserted, ..., tag="foo", z="not_an_arg") + + def test_setitem(self, c2) -> None: + c2["D"] = "add", "A:x-y", "B:y-z", dict(sums=True) + + result = c2.get("D:x-y-z") + assert set("xyz") == set(result.dims) + def test_cache(caplog, tmp_path, test_data_path, ureg): caplog.set_level(logging.INFO) @@ -277,15 +416,6 @@ def myfunc2(*args, **kwargs): assert "'cache_path' configuration not set; using " in caplog.messages[0] -def test_contains(): - """:meth:`Computer.__contains__` works regardless of dimension order.""" - c = Computer() - c.add("a:x-y", 1) - - assert "a:x-y" in c - assert "a:y-x" in c - - def test_eval(ureg): c = Computer() add_test_data(c) diff --git a/genno/tests/core/test_describe.py b/genno/tests/core/test_describe.py index b81fc484..c23591c2 100644 --- a/genno/tests/core/test_describe.py +++ b/genno/tests/core/test_describe.py @@ -40,6 +40,18 @@ def test_describe_shorten(): ) +def test_describe_cyclic(): + """Test that :meth:`.describe` works on a cyclic graph without RecursionError.""" + + c = Computer() + + c["x"] = "mul", 1, 2 + c["y"] = "concat", "x", "z:" + 
c["z"] = "add", "y", 0.5 + + assert "CYCLE DETECTED" in c.describe("z") + + def test_label(): """:func:`label` handles unusual callables.""" assert "operator.itemgetter(0)" == label(itemgetter(0)) diff --git a/genno/tests/core/test_graph.py b/genno/tests/core/test_graph.py index 683a72b2..db8e5735 100644 --- a/genno/tests/core/test_graph.py +++ b/genno/tests/core/test_graph.py @@ -11,11 +11,25 @@ def g(self): g["foo:c-b-a"] = 1 yield g - def test_contains(self, g) -> None: + def test_contains0(self, g) -> None: """__contains__ handles incompatible types, returning False.""" q = Quantity() assert (q in g) is False + def test_contains1(self, g) -> None: + """__contains__ handles compatible types.""" + # Compare to a key originally str and unsorted + assert ("foo:c-b-a" in g) is True + assert (Key("foo:c-b-a") in g) is True + assert (Key("foo:a-b-c") in g) is True + + # Compare to a key originally Key and sorted + g[Key("bar:x-y-z")] = None + assert ("bar:x-y-z" in g) is True + assert ("bar:z-x-y" in g) is True + assert (Key("bar:x-y-z") in g) is True + assert (Key("bar:z-x-y") in g) is True + def test_delitem(self, g) -> None: assert Key("foo", "cba") == g.full_key("foo") del g["foo:c-b-a"] diff --git a/genno/tests/core/test_key.py b/genno/tests/core/test_key.py index 1a6cfe5f..b8705b20 100644 --- a/genno/tests/core/test_key.py +++ b/genno/tests/core/test_key.py @@ -1,6 +1,6 @@ import pytest -from genno import Key, KeySeq +from genno import Key, Keys, KeySeq from genno.core.key import iter_keys, single_key from genno.testing import raises_or_warns @@ -104,6 +104,41 @@ def test_drop(self): def test_eq(self): assert False is (Key("x:a-b-c") == 3.4) + def test_generated(self) -> None: + k = Key("A:x") + + # Generate some related keys + k[3] + k["baz"] + k[2] + k["bar"] + k[1] + + exp = tuple(map(Key, ["A:x:3", "A:x:baz", "A:x:2", "A:x:bar", "A:x:1"])) + assert exp == k.generated + + def test_getitem(self) -> None: + k = Key("foo:x-y-z:bar") + + # __getitem__ works with 
str argument + assert "foo:x-y-z:bar+baz" == k["baz"] + assert "foo:x-y-z:bar+qux" == k["qux"] + + # __getitem__ works with int argument + assert "foo:x-y-z:bar+0" == k[0] + assert "foo:x-y-z:bar+1" == k[1] + + assert "foo:x-y-z:bar+2" == next(k) + assert "foo:x-y-z:bar+2" == k.last + + def test_hash(self) -> None: + k1 = Key("x:a-b-c") + k2 = Key("x:c-b-a") + + d = {k1: None} + + assert k2 in d + def test_operations(self): key = Key("x:a-b-c") @@ -137,6 +172,69 @@ def test_operations(self): with pytest.raises(TypeError): key / 3.3 + def test_sorted(self) -> None: + k1 = Key("foo", "abc") + k2 = Key("foo", "cba") + + # Keys with same dimensions, ordered differently, compare equal + assert k1 == k2 + + # Ordered returns a key with sorted dimensions + assert k1.dims == k2.sorted.dims + + # Keys compare equal to an equivalent string and to one another + assert k1 == "foo:b-a-c" == k2 == "foo:b-c-a" + + # Keys hash equal to a string with sorted dimensions + assert hash("foo:a-b-c") == hash(k1) == hash(k2) + + # `k2` does not hash equal to its own (unsorted) string representation + assert hash(k2) != hash(str(k2)) + + +class TestKeys: + """:class:`.Keys` behaves as expected.""" + + @pytest.fixture(scope="function") + def keys(self) -> Keys: + return Keys(foo=Key("foo:a-b-c"), bar="bar:a-b-c") + + def test_init(self, keys: Keys) -> None: + """:class:`.Keys` can be initialized with :any:`.KeyLike`.""" + assert isinstance(keys.foo, Key) and isinstance(keys.bar, Key) + + def test_delattr(self, keys: Keys) -> None: + """Keys can be deleted.""" + del keys.bar + + with pytest.raises(AttributeError): + keys.bar + + def test_getattr(self, keys: Keys) -> None: + """Keys can be accessed and used.""" + assert "foo:a-b-c:0" == keys.foo[0] + + # Binary operations work + assert "foo:a-c" == keys.foo / "b" + assert "foo:a-b-c-d" == keys.foo * "d" + assert "foo:a-b-c:tag" == keys.foo + "tag" + + def test_repr(self, keys: Keys) -> None: + keys.baz = Key("it's confusing:m-n-o-p") + # 
repr() does not include the Key.name, but the name in the namespace + assert "<3 keys: bar baz foo>" == repr(keys) + + def test_setattr(self, keys: Keys) -> None: + """Keys can be set and updated.""" + # Update an existing name + keys.bar = Key("bar:x-y-z") + # Update occurred + assert "bar:x-y-z" == keys.bar + + # New key + keys.baz = Key("baz:c-b-a") + assert "baz:a-b-c" == keys.baz + class TestKeySeq: @pytest.fixture @@ -204,26 +302,7 @@ def test_key_ops(self, ks) -> None: assert "foo:x-z:bar" == (ks / "y").base -def test_sorted(): - k1 = Key("foo", "abc") - k2 = Key("foo", "cba") - - # Keys with same dimensions, ordered differently, compare equal - assert k1 == k2 - - # Ordered returns a key with sorted dimensions - assert k1.dims == k2.sorted.dims - - # Keys compare equal to an equivalent string and to one another - assert k1 == "foo:b-a-c" == k2 == "foo:b-c-a" - - # Keys do not hash equal - assert hash(k1) == hash("foo:a-b-c") - assert hash(k2) == hash("foo:c-b-a") - assert hash(k1) != hash(k2) - - -def test_gt_lt(): +def test_gt_lt() -> None: """Test :meth:`Key.__gt__` and :meth:`Key.__lt__`.""" k = Key("foo", "abd") assert k > "foo:a-b-c" @@ -239,17 +318,17 @@ def test_gt_lt(): assert k > 1.1 -def test_iter_keys(): +def test_iter_keys() -> None: # Non-iterable with pytest.raises(TypeError): - next(iter_keys(1.2)) + next(iter_keys(1.2)) # type: ignore [arg-type] # Iterable containing non-keys with pytest.raises(TypeError): - list(iter_keys([Key("a"), Key("b"), 1.2])) + list(iter_keys([Key("a"), Key("b"), 1.2])) # type: ignore [arg-type] -def test_single_key(): +def test_single_key() -> None: # Single key is unpacked k = Key("a") result = single_key((k,)) @@ -257,7 +336,7 @@ def test_single_key(): # Tuple containing 1 non-key with pytest.raises(TypeError): - single_key((1.2,)) + single_key((1.2,)) # type: ignore [arg-type] # Tuple containing >1 Keys with pytest.raises(TypeError): @@ -265,4 +344,4 @@ def test_single_key(): # Empty iterable with 
pytest.raises(TypeError): - single_key([]) + single_key([]) # type: ignore [arg-type] diff --git a/genno/tests/core/test_sparsedataarray.py b/genno/tests/core/test_sparsedataarray.py index bef41107..500236a2 100644 --- a/genno/tests/core/test_sparsedataarray.py +++ b/genno/tests/core/test_sparsedataarray.py @@ -10,7 +10,8 @@ import genno from genno import Computer from genno.core.sparsedataarray import SparseDataArray -from genno.testing import add_test_data, random_qty +from genno.operator import random_qty +from genno.testing import add_test_data sparse = pytest.importorskip( "sparse", diff --git a/genno/tests/test_operator.py b/genno/tests/test_operator.py index 0349cf68..a2e51822 100644 --- a/genno/tests/test_operator.py +++ b/genno/tests/test_operator.py @@ -4,6 +4,7 @@ from collections.abc import Hashable, Iterable, Mapping from contextlib import nullcontext from functools import partial +from itertools import compress import numpy as np import pandas as pd @@ -16,6 +17,7 @@ import genno from genno import Computer, operator, quote from genno.core.sparsedataarray import SparseDataArray +from genno.operator import random_qty from genno.testing import ( MARK, add_large_data, @@ -24,7 +26,6 @@ assert_qty_allclose, assert_qty_equal, assert_units, - random_qty, ) pytestmark = pytest.mark.usefixtures("parametrize_quantity_class") @@ -453,6 +454,63 @@ def test_drop_vars(data): assert set(x.dims) == {"t"} | set(result.dims) +@pytest.mark.parametrize( + "shape_in", + ( + dict(), # 0 dimensions + dict(x=2), # 1 dimension + dict(x=2, y=2, z=2), # >1 dimension + ), +) +def test_expand_dims0(shape_in): + q_in = random_qty(shape_in, units="kg") + + def _shape(q): + return dict(zip(q.dims, q.shape)) + + # …no arguments → no-op + result0 = operator.expand_dims(q_in, {}) + assert shape_in == _shape(result0) + + # …single hashable → 1 dim of length 1 + result1 = operator.expand_dims(q_in, "a") + assert dict(a=1) | shape_in == _shape(result1) + + # …iterable of dimension IDs + 
result2 = operator.expand_dims(q_in, tuple("ab")) + assert dict(a=1, b=1) | shape_in == _shape(result2) + assert_units(result2, "kg") + + # …dict of dimension lengths + result3 = operator.expand_dims(q_in, dict(a=2, b=3)) + assert dict(a=2, b=3) | shape_in == _shape(result3) + assert [0, 1] == result3.coords["a"].data.tolist() + + # …dict of dimension values + result4 = operator.expand_dims(q_in, dict(a=["a0", "a1"], b=["b0", "b1", "b2"])) + assert dict(a=2, b=3) | shape_in == _shape(result4) + assert ["a0", "a1"] == result4.coords["a"].data.tolist() + + # …dict of dimension with empty list values + result5 = operator.expand_dims(q_in, dict(a=[], b=[])) + with ( + pytest.raises(AssertionError) + if isinstance(q_in, SparseDataArray) + else nullcontext() + ): + assert dict(a=1, b=1) | shape_in == _shape(result5) + + +def test_expand_dims1() -> None: + q_in = random_qty(dict(), units="kg") + + with pytest.raises(TypeError): + operator.expand_dims(q_in, 1) + + with pytest.raises(ValueError): + operator.expand_dims(q_in, ["a", "b", "a"]) + + def test_group_sum(ureg): X = random_qty(dict(a=2, b=3), units=ureg.kg, name="Foo") @@ -961,6 +1019,13 @@ def test_where(data) -> None: assert x.units == result.units +def test_wildcard_qty() -> None: + result = operator.wildcard_qty(1.0, "dimensionless", "abc") + + assert set("abc") == set(result.dims) + assert all(c.data == ["*"] for c in result.coords.values()) + + def test_write_report0(tmp_path, data) -> None: p = tmp_path.joinpath("foo.txt") *_, x = data @@ -978,10 +1043,31 @@ def test_write_report0(tmp_path, data) -> None: assert "Hello, world!" == p.read_text() -def test_write_report1(tmp_path, data) -> None: +EXP_HEADER = r"""# Hello, world! 
+# $ +# Generated: 20..-..-..T..:..:...*$ +# $ +# Units: kg +# $""" + + +@pytest.mark.parametrize( + "kwargs, lines", + ( + (dict(), [1, 1]), + (dict(header_datetime=True), [1, 1, 1, 1]), + (dict(header_units=True), [1, 1, 0, 0, 1, 1]), + (dict(header_datetime=True, header_units=True), [1, 1, 1, 1, 1, 1]), + ), +) +def test_write_report1(tmp_path, data, kwargs, lines) -> None: p = tmp_path.joinpath("foo.csv") *_, x = data + # Compile the expected header + expr = re.compile("\n".join(compress(EXP_HEADER.splitlines(), lines)), flags=re.M) + # Header comment is written - operator.write_report(x, p, dict(header_comment="Hello, world!\n")) - assert p.read_text().startswith("# Hello, world!\n#") + operator.write_report(x, p, dict(header_comment="Hello, world!\n") | kwargs) + match = expr.match(p.read_text()) + assert match and 0 == match.pos diff --git a/genno/tests/test_testing.py b/genno/tests/test_testing.py index 4c0827bb..d1735a25 100644 --- a/genno/tests/test_testing.py +++ b/genno/tests/test_testing.py @@ -14,7 +14,7 @@ @pytest.mark.xfail() -def test_assert_logs(caplog): +def test_assert_logs(caplog) -> None: caplog.set_level(logging.DEBUG) with assert_logs(caplog, "foo"): @@ -23,11 +23,11 @@ def test_assert_logs(caplog): log.warning("spam and eggs") -def test_assert_units(): +def test_assert_units() -> None: assert_units(Quantity(), "") -def test_assert_check_type(): +def test_assert_check_type() -> None: """Mismatched types in :func:`assert_qty_equal` and :func:`assert_qty_allclose`.""" with pytest.raises(AssertionError): assert_qty_equal(int(1), 2.2) @@ -36,8 +36,13 @@ def test_assert_check_type(): assert_qty_allclose(int(1), 2.2) +def test_deprecated_import() -> None: + with pytest.warns(DeprecationWarning, match="random_qty"): + from genno.testing import random_qty # noqa:F401 + + @pytest.mark.xfail(raises=TypeError) -def test_runtest_makereport(): +def test_runtest_makereport() -> None: """The Pytest hook :func:`.pytest_runtest_makereport` works.""" c = 
Computer() diff --git a/genno/types.py b/genno/types.py index a8eff120..ad8a8f10 100644 --- a/genno/types.py +++ b/genno/types.py @@ -6,18 +6,22 @@ # pragma: exclude file from collections.abc import Hashable, Sequence -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, TypeVar, Union from pint import Unit from xarray.core.types import Dims, InterpOptions, ScalarOrArray +from .core.attrseries import AttrSeries from .core.key import KeyLike from .core.quantity import AnyQuantity +from .core.sparsedataarray import SparseDataArray if TYPE_CHECKING: # TODO Remove this block once Python 3.10 is the lowest supported version from typing import TypeAlias + from .core.key import Key + __all__ = [ "AnyQuantity", "Dims", @@ -25,8 +29,18 @@ "InterpOptions", "KeyLike", "ScalarOrArray", + "TKeyLike", + "TQuantity", "Unit", ] # Mirror the definition from pandas-stubs IndexLabel: "TypeAlias" = Union[Hashable, Sequence[Hashable]] + +#: Similar to :any:`KeyLike`, but as a variable that can be use to match function/method +#: outputs to inputs. +TKeyLike = TypeVar("TKeyLike", "Key", str) + +#: Similar to :any:`.AnyQuantity`, but as a variable that can be used to match function +#: /method outputs to inputs. +TQuantity = TypeVar("TQuantity", AttrSeries, SparseDataArray)