Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/development' into sweep
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaolong0728 committed Jan 23, 2025
2 parents c2b6b95 + fc89966 commit e2ed406
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ ingester3 = "^2.1.0"
properscoring = "^0.1"
wandb = "^0.18.7"
pyprojroot="^0.3.0"
views_evaluation = { git = "https://github.com/views-platform/views-evaluation.git", branch = "main" }
views-evaluation = { git = "https://github.com/views-platform/views-evaluation.git", branch = "main" }
views-forecasts = "^0.5.5"


Expand Down
6 changes: 4 additions & 2 deletions tests/test_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def get_deployment_config():
manager._wandb_alert(title="Test Alert", text="This is a test alert", level="info")
mock_alert.assert_called_once_with(title="Test Alert", text="This is a test alert", level="info")

def test_model_manager_init(mock_model_path):
@patch("views_forecasts.extensions.ForecastAccessor.read_store")
def test_model_manager_init(mock_model_path, mock_read_store):
"""
Test the initialization of the ModelManager class.
Expand Down Expand Up @@ -133,7 +134,8 @@ def get_meta_config():
assert manager._config_hyperparameters == {"hp_key": "hp_value"}
assert manager._config_meta == {"meta_key": "meta_value"}

def test_load_config(mock_model_path):
@patch("views_forecasts.extensions.ForecastAccessor.read_store")
def test_load_config(mock_model_pat, mock_read_store):
"""
Test the __load_config method of the ModelManager class.
Expand Down
4 changes: 3 additions & 1 deletion views_pipeline_core/managers/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from views_pipeline_core.files.utils import save_dataframe, read_dataframe
from views_pipeline_core.configs.pipeline import PipelineConfig
from views_evaluation.evaluation.metrics import MetricsManager
from views_forecasts.extensions import *
from views_forecasts.extensions import *
from typing import Union, Optional, List, Dict
import wandb
import logging
Expand Down Expand Up @@ -353,6 +353,7 @@ def _train_model_artifact(
def _evaluate_model_artifact(
self, model_name: str, run_type: str, eval_type: str
) -> List[pd.DataFrame]:
# from views_forecasts.extensions import ForecastAccessor
logger.info(f"Evaluating single model {model_name}...")

model_path = ModelPathManager(model_name)
Expand Down Expand Up @@ -423,6 +424,7 @@ def _evaluate_model_artifact(
return preds

def _forecast_model_artifact(self, model_name: str, run_type: str) -> pd.DataFrame:
# from views_forecasts.extensions import ForecastAccessor
logger.info(f"Forecasting single model {model_name}...")

model_path = ModelPathManager(model_name)
Expand Down
23 changes: 13 additions & 10 deletions views_pipeline_core/managers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from views_evaluation.evaluation.metrics import MetricsManager
from views_forecasts.extensions import *
import traceback
import numpy as np


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -648,9 +647,9 @@ def __init__(
) -> None:
"""
Initializes the ModelManager with the given model path.
Args:
model_path (ModelPathManager): The path manager for the model.
wandb_notifications (bool, optional): Whether to enable WandB notifications. Defaults to True.
"""
self._entity = "views_pipeline"
self._model_repo = "views-models"
Expand Down Expand Up @@ -786,7 +785,8 @@ def __get_pred_store_name(self) -> str:
# self._owner, self._model_repo
# )
from views_pipeline_core.managers.package import PackageManager

# from views_forecasts.extensions import ViewsMetadata

version = PackageManager.get_latest_release_version_from_github(
repository_name=self._model_repo
)
Expand Down Expand Up @@ -937,18 +937,20 @@ def _generate_evaluation_table(self, metric_dict: Dict) -> str:
Returns:
str: A formatted string representing the evaluation table.
"""
table_str = "\n{:<70} {:<30}".format("Metric", "Value")
table_str += "-" * 100 + "\n"
from tabulate import tabulate
# create an empty dataframe with columns 'Metric' and 'Value'
metric_df = pd.DataFrame(columns=["Metric", "Value"])
for key, value in metric_dict.items():
try:
if not str(key).startswith("_"):
value = float(
value
) # Super hacky way to filter out metrics. 0/10 do not recommend
table_str += "{:<70}\t{:<30}\n".format(str(key), value)
value = float(value)
# add metric and value to the dataframe
metric_df = metric_df.append({"Metric": key, "Value": value}, ignore_index=True)
except:
continue
return table_str
result = tabulate(metric_df, headers='keys', tablefmt='grid')
print(result)
return f"```\n{result}\n```"

def _save_model_artifact(self, run_type: str) -> None:
"""
Expand Down Expand Up @@ -1085,6 +1087,7 @@ def _save_predictions(
path_generated (str or Path): The path where the predictions should be saved.
sequence_number (int): The sequence number.
"""
from views_forecasts.extensions import ForecastAccessor
try:
path_generated = Path(path_generated)
path_generated.mkdir(parents=True, exist_ok=True)
Expand Down

0 comments on commit e2ed406

Please sign in to comment.