Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Full type hints for the module base #1668

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/releases/unreleased.md
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
# Unreleased

## base

- The `base` module is now fully type-annotated. Some type hints have changed, but this does not impact the behaviour of the code. For instance, the regression target is now indicated as a `float` instead of a `Number`.
- The `tags` and `more_tags` properties of `base.Estimator` are now both a set of strings.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ ignore_missing_imports = true
[[tool.mypy.overrides]]
# Disable strict mode for all non fully-typed modules
module = [
"river.base.*",
"river.metrics.*",
"river.utils.*",
"river.stats.*",
Expand Down
4 changes: 2 additions & 2 deletions river/anomaly/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,15 @@ def learn_one(self, x: dict | None, y: base.typing.Target | float):
else:
self.predictive_model.learn_one(y=y, x=x)
else:
self.predictive_model.learn_one(x=x, y=y)
self.predictive_model.learn_one(x=x, y=y) # type:ignore[attr-defined]

def score_one(self, x: dict, y: base.typing.Target):
# Return the predicted value of x from the predictive model, first by checking whether
# it is a time-series forecaster.
if isinstance(self.predictive_model, time_series.base.Forecaster):
y_pred = self.predictive_model.forecast(self.horizon)[0]
else:
y_pred = self.predictive_model.predict_one(x)
y_pred = self.predictive_model.predict_one(x) # type:ignore[attr-defined]

# Calculate the squared error
squared_error = (y_pred - y) ** 2
Expand Down
4 changes: 2 additions & 2 deletions river/bandit/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ def evaluate(
if done[policy_idx]:
continue

arm = policy_.pull(range(env_.action_space.n)) # type: ignore[attr-defined]
arm = policy_.pull(range(env_.action_space.n)) # type: ignore[attr-defined, arg-type]
observation, reward, terminated, truncated, info = env_.step(arm)
policy_.update(arm, reward)
reward_stat_.update(reward)
reward_stat_.update(reward) # type: ignore[arg-type]

yield {
"episode": episode,
Expand Down
2 changes: 2 additions & 0 deletions river/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .multi_output import MultiLabelClassifier, MultiTargetRegressor
from .regressor import MiniBatchRegressor, Regressor
from .transformer import (
BaseTransformer,
MiniBatchSupervisedTransformer,
MiniBatchTransformer,
SupervisedTransformer,
Expand All @@ -38,6 +39,7 @@

__all__ = [
"Base",
"BaseTransformer",
"BinaryDriftDetector",
"BinaryDriftAndWarningDetector",
"Classifier",
Expand Down
47 changes: 28 additions & 19 deletions river/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import types
import typing

import typing_extensions


class Base:
"""Base class that is inherited by the majority of classes in River.
Expand All @@ -22,14 +24,14 @@ class Base:

"""

def __str__(self):
def __str__(self) -> str:
return self.__class__.__name__

def __repr__(self):
def __repr__(self) -> str:
return _repr_obj(obj=self)

@classmethod
def _unit_test_params(cls):
def _unit_test_params(cls) -> collections.abc.Iterator[dict[str, typing.Any]]:
"""Instantiates an object with default arguments.

Most parameters of each object have a default value. However, this isn't always the case,
Expand Down Expand Up @@ -71,7 +73,9 @@ def _get_params(self) -> dict[str, typing.Any]:

return params

def clone(self, new_params: dict | None = None, include_attributes=False):
def clone(
self, new_params: dict[str, typing.Any] | None = None, include_attributes: bool = False
) -> typing_extensions.Self:
"""Return a fresh estimator with the same parameters.

The clone has the same parameters but has not been updated with any data.
Expand Down Expand Up @@ -167,7 +171,7 @@ def clone(self, new_params: dict | None = None, include_attributes=False):

"""

def is_class_param(param):
def is_class_param(param: typing.Any) -> bool:
# See expand_param_grid to understand why this is necessary
return (
isinstance(param, tuple)
Expand Down Expand Up @@ -202,10 +206,10 @@ def is_class_param(param):
return clone

@property
def _mutable_attributes(self) -> set:
def _mutable_attributes(self) -> set[str]:
return set()

def mutate(self, new_attrs: dict):
def mutate(self, new_attrs: dict[str, typing.Any]) -> None:
"""Modify attributes.

This changes parameters inplace. Although you can change attributes yourself, this is the
Expand Down Expand Up @@ -296,8 +300,8 @@ def mutate(self, new_attrs: dict):

"""

def _mutate(obj, new_attrs):
def is_class_attr(name, attr):
def _mutate(obj: typing.Any, new_attrs: dict[str, typing.Any]) -> None:
def is_class_attr(name: str, attr: typing.Any) -> bool:
return hasattr(getattr(obj, name), "mutate") and isinstance(attr, dict)

for name, attr in new_attrs.items():
Expand All @@ -318,7 +322,7 @@ def is_class_attr(name, attr):
_mutate(obj=self, new_attrs=new_attrs)

@property
def _is_stochastic(self):
def _is_stochastic(self) -> bool:
"""Indicates if the model contains an unset seed parameter.

The convention in River is to control randomness by exposing a seed parameter. This seed
Expand All @@ -329,14 +333,14 @@ def _is_stochastic(self):

"""

def is_class_param(param):
def is_class_param(param: typing.Any) -> bool:
return (
isinstance(param, tuple)
and inspect.isclass(param[0])
and isinstance(param[1], dict)
)

def find(params):
def find(params: dict[str, typing.Any]) -> bool:
if not isinstance(params, dict):
return False
for name, param in params.items():
Expand All @@ -354,7 +358,7 @@ def _raw_memory_usage(self) -> int:

import numpy as np

buffer = collections.deque([self])
buffer: collections.deque[typing.Any] = collections.deque([self])
seen = set()
size = 0
while len(buffer) > 0:
Expand All @@ -369,7 +373,7 @@ def _raw_memory_usage(self) -> int:
buffer.extend([k for k in obj.keys()])
buffer.extend([v for v in obj.values()])
elif hasattr(obj, "__dict__"): # Save object contents
contents: dict = vars(obj)
contents = vars(obj)
size += sys.getsizeof(contents)
buffer.extend([k for k in contents.keys()])
buffer.extend([v for v in contents.values()])
Expand All @@ -384,7 +388,7 @@ def _raw_memory_usage(self) -> int:
elif hasattr(obj, "__iter__") and not (
isinstance(obj, str) or isinstance(obj, bytes) or isinstance(obj, bytearray)
):
buffer.extend([i for i in obj]) # type: ignore
buffer.extend([i for i in obj])

return size

Expand All @@ -396,7 +400,12 @@ def _memory_usage(self) -> str:
return utils.pretty.humanize_bytes(self._raw_memory_usage)


def _log_method_calls(self, name, class_condition, method_condition):
def _log_method_calls(
self: typing.Any,
name: str,
class_condition: typing.Callable[[typing.Any], bool],
method_condition: typing.Callable[[typing.Any], bool],
) -> typing.Any:
method = object.__getattribute__(self, name)
if (
not name.startswith("_")
Expand All @@ -412,7 +421,7 @@ def _log_method_calls(self, name, class_condition, method_condition):
def log_method_calls(
class_condition: typing.Callable[[typing.Any], bool] | None = None,
method_condition: typing.Callable[[typing.Any], bool] | None = None,
):
) -> collections.abc.Iterator[None]:
"""A context manager to log method calls.

All method calls will be logged by default. This behavior can be overriden by passing filtering
Expand Down Expand Up @@ -477,7 +486,7 @@ def log_method_calls(
Base.__getattribute__ = old # type: ignore


def _repr_obj(obj, show_modules: bool = False, depth: int = 0) -> str:
def _repr_obj(obj: typing.Any, show_modules: bool = False, depth: int = 0) -> str:
"""Return a pretty representation of an object."""

rep = f"{obj.__class__.__name__} ("
Expand All @@ -487,7 +496,7 @@ def _repr_obj(obj, show_modules: bool = False, depth: int = 0) -> str:

params = {
name: getattr(obj, name)
for name, param in inspect.signature(obj.__init__).parameters.items() # type: ignore
for name, param in inspect.signature(obj.__init__).parameters.items()
if not (
param.name == "args"
and param.kind == param.VAR_POSITIONAL
Expand Down
15 changes: 10 additions & 5 deletions river/base/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import typing
from typing import Any

from river import base

Expand All @@ -15,7 +16,7 @@ class Classifier(estimator.Estimator):
"""A classifier."""

@abc.abstractmethod
def learn_one(self, x: dict, y: base.typing.ClfTarget) -> None:
def learn_one(self, x: dict[base.typing.FeatureName, Any], y: base.typing.ClfTarget) -> None:
"""Update the model with a set of features `x` and a label `y`.

Parameters
Expand All @@ -27,7 +28,9 @@ def learn_one(self, x: dict, y: base.typing.ClfTarget) -> None:

"""

def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]:
def predict_proba_one(
self, x: dict[base.typing.FeatureName, Any], **kwargs: Any
) -> dict[base.typing.ClfTarget, float]:
"""Predict the probability of each label for a dictionary of features `x`.

Parameters
Expand All @@ -47,7 +50,9 @@ def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]:
# that a classifier does not support predict_proba_one.
raise NotImplementedError

def predict_one(self, x: dict, **kwargs) -> base.typing.ClfTarget | None:
def predict_one(
self, x: dict[base.typing.FeatureName, Any], **kwargs: Any
) -> base.typing.ClfTarget | None:
"""Predict the label of a set of features `x`.

Parameters
Expand All @@ -69,11 +74,11 @@ def predict_one(self, x: dict, **kwargs) -> base.typing.ClfTarget | None:
return None

@property
def _multiclass(self):
def _multiclass(self) -> bool:
return False

@property
def _supervised(self):
def _supervised(self) -> bool:
return True


Expand Down
9 changes: 5 additions & 4 deletions river/base/clusterer.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from __future__ import annotations

import abc
from typing import Any

from . import estimator
from . import estimator, typing


class Clusterer(estimator.Estimator):
"""A clustering model."""

@property
def _supervised(self):
def _supervised(self) -> bool:
return False

@abc.abstractmethod
def learn_one(self, x: dict) -> None:
def learn_one(self, x: dict[typing.FeatureName, Any]) -> None:
"""Update the model with a set of features `x`.

Parameters
Expand All @@ -24,7 +25,7 @@ def learn_one(self, x: dict) -> None:
"""

@abc.abstractmethod
def predict_one(self, x: dict) -> int:
def predict_one(self, x: dict[typing.FeatureName, Any]) -> int:
"""Predicts the cluster number for a set of features `x`.

Parameters
Expand Down
8 changes: 4 additions & 4 deletions river/base/drift_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class _BaseDriftDetector(base.Base):

"""

def __init__(self):
def __init__(self) -> None:
self._drift_detected = False

def _reset(self) -> None:
Expand All @@ -40,16 +40,16 @@ class _BaseDriftAndWarningDetector(_BaseDriftDetector):

"""

def __init__(self):
def __init__(self) -> None:
super().__init__()
self._warning_detected = False

def _reset(self):
def _reset(self) -> None:
super()._reset()
self._warning_detected = False

@property
def warning_detected(self):
def warning_detected(self) -> bool:
"""Whether or not a drift is detected following the last update."""
return self._warning_detected

Expand Down
12 changes: 6 additions & 6 deletions river/base/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .wrapper import Wrapper


class Ensemble(UserList):
class Ensemble(UserList[Estimator]):
"""An ensemble is a model which is composed of a list of models.

Parameters
Expand All @@ -17,7 +17,7 @@ class Ensemble(UserList):

"""

def __init__(self, models: Iterator[Estimator]):
def __init__(self, models: Iterator[Estimator]) -> None:
super().__init__(models)

if len(self) < self._min_number_of_models:
Expand All @@ -27,11 +27,11 @@ def __init__(self, models: Iterator[Estimator]):
)

@property
def _min_number_of_models(self):
def _min_number_of_models(self) -> int:
return 2

@property
def models(self):
def models(self) -> list[Estimator]:
return self.data


Expand All @@ -49,13 +49,13 @@ class WrapperEnsemble(Ensemble, Wrapper):

"""

def __init__(self, model, n_models, seed):
def __init__(self, model: Estimator, n_models: int, seed: int | None) -> None:
super().__init__(model.clone() for _ in range(n_models))
self.model = model
self.n_models = n_models
self.seed = seed
self._rng = Random(seed)

@property
def _wrapped_model(self):
def _wrapped_model(self) -> Estimator:
return self.model
Loading
Loading