Skip to content

Commit

Permalink
Model configs (#170)
Browse files Browse the repository at this point in the history
- Added configs to models

Сloses #108
#157
  • Loading branch information
feldlime authored Nov 9, 2024
1 parent 0be5e15 commit 0d1df86
Show file tree
Hide file tree
Showing 28 changed files with 1,948 additions and 163 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## Unreleased

### Added
- `from_config`, `get_config` and `get_params` methods to all models except neural-net-based([#170](https://github.com/MobileTeleSystems/RecTools/pull/170))


## [0.8.0] - 28.08.2024

### Added
Expand Down
150 changes: 141 additions & 9 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ tqdm = "^4.27.0"
implicit = "^0.7.1"
attrs = ">=19.1.0,<24.0.0"
typeguard = "^4.1.0"

pydantic = "^2.8.2"
pydantic-core = "^2.20.1"
typing-extensions = "^4.12.2"

# The latest released version of lightfm is 1.17 and it's not compatible with PEP-517 installers (like latest poetry versions).
rectools-lightfm = {version="1.17.1", python = "<3.12", optional = true}
rectools-lightfm = {version="1.17.2", python = "<3.12", optional = true}

nmslib = {version = "^2.0.4", python = "<3.11", optional = true}
# nmslib officialy doens't support Python 3.11 and 3.12. Use https://github.com/metabrainz/nmslib-metabrainz instead
Expand Down
3 changes: 2 additions & 1 deletion rectools/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def get_user_item_matrix(
include_weights: bool = True,
include_warm_users: bool = False,
include_warm_items: bool = False,
dtype: tp.Type = np.float32,
) -> sparse.csr_matrix:
"""
Construct user-item CSR matrix based on `interactions` attribute.
Expand All @@ -224,7 +225,7 @@ def get_user_item_matrix(
csr_matrix
Resized user-item CSR matrix
"""
matrix = self.interactions.get_user_item_matrix(include_weights)
matrix = self.interactions.get_user_item_matrix(include_weights, dtype)
n_rows = self.user_id_map.size if include_warm_users else matrix.shape[0]
n_columns = self.item_id_map.size if include_warm_items else matrix.shape[1]
matrix.resize(n_rows, n_columns)
Expand Down
6 changes: 4 additions & 2 deletions rectools/dataset/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Structure for saving user-item interactions."""

import typing as tp

import attr
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -121,7 +123,7 @@ def from_raw(

return cls(df)

def get_user_item_matrix(self, include_weights: bool = True) -> sparse.csr_matrix:
def get_user_item_matrix(self, include_weights: bool = True, dtype: tp.Type = np.float32) -> sparse.csr_matrix:
"""
Form a user-item CSR matrix based on interactions data.
Expand All @@ -142,7 +144,7 @@ def get_user_item_matrix(self, include_weights: bool = True) -> sparse.csr_matri

csr = sparse.csr_matrix(
(
values.astype(np.float32),
values.astype(dtype),
(
self.df[Columns.User].values,
self.df[Columns.Item].values,
Expand Down
140 changes: 139 additions & 1 deletion rectools/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@

import numpy as np
import pandas as pd
import typing_extensions as tpe
from pydantic import PlainSerializer
from pydantic_core import PydanticSerializationError

from rectools import Columns, ExternalIds, InternalIds
from rectools.dataset import Dataset
from rectools.dataset.identifiers import IdMap
from rectools.exceptions import NotFittedError
from rectools.types import ExternalIdsArray, InternalIdsArray
from rectools.utils.config import BaseConfig
from rectools.utils.misc import make_dict_flat

T = tp.TypeVar("T", bound="ModelBase")
ScoresArray = np.ndarray
Expand All @@ -38,7 +43,30 @@
RecoTriplet_T = tp.TypeVar("RecoTriplet_T", InternalRecoTriplet, SemiInternalRecoTriplet, ExternalRecoTriplet)


class ModelBase:
def _serialize_random_state(rs: tp.Optional[tp.Union[None, int, np.random.RandomState]]) -> tp.Union[None, int]:
if rs is None or isinstance(rs, int):
return rs

# NOBUG: We can add serialization using get/set_state, but it's not human readable
raise TypeError("`random_state` must be ``None`` or have ``int`` type to convert it to simple type")


RandomState = tpe.Annotated[
tp.Union[None, int, np.random.RandomState],
PlainSerializer(func=_serialize_random_state, when_used="json"),
]


class ModelConfig(BaseConfig):
"""Base model config."""

verbose: int = 0


ModelConfig_T = tp.TypeVar("ModelConfig_T", bound=ModelConfig)


class ModelBase(tp.Generic[ModelConfig_T]):
"""
Base model class.
Expand All @@ -49,10 +77,120 @@ class ModelBase:
recommends_for_warm: bool = False
recommends_for_cold: bool = False

config_class: tp.Type[ModelConfig_T]

def __init__(self, *args: tp.Any, verbose: int = 0, **kwargs: tp.Any) -> None:
self.is_fitted = False
self.verbose = verbose

@tp.overload
def get_config( # noqa: D102
self, mode: tp.Literal["pydantic"], simple_types: bool = False
) -> ModelConfig_T: # pragma: no cover
...

@tp.overload
def get_config( # noqa: D102
self, mode: tp.Literal["dict"] = "dict", simple_types: bool = False
) -> tp.Dict[str, tp.Any]: # pragma: no cover
...

def get_config(
self, mode: tp.Literal["pydantic", "dict"] = "dict", simple_types: bool = False
) -> tp.Union[ModelConfig_T, tp.Dict[str, tp.Any]]:
"""
Return model config.
Parameters
----------
mode : {'pydantic', 'dict'}, default 'dict'
Format of returning config.
simple_types : bool, default False
If True, return config with JSON serializable types.
Only works for `mode='dict'`.
Returns
-------
Pydantic model or dict
Model config.
Raises
------
ValueError
If `mode` is not 'object' or 'dict', or if `simple_types` is ``True`` and format is not 'dict'.
"""
config = self._get_config()
if mode == "pydantic":
if simple_types:
raise ValueError("`simple_types` is not compatible with `mode='pydantic'`")
return config

pydantic_mode = "json" if simple_types else "python"
try:
config_dict = config.model_dump(mode=pydantic_mode)
except PydanticSerializationError as e:
if e.__cause__ is not None:
raise e.__cause__
raise e

if mode == "dict":
return config_dict

raise ValueError(f"Unknown mode: {mode}")

def _get_config(self) -> ModelConfig_T:
raise NotImplementedError(f"`get_config` method is not implemented for `{self.__class__.__name__}` model")

def get_params(self, simple_types: bool = False, sep: str = ".") -> tp.Dict[str, tp.Any]:
"""
Return model parameters.
Same as `get_config` but returns flat dict.
Parameters
----------
simple_types : bool, default False
If True, return config with JSON serializable types.
sep : str, default "."
Separator for nested keys.
Returns
-------
dict
Model parameters.
"""
config_dict = self.get_config(mode="dict", simple_types=simple_types)
config_flat = make_dict_flat(config_dict, sep=sep) # NOBUG: We're not handling lists for now
return config_flat

@classmethod
def from_config(cls, config: tp.Union[dict, ModelConfig_T]) -> tpe.Self:
"""
Create model from config.
Parameters
----------
config : dict or ModelConfig
Model config.
Returns
-------
Model instance.
"""
try:
config_cls = cls.config_class
except AttributeError:
raise NotImplementedError(f"`from_config` method is not implemented for `{cls.__name__}` model.") from None

if not isinstance(config, config_cls):
config_obj = cls.config_class.model_validate(config)
else:
config_obj = config
return cls._from_config(config_obj)

@classmethod
def _from_config(cls, config: ModelConfig_T) -> tpe.Self:
raise NotImplementedError()

def fit(self: T, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> T:
"""
Fit model.
Expand Down
21 changes: 20 additions & 1 deletion rectools/models/ease.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,26 @@
import typing as tp

import numpy as np
import typing_extensions as tpe
from scipy import sparse

from rectools import InternalIds
from rectools.dataset import Dataset
from rectools.models.base import ModelConfig
from rectools.types import InternalIdsArray

from .base import ModelBase, Scores
from .rank import Distance, ImplicitRanker


class EASEModel(ModelBase):
class EASEModelConfig(ModelConfig):
"""Config for `EASE` model."""

regularization: float = 500.0
num_threads: int = 1


class EASEModel(ModelBase[EASEModelConfig]):
"""
Embarrassingly Shallow Autoencoders for Sparse Data model.
Expand All @@ -51,17 +60,27 @@ class EASEModel(ModelBase):
recommends_for_warm = False
recommends_for_cold = False

config_class = EASEModelConfig

def __init__(
self,
regularization: float = 500.0,
num_threads: int = 1,
verbose: int = 0,
):

super().__init__(verbose=verbose)
self.weight: np.ndarray
self.regularization = regularization
self.num_threads = num_threads

def _get_config(self) -> EASEModelConfig:
return EASEModelConfig(regularization=self.regularization, num_threads=self.num_threads, verbose=self.verbose)

@classmethod
def _from_config(cls, config: EASEModelConfig) -> tpe.Self:
return cls(regularization=config.regularization, num_threads=config.num_threads, verbose=config.verbose)

def _fit(self, dataset: Dataset) -> None: # type: ignore
ui_csr = dataset.get_user_item_matrix(include_weights=True)

Expand Down
Loading

0 comments on commit 0d1df86

Please sign in to comment.