From fc899662455092f6bd4c4a0aa06a967c880f4eca Mon Sep 17 00:00:00 2001
From: Dylan <52908667+smellycloud@users.noreply.github.com>
Date: Wed, 22 Jan 2025 23:33:40 +0100
Subject: [PATCH] fix eval metrics table

---
 pyproject.toml                           |  2 +-
 tests/test_model_manager.py              |  6 ++++--
 views_pipeline_core/managers/ensemble.py |  4 +++-
 views_pipeline_core/managers/model.py    | 23 +++++++++++++----------
 4 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index ae01e51..6389ec0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
 ingester3 = "^2.1.0"
 properscoring = "^0.1"
 wandb = "^0.18.7"
 pyprojroot="^0.3.0"
-views_evaluation = { git = "https://github.com/views-platform/views-evaluation.git", branch = "main" }
+views-evaluation = { git = "https://github.com/views-platform/views-evaluation.git", branch = "main" }
 views-forecasts = "^0.5.5"
 
diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py
index 5d986cc..4d291c2 100644
--- a/tests/test_model_manager.py
+++ b/tests/test_model_manager.py
@@ -89,7 +89,8 @@ def get_deployment_config():
     manager._wandb_alert(title="Test Alert", text="This is a test alert", level="info")
     mock_alert.assert_called_once_with(title="Test Alert", text="This is a test alert", level="info")
 
-def test_model_manager_init(mock_model_path):
+@patch("views_forecasts.extensions.ForecastAccessor.read_store")
+def test_model_manager_init(mock_read_store, mock_model_path):
     """
     Test the initialization of the ModelManager class.
 
@@ -133,7 +134,8 @@ def get_meta_config():
     assert manager._config_hyperparameters == {"hp_key": "hp_value"}
     assert manager._config_meta == {"meta_key": "meta_value"}
 
-def test_load_config(mock_model_path):
+@patch("views_forecasts.extensions.ForecastAccessor.read_store")
+def test_load_config(mock_read_store, mock_model_path):
     """
     Test the __load_config method of the ModelManager class.
diff --git a/views_pipeline_core/managers/ensemble.py b/views_pipeline_core/managers/ensemble.py
index 049443a..13731e2 100644
--- a/views_pipeline_core/managers/ensemble.py
+++ b/views_pipeline_core/managers/ensemble.py
@@ -5,7 +5,7 @@
 from views_pipeline_core.files.utils import save_dataframe, read_dataframe
 from views_pipeline_core.configs.pipeline import PipelineConfig
 from views_evaluation.evaluation.metrics import MetricsManager
-from views_forecasts.extensions import * 
+from views_forecasts.extensions import *
 from typing import Union, Optional, List, Dict
 import wandb
 import logging
@@ -353,6 +353,7 @@ def _train_model_artifact(
     def _evaluate_model_artifact(
         self, model_name: str, run_type: str, eval_type: str
     ) -> List[pd.DataFrame]:
+        # from views_forecasts.extensions import ForecastAccessor
         logger.info(f"Evaluating single model {model_name}...")
 
         model_path = ModelPathManager(model_name)
@@ -423,6 +424,7 @@ def _evaluate_model_artifact(
         return preds
 
     def _forecast_model_artifact(self, model_name: str, run_type: str) -> pd.DataFrame:
+        # from views_forecasts.extensions import ForecastAccessor
         logger.info(f"Forecasting single model {model_name}...")
 
         model_path = ModelPathManager(model_name)
diff --git a/views_pipeline_core/managers/model.py b/views_pipeline_core/managers/model.py
index 9f9aa35..74d85d2 100644
--- a/views_pipeline_core/managers/model.py
+++ b/views_pipeline_core/managers/model.py
@@ -22,7 +22,6 @@
 from views_evaluation.evaluation.metrics import MetricsManager
 from views_forecasts.extensions import *
 import traceback
-import numpy as np
 
 logger = logging.getLogger(__name__)
 
@@ -645,9 +644,9 @@ def __init__(
     ) -> None:
         """
         Initializes the ModelManager with the given model path.
-
         Args:
             model_path (ModelPathManager): The path manager for the model.
+            wandb_notifications (bool, optional): Whether to enable WandB notifications. Defaults to True.
         """
         self._entity = "views_pipeline"
         self._model_repo = "views-models"
@@ -805,7 +804,8 @@ def __get_pred_store_name(self) -> str:
         #     self._owner, self._model_repo
         # )
         from views_pipeline_core.managers.package import PackageManager
-
+        # from views_forecasts.extensions import ViewsMetadata
+
         version = PackageManager.get_latest_release_version_from_github(
             repository_name=self._model_repo
         )
@@ -955,18 +955,20 @@ def _generate_evaluation_table(self, metric_dict: Dict) -> str:
         Returns:
             str: A formatted string representing the evaluation table.
         """
-        table_str = "\n{:<70} {:<30}".format("Metric", "Value")
-        table_str += "-" * 100 + "\n"
+        from tabulate import tabulate
+        # Collect numeric metrics as rows; DataFrame.append was removed in pandas 2.0.
+        rows = []
         for key, value in metric_dict.items():
             try:
                 if not str(key).startswith("_"):
-                    value = float(
-                        value
-                    )  # Super hacky way to filter out metrics. 0/10 do not recommend
-                    table_str += "{:<70}\t{:<30}\n".format(str(key), value)
+                    # Non-numeric values raise here and are filtered out below.
+                    value = float(value)
+                    rows.append({"Metric": str(key), "Value": value})
             except:
                 continue
-        return table_str
+        metric_df = pd.DataFrame(rows, columns=["Metric", "Value"])
+        result = tabulate(metric_df, headers="keys", tablefmt="grid")
+        return f"```\n{result}\n```"
 
     def _save_model_artifact(self, run_type):
         """
@@ -1103,6 +1105,7 @@ def _save_predictions(
             path_generated (str or Path): The path where the predictions should be saved.
             sequence_number (int): The sequence number.
""" + from views_forecasts.extensions import ForecastAccessor try: path_generated = Path(path_generated) path_generated.mkdir(parents=True, exist_ok=True)