-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4691709
commit 68a2b80
Showing
6 changed files
with
336 additions
and
320 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,147 +1,81 @@ | ||
import pytest | ||
import wandb | ||
from unittest.mock import patch, MagicMock | ||
import sys | ||
from pathlib import Path | ||
|
||
from views_pipeline_core.wandb.utils import ( | ||
add_wandb_monthly_metrics, | ||
generate_wandb_log_dict, | ||
log_wandb_log_dict, | ||
add_wandb_metrics, | ||
generate_wandb_step_wise_log_dict, | ||
generate_wandb_month_wise_log_dict, | ||
generate_wandb_time_series_wise_log_dict, | ||
log_wandb_log_dict | ||
) | ||
|
||
from views_evaluation.evaluation.metrics import EvaluationMetrics | ||
|
||
@pytest.fixture | ||
def mock_wandb(): | ||
""" | ||
Fixture to mock wandb methods. | ||
This fixture patches the `wandb.define_metric` and `wandb.log` methods to prevent | ||
actual calls to the Weights and Biases API during testing. It yields the mocked | ||
methods for use in tests. | ||
Yields: | ||
tuple: A tuple containing the mocked `wandb.define_metric` and `wandb.log` methods. | ||
""" | ||
with patch("wandb.define_metric") as mock_define_metric, patch( | ||
"wandb.log" | ||
) as mock_log: | ||
yield mock_define_metric, mock_log | ||
|
||
|
||
def test_add_wandb_monthly_metrics(mock_wandb): | ||
""" | ||
Test defining WandB metrics for monthly evaluation. | ||
This test verifies that the `add_wandb_monthly_metrics` function correctly defines | ||
the necessary metrics for monthly evaluation using the `wandb.define_metric` method. | ||
It checks that the metrics "monthly/out_sample_month" and "monthly/*" with the step | ||
metric "monthly/out_sample_month" are defined. | ||
with patch('views_pipeline_core.wandb.utils.wandb') as mock_wandb: | ||
yield mock_wandb | ||
|
||
Args: | ||
mock_wandb (tuple): A tuple containing the mocked `wandb.define_metric` and `wandb.log` methods. | ||
""" | ||
mock_define_metric, _ = mock_wandb | ||
add_wandb_monthly_metrics() | ||
mock_define_metric.assert_any_call("monthly/out_sample_month") | ||
mock_define_metric.assert_any_call( | ||
"monthly/*", step_metric="monthly/out_sample_month" | ||
) | ||
|
||
|
||
def test_generate_wandb_log_dict(): | ||
""" | ||
Test updating the log dictionary with evaluation metrics for a specific time step. | ||
This test verifies that the `generate_wandb_log_dict` function correctly updates the | ||
log dictionary with evaluation metrics for a given time step. It checks that the | ||
output is a dictionary with keys prefixed by "monthly/" and values that are either | ||
integers or floats. | ||
Raises: | ||
AssertionError: If the output is not a dictionary or if the keys and values do not | ||
meet the expected criteria. | ||
""" | ||
@pytest.fixture | ||
def eval_metrics(): | ||
return EvaluationMetrics(RMSLE=0.1, CRPS=0.2) | ||
|
||
def test_add_wandb_metrics(mock_wandb): | ||
add_wandb_metrics() | ||
mock_wandb.define_metric.assert_any_call("step-wise/step") | ||
mock_wandb.define_metric.assert_any_call("step-wise/*", step_metric="step-wise/step") | ||
mock_wandb.define_metric.assert_any_call("month-wise/month") | ||
mock_wandb.define_metric.assert_any_call("month-wise/*", step_metric="month-wise/month") | ||
mock_wandb.define_metric.assert_any_call("time-series-wise/time-series") | ||
mock_wandb.define_metric.assert_any_call("time-series-wise/*", step_metric="time-series-wise/time-series") | ||
|
||
def test_generate_wandb_step_wise_log_dict(eval_metrics): | ||
log_dict = {} | ||
dict_of_eval_dicts = { | ||
"step01": {"MSE": 0.1, "AP": 0.2, "AUC": 0.3, "Brier": 0.4}, | ||
"step02": {"MSE": 0.2, "AP": 0.3, "AUC": 0.4, "Brier": 0.5}, | ||
dict_of_eval_dicts = {'step01': eval_metrics} | ||
step = 'step01' | ||
result = generate_wandb_step_wise_log_dict(log_dict, dict_of_eval_dicts, step) | ||
assert result == { | ||
"step-wise/RMSLE": 0.1, | ||
"step-wise/CRPS": 0.2 | ||
} | ||
updated_log_dict = generate_wandb_log_dict(log_dict, dict_of_eval_dicts, "step01") | ||
|
||
# Verify the type of the output | ||
assert isinstance(updated_log_dict, dict) | ||
|
||
# Verify the shape of the output | ||
assert all(key.startswith("monthly/") for key in updated_log_dict.keys()) | ||
assert all(isinstance(value, (int, float)) for value in updated_log_dict.values()) | ||
|
||
|
||
def test_generate_wandb_log_dict_with_none_values(): | ||
""" | ||
Test updating the log dictionary with evaluation metrics, ignoring None values. | ||
This test verifies that the `generate_wandb_log_dict` function correctly updates the | ||
log dictionary with evaluation metrics for a given time step, ignoring any None values. | ||
It checks that the output is a dictionary with keys prefixed by "monthly/" and values | ||
that are either integers or floats. | ||
Raises: | ||
AssertionError: If the output is not a dictionary or if the keys and values do not | ||
meet the expected criteria. | ||
""" | ||
def test_generate_wandb_month_wise_log_dict(eval_metrics): | ||
log_dict = {} | ||
dict_of_eval_dicts = {"step01": {"MSE": 0.1, "AP": None, "AUC": 0.3, "Brier": None}} | ||
updated_log_dict = generate_wandb_log_dict(log_dict, dict_of_eval_dicts, "step01") | ||
|
||
# Verify the type of the output | ||
assert isinstance(updated_log_dict, dict) | ||
|
||
# Verify the shape of the output | ||
assert all(key.startswith("monthly/") for key in updated_log_dict.keys()) | ||
assert all(isinstance(value, (int, float)) for value in updated_log_dict.values()) | ||
|
||
|
||
def test_log_wandb_log_dict(mock_wandb): | ||
""" | ||
Test logging the WandB log dictionary for each step in the configuration. | ||
This test verifies that the `log_wandb_log_dict` function correctly logs the WandB | ||
log dictionary for each step specified in the configuration. It checks that the | ||
output for each step is a dictionary with keys prefixed by "monthly/" and values | ||
that are either integers or floats, and that the `wandb.log` method is called with | ||
the expected log dictionary. | ||
Args: | ||
mock_wandb (tuple): A tuple containing the mocked `wandb.define_metric` and `wandb.log` methods. | ||
Raises: | ||
AssertionError: If the output is not a dictionary or if the keys and values do not | ||
meet the expected criteria, or if `wandb.log` is not called with | ||
the expected log dictionary. | ||
""" | ||
_, mock_log = mock_wandb | ||
config = {"steps": [1, 2]} | ||
evaluation = { | ||
"step01": {"MSE": 0.1, "AP": 0.2, "AUC": 0.3, "Brier": 0.4}, | ||
"step02": {"MSE": 0.2, "AP": 0.3, "AUC": 0.4, "Brier": 0.5}, | ||
dict_of_eval_dicts = {'month501': eval_metrics} | ||
month = 'month501' | ||
result = generate_wandb_month_wise_log_dict(log_dict, dict_of_eval_dicts, month) | ||
assert result == { | ||
"month-wise/RMSLE": 0.1, | ||
"month-wise/CRPS": 0.2 | ||
} | ||
log_wandb_log_dict(config, evaluation) | ||
|
||
# Verify the type and shape of the output for each step | ||
for t in config["steps"]: | ||
step = f"step{str(t).zfill(2)}" | ||
expected_log_dict = generate_wandb_log_dict( | ||
{"monthly/out_sample_month": t}, evaluation, step | ||
) | ||
|
||
# Verify the type of the output | ||
assert isinstance(expected_log_dict, dict) | ||
|
||
# Verify the shape of the output | ||
assert all(key.startswith("monthly/") for key in expected_log_dict.keys()) | ||
assert all( | ||
isinstance(value, (int, float)) for value in expected_log_dict.values() | ||
) | ||
def test_generate_wandb_time_series_wise_log_dict(eval_metrics): | ||
log_dict = {} | ||
dict_of_eval_dicts = {'ts01': eval_metrics} | ||
time_series = 'ts01' | ||
result = generate_wandb_time_series_wise_log_dict(log_dict, dict_of_eval_dicts, time_series) | ||
assert result == { | ||
"time-series-wise/RMSLE": 0.1, | ||
"time-series-wise/CRPS": 0.2 | ||
} | ||
|
||
mock_log.assert_any_call(expected_log_dict) | ||
def test_log_wandb_log_dict(mock_wandb, eval_metrics): | ||
step_wise_evaluation = {'step01': eval_metrics} | ||
time_series_wise_evaluation = {'ts01': eval_metrics} | ||
month_wise_evaluation = {'month501': eval_metrics} | ||
|
||
log_wandb_log_dict(step_wise_evaluation, time_series_wise_evaluation, month_wise_evaluation) | ||
|
||
mock_wandb.log.assert_any_call({ | ||
"step-wise/step": 1, | ||
"step-wise/RMSLE": 0.1, | ||
"step-wise/CRPS": 0.2 | ||
}) | ||
mock_wandb.log.assert_any_call({ | ||
"month-wise/month": 501, | ||
"month-wise/RMSLE": 0.1, | ||
"month-wise/CRPS": 0.2 | ||
}) | ||
mock_wandb.log.assert_any_call({ | ||
"time-series-wise/time-series": 1, | ||
"time-series-wise/RMSLE": 0.1, | ||
"time-series-wise/CRPS": 0.2 | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.