From 5efc297e212425959c8fd0930dfe42d733c3d309 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=89mile=20Royer?=
Date: Wed, 18 Dec 2024 17:50:36 +0100
Subject: [PATCH 01/20] Remove the MyPy override for river.base

---
 pyproject.toml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 346c94489a..88e1056af9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -174,7 +174,6 @@ ignore_missing_imports = true
 [[tool.mypy.overrides]]
 # Disable strict mode for all non fully-typed modules
 module = [
-    "river.base.*",
     "river.metrics.*",
     "river.utils.*",
     "river.stats.*",

From 0bfd845475deefe43446fe5b59451b18599e322f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=89mile=20Royer?=
Date: Wed, 8 Jan 2025 17:04:15 +0100
Subject: [PATCH 02/20] Automatically infer simple type annotations

We use Autotyping to add annotations where they are obvious from the
context.
---
 river/base/base.py           | 14 +++++++-------
 river/base/classifier.py     |  4 ++--
 river/base/clusterer.py      |  2 +-
 river/base/drift_detector.py |  6 +++---
 river/base/ensemble.py       |  6 +++---
 river/base/estimator.py      |  4 ++--
 river/base/test_base.py      | 16 ++++++++--------
 river/base/transformer.py    |  6 +++---
 river/base/wrapper.py        |  4 ++--
 9 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/river/base/base.py b/river/base/base.py
index 16e4f829b3..1e03d8e24e 100644
--- a/river/base/base.py
+++ b/river/base/base.py
@@ -22,10 +22,10 @@ class Base:

     """

-    def __str__(self):
+    def __str__(self) -> str:
         return self.__class__.__name__

-    def __repr__(self):
+    def __repr__(self) -> str:
         return _repr_obj(obj=self)

     @classmethod
@@ -71,7 +71,7 @@ def _get_params(self) -> dict[str, typing.Any]:

         return params

-    def clone(self, new_params: dict | None = None, include_attributes=False):
+    def clone(self, new_params: dict | None = None, include_attributes: bool = False):
         """Return a fresh estimator with the same parameters.

         The clone has the same parameters but has not been updated with any data.
@@ -205,7 +205,7 @@ def is_class_param(param):
     def _mutable_attributes(self) -> set:
         return set()

-    def mutate(self, new_attrs: dict):
+    def mutate(self, new_attrs: dict) -> None:
         """Modify attributes. This changes parameters inplace.
Although you can change attributes yourself, this is the @@ -297,7 +297,7 @@ def mutate(self, new_attrs: dict): """ def _mutate(obj, new_attrs): - def is_class_attr(name, attr): + def is_class_attr(name: str, attr): return hasattr(getattr(obj, name), "mutate") and isinstance(attr, dict) for name, attr in new_attrs.items(): @@ -336,7 +336,7 @@ def is_class_param(param): and isinstance(param[1], dict) ) - def find(params): + def find(params) -> bool: if not isinstance(params, dict): return False for name, param in params.items(): @@ -396,7 +396,7 @@ def _memory_usage(self) -> str: return utils.pretty.humanize_bytes(self._raw_memory_usage) -def _log_method_calls(self, name, class_condition, method_condition): +def _log_method_calls(self, name: str, class_condition, method_condition): method = object.__getattribute__(self, name) if ( not name.startswith("_") diff --git a/river/base/classifier.py b/river/base/classifier.py index 876bef4e13..38730c9003 100644 --- a/river/base/classifier.py +++ b/river/base/classifier.py @@ -69,11 +69,11 @@ def predict_one(self, x: dict, **kwargs) -> base.typing.ClfTarget | None: return None @property - def _multiclass(self): + def _multiclass(self) -> bool: return False @property - def _supervised(self): + def _supervised(self) -> bool: return True diff --git a/river/base/clusterer.py b/river/base/clusterer.py index 35a73a71ad..92514ce2dc 100644 --- a/river/base/clusterer.py +++ b/river/base/clusterer.py @@ -9,7 +9,7 @@ class Clusterer(estimator.Estimator): """A clustering model.""" @property - def _supervised(self): + def _supervised(self) -> bool: return False @abc.abstractmethod diff --git a/river/base/drift_detector.py b/river/base/drift_detector.py index a350c40d4b..56ae42d43b 100644 --- a/river/base/drift_detector.py +++ b/river/base/drift_detector.py @@ -20,7 +20,7 @@ class _BaseDriftDetector(base.Base): """ - def __init__(self): + def __init__(self) -> None: self._drift_detected = False def _reset(self) -> None: @@ -40,11 +40,11 @@ class _BaseDriftAndWarningDetector(_BaseDriftDetector): """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._warning_detected = False - def _reset(self): + def _reset(self) -> None: super()._reset() self._warning_detected = False diff --git a/river/base/ensemble.py b/river/base/ensemble.py index c88f5181ef..9426a44954 100644 --- a/river/base/ensemble.py +++ b/river/base/ensemble.py @@ -17,7 +17,7 @@ class Ensemble(UserList): """ - def __init__(self, models: Iterator[Estimator]): + def __init__(self, models: Iterator[Estimator]) -> None: super().__init__(models) if len(self) < self._min_number_of_models: @@ -27,7 +27,7 @@ def __init__(self, models: Iterator[Estimator]): ) @property - def _min_number_of_models(self): + def _min_number_of_models(self) -> int: return 2 @property @@ -49,7 +49,7 @@ class WrapperEnsemble(Ensemble, Wrapper): """ - def __init__(self, model, n_models, seed): + def __init__(self, model, n_models: int, seed: int) -> None: super().__init__(model.clone() for _ in range(n_models)) self.model = model self.n_models = n_models diff --git a/river/base/estimator.py b/river/base/estimator.py index 630b76d640..98aa2c8c9e 100644 --- a/river/base/estimator.py +++ b/river/base/estimator.py @@ -9,7 +9,7 @@ class Estimator(base.Base, abc.ABC): """An estimator.""" @property - def _supervised(self): + def _supervised(self) -> bool: """Indicates whether or not the estimator is supervised or not. 
This is useful internally for determining if an estimator expects to be provided with a `y` @@ -35,7 +35,7 @@ def __ror__(self, other): return other.__or__(self) return compose.Pipeline(other, self) - def _repr_html_(self): + def _repr_html_(self) -> str: from xml.etree import ElementTree as ET from river.base import viz diff --git a/river/base/test_base.py b/river/base/test_base.py index 412303b93b..c8ab67606a 100644 --- a/river/base/test_base.py +++ b/river/base/test_base.py @@ -3,7 +3,7 @@ from river import compose, datasets, linear_model, optim, preprocessing, stats, time_series -def test_clone_estimator(): +def test_clone_estimator() -> None: obj = linear_model.LinearRegression(l2=42) obj.learn_one({"x": 3}, 6) @@ -14,7 +14,7 @@ def test_clone_estimator(): assert new.weights != obj.weights -def test_clone_include_attributes(): +def test_clone_include_attributes() -> None: var = stats.Var() var.update(1) var.update(2) @@ -25,7 +25,7 @@ def test_clone_include_attributes(): assert var.clone(include_attributes=True)._S == 2 -def test_clone_pipeline(): +def test_clone_pipeline() -> None: obj = preprocessing.StandardScaler() | linear_model.LinearRegression(l2=42) obj.learn_one({"x": 3}, 6) @@ -37,7 +37,7 @@ def test_clone_pipeline(): assert new["LinearRegression"].weights != obj["LinearRegression"].weights -def test_clone_idempotent(): +def test_clone_idempotent() -> None: model = preprocessing.StandardScaler() | linear_model.LogisticRegression( optimizer=optim.Adam(), l2=0.1 ) @@ -53,7 +53,7 @@ def test_clone_idempotent(): clone.learn_one(x, y) -def test_memory_usage(): +def test_memory_usage() -> None: model = preprocessing.StandardScaler() | linear_model.LogisticRegression() # We can't test the exact value because it depends on the platform and the Python version @@ -61,7 +61,7 @@ def test_memory_usage(): assert isinstance(model._memory_usage, str) -def test_mutate(): +def test_mutate() -> None: """ >>> from river import datasets, linear_model, optim, preprocessing @@ -114,13 +114,13 @@ def test_mutate(): """ -def test_clone_positional_args(): +def test_clone_positional_args() -> None: assert compose.Select(1, 2, 3).clone().keys == {1, 2, 3} assert compose.Discard("a", "b", "c").clone().keys == {"a", "b", "c"} assert compose.SelectType(float, int).clone().types == (float, int) -def test_clone_nested_pipeline(): +def test_clone_nested_pipeline() -> None: model = time_series.SNARIMAX( p=2, d=1, diff --git a/river/base/transformer.py b/river/base/transformer.py index 16f9aba276..b4bab312fc 100644 --- a/river/base/transformer.py +++ b/river/base/transformer.py @@ -54,7 +54,7 @@ class Transformer(base.Estimator, BaseTransformer): """A transformer.""" @property - def _supervised(self): + def _supervised(self) -> bool: return False def learn_one(self, x: dict) -> None: @@ -78,7 +78,7 @@ class SupervisedTransformer(base.Estimator, BaseTransformer): """A supervised transformer.""" @property - def _supervised(self): + def _supervised(self) -> bool: return True def learn_one(self, x: dict, y: base.typing.Target) -> None: @@ -134,7 +134,7 @@ class MiniBatchSupervisedTransformer(Transformer): """A supervised transformer that can operate on mini-batches.""" @property - def _supervised(self): + def _supervised(self) -> bool: return True @abc.abstractmethod diff --git a/river/base/wrapper.py b/river/base/wrapper.py index d59bf1e232..559d9555a7 100644 --- a/river/base/wrapper.py +++ b/river/base/wrapper.py @@ -12,11 +12,11 @@ def _wrapped_model(self): """Provides access to the wrapped model.""" @property 
- def _labelloc(self): + def _labelloc(self) -> str: """Indicates location of the wrapper name when drawing pipelines.""" return "t" # for top - def __str__(self): + def __str__(self) -> str: return f"{type(self).__name__}({self._wrapped_model})" def _more_tags(self): From 3c45baec6a483374e3c1b9391b0fffe8feeca06c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Wed, 8 Jan 2025 19:32:18 +0100 Subject: [PATCH 03/20] Remove unused ignores --- river/base/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/river/base/base.py b/river/base/base.py index 1e03d8e24e..d34f4a5c67 100644 --- a/river/base/base.py +++ b/river/base/base.py @@ -384,7 +384,7 @@ def _raw_memory_usage(self) -> int: elif hasattr(obj, "__iter__") and not ( isinstance(obj, str) or isinstance(obj, bytes) or isinstance(obj, bytearray) ): - buffer.extend([i for i in obj]) # type: ignore + buffer.extend([i for i in obj]) return size @@ -487,7 +487,7 @@ def _repr_obj(obj, show_modules: bool = False, depth: int = 0) -> str: params = { name: getattr(obj, name) - for name, param in inspect.signature(obj.__init__).parameters.items() # type: ignore + for name, param in inspect.signature(obj.__init__).parameters.items() if not ( param.name == "args" and param.kind == param.VAR_POSITIONAL From b79e7480b6afce20002307d2e40d9d795e99f19a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Wed, 8 Jan 2025 19:33:06 +0100 Subject: [PATCH 04/20] Start adding type hints --- river/base/base.py | 26 ++++++++++++++------------ river/base/test_base.py | 1 + river/base/transformer.py | 4 ++-- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/river/base/base.py b/river/base/base.py index d34f4a5c67..382ec0bc35 100644 --- a/river/base/base.py +++ b/river/base/base.py @@ -10,6 +10,8 @@ import types import typing +import typing_extensions + class Base: """Base class that is inherited by the majority of classes in River. @@ -29,7 +31,7 @@ def __repr__(self) -> str: return _repr_obj(obj=self) @classmethod - def _unit_test_params(cls): + def _unit_test_params(cls) -> typing.Iterator[dict[str, typing.Any]]: """Instantiates an object with default arguments. Most parameters of each object have a default value. However, this isn't always the case, @@ -71,7 +73,7 @@ def _get_params(self) -> dict[str, typing.Any]: return params - def clone(self, new_params: dict | None = None, include_attributes: bool = False): + def clone(self, new_params: dict[str, typing.Any] | None = None, include_attributes: bool = False) -> typing_extensions.Self: """Return a fresh estimator with the same parameters. The clone has the same parameters but has not been updated with any data. @@ -167,7 +169,7 @@ def clone(self, new_params: dict | None = None, include_attributes: bool = False """ - def is_class_param(param): + def is_class_param(param) -> bool: # See expand_param_grid to understand why this is necessary return ( isinstance(param, tuple) @@ -202,10 +204,10 @@ def is_class_param(param): return clone @property - def _mutable_attributes(self) -> set: + def _mutable_attributes(self) -> set[str]: return set() - def mutate(self, new_attrs: dict) -> None: + def mutate(self, new_attrs: dict[str, typing.Any]) -> None: """Modify attributes. This changes parameters inplace. 
Although you can change attributes yourself, this is the @@ -296,8 +298,8 @@ def mutate(self, new_attrs: dict) -> None: """ - def _mutate(obj, new_attrs): - def is_class_attr(name: str, attr): + def _mutate(obj, new_attrs) -> None: + def is_class_attr(name: str, attr) -> bool: return hasattr(getattr(obj, name), "mutate") and isinstance(attr, dict) for name, attr in new_attrs.items(): @@ -318,7 +320,7 @@ def is_class_attr(name: str, attr): _mutate(obj=self, new_attrs=new_attrs) @property - def _is_stochastic(self): + def _is_stochastic(self) -> bool: """Indicates if the model contains an unset seed parameter. The convention in River is to control randomness by exposing a seed parameter. This seed @@ -329,14 +331,14 @@ def _is_stochastic(self): """ - def is_class_param(param): + def is_class_param(param) -> bool: return ( isinstance(param, tuple) and inspect.isclass(param[0]) and isinstance(param[1], dict) ) - def find(params) -> bool: + def find(params: dict[str, typing.Any]) -> bool: if not isinstance(params, dict): return False for name, param in params.items(): @@ -354,7 +356,7 @@ def _raw_memory_usage(self) -> int: import numpy as np - buffer = collections.deque([self]) + buffer: collections.deque[typing.Any] = collections.deque([self]) seen = set() size = 0 while len(buffer) > 0: @@ -369,7 +371,7 @@ def _raw_memory_usage(self) -> int: buffer.extend([k for k in obj.keys()]) buffer.extend([v for v in obj.values()]) elif hasattr(obj, "__dict__"): # Save object contents - contents: dict = vars(obj) + contents= vars(obj) size += sys.getsizeof(contents) buffer.extend([k for k in contents.keys()]) buffer.extend([v for v in contents.values()]) diff --git a/river/base/test_base.py b/river/base/test_base.py index c8ab67606a..82d818d4af 100644 --- a/river/base/test_base.py +++ b/river/base/test_base.py @@ -8,6 +8,7 @@ def test_clone_estimator() -> None: obj.learn_one({"x": 3}, 6) new = obj.clone({"l2": 21}) + assert type(new) is type(obj) assert new.l2 == 21 assert obj.l2 == 42 assert new.weights == {} diff --git a/river/base/transformer.py b/river/base/transformer.py index b4bab312fc..83bc623521 100644 --- a/river/base/transformer.py +++ b/river/base/transformer.py @@ -25,10 +25,10 @@ def __radd__(self, other): def __mul__(self, other): from river import compose - if isinstance(other, Transformer) or isinstance(other, compose.Pipeline): + if isinstance(other, BaseTransformer) or isinstance(other, compose.Pipeline): return compose.TransformerProduct(self, other) - return compose.Grouper(transformer=self, by=other) + return compose.Grouper(transformer=self, by=other) # type: ignore[arg-type] def __rmul__(self, other): """Creates a Grouper.""" From aa9140953993082ecc3250a1e6f164d07dc6862b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Wed, 21 Aug 2024 15:06:12 +0200 Subject: [PATCH 05/20] Correct a confusion in arguments list annotations Arbitrary argument lists should be annotated with the type of the arguments and not as a tuple containing the type; unless the argument should all be tuples. 
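
Under PEP 484, the two spellings mean different things. As a minimal
illustration (a hypothetical function, not code from this patch):

    def select(*keys: tuple[str]) -> None:
        # Every argument must itself be a one-element tuple of str:
        # select(("a",), ("b",)) type-checks, select("a", "b") does not.
        ...

    def select(*keys: str) -> None:
        # Every argument is a plain str, and `keys` is a tuple[str, ...]
        # inside the body. This is what Select/Discard/SelectType intend.
        ...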
--- river/compose/select.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/river/compose/select.py b/river/compose/select.py index 088bfe714e..cbfa28f04c 100644 --- a/river/compose/select.py +++ b/river/compose/select.py @@ -42,7 +42,7 @@ class Discard(base.Transformer): """ - def __init__(self, *keys: tuple[base.typing.FeatureName]): + def __init__(self, *keys: base.typing.FeatureName): self.keys = set(keys) def transform_one(self, x): @@ -124,7 +124,7 @@ class Select(base.MiniBatchTransformer): """ - def __init__(self, *keys: tuple[base.typing.FeatureName]): + def __init__(self, *keys: base.typing.FeatureName): self.keys = set(keys) def transform_one(self, x): @@ -173,7 +173,7 @@ class SelectType(base.Transformer): """ - def __init__(self, *types: tuple[type]): + def __init__(self, *types: type): self.types = types def transform_one(self, x): From f7339f4055d5bc5a935af169aa540b237a2287fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Wed, 15 Jan 2025 17:05:14 +0100 Subject: [PATCH 06/20] Complete general annotations --- river/base/base.py | 14 +++++++------- river/base/drift_detector.py | 2 +- river/base/ensemble.py | 6 +++--- river/base/estimator.py | 11 +++++++---- river/base/transformer.py | 10 +++++----- river/base/viz.py | 17 ++++++++--------- river/base/wrapper.py | 9 +++++---- 7 files changed, 36 insertions(+), 33 deletions(-) diff --git a/river/base/base.py b/river/base/base.py index 382ec0bc35..82e30c2e71 100644 --- a/river/base/base.py +++ b/river/base/base.py @@ -169,7 +169,7 @@ def clone(self, new_params: dict[str, typing.Any] | None = None, include_attribu """ - def is_class_param(param) -> bool: + def is_class_param(param: typing.Any) -> bool: # See expand_param_grid to understand why this is necessary return ( isinstance(param, tuple) @@ -298,8 +298,8 @@ def mutate(self, new_attrs: dict[str, typing.Any]) -> None: """ - def _mutate(obj, new_attrs) -> None: - def is_class_attr(name: str, attr) -> bool: + def _mutate(obj: typing.Any, new_attrs: dict[str, typing.Any]) -> None: + def is_class_attr(name: str, attr: typing.Any) -> bool: return hasattr(getattr(obj, name), "mutate") and isinstance(attr, dict) for name, attr in new_attrs.items(): @@ -331,7 +331,7 @@ def _is_stochastic(self) -> bool: """ - def is_class_param(param) -> bool: + def is_class_param(param: typing.Any) -> bool: return ( isinstance(param, tuple) and inspect.isclass(param[0]) @@ -398,7 +398,7 @@ def _memory_usage(self) -> str: return utils.pretty.humanize_bytes(self._raw_memory_usage) -def _log_method_calls(self, name: str, class_condition, method_condition): +def _log_method_calls(self: typing.Any, name: str, class_condition: typing.Callable[[typing.Any], bool], method_condition: typing.Callable[[typing.Any], bool]) -> typing.Any: method = object.__getattribute__(self, name) if ( not name.startswith("_") @@ -414,7 +414,7 @@ def _log_method_calls(self, name: str, class_condition, method_condition): def log_method_calls( class_condition: typing.Callable[[typing.Any], bool] | None = None, method_condition: typing.Callable[[typing.Any], bool] | None = None, -): +) -> collections.abc.Iterator[None]: """A context manager to log method calls. All method calls will be logged by default. 
This behavior can be overriden by passing filtering @@ -479,7 +479,7 @@ def log_method_calls( Base.__getattribute__ = old # type: ignore -def _repr_obj(obj, show_modules: bool = False, depth: int = 0) -> str: +def _repr_obj(obj: typing.Any, show_modules: bool = False, depth: int = 0) -> str: """Return a pretty representation of an object.""" rep = f"{obj.__class__.__name__} (" diff --git a/river/base/drift_detector.py b/river/base/drift_detector.py index 56ae42d43b..ab76036e64 100644 --- a/river/base/drift_detector.py +++ b/river/base/drift_detector.py @@ -49,7 +49,7 @@ def _reset(self) -> None: self._warning_detected = False @property - def warning_detected(self): + def warning_detected(self) -> bool: """Whether or not a drift is detected following the last update.""" return self._warning_detected diff --git a/river/base/ensemble.py b/river/base/ensemble.py index 9426a44954..9efbff3f0c 100644 --- a/river/base/ensemble.py +++ b/river/base/ensemble.py @@ -31,7 +31,7 @@ def _min_number_of_models(self) -> int: return 2 @property - def models(self): + def models(self) -> list: return self.data @@ -49,7 +49,7 @@ class WrapperEnsemble(Ensemble, Wrapper): """ - def __init__(self, model, n_models: int, seed: int) -> None: + def __init__(self, model: Estimator, n_models: int, seed: int | None) -> None: super().__init__(model.clone() for _ in range(n_models)) self.model = model self.n_models = n_models @@ -57,5 +57,5 @@ def __init__(self, model, n_models: int, seed: int) -> None: self._rng = Random(seed) @property - def _wrapped_model(self): + def _wrapped_model(self) -> Estimator: return self.model diff --git a/river/base/estimator.py b/river/base/estimator.py index 98aa2c8c9e..ebc6372805 100644 --- a/river/base/estimator.py +++ b/river/base/estimator.py @@ -1,8 +1,11 @@ from __future__ import annotations import abc +from typing import Any +from collections.abc import Iterator from . import base +from river import compose class Estimator(base.Base, abc.ABC): @@ -19,7 +22,7 @@ def _supervised(self) -> bool: """ return True - def __or__(self, other): + def __or__(self, other: Estimator | compose.Pipeline) -> compose.Pipeline: """Merge with another Transformer into a Pipeline.""" from river import compose @@ -27,7 +30,7 @@ def __or__(self, other): return other.__ror__(self) return compose.Pipeline(self, other) - def __ror__(self, other): + def __ror__(self, other: Estimator | compose.Pipeline) -> compose.Pipeline: """Merge with another Transformer into a Pipeline.""" from river import compose @@ -71,7 +74,7 @@ def _tags(self) -> dict[str, bool]: return tags @classmethod - def _unit_test_params(self): + def _unit_test_params(self) -> Iterator[dict[str, Any]]: """Indicates which parameters to use during unit testing. Most estimators have a default value for each of their parameters. However, in some cases, @@ -84,7 +87,7 @@ def _unit_test_params(self): """ yield {} - def _unit_test_skips(self): + def _unit_test_skips(self) -> set[str]: """Indicates which checks to skip during unit testing. Most estimators pass the full test suite. 
However, in some cases, some estimators might not diff --git a/river/base/transformer.py b/river/base/transformer.py index 83bc623521..8b4f816c1a 100644 --- a/river/base/transformer.py +++ b/river/base/transformer.py @@ -3,26 +3,26 @@ import abc import typing -from river import base +from river import base, compose if typing.TYPE_CHECKING: import pandas as pd class BaseTransformer: - def __add__(self, other): + def __add__(self, other: BaseTransformer) -> compose.TransformerUnion: """Fuses with another Transformer into a TransformerUnion.""" from river import compose return compose.TransformerUnion(self, other) - def __radd__(self, other): + def __radd__(self, other: BaseTransformer) -> compose.TransformerUnion: """Fuses with another Transformer into a TransformerUnion.""" from river import compose return compose.TransformerUnion(other, self) - def __mul__(self, other): + def __mul__(self, other: BaseTransformer | compose.Pipeline | base.typing.FeatureName | list[base.typing.FeatureName]) -> compose.Grouper | compose.TransformerProduct: from river import compose if isinstance(other, BaseTransformer) or isinstance(other, compose.Pipeline): @@ -30,7 +30,7 @@ def __mul__(self, other): return compose.Grouper(transformer=self, by=other) # type: ignore[arg-type] - def __rmul__(self, other): + def __rmul__(self, other: BaseTransformer | compose.Pipeline | base.typing.FeatureName | list[base.typing.FeatureName]) -> compose.Grouper | compose.TransformerProduct: """Creates a Grouper.""" return self * other diff --git a/river/base/viz.py b/river/base/viz.py index b99c272ff2..7248ace146 100644 --- a/river/base/viz.py +++ b/river/base/viz.py @@ -1,13 +1,14 @@ from __future__ import annotations +# This import is not cyclic because 'viz' is not exported by 'base' +from river import base, compose + import inspect import textwrap from xml.etree import ElementTree as ET -def to_html(obj) -> ET.Element: - from river import base, compose - +def to_html(obj: base.Estimator) -> ET.Element: if isinstance(obj, compose.Pipeline): return pipeline_to_html(obj) if isinstance(obj, compose.TransformerUnion): @@ -17,9 +18,7 @@ def to_html(obj) -> ET.Element: return estimator_to_html(obj) -def estimator_to_html(estimator) -> ET.Element: - from river import compose - +def estimator_to_html(estimator: base.Estimator) -> ET.Element: details = ET.Element("details", attrib={"class": "river-component river-estimator"}) summary = ET.Element("summary", attrib={"class": "river-summary"}) @@ -45,7 +44,7 @@ def estimator_to_html(estimator) -> ET.Element: return details -def pipeline_to_html(pipeline) -> ET.Element: +def pipeline_to_html(pipeline: compose.Pipeline) -> ET.Element: div = ET.Element("div", attrib={"class": "river-component river-pipeline"}) for step in pipeline.steps.values(): @@ -54,7 +53,7 @@ def pipeline_to_html(pipeline) -> ET.Element: return div -def union_to_html(union) -> ET.Element: +def union_to_html(union: compose.TransformerUnion) -> ET.Element: div = ET.Element("div", attrib={"class": "river-component river-union"}) for transformer in union.transformers.values(): @@ -63,7 +62,7 @@ def union_to_html(union) -> ET.Element: return div -def wrapper_to_html(wrapper) -> ET.Element: +def wrapper_to_html(wrapper: base.Wrapper) -> ET.Element: div = ET.Element("div", attrib={"class": "river-component river-wrapper"}) details = ET.Element("details", attrib={"class": "river-details"}) diff --git a/river/base/wrapper.py b/river/base/wrapper.py index 559d9555a7..815beb7b15 100644 --- a/river/base/wrapper.py +++ 
b/river/base/wrapper.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from river import base class Wrapper(ABC): @@ -8,7 +9,7 @@ class Wrapper(ABC): @property @abstractmethod - def _wrapped_model(self): + def _wrapped_model(self) -> base.Estimator: """Provides access to the wrapped model.""" @property @@ -19,13 +20,13 @@ def _labelloc(self) -> str: def __str__(self) -> str: return f"{type(self).__name__}({self._wrapped_model})" - def _more_tags(self): + def _more_tags(self) -> set[str]: return self._wrapped_model._tags @property - def _supervised(self): + def _supervised(self) -> bool: return self._wrapped_model._supervised @property - def _multiclass(self): + def _multiclass(self) -> bool: return self._wrapped_model._multiclass From d374d10b49835653b230f3b6474f5e688eb8c208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Tue, 14 Jan 2025 18:07:23 +0100 Subject: [PATCH 07/20] Add annotations to generics --- river/base/classifier.py | 7 ++++--- river/base/clusterer.py | 7 ++++--- river/base/ensemble.py | 4 ++-- river/base/multi_output.py | 11 ++++++----- river/base/regressor.py | 5 +++-- river/base/transformer.py | 7 ++++--- river/base/typing.py | 4 ++-- 7 files changed, 25 insertions(+), 20 deletions(-) diff --git a/river/base/classifier.py b/river/base/classifier.py index 38730c9003..d237c1a470 100644 --- a/river/base/classifier.py +++ b/river/base/classifier.py @@ -2,6 +2,7 @@ import abc import typing +from typing import Any from river import base @@ -15,7 +16,7 @@ class Classifier(estimator.Estimator): """A classifier.""" @abc.abstractmethod - def learn_one(self, x: dict, y: base.typing.ClfTarget) -> None: + def learn_one(self, x: dict[base.typing.FeatureName, Any], y: base.typing.ClfTarget) -> None: """Update the model with a set of features `x` and a label `y`. Parameters @@ -27,7 +28,7 @@ def learn_one(self, x: dict, y: base.typing.ClfTarget) -> None: """ - def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]: + def predict_proba_one(self, x: dict[base.typing.FeatureName, Any]) -> dict[base.typing.ClfTarget, float]: """Predict the probability of each label for a dictionary of features `x`. Parameters @@ -47,7 +48,7 @@ def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]: # that a classifier does not support predict_proba_one. raise NotImplementedError - def predict_one(self, x: dict, **kwargs) -> base.typing.ClfTarget | None: + def predict_one(self, x: dict[base.typing.FeatureName, Any], **kwargs) -> base.typing.ClfTarget | None: """Predict the label of a set of features `x`. Parameters diff --git a/river/base/clusterer.py b/river/base/clusterer.py index 92514ce2dc..ea314a65d4 100644 --- a/river/base/clusterer.py +++ b/river/base/clusterer.py @@ -1,8 +1,9 @@ from __future__ import annotations import abc +from typing import Any -from . import estimator +from . import estimator, typing class Clusterer(estimator.Estimator): @@ -13,7 +14,7 @@ def _supervised(self) -> bool: return False @abc.abstractmethod - def learn_one(self, x: dict) -> None: + def learn_one(self, x: dict[typing.FeatureName, Any]) -> None: """Update the model with a set of features `x`. Parameters @@ -24,7 +25,7 @@ def learn_one(self, x: dict) -> None: """ @abc.abstractmethod - def predict_one(self, x: dict) -> int: + def predict_one(self, x: dict[typing.FeatureName, Any]) -> int: """Predicts the cluster number for a set of features `x`. 
Parameters diff --git a/river/base/ensemble.py b/river/base/ensemble.py index 9efbff3f0c..f0a4e4fdf0 100644 --- a/river/base/ensemble.py +++ b/river/base/ensemble.py @@ -8,7 +8,7 @@ from .wrapper import Wrapper -class Ensemble(UserList): +class Ensemble(UserList[Estimator]): """An ensemble is a model which is composed of a list of models. Parameters @@ -31,7 +31,7 @@ def _min_number_of_models(self) -> int: return 2 @property - def models(self) -> list: + def models(self) -> list[Estimator]: return self.data diff --git a/river/base/multi_output.py b/river/base/multi_output.py index 078ed1a362..974c0edde4 100644 --- a/river/base/multi_output.py +++ b/river/base/multi_output.py @@ -1,6 +1,7 @@ from __future__ import annotations import abc +import typing from .estimator import Estimator from .typing import FeatureName, RegTarget @@ -10,7 +11,7 @@ class MultiLabelClassifier(Estimator, abc.ABC): """Multi-label classifier.""" @abc.abstractmethod - def learn_one(self, x: dict, y: dict[FeatureName, bool]) -> None: + def learn_one(self, x: dict[FeatureName, typing.Any], y: dict[FeatureName, bool]) -> None: """Update the model with a set of features `x` and the labels `y`. Parameters @@ -22,7 +23,7 @@ def learn_one(self, x: dict, y: dict[FeatureName, bool]) -> None: """ - def predict_proba_one(self, x: dict, **kwargs) -> dict[FeatureName, dict[bool, float]]: + def predict_proba_one(self, x: dict[FeatureName, typing.Any], **kwargs) -> dict[FeatureName, dict[bool, float]]: """Predict the probability of each label appearing given dictionary of features `x`. Parameters @@ -39,7 +40,7 @@ def predict_proba_one(self, x: dict, **kwargs) -> dict[FeatureName, dict[bool, f # In case the multi-label classifier does not support probabilities raise NotImplementedError - def predict_one(self, x: dict, **kwargs) -> dict[FeatureName, bool]: + def predict_one(self, x: dict[FeatureName, typing.Any], **kwargs) -> dict[FeatureName, bool]: """Predict the labels of a set of features `x`. Parameters @@ -68,7 +69,7 @@ class MultiTargetRegressor(Estimator, abc.ABC): """Multi-target regressor.""" @abc.abstractmethod - def learn_one(self, x: dict, y: dict[FeatureName, RegTarget], **kwargs) -> None: + def learn_one(self, x: dict[FeatureName, typing.Any], y: dict[FeatureName, RegTarget], **kwargs) -> None: """Fits to a set of features `x` and a real-valued target `y`. Parameters @@ -81,7 +82,7 @@ def learn_one(self, x: dict, y: dict[FeatureName, RegTarget], **kwargs) -> None: """ @abc.abstractmethod - def predict_one(self, x: dict) -> dict[FeatureName, RegTarget]: + def predict_one(self, x: dict[FeatureName, typing.Any]) -> dict[FeatureName, RegTarget]: """Predict the outputs of features `x`. Parameters diff --git a/river/base/regressor.py b/river/base/regressor.py index 09abacb2b6..88c8a351a2 100644 --- a/river/base/regressor.py +++ b/river/base/regressor.py @@ -2,6 +2,7 @@ import abc import typing +from typing import Any from river import base @@ -15,7 +16,7 @@ class Regressor(estimator.Estimator): """A regressor.""" @abc.abstractmethod - def learn_one(self, x: dict, y: base.typing.RegTarget) -> None: + def learn_one(self, x: dict[base.typing.FeatureName, Any], y: base.typing.RegTarget) -> None: """Fits to a set of features `x` and a real-valued target `y`. 
Parameters @@ -28,7 +29,7 @@ def learn_one(self, x: dict, y: base.typing.RegTarget) -> None: """ @abc.abstractmethod - def predict_one(self, x: dict) -> base.typing.RegTarget: + def predict_one(self, x: dict[base.typing.FeatureName, Any]) -> base.typing.RegTarget: """Predict the output of features `x`. Parameters diff --git a/river/base/transformer.py b/river/base/transformer.py index 8b4f816c1a..b1c255b7e1 100644 --- a/river/base/transformer.py +++ b/river/base/transformer.py @@ -2,6 +2,7 @@ import abc import typing +from typing import Any from river import base, compose @@ -35,7 +36,7 @@ def __rmul__(self, other: BaseTransformer | compose.Pipeline | base.typing.Featu return self * other @abc.abstractmethod - def transform_one(self, x: dict) -> dict: + def transform_one(self, x: dict[base.typing.FeatureName, Any]) -> dict[base.typing.FeatureName, Any]: """Transform a set of features `x`. Parameters @@ -57,7 +58,7 @@ class Transformer(base.Estimator, BaseTransformer): def _supervised(self) -> bool: return False - def learn_one(self, x: dict) -> None: + def learn_one(self, x: dict[base.typing.FeatureName, Any]) -> None: """Update with a set of features `x`. A lot of transformers don't actually have to do anything during the `learn_one` step @@ -81,7 +82,7 @@ class SupervisedTransformer(base.Estimator, BaseTransformer): def _supervised(self) -> bool: return True - def learn_one(self, x: dict, y: base.typing.Target) -> None: + def learn_one(self, x: dict[base.typing.FeatureName, Any], y: base.typing.Target) -> None: """Update with a set of features `x` and a target `y`. Parameters diff --git a/river/base/typing.py b/river/base/typing.py index 762526b5ff..7d9d715aaf 100644 --- a/river/base/typing.py +++ b/river/base/typing.py @@ -7,5 +7,5 @@ RegTarget = numbers.Number ClfTarget = typing.Union[bool, str, int] # noqa: UP007 Target = typing.Union[ClfTarget, RegTarget] # noqa: UP007 -Dataset = typing.Iterable[typing.Tuple[dict, typing.Any]] # noqa: UP006 -Stream = typing.Iterator[typing.Tuple[dict, typing.Any]] # noqa: UP006 +Dataset = typing.Iterable[typing.Tuple[dict[FeatureName, typing.Any], typing.Any]] # noqa: UP006 +Stream = typing.Iterator[typing.Tuple[dict[FeatureName, typing.Any], typing.Any]] # noqa: UP006 From b5645da13bc274a00d8183e33b454db1febe5db4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Sun, 1 Sep 2024 02:09:58 +0200 Subject: [PATCH 08/20] Make tags a set instead of a dict --- river/base/estimator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/river/base/estimator.py b/river/base/estimator.py index ebc6372805..dae5351ab2 100644 --- a/river/base/estimator.py +++ b/river/base/estimator.py @@ -47,11 +47,11 @@ def _repr_html_(self) -> str: div_str = ET.tostring(div, encoding="unicode") return f"
<div>{div_str}</div>
" - def _more_tags(self): + def _more_tags(self) -> set[str]: return set() @property - def _tags(self) -> dict[str, bool]: + def _tags(self) -> set[str]: """Return the estimator's tags. Tags can be used to specify what kind of inputs an estimator is able to process. For @@ -67,7 +67,7 @@ def _tags(self) -> dict[str, bool]: for parent in self.__class__.__mro__: try: - tags |= parent._more_tags(self) # type: ignore + tags |= parent._more_tags(self) # type: ignore[attr-defined] except AttributeError: pass From 1a6939b00753077ac4751136644c349f06124528 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Mon, 2 Sep 2024 00:18:09 +0200 Subject: [PATCH 09/20] Use Iterator from the collections.abc module typing.Iterator is deprecated since Python 3.9. --- river/base/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/river/base/base.py b/river/base/base.py index 82e30c2e71..b456f204f2 100644 --- a/river/base/base.py +++ b/river/base/base.py @@ -31,7 +31,7 @@ def __repr__(self) -> str: return _repr_obj(obj=self) @classmethod - def _unit_test_params(cls) -> typing.Iterator[dict[str, typing.Any]]: + def _unit_test_params(cls) -> collections.abc.Iterator[dict[str, typing.Any]]: """Instantiates an object with default arguments. Most parameters of each object have a default value. However, this isn't always the case, From 6ca8ed97ca5197c929e751350dd772e77694f1e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Wed, 18 Dec 2024 18:01:14 +0100 Subject: [PATCH 10/20] Add some supplementary hints in other modules --- river/compose/pipeline.py | 6 +++--- river/compose/union.py | 2 +- river/datasets/phishing.py | 2 +- river/optim/adam.py | 2 +- river/optim/sgd.py | 2 +- river/preprocessing/scale.py | 2 +- river/rules/amrules.py | 1 + river/stats/var.py | 4 ++-- river/time_series/snarimax.py | 4 ++-- river/utils/pretty.py | 2 +- 10 files changed, 14 insertions(+), 13 deletions(-) diff --git a/river/compose/pipeline.py b/river/compose/pipeline.py index 2c894ede04..b0588220fd 100644 --- a/river/compose/pipeline.py +++ b/river/compose/pipeline.py @@ -274,7 +274,7 @@ class Pipeline(base.Estimator): _LEARN_UNSUPERVISED_DURING_PREDICT = False - def __init__(self, *steps): + def __init__(self, *steps) -> None: self.steps = collections.OrderedDict() for step in steps: self |= step @@ -289,12 +289,12 @@ def __len__(self): """Just for convenience.""" return len(self.steps) - def __or__(self, other): + def __or__(self, other) -> Pipeline: """Insert a step at the end of the pipeline.""" self._add_step(other, at_start=False) return self - def __ror__(self, other): + def __ror__(self, other) -> Pipeline: """Insert a step at the start of the pipeline.""" self._add_step(other, at_start=True) return self diff --git a/river/compose/union.py b/river/compose/union.py index 0210745a1e..b7af584833 100644 --- a/river/compose/union.py +++ b/river/compose/union.py @@ -156,7 +156,7 @@ class TransformerUnion(base.MiniBatchTransformer): """ - def __init__(self, *transformers): + def __init__(self, *transformers) -> None: self.transformers = {} for transformer in transformers: if transformer.__class__ == self.__class__: diff --git a/river/datasets/phishing.py b/river/datasets/phishing.py index d76131edcc..ef84d4e486 100644 --- a/river/datasets/phishing.py +++ b/river/datasets/phishing.py @@ -16,7 +16,7 @@ class Phishing(base.FileDataset): """ - def __init__(self): + def __init__(self) -> None: super().__init__( n_samples=1_250, n_features=9, diff --git a/river/optim/adam.py 
b/river/optim/adam.py
index f78afe05ec..748445e119 100644
--- a/river/optim/adam.py
+++ b/river/optim/adam.py
@@ -51,7 +51,7 @@ class Adam(optim.base.Optimizer):

     """

-    def __init__(self, lr=0.1, beta_1=0.9, beta_2=0.999, eps=1e-8):
+    def __init__(self, lr=0.1, beta_1=0.9, beta_2=0.999, eps=1e-8) -> None:
         super().__init__(lr)
         self.beta_1 = beta_1
         self.beta_2 = beta_2
diff --git a/river/optim/sgd.py b/river/optim/sgd.py
index 0e05e8b98a..6d8e9e7321 100644
--- a/river/optim/sgd.py
+++ b/river/optim/sgd.py
@@ -39,7 +39,7 @@ class SGD(optim.base.Optimizer):

     """

-    def __init__(self, lr=0.01):
+    def __init__(self, lr=0.01) -> None:
         super().__init__(lr)

     def _step_with_dict(self, w, g):
diff --git a/river/preprocessing/scale.py b/river/preprocessing/scale.py
index eab0a68966..8c3bd9b0f7 100644
--- a/river/preprocessing/scale.py
+++ b/river/preprocessing/scale.py
@@ -152,7 +152,7 @@ class StandardScaler(base.MiniBatchTransformer):

     """

-    def __init__(self, with_std=True):
+    def __init__(self, with_std=True) -> None:
         self.with_std = with_std
         self.counts = collections.Counter()
         self.means = collections.defaultdict(float)
diff --git a/river/rules/amrules.py b/river/rules/amrules.py
index 245f71c1f4..602dc81592 100644
--- a/river/rules/amrules.py
+++ b/river/rules/amrules.py
@@ -334,6 +334,7 @@ def n_drifts_detected(self) -> int:
         return self._n_drifts_detected

     def _new_rule(self) -> RegRule:
+        predictor: base.Regressor
         if self.pred_type == self._PRED_MEAN:
             predictor = MeanRegressor()
         elif self.pred_type == self._PRED_MODEL:
diff --git a/river/stats/var.py b/river/stats/var.py
index 86fa412061..cfdc2b6fbf 100644
--- a/river/stats/var.py
+++ b/river/stats/var.py
@@ -70,7 +70,7 @@ class Var(stats.base.Univariate):

     """

-    def __init__(self, ddof=1):
+    def __init__(self, ddof=1) -> None:
         self.ddof = ddof
         self.mean = stats.Mean()
         self._S = 0
@@ -79,7 +79,7 @@ def __init__(self, ddof=1):
     def n(self):
         return self.mean.n

-    def update(self, x, w=1.0):
+    def update(self, x, w=1.0) -> None:
         mean_old = self.mean.get()
         self.mean.update(x, w)
         mean_new = self.mean.get()
diff --git a/river/time_series/snarimax.py b/river/time_series/snarimax.py
index 150f15c2bb..fa399b2c65 100644
--- a/river/time_series/snarimax.py
+++ b/river/time_series/snarimax.py
@@ -4,7 +4,7 @@
 import itertools
 import math

-from river import base, linear_model, preprocessing, time_series
+from river import base, linear_model, preprocessing, time_series, compose

 __all__ = ["SNARIMAX"]
@@ -280,7 +280,7 @@ def __init__(
         sp: int = 0,
         sd: int = 0,
         sq: int = 0,
-        regressor: base.Regressor | None = None,
+        regressor: base.Regressor | compose.Pipeline | None = None,
     ):
         self.p = p
         self.d = d
diff --git a/river/utils/pretty.py b/river/utils/pretty.py
index df9d7c926a..8b53d0a83b 100644
--- a/river/utils/pretty.py
+++ b/river/utils/pretty.py
@@ -56,7 +56,7 @@ def print_table(
     return table

-def humanize_bytes(n_bytes: int):
+def humanize_bytes(n_bytes: int) -> str:
     """Returns a human-friendly byte size.

     Parameters

From 387a7a57a5a43820f9eb1fd44199a71d7403068f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=89mile=20Royer?=
Date: Wed, 4 Sep 2024 19:53:12 +0200
Subject: [PATCH 11/20] Check if the _multiclass attribute exists

The wrapped model can be any type of model, but only the classifiers
have a _multiclass property. Before accessing it, we must make sure the
attribute exists.
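
As an illustration (a hypothetical snippet, not part of the diff below),
any wrapper around a non-classifier used to raise:

    from river import ensemble, linear_model

    bag = ensemble.BaggingRegressor(linear_model.LinearRegression())
    bag._multiclass  # AttributeError: regressors define no _multiclass

With the isinstance() guard, the property simply returns False when the
wrapped model is not a classifier.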
---
 river/base/wrapper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/river/base/wrapper.py b/river/base/wrapper.py
index 815beb7b15..65bfc0af71 100644
--- a/river/base/wrapper.py
+++ b/river/base/wrapper.py
@@ -29,4 +29,4 @@ def _supervised(self) -> bool:

     @property
     def _multiclass(self) -> bool:
-        return self._wrapped_model._multiclass
+        return isinstance(self._wrapped_model, base.Classifier) and self._wrapped_model._multiclass

From 627818aaf603551cc003d9be5556eb6c68489ef4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=89mile=20Royer?=
Date: Wed, 4 Sep 2024 19:53:57 +0200
Subject: [PATCH 12/20] The Grouper should accept all transformers

base.Transformer corresponds only to unbatched unsupervised
transformers. Other transformers should also be accepted for grouping.
---
 river/compose/grouper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/river/compose/grouper.py b/river/compose/grouper.py
index 7d19f50ebf..b2bb853f94 100644
--- a/river/compose/grouper.py
+++ b/river/compose/grouper.py
@@ -28,7 +28,7 @@ class Grouper(base.Transformer):

     def __init__(
         self,
-        transformer: base.Transformer,
+        transformer: base.BaseTransformer,
         by: base.typing.FeatureName | list[base.typing.FeatureName],
     ):
         self.transformer = transformer

From a0d6486678db367eeef86fb8b58ad840048db58f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=89mile=20Royer?=
Date: Sat, 2 Nov 2024 19:05:26 +0100
Subject: [PATCH 13/20] Change the regression target to 'float'

The Number interface matches all Python numbers and more. This includes
complex numbers. Most formulas are intended to deal with real numbers
only, and will error out if applied to complex numbers.

Moreover, MyPy does not recognise the interfaces of the 'numbers'
module, and recommends using Python's default numeric tower as proposed
in PEP 484.
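
A minimal sketch of the MyPy failure, assuming nothing beyond the
standard library:

    import numbers

    def halve(y: numbers.Number) -> numbers.Number:
        return y / 2  # rejected: Number declares no arithmetic operators

    def halve_float(y: float) -> float:
        return y / 2  # accepted; int arguments are also fine per PEP 484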
--- river/base/typing.py | 3 +-- river/metrics/base.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/river/base/typing.py b/river/base/typing.py index 7d9d715aaf..05d3ec9f84 100644 --- a/river/base/typing.py +++ b/river/base/typing.py @@ -1,10 +1,9 @@ from __future__ import annotations -import numbers import typing FeatureName = typing.Hashable -RegTarget = numbers.Number +RegTarget = float ClfTarget = typing.Union[bool, str, int] # noqa: UP007 Target = typing.Union[ClfTarget, RegTarget] # noqa: UP007 Dataset = typing.Iterable[typing.Tuple[dict[FeatureName, typing.Any], typing.Any]] # noqa: UP006 diff --git a/river/metrics/base.py b/river/metrics/base.py index 788de42d3f..7110c66844 100644 --- a/river/metrics/base.py +++ b/river/metrics/base.py @@ -2,7 +2,6 @@ import abc import collections -import numbers import operator from river import base, stats, utils @@ -190,11 +189,11 @@ class RegressionMetric(Metric): _fmt = ",.6f" # use commas to separate big numbers and show 6 decimals @abc.abstractmethod - def update(self, y_true: numbers.Number, y_pred: numbers.Number) -> None: + def update(self, y_true: float, y_pred: float) -> None: """Update the metric.""" @abc.abstractmethod - def revert(self, y_true: numbers.Number, y_pred: numbers.Number) -> None: + def revert(self, y_true: float, y_pred: float) -> None: """Revert the metric.""" @property From 439f1894d97ef41184630b40c415e5fdc005ab01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Sat, 2 Nov 2024 19:18:34 +0100 Subject: [PATCH 14/20] Annotate the kwargs argument for predict_* --- river/base/classifier.py | 4 ++-- river/base/multi_output.py | 6 +++--- river/forest/adaptive_random_forest.py | 2 +- river/tree/stochastic_gradient_tree.py | 3 ++- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/river/base/classifier.py b/river/base/classifier.py index d237c1a470..c3b6c044cb 100644 --- a/river/base/classifier.py +++ b/river/base/classifier.py @@ -28,7 +28,7 @@ def learn_one(self, x: dict[base.typing.FeatureName, Any], y: base.typing.ClfTar """ - def predict_proba_one(self, x: dict[base.typing.FeatureName, Any]) -> dict[base.typing.ClfTarget, float]: + def predict_proba_one(self, x: dict[base.typing.FeatureName, Any], **kwargs: Any) -> dict[base.typing.ClfTarget, float]: """Predict the probability of each label for a dictionary of features `x`. Parameters @@ -48,7 +48,7 @@ def predict_proba_one(self, x: dict[base.typing.FeatureName, Any]) -> dict[base. # that a classifier does not support predict_proba_one. raise NotImplementedError - def predict_one(self, x: dict[base.typing.FeatureName, Any], **kwargs) -> base.typing.ClfTarget | None: + def predict_one(self, x: dict[base.typing.FeatureName, Any], **kwargs: Any) -> base.typing.ClfTarget | None: """Predict the label of a set of features `x`. Parameters diff --git a/river/base/multi_output.py b/river/base/multi_output.py index 974c0edde4..ae970b16c9 100644 --- a/river/base/multi_output.py +++ b/river/base/multi_output.py @@ -23,7 +23,7 @@ def learn_one(self, x: dict[FeatureName, typing.Any], y: dict[FeatureName, bool] """ - def predict_proba_one(self, x: dict[FeatureName, typing.Any], **kwargs) -> dict[FeatureName, dict[bool, float]]: + def predict_proba_one(self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any) -> dict[FeatureName, dict[bool, float]]: """Predict the probability of each label appearing given dictionary of features `x`. 
Parameters @@ -40,7 +40,7 @@ def predict_proba_one(self, x: dict[FeatureName, typing.Any], **kwargs) -> dict[ # In case the multi-label classifier does not support probabilities raise NotImplementedError - def predict_one(self, x: dict[FeatureName, typing.Any], **kwargs) -> dict[FeatureName, bool]: + def predict_one(self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any) -> dict[FeatureName, bool]: """Predict the labels of a set of features `x`. Parameters @@ -69,7 +69,7 @@ class MultiTargetRegressor(Estimator, abc.ABC): """Multi-target regressor.""" @abc.abstractmethod - def learn_one(self, x: dict[FeatureName, typing.Any], y: dict[FeatureName, RegTarget], **kwargs) -> None: + def learn_one(self, x: dict[FeatureName, typing.Any], y: dict[FeatureName, RegTarget], **kwargs: typing.Any) -> None: """Fits to a set of features `x` and a real-valued target `y`. Parameters diff --git a/river/forest/adaptive_random_forest.py b/river/forest/adaptive_random_forest.py index 791e614d88..7fcdc80131 100644 --- a/river/forest/adaptive_random_forest.py +++ b/river/forest/adaptive_random_forest.py @@ -663,7 +663,7 @@ def _mutable_attributes(self): def _multiclass(self): return True - def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]: + def predict_proba_one(self, x: dict, **kwargs: typing.Any) -> dict[base.typing.ClfTarget, float]: y_pred: typing.Counter = collections.Counter() if len(self) == 0: diff --git a/river/tree/stochastic_gradient_tree.py b/river/tree/stochastic_gradient_tree.py index 8ea1d76371..01dfb3dcce 100644 --- a/river/tree/stochastic_gradient_tree.py +++ b/river/tree/stochastic_gradient_tree.py @@ -2,6 +2,7 @@ import abc import sys +from typing import Any from scipy.stats import f as f_dist @@ -291,7 +292,7 @@ def __init__( def _target_transform(self, y): return float(y) - def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]: + def predict_proba_one(self, x: dict, **kwargs: Any) -> dict[base.typing.ClfTarget, float]: if isinstance(self._root, DTBranch): leaf = self._root.traverse(x, until_leaf=True) else: From b448db9704b2cbe89b4e78ed2d52af77ed155224 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Sat, 2 Nov 2024 21:38:48 +0100 Subject: [PATCH 15/20] Add type hints for some container variables --- river/compose/pipeline.py | 2 +- river/compose/union.py | 2 +- river/preprocessing/scale.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/river/compose/pipeline.py b/river/compose/pipeline.py index b0588220fd..be80adfe5a 100644 --- a/river/compose/pipeline.py +++ b/river/compose/pipeline.py @@ -275,7 +275,7 @@ class Pipeline(base.Estimator): _LEARN_UNSUPERVISED_DURING_PREDICT = False def __init__(self, *steps) -> None: - self.steps = collections.OrderedDict() + self.steps: collections.OrderedDict = collections.OrderedDict() for step in steps: self |= step diff --git a/river/compose/union.py b/river/compose/union.py index b7af584833..b7b600125b 100644 --- a/river/compose/union.py +++ b/river/compose/union.py @@ -157,7 +157,7 @@ class TransformerUnion(base.MiniBatchTransformer): """ def __init__(self, *transformers) -> None: - self.transformers = {} + self.transformers: dict = {} for transformer in transformers: if transformer.__class__ == self.__class__: for t in transformer: diff --git a/river/preprocessing/scale.py b/river/preprocessing/scale.py index 8c3bd9b0f7..654075e15b 100644 --- a/river/preprocessing/scale.py +++ b/river/preprocessing/scale.py @@ -154,9 +154,9 @@ class 
StandardScaler(base.MiniBatchTransformer):

     def __init__(self, with_std=True) -> None:
         self.with_std = with_std
-        self.counts = collections.Counter()
-        self.means = collections.defaultdict(float)
-        self.vars = collections.defaultdict(float)
+        self.counts: collections.Counter = collections.Counter()
+        self.means: collections.defaultdict = collections.defaultdict(float)
+        self.vars: collections.defaultdict = collections.defaultdict(float)

     def learn_one(self, x):
         for i, xi in x.items():

From 590efbf7aa16957f6970c690ee0cc8c6ec16a7b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=89mile=20Royer?=
Date: Sat, 2 Nov 2024 22:29:37 +0100
Subject: [PATCH 16/20] Only import compose when type checking

---
 river/base/estimator.py   | 6 ++++--
 river/base/transformer.py | 3 ++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/river/base/estimator.py b/river/base/estimator.py
index dae5351ab2..e38d4b8677 100644
--- a/river/base/estimator.py
+++ b/river/base/estimator.py
@@ -1,11 +1,13 @@
 from __future__ import annotations

 import abc
-from typing import Any
+from typing import Any, TYPE_CHECKING
 from collections.abc import Iterator

 from . import base
-from river import compose
+
+if TYPE_CHECKING:
+    from river import compose

 class Estimator(base.Base, abc.ABC):
diff --git a/river/base/transformer.py b/river/base/transformer.py
index b1c255b7e1..9a6f66df57 100644
--- a/river/base/transformer.py
+++ b/river/base/transformer.py
@@ -4,10 +4,11 @@
 import typing
 from typing import Any

-from river import base, compose
+from river import base

 if typing.TYPE_CHECKING:
     import pandas as pd
+    from river import compose

 class BaseTransformer:

From 9eed9f2eec1c4617d92e09e14b4f456e78a7a633 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=89mile=20Royer?=
Date: Wed, 13 Nov 2024 10:37:57 +0100
Subject: [PATCH 17/20] Export BaseTransformer

---
 river/base/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/river/base/__init__.py b/river/base/__init__.py
index 0aaa521934..f54493087f 100644
--- a/river/base/__init__.py
+++ b/river/base/__init__.py
@@ -29,6 +29,7 @@
 from .multi_output import MultiLabelClassifier, MultiTargetRegressor
 from .regressor import MiniBatchRegressor, Regressor
 from .transformer import (
+    BaseTransformer,
     MiniBatchSupervisedTransformer,
     MiniBatchTransformer,
     SupervisedTransformer,
@@ -38,6 +39,7 @@
 __all__ = [
     "Base",
+    "BaseTransformer",
     "BinaryDriftDetector",
     "BinaryDriftAndWarningDetector",
     "Classifier",

From 69e6382996f13b9f3f58c35a6b3e4876b1bf5441 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=89mile=20Royer?=
Date: Sun, 24 Nov 2024 21:57:36 +0100
Subject: [PATCH 18/20] Ignore errors in other modules

These errors were indirectly caused by changes in `base`, where the
stricter typing invalidates the previously ambiguous (but valid) type
checks. These errors can be treated in their own modules.
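
A representative case (an illustration, using names from the diffs
below): Ensemble is now a UserList[Estimator], so iterating over an
ensemble yields plain Estimator objects, and calls to subclass-only
methods get flagged:

    for model in self:
        y_pred = model.predict_one(x)  # attr-defined: "Estimator" has no "predict_one"

The proper fix is to narrow the element types in each module; until
then, the errors are silenced where they occur.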
--- river/anomaly/pad.py | 4 ++-- river/bandit/evaluate.py | 4 ++-- river/ensemble/streaming_random_patches.py | 14 +++++++------- river/forest/adaptive_random_forest.py | 16 ++++++++-------- river/forest/aggregated_mondrian_forest.py | 4 ++-- river/forest/online_extra_trees.py | 4 ++-- river/multioutput/chain.py | 4 ++-- river/stream/iter_sql.py | 2 +- 8 files changed, 26 insertions(+), 26 deletions(-) diff --git a/river/anomaly/pad.py b/river/anomaly/pad.py index 0ddd3a403f..d5126c5913 100644 --- a/river/anomaly/pad.py +++ b/river/anomaly/pad.py @@ -130,7 +130,7 @@ def learn_one(self, x: dict | None, y: base.typing.Target | float): else: self.predictive_model.learn_one(y=y, x=x) else: - self.predictive_model.learn_one(x=x, y=y) + self.predictive_model.learn_one(x=x, y=y) # type:ignore[attr-defined] def score_one(self, x: dict, y: base.typing.Target): # Return the predicted value of x from the predictive model, first by checking whether @@ -138,7 +138,7 @@ def score_one(self, x: dict, y: base.typing.Target): if isinstance(self.predictive_model, time_series.base.Forecaster): y_pred = self.predictive_model.forecast(self.horizon)[0] else: - y_pred = self.predictive_model.predict_one(x) + y_pred = self.predictive_model.predict_one(x) # type:ignore[attr-defined] # Calculate the squared error squared_error = (y_pred - y) ** 2 diff --git a/river/bandit/evaluate.py b/river/bandit/evaluate.py index 0079079c8e..6de45b4494 100644 --- a/river/bandit/evaluate.py +++ b/river/bandit/evaluate.py @@ -136,10 +136,10 @@ def evaluate( if done[policy_idx]: continue - arm = policy_.pull(range(env_.action_space.n)) # type: ignore[attr-defined] + arm = policy_.pull(range(env_.action_space.n)) # type: ignore[attr-defined, arg-type] observation, reward, terminated, truncated, info = env_.step(arm) policy_.update(arm, reward) - reward_stat_.update(reward) + reward_stat_.update(reward) # type: ignore[arg-type] yield { "episode": episode, diff --git a/river/ensemble/streaming_random_patches.py b/river/ensemble/streaming_random_patches.py index 415b30be10..21d94e77e3 100644 --- a/river/ensemble/streaming_random_patches.py +++ b/river/ensemble/streaming_random_patches.py @@ -93,11 +93,11 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs): for model in self: # Get prediction for instance - y_pred = model.predict_one(x) + y_pred = model.predict_one(x) # type:ignore[attr-defined] # Update performance evaluator if y_pred is not None: - model.metric.update(y_true=y, y_pred=y_pred) + model.metric.update(y_true=y, y_pred=y_pred) # type: ignore[attr-defined] # BaseSRPEstimator has a metric field # Train using random subspaces without resampling, # i.e. all instances are used for training. 
@@ -109,7 +109,7 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs): k = poisson(rate=self.lam, rng=self._rng) if k == 0: continue - model.learn_one(x=x, y=y, w=k, n_samples_seen=self._n_samples_seen) + model.learn_one(x=x, y=y, w=k, n_samples_seen=self._n_samples_seen) # type:ignore[attr-defined] def _generate_subspaces(self, features: list): n_features = len(features) @@ -543,7 +543,7 @@ def learn_one( # TODO Find a way to verify if the model natively supports sample_weight (w) for _ in range(int(w)): - self.model.learn_one(x=x_subset, y=y, **kwargs) + self.model.learn_one(x=x_subset, y=y, **kwargs) # type:ignore[attr-defined] if self._background_learner: # Train the background learner @@ -557,7 +557,7 @@ def learn_one( ) if not self.disable_drift_detector and not self.is_background_learner: - correctly_classifies = self.model.predict_one(x_subset) == y + correctly_classifies = self.model.predict_one(x_subset) == y # type:ignore[attr-defined] # Check for warnings only if the background learner is active if not self.disable_background_learner: # Update the warning detection method @@ -845,10 +845,10 @@ def learn_one( # TODO Find a way to verify if the model natively supports sample_weight (w) for _ in range(int(w)): - self.model.learn_one(x=x_subset, y=y, **kwargs) + self.model.learn_one(x=x_subset, y=y, **kwargs) # type:ignore[attr-defined] # Drift detection input - y_pred = self.model.predict_one(x_subset) + y_pred = self.model.predict_one(x_subset) # type:ignore[attr-defined] if self.drift_detection_criteria == "error": # Track absolute error drift_detector_input = abs(y_pred - y) diff --git a/river/forest/adaptive_random_forest.py b/river/forest/adaptive_random_forest.py index 7fcdc80131..5137b240bf 100644 --- a/river/forest/adaptive_random_forest.py +++ b/river/forest/adaptive_random_forest.py @@ -155,13 +155,13 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs): self._init_ensemble(sorted(x.keys())) for i, model in enumerate(self): - y_pred = model.predict_one(x) + y_pred = model.predict_one(x) # type:ignore[attr-defined] # Update performance evaluator self._metrics[i].update( - y_true=y, + y_true=y, # type:ignore[arg-type] y_pred=( - model.predict_proba_one(x) + model.predict_proba_one(x) # type:ignore[attr-defined] if isinstance(self.metric, metrics.base.ClassificationMetric) and not self.metric.requires_labels else y_pred @@ -173,7 +173,7 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs): if not self._warning_detection_disabled and self._background[i] is not None: self._background[i].learn_one(x=x, y=y, w=k) # type: ignore - model.learn_one(x=x, y=y, w=k) + model.learn_one(x=x, y=y, w=k) # type:ignore[attr-defined] drift_input = None if not self._warning_detection_disabled: @@ -198,7 +198,7 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs): if self._drift_detectors[i].drift_detected: if not self._warning_detection_disabled and self._background[i] is not None: - self.data[i] = self._background[i] + self.data[i] = self._background[i] # type:ignore[assignment] self._background[i] = None self._warning_detectors[i] = self.warning_detector.clone() self._drift_detectors[i] = self.drift_detector.clone() @@ -671,7 +671,7 @@ def predict_proba_one(self, x: dict, **kwargs: typing.Any) -> dict[base.typing.C return y_pred # type: ignore for i, model in enumerate(self): - y_proba_temp = model.predict_proba_one(x) + y_proba_temp = model.predict_proba_one(x) # type:ignore[attr-defined] metric_value = self._metrics[i].get() if not 
self.disable_weighted_vote and metric_value > 0.0: y_proba_temp = {k: val * metric_value for k, val in y_proba_temp.items()} @@ -952,7 +952,7 @@ def predict_one(self, x: dict) -> base.typing.RegTarget: weights = np.zeros(self.n_models) sum_weights = 0.0 for i, model in enumerate(self): - y_pred[i] = model.predict_one(x) + y_pred[i] = model.predict_one(x) # type:ignore[attr-defined] weights[i] = self._metrics[i].get() sum_weights += weights[i] @@ -964,7 +964,7 @@ def predict_one(self, x: dict) -> base.typing.RegTarget: y_pred *= weights else: for i, model in enumerate(self): - y_pred[i] = model.predict_one(x) + y_pred[i] = model.predict_one(x) # type:ignore[attr-defined] if self.aggregation_method == self._MEAN: y_pred = y_pred.mean() diff --git a/river/forest/aggregated_mondrian_forest.py b/river/forest/aggregated_mondrian_forest.py index 12601d0cd5..2e79b0667c 100644 --- a/river/forest/aggregated_mondrian_forest.py +++ b/river/forest/aggregated_mondrian_forest.py @@ -170,7 +170,7 @@ def __init__( self._classes: set[base.typing.ClfTarget] = set() def _initialize_trees(self) -> None: - self.data: list[MondrianTreeClassifier] = [] + self.data: list[MondrianTreeClassifier] = [] # type:ignore[assignment] for _ in range(self.n_estimators): tree = MondrianTreeClassifier( self.step, @@ -290,7 +290,7 @@ def __init__( def _initialize_trees(self) -> None: """Initialize the forest.""" - self.data: list[MondrianTreeRegressor] = [] + self.data: list[MondrianTreeRegressor] = [] # type:ignore[assignment] for _ in range(self.n_estimators): # We don't want to have the same stochastic scheme for each tree, or it'll break the randomness # Hence we introduce a new seed for each, that is derived of the given seed by a deterministic process diff --git a/river/forest/online_extra_trees.py b/river/forest/online_extra_trees.py index ee361007eb..7c8808d401 100644 --- a/river/forest/online_extra_trees.py +++ b/river/forest/online_extra_trees.py @@ -720,7 +720,7 @@ def predict_one(self, x: dict) -> base.typing.RegTarget: weights = [] for perf, model in zip(self._perfs, self.models): - preds.append(model.predict_one(x)) + preds.append(model.predict_one(x)) # type:ignore[attr-defined] weights.append(perf.get()) sum_weights = sum(weights) @@ -733,6 +733,6 @@ def predict_one(self, x: dict) -> base.typing.RegTarget: preds = [(w / sum_weights) * pred for w, pred in zip(weights, preds)] return sum(preds) else: - preds = [model.predict_one(x) for model in self.models] + preds = [model.predict_one(x) for model in self.models] # type:ignore[attr-defined] return sum(preds) / len(preds) diff --git a/river/multioutput/chain.py b/river/multioutput/chain.py index e4516fd708..091a2bef3c 100644 --- a/river/multioutput/chain.py +++ b/river/multioutput/chain.py @@ -34,7 +34,7 @@ def __getitem__(self, key): return self[key] -class ClassifierChain(BaseChain, base.MultiLabelClassifier): +class ClassifierChain(BaseChain, base.MultiLabelClassifier): # type:ignore[misc] """A multi-output model that arranges classifiers into a chain. This will create one model per output. The prediction of the first output will be used as a @@ -165,7 +165,7 @@ def predict_proba_one(self, x, **kwargs): return y_pred -class RegressorChain(BaseChain, base.MultiTargetRegressor): +class RegressorChain(BaseChain, base.MultiTargetRegressor): # type:ignore[misc] """A multi-output model that arranges regressors into a chain. This will create one model per output. 
The prediction of the first output will be used as a diff --git a/river/stream/iter_sql.py b/river/stream/iter_sql.py index 13462315e0..2c007da4cc 100644 --- a/river/stream/iter_sql.py +++ b/river/stream/iter_sql.py @@ -102,4 +102,4 @@ def iter_sql( for row in result_proxy: x = dict(row._mapping.items()) y = x.pop(target_name) - yield x, y + yield x, y # type: ignore[misc] From b131ec8b005615b9b73556523a945c2b4350435d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Mon, 2 Sep 2024 00:30:28 +0200 Subject: [PATCH 19/20] Formatting --- river/bandit/evaluate.py | 2 +- river/base/base.py | 13 ++++++++++--- river/base/classifier.py | 8 ++++++-- river/base/estimator.py | 6 +++--- river/base/multi_output.py | 15 ++++++++++++--- river/base/transformer.py | 23 +++++++++++++++++++---- river/base/viz.py | 6 +++--- river/base/wrapper.py | 1 + river/forest/adaptive_random_forest.py | 4 +++- river/stream/iter_sql.py | 2 +- river/time_series/snarimax.py | 2 +- 11 files changed, 60 insertions(+), 22 deletions(-) diff --git a/river/bandit/evaluate.py b/river/bandit/evaluate.py index 6de45b4494..caf4ea46dd 100644 --- a/river/bandit/evaluate.py +++ b/river/bandit/evaluate.py @@ -139,7 +139,7 @@ def evaluate( arm = policy_.pull(range(env_.action_space.n)) # type: ignore[attr-defined, arg-type] observation, reward, terminated, truncated, info = env_.step(arm) policy_.update(arm, reward) - reward_stat_.update(reward) # type: ignore[arg-type] + reward_stat_.update(reward) # type: ignore[arg-type] yield { "episode": episode, diff --git a/river/base/base.py b/river/base/base.py index b456f204f2..0db98c8759 100644 --- a/river/base/base.py +++ b/river/base/base.py @@ -73,7 +73,9 @@ def _get_params(self) -> dict[str, typing.Any]: return params - def clone(self, new_params: dict[str, typing.Any] | None = None, include_attributes: bool = False) -> typing_extensions.Self: + def clone( + self, new_params: dict[str, typing.Any] | None = None, include_attributes: bool = False + ) -> typing_extensions.Self: """Return a fresh estimator with the same parameters. The clone has the same parameters but has not been updated with any data. 
@@ -371,7 +373,7 @@ def _raw_memory_usage(self) -> int: buffer.extend([k for k in obj.keys()]) buffer.extend([v for v in obj.values()]) elif hasattr(obj, "__dict__"): # Save object contents - contents= vars(obj) + contents = vars(obj) size += sys.getsizeof(contents) buffer.extend([k for k in contents.keys()]) buffer.extend([v for v in contents.values()]) @@ -398,7 +400,12 @@ def _memory_usage(self) -> str: return utils.pretty.humanize_bytes(self._raw_memory_usage) -def _log_method_calls(self: typing.Any, name: str, class_condition: typing.Callable[[typing.Any], bool], method_condition: typing.Callable[[typing.Any], bool]) -> typing.Any: +def _log_method_calls( + self: typing.Any, + name: str, + class_condition: typing.Callable[[typing.Any], bool], + method_condition: typing.Callable[[typing.Any], bool], +) -> typing.Any: method = object.__getattribute__(self, name) if ( not name.startswith("_") diff --git a/river/base/classifier.py b/river/base/classifier.py index c3b6c044cb..01f6b8b9d4 100644 --- a/river/base/classifier.py +++ b/river/base/classifier.py @@ -28,7 +28,9 @@ def learn_one(self, x: dict[base.typing.FeatureName, Any], y: base.typing.ClfTar """ - def predict_proba_one(self, x: dict[base.typing.FeatureName, Any], **kwargs: Any) -> dict[base.typing.ClfTarget, float]: + def predict_proba_one( + self, x: dict[base.typing.FeatureName, Any], **kwargs: Any + ) -> dict[base.typing.ClfTarget, float]: """Predict the probability of each label for a dictionary of features `x`. Parameters @@ -48,7 +50,9 @@ def predict_proba_one(self, x: dict[base.typing.FeatureName, Any], **kwargs: Any # that a classifier does not support predict_proba_one. raise NotImplementedError - def predict_one(self, x: dict[base.typing.FeatureName, Any], **kwargs: Any) -> base.typing.ClfTarget | None: + def predict_one( + self, x: dict[base.typing.FeatureName, Any], **kwargs: Any + ) -> base.typing.ClfTarget | None: """Predict the label of a set of features `x`. Parameters diff --git a/river/base/estimator.py b/river/base/estimator.py index e38d4b8677..77504c69f2 100644 --- a/river/base/estimator.py +++ b/river/base/estimator.py @@ -1,14 +1,14 @@ from __future__ import annotations import abc -from typing import Any, TYPE_CHECKING from collections.abc import Iterator - -from . import base +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from river import compose +from . import base + class Estimator(base.Base, abc.ABC): """An estimator.""" diff --git a/river/base/multi_output.py b/river/base/multi_output.py index ae970b16c9..68cf013afc 100644 --- a/river/base/multi_output.py +++ b/river/base/multi_output.py @@ -23,7 +23,9 @@ def learn_one(self, x: dict[FeatureName, typing.Any], y: dict[FeatureName, bool] """ - def predict_proba_one(self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any) -> dict[FeatureName, dict[bool, float]]: + def predict_proba_one( + self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any + ) -> dict[FeatureName, dict[bool, float]]: """Predict the probability of each label appearing given dictionary of features `x`. Parameters @@ -40,7 +42,9 @@ def predict_proba_one(self, x: dict[FeatureName, typing.Any], **kwargs: typing.A # In case the multi-label classifier does not support probabilities raise NotImplementedError - def predict_one(self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any) -> dict[FeatureName, bool]: + def predict_one( + self, x: dict[FeatureName, typing.Any], **kwargs: typing.Any + ) -> dict[FeatureName, bool]: """Predict the labels of a set of features `x`. 
Parameters @@ -69,7 +73,12 @@ class MultiTargetRegressor(Estimator, abc.ABC): """Multi-target regressor.""" @abc.abstractmethod - def learn_one(self, x: dict[FeatureName, typing.Any], y: dict[FeatureName, RegTarget], **kwargs: typing.Any) -> None: + def learn_one( + self, + x: dict[FeatureName, typing.Any], + y: dict[FeatureName, RegTarget], + **kwargs: typing.Any, + ) -> None: """Fits to a set of features `x` and a real-valued target `y`. Parameters diff --git a/river/base/transformer.py b/river/base/transformer.py index 9a6f66df57..b41aa78f32 100644 --- a/river/base/transformer.py +++ b/river/base/transformer.py @@ -8,6 +8,7 @@ if typing.TYPE_CHECKING: import pandas as pd + from river import compose @@ -24,20 +25,34 @@ def __radd__(self, other: BaseTransformer) -> compose.TransformerUnion: return compose.TransformerUnion(other, self) - def __mul__(self, other: BaseTransformer | compose.Pipeline | base.typing.FeatureName | list[base.typing.FeatureName]) -> compose.Grouper | compose.TransformerProduct: + def __mul__( + self, + other: BaseTransformer + | compose.Pipeline + | base.typing.FeatureName + | list[base.typing.FeatureName], + ) -> compose.Grouper | compose.TransformerProduct: from river import compose if isinstance(other, BaseTransformer) or isinstance(other, compose.Pipeline): return compose.TransformerProduct(self, other) - return compose.Grouper(transformer=self, by=other) # type: ignore[arg-type] + return compose.Grouper(transformer=self, by=other) - def __rmul__(self, other: BaseTransformer | compose.Pipeline | base.typing.FeatureName | list[base.typing.FeatureName]) -> compose.Grouper | compose.TransformerProduct: + def __rmul__( + self, + other: BaseTransformer + | compose.Pipeline + | base.typing.FeatureName + | list[base.typing.FeatureName], + ) -> compose.Grouper | compose.TransformerProduct: """Creates a Grouper.""" return self * other @abc.abstractmethod - def transform_one(self, x: dict[base.typing.FeatureName, Any]) -> dict[base.typing.FeatureName, Any]: + def transform_one( + self, x: dict[base.typing.FeatureName, Any] + ) -> dict[base.typing.FeatureName, Any]: """Transform a set of features `x`. 
Parameters diff --git a/river/base/viz.py b/river/base/viz.py index 7248ace146..8c7201cca3 100644 --- a/river/base/viz.py +++ b/river/base/viz.py @@ -1,12 +1,12 @@ from __future__ import annotations -# This import is not cyclic because 'viz' is not exported by 'base' -from river import base, compose - import inspect import textwrap from xml.etree import ElementTree as ET +# This import is not cyclic because 'viz' is not exported by 'base' +from river import base, compose + def to_html(obj: base.Estimator) -> ET.Element: if isinstance(obj, compose.Pipeline): diff --git a/river/base/wrapper.py b/river/base/wrapper.py index 65bfc0af71..b1f484d710 100644 --- a/river/base/wrapper.py +++ b/river/base/wrapper.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod + from river import base diff --git a/river/forest/adaptive_random_forest.py b/river/forest/adaptive_random_forest.py index 5137b240bf..9ed1df53dd 100644 --- a/river/forest/adaptive_random_forest.py +++ b/river/forest/adaptive_random_forest.py @@ -663,7 +663,9 @@ def _mutable_attributes(self): def _multiclass(self): return True - def predict_proba_one(self, x: dict, **kwargs: typing.Any) -> dict[base.typing.ClfTarget, float]: + def predict_proba_one( + self, x: dict, **kwargs: typing.Any + ) -> dict[base.typing.ClfTarget, float]: y_pred: typing.Counter = collections.Counter() if len(self) == 0: diff --git a/river/stream/iter_sql.py b/river/stream/iter_sql.py index 2c007da4cc..453cd0c33e 100644 --- a/river/stream/iter_sql.py +++ b/river/stream/iter_sql.py @@ -102,4 +102,4 @@ def iter_sql( for row in result_proxy: x = dict(row._mapping.items()) y = x.pop(target_name) - yield x, y # type: ignore[misc] + yield x, y  # type: ignore[misc] diff --git a/river/time_series/snarimax.py b/river/time_series/snarimax.py index fa399b2c65..144d34fb63 100644 --- a/river/time_series/snarimax.py +++ b/river/time_series/snarimax.py @@ -4,7 +4,7 @@ import itertools import math -from river import base, linear_model, preprocessing, time_series, compose +from river import base, compose, linear_model, preprocessing, time_series __all__ = ["SNARIMAX"] From 7ab53ebf8e133bd9b812e80e5543811a3507a735 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89mile=20Royer?= Date: Fri, 17 Jan 2025 14:08:03 +0100 Subject: [PATCH 20/20] Add a changelog note --- docs/releases/unreleased.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/releases/unreleased.md b/docs/releases/unreleased.md index 79e701b844..582cabbcb9 100644 --- a/docs/releases/unreleased.md +++ b/docs/releases/unreleased.md @@ -1 +1,6 @@ # Unreleased + +## base + +- The `base` module is now fully type-annotated. Some type hints have changed, but this does not affect the behaviour of the code. For instance, the regression target is now typed as a `float` instead of a `Number`.
+- The `tags` and `more_tags` properties of `base.Estimator` are now both sets of strings.
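
To make the first changelog bullet concrete, here is a minimal sketch of a user-defined regressor written against the annotated `base` module. The `RunningMean` class and the usage lines are hypothetical illustrations, not part of River or of this patch series; they assume `base.typing.RegTarget` now resolves to `float` as stated above, in which case the snippet should type-check under MyPy's strict mode without casts:

```python
from __future__ import annotations

from river import base


class RunningMean(base.Regressor):
    """Hypothetical regressor that predicts the mean of the targets seen so far."""

    def __init__(self) -> None:
        self._n = 0
        self._mean = 0.0

    def learn_one(self, x: dict, y: base.typing.RegTarget) -> None:
        # With RegTarget narrowed to float, plain float arithmetic type-checks
        # without detouring through numbers.Number.
        self._n += 1
        self._mean += (y - self._mean) / self._n

    def predict_one(self, x: dict) -> base.typing.RegTarget:
        return self._mean


model = RunningMean()
model.learn_one({"x": 1.0}, y=3.0)
model.learn_one({"x": 2.0}, y=5.0)
assert model.predict_one({"x": 3.0}) == 4.0
```

The sketch mirrors the design intent visible in the diffs above: prefer precise annotations (`float`, `typing_extensions.Self`, parameterised `dict`) over `Any` or abstract numeric types, and fence off the remaining dynamic spots with targeted error codes such as `# type: ignore[attr-defined]` rather than blanket ignores.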