diff --git a/CHANGELOG.md b/CHANGELOG.md index ac9a48cf..af5ce099 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added +- If `Column.Model` is not provided in `VisualApp`, then default model name added to dataframe ([#235](https://github.com/MobileTeleSystems/RecTools/pull/235)) ## [0.9.0] - 11.12.2024 diff --git a/rectools/visuals/visual_app.py b/rectools/visuals/visual_app.py index 2de67abf..a2a6d89d 100644 --- a/rectools/visuals/visual_app.py +++ b/rectools/visuals/visual_app.py @@ -29,6 +29,7 @@ MIN_WIDTH_LIMIT = 10 REQUEST_NAMES_COL = "request_name" REQUEST_IDS_COL = "request_id" +DEFAULT_MODEL_NAME = "model" VisualAppT = tp.TypeVar("VisualAppT", bound="VisualAppBase") @@ -71,7 +72,8 @@ def from_raw( ---------- reco : tp.Union[pd.DataFrame, TablesDict] Recommendations from different models in a form of a pd.DataFrame or a dict. - In DataFrame form model names must be specified in `Columns.Model` column. In dict form + In DataFrame form model names must be specified in `Columns.Model` column. + If not, `Columns.Model` column will be created with default value ``model1``. In dict form model names are supposed to be dict keys. item_data : pd.DataFrame Data for items that is used for visualisation in both interactions and recommendations @@ -100,7 +102,7 @@ def from_raw( if isinstance(reco, pd.DataFrame): if Columns.Model not in reco.columns: - raise KeyError("Missing `{Columns.Model}` column in `reco` DataFrame") + reco[Columns.Model] = DEFAULT_MODEL_NAME reco = cls._df_to_tables_dict(reco, Columns.Model) cls._check_columns_present_in_reco(reco=reco, id_col=id_col) diff --git a/tests/visuals/test_visual_app.py b/tests/visuals/test_visual_app.py index ab724f7f..eb05ada3 100644 --- a/tests/visuals/test_visual_app.py +++ b/tests/visuals/test_visual_app.py @@ -21,7 +21,14 @@ import pytest from rectools import Columns, ExternalId -from rectools.visuals.visual_app import AppDataStorage, ItemToItemVisualApp, StorageFiles, TablesDict, VisualApp +from rectools.visuals.visual_app import ( + DEFAULT_MODEL_NAME, + AppDataStorage, + ItemToItemVisualApp, + StorageFiles, + TablesDict, + VisualApp, +) RECO_U2I: TablesDict = { "model1": pd.DataFrame( @@ -183,7 +190,6 @@ def test_empty_selected_requests(self, selected_requests: tp.Optional[tp.Dict[tp assert "random_2" in ads.selected_requests def test_missing_columns_validation(self) -> None: - # Missing `Columns.User` for u2i with pytest.raises(KeyError): incorrect_u2i_reco: TablesDict = { @@ -228,18 +234,29 @@ def test_missing_columns_validation(self) -> None: selected_requests=SELECTED_REQUESTS_U2I, ) - # Missing `Columns.Model` in reco pd.DataFrame - with pytest.raises(KeyError): - incorrect_reco = pd.DataFrame( - {Columns.User: [1, 2, 3, 4], Columns.Item: [3, 4, 3, 4], Columns.Score: [0.99, 0.9, 0.5, 0.5]} - ) - AppDataStorage.from_raw( - reco=incorrect_reco, - item_data=ITEM_DATA, - interactions=INTERACTIONS, - is_u2i=True, - selected_requests=SELECTED_REQUESTS_U2I, - ) + def test_successful_path_with_missing_model(self) -> None: + # Missing `Columns.Model` + reco_without_model = pd.DataFrame( + {Columns.User: [1, 2, 3, 4], Columns.Item: [3, 4, 3, 4], Columns.Score: [0.99, 0.9, 0.5, 0.5]} + ) + ads = AppDataStorage.from_raw( + reco=reco_without_model, + item_data=ITEM_DATA, + interactions=INTERACTIONS, + is_u2i=True, + selected_requests=SELECTED_REQUESTS_U2I, + ) + expected_grouped_reco = { + DEFAULT_MODEL_NAME: { + "user_one": pd.DataFrame({Columns.Item: [3], "feature_1": ["one"], Columns.Score: [0.99]}), + "user_three": pd.DataFrame({Columns.Item: [3], "feature_1": ["one"], Columns.Score: [0.5]}), + } + } + assert expected_grouped_reco.keys() == ads.grouped_reco.keys() + for model_name, model_reco in expected_grouped_reco.items(): + assert model_reco.keys() == ads.grouped_reco[model_name].keys() + for user_name, user_reco in model_reco.items(): + pd.testing.assert_frame_equal(user_reco, ads.grouped_reco[model_name][user_name]) def test_incorrect_interactions_for_reco_case(self) -> None: