Skip to content

Commit

Permalink
update all wandb and minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
smellycloud committed Jan 14, 2025
1 parent 4691709 commit 68a2b80
Show file tree
Hide file tree
Showing 6 changed files with 336 additions and 320 deletions.
2 changes: 1 addition & 1 deletion tests/test_files_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_save_dataframe_pickle(tmp_path, sample_dataframe):

def test_save_dataframe_invalid_extension(tmp_path, sample_dataframe):
save_path = tmp_path / "test.txt"
with pytest.raises(ValueError, match="The file extension must be provided. E.g. .parquet"):
with pytest.raises(ValueError, match="A valid file extension must be provided.E.g. .pkl or .parquet"):
save_dataframe(sample_dataframe, save_path)

# def test_read_dataframe_csv(tmp_path, sample_dataframe):
Expand Down
198 changes: 66 additions & 132 deletions tests/test_wandb_utils.py
Original file line number Diff line number Diff line change
@@ -1,147 +1,81 @@
import pytest
import wandb
from unittest.mock import patch, MagicMock
import sys
from pathlib import Path

from views_pipeline_core.wandb.utils import (
add_wandb_monthly_metrics,
generate_wandb_log_dict,
log_wandb_log_dict,
add_wandb_metrics,
generate_wandb_step_wise_log_dict,
generate_wandb_month_wise_log_dict,
generate_wandb_time_series_wise_log_dict,
log_wandb_log_dict
)

from views_evaluation.evaluation.metrics import EvaluationMetrics

@pytest.fixture
def mock_wandb():
"""
Fixture to mock wandb methods.
This fixture patches the `wandb.define_metric` and `wandb.log` methods to prevent
actual calls to the Weights and Biases API during testing. It yields the mocked
methods for use in tests.
Yields:
tuple: A tuple containing the mocked `wandb.define_metric` and `wandb.log` methods.
"""
with patch("wandb.define_metric") as mock_define_metric, patch(
"wandb.log"
) as mock_log:
yield mock_define_metric, mock_log


def test_add_wandb_monthly_metrics(mock_wandb):
"""
Test defining WandB metrics for monthly evaluation.
This test verifies that the `add_wandb_monthly_metrics` function correctly defines
the necessary metrics for monthly evaluation using the `wandb.define_metric` method.
It checks that the metrics "monthly/out_sample_month" and "monthly/*" with the step
metric "monthly/out_sample_month" are defined.
with patch('views_pipeline_core.wandb.utils.wandb') as mock_wandb:
yield mock_wandb

Args:
mock_wandb (tuple): A tuple containing the mocked `wandb.define_metric` and `wandb.log` methods.
"""
mock_define_metric, _ = mock_wandb
add_wandb_monthly_metrics()
mock_define_metric.assert_any_call("monthly/out_sample_month")
mock_define_metric.assert_any_call(
"monthly/*", step_metric="monthly/out_sample_month"
)


def test_generate_wandb_log_dict():
"""
Test updating the log dictionary with evaluation metrics for a specific time step.
This test verifies that the `generate_wandb_log_dict` function correctly updates the
log dictionary with evaluation metrics for a given time step. It checks that the
output is a dictionary with keys prefixed by "monthly/" and values that are either
integers or floats.
Raises:
AssertionError: If the output is not a dictionary or if the keys and values do not
meet the expected criteria.
"""
@pytest.fixture
def eval_metrics():
return EvaluationMetrics(RMSLE=0.1, CRPS=0.2)

def test_add_wandb_metrics(mock_wandb):
add_wandb_metrics()
mock_wandb.define_metric.assert_any_call("step-wise/step")
mock_wandb.define_metric.assert_any_call("step-wise/*", step_metric="step-wise/step")
mock_wandb.define_metric.assert_any_call("month-wise/month")
mock_wandb.define_metric.assert_any_call("month-wise/*", step_metric="month-wise/month")
mock_wandb.define_metric.assert_any_call("time-series-wise/time-series")
mock_wandb.define_metric.assert_any_call("time-series-wise/*", step_metric="time-series-wise/time-series")

def test_generate_wandb_step_wise_log_dict(eval_metrics):
log_dict = {}
dict_of_eval_dicts = {
"step01": {"MSE": 0.1, "AP": 0.2, "AUC": 0.3, "Brier": 0.4},
"step02": {"MSE": 0.2, "AP": 0.3, "AUC": 0.4, "Brier": 0.5},
dict_of_eval_dicts = {'step01': eval_metrics}
step = 'step01'
result = generate_wandb_step_wise_log_dict(log_dict, dict_of_eval_dicts, step)
assert result == {
"step-wise/RMSLE": 0.1,
"step-wise/CRPS": 0.2
}
updated_log_dict = generate_wandb_log_dict(log_dict, dict_of_eval_dicts, "step01")

# Verify the type of the output
assert isinstance(updated_log_dict, dict)

# Verify the shape of the output
assert all(key.startswith("monthly/") for key in updated_log_dict.keys())
assert all(isinstance(value, (int, float)) for value in updated_log_dict.values())


def test_generate_wandb_log_dict_with_none_values():
"""
Test updating the log dictionary with evaluation metrics, ignoring None values.
This test verifies that the `generate_wandb_log_dict` function correctly updates the
log dictionary with evaluation metrics for a given time step, ignoring any None values.
It checks that the output is a dictionary with keys prefixed by "monthly/" and values
that are either integers or floats.
Raises:
AssertionError: If the output is not a dictionary or if the keys and values do not
meet the expected criteria.
"""
def test_generate_wandb_month_wise_log_dict(eval_metrics):
log_dict = {}
dict_of_eval_dicts = {"step01": {"MSE": 0.1, "AP": None, "AUC": 0.3, "Brier": None}}
updated_log_dict = generate_wandb_log_dict(log_dict, dict_of_eval_dicts, "step01")

# Verify the type of the output
assert isinstance(updated_log_dict, dict)

# Verify the shape of the output
assert all(key.startswith("monthly/") for key in updated_log_dict.keys())
assert all(isinstance(value, (int, float)) for value in updated_log_dict.values())


def test_log_wandb_log_dict(mock_wandb):
"""
Test logging the WandB log dictionary for each step in the configuration.
This test verifies that the `log_wandb_log_dict` function correctly logs the WandB
log dictionary for each step specified in the configuration. It checks that the
output for each step is a dictionary with keys prefixed by "monthly/" and values
that are either integers or floats, and that the `wandb.log` method is called with
the expected log dictionary.
Args:
mock_wandb (tuple): A tuple containing the mocked `wandb.define_metric` and `wandb.log` methods.
Raises:
AssertionError: If the output is not a dictionary or if the keys and values do not
meet the expected criteria, or if `wandb.log` is not called with
the expected log dictionary.
"""
_, mock_log = mock_wandb
config = {"steps": [1, 2]}
evaluation = {
"step01": {"MSE": 0.1, "AP": 0.2, "AUC": 0.3, "Brier": 0.4},
"step02": {"MSE": 0.2, "AP": 0.3, "AUC": 0.4, "Brier": 0.5},
dict_of_eval_dicts = {'month501': eval_metrics}
month = 'month501'
result = generate_wandb_month_wise_log_dict(log_dict, dict_of_eval_dicts, month)
assert result == {
"month-wise/RMSLE": 0.1,
"month-wise/CRPS": 0.2
}
log_wandb_log_dict(config, evaluation)

# Verify the type and shape of the output for each step
for t in config["steps"]:
step = f"step{str(t).zfill(2)}"
expected_log_dict = generate_wandb_log_dict(
{"monthly/out_sample_month": t}, evaluation, step
)

# Verify the type of the output
assert isinstance(expected_log_dict, dict)

# Verify the shape of the output
assert all(key.startswith("monthly/") for key in expected_log_dict.keys())
assert all(
isinstance(value, (int, float)) for value in expected_log_dict.values()
)
def test_generate_wandb_time_series_wise_log_dict(eval_metrics):
log_dict = {}
dict_of_eval_dicts = {'ts01': eval_metrics}
time_series = 'ts01'
result = generate_wandb_time_series_wise_log_dict(log_dict, dict_of_eval_dicts, time_series)
assert result == {
"time-series-wise/RMSLE": 0.1,
"time-series-wise/CRPS": 0.2
}

mock_log.assert_any_call(expected_log_dict)
def test_log_wandb_log_dict(mock_wandb, eval_metrics):
step_wise_evaluation = {'step01': eval_metrics}
time_series_wise_evaluation = {'ts01': eval_metrics}
month_wise_evaluation = {'month501': eval_metrics}

log_wandb_log_dict(step_wise_evaluation, time_series_wise_evaluation, month_wise_evaluation)

mock_wandb.log.assert_any_call({
"step-wise/step": 1,
"step-wise/RMSLE": 0.1,
"step-wise/CRPS": 0.2
})
mock_wandb.log.assert_any_call({
"month-wise/month": 501,
"month-wise/RMSLE": 0.1,
"month-wise/CRPS": 0.2
})
mock_wandb.log.assert_any_call({
"time-series-wise/time-series": 1,
"time-series-wise/RMSLE": 0.1,
"time-series-wise/CRPS": 0.2
})
2 changes: 1 addition & 1 deletion views_pipeline_core/configs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class PipelineConfig:
def __init__(self):
self._dataframe_format = '.pkl'
self._dataframe_format = '.parquet'
self._model_format = '.pkl'
self._organization_name = 'views'
# self._version_range = ">=0.2.0,<1.0.0"
Expand Down
24 changes: 12 additions & 12 deletions views_pipeline_core/files/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,26 +147,26 @@ def save_dataframe(dataframe: pd.DataFrame, save_path: Union[str, Path]):
- ValueError: If the file extension is not provided or is not supported.
- Exception: If there is an error saving the DataFrame.
"""
FILE_EXTENSION_ERROR_MESSAGE = "The file extension must be provided. E.g. .parquet"
FILE_EXTENSION_ERROR_MESSAGE = "A valid file extension must be provided.E.g. .pkl or .parquet"

# Checks
if not isinstance(save_path, Path):
save_path = Path(save_path)
file_extension = save_path.suffix.lower()
file_extension = save_path.suffix
if dataframe is None:
raise ValueError("The DataFrame must be provided")
if not isinstance(dataframe, pd.DataFrame):
raise ValueError("The DataFrame must be a pandas DataFrame")
if file_extension is None or file_extension == "":
raise ValueError(f"Invalid file extension {file_extension} found. {FILE_EXTENSION_ERROR_MESSAGE}")
raise ValueError(f"No file extension {file_extension} found. {FILE_EXTENSION_ERROR_MESSAGE}")

try:
logger.debug(f"Saving the DataFrame to {save_path} in {file_extension} format")
if file_extension == ".csv":
dataframe.to_csv(save_path, index=True)
# if file_extension == ".csv":
# dataframe.to_csv(save_path, index=True)
# elif file_extension == ".xlsx":
# dataframe.to_excel(save_path)
elif file_extension == ".parquet":
if file_extension == ".parquet":
dataframe.to_parquet(save_path)
elif file_extension == ".pkl":
dataframe.to_pickle(save_path)
Expand All @@ -189,22 +189,22 @@ def read_dataframe(file_path: Union[str, Path]) -> pd.DataFrame:
- ValueError: If the file extension is not provided or is not supported.
- Exception: If there is an error reading the DataFrame.
"""
FILE_EXTENSION_ERROR_MESSAGE = "The file extension must be provided. E.g. .parquet"
FILE_EXTENSION_ERROR_MESSAGE = "A valid extension must be provided. E.g. .pkl or .parquet"

# Checks
if not isinstance(file_path, Path):
file_path = Path(file_path)
file_extension = file_path.suffix.lower()
file_extension = file_path.suffix
if file_extension is None or file_extension == "":
raise ValueError(f"Invalid file extension {file_extension} found. {FILE_EXTENSION_ERROR_MESSAGE}")
raise ValueError(f"No file extension {file_extension} found. {FILE_EXTENSION_ERROR_MESSAGE}")

try:
logger.debug(f"Reading the DataFrame from {file_path} in {file_extension} format")
if file_extension == ".csv":
return pd.read_csv(file_path, index_col=[0, 1])
# if file_extension == ".csv":
# return pd.read_csv(file_path, index_col=[0, 1])
# elif file_extension == ".xlsx":
# return pd.read_excel(file_path)
elif file_extension == ".parquet":
if file_extension == ".parquet":
return pd.read_parquet(file_path)
elif file_extension == ".pkl":
return pd.read_pickle(file_path)
Expand Down
Loading

0 comments on commit 68a2b80

Please sign in to comment.