Add option to display errored metrics
katxiao committed Dec 1, 2021
1 parent 15aab1c commit 13d2c23
Showing 1 changed file with 11 additions and 1 deletion.
sdv/evaluation.py (12 changes: 11 additions & 1 deletion)
@@ -94,7 +94,7 @@ def _select_metrics(synthetic_data, metrics):


 def evaluate(synthetic_data, real_data=None, metadata=None, root_path=None,
-             table_name=None, metrics=None, aggregate=True):
+             table_name=None, metrics=None, aggregate=True, report_errors=False):
"""Apply multiple metrics at once.
Args:
@@ -117,6 +117,9 @@ def evaluate(synthetic_data, real_data=None, metadata=None, root_path=None,
             If ``get_report`` is ``False``, whether to compute the mean of all the normalized
             scores to return a single float value or return a ``dict`` containing the score
             that each metric obtained. Defaults to ``True``.
+        report_errors (bool):
+            If ``True``, report the metrics that errored out and their corresponding errors.
+            If ``False``, omit the metrics that errored out.
     Return:
         float or sdmetrics.MetricsReport
@@ -133,7 +136,14 @@ def evaluate(synthetic_data, real_data=None, metadata=None, root_path=None,
         synthetic_data = synthetic_data[table]

     scores = sdmetrics.compute_metrics(metrics, real_data, synthetic_data, metadata=metadata)

+    if report_errors:
+        errored_metrics = scores[~scores['error'].isnull()]
+        errored_metrics = errored_metrics[['metric', 'name', 'error']]
+        print(f'The following metrics errored out:\n{errored_metrics.to_string()}')

     scores.dropna(inplace=True)
+    scores = scores.drop(columns=['error'], errors='ignore')

     if aggregate:
         return scores.normalized_score.mean()
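
For illustration, a minimal, self-contained sketch of the error-row selection used in the diff above. The scores frame here is a made-up stand-in for the DataFrame that sdmetrics.compute_metrics returns, limited to the columns the diff actually touches; the values are illustrative only.

import numpy as np
import pandas as pd

# Made-up stand-in for the scores frame returned by sdmetrics.compute_metrics.
scores = pd.DataFrame({
    'metric': ['CSTest', 'KSTest', 'LogisticDetection'],
    'name': ['Chi-Squared', 'Inverted Kolmogorov-Smirnov', 'Logistic Detection'],
    'normalized_score': [0.95, 0.87, np.nan],
    'error': [None, None, "ValueError('could not convert string to float')"],
})

# Rows with a non-null 'error' are the metrics that errored out.
errored_metrics = scores[~scores['error'].isnull()][['metric', 'name', 'error']]
print(f'The following metrics errored out:\n{errored_metrics.to_string()}')

With aggregate=True (the default), evaluate still returns the mean normalized_score over the remaining rows; the printed table is informational only.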
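A hedged usage sketch of the new flag, assuming an SDV 0.x install that already includes this change; the demo dataset and the GaussianCopula model are illustrative choices, not something this commit prescribes.

from sdv.demo import load_tabular_demo
from sdv.evaluation import evaluate
from sdv.tabular import GaussianCopula

# Fit a model on a demo table and sample synthetic rows of the same size.
real_data = load_tabular_demo('student_placements')
model = GaussianCopula()
model.fit(real_data)
synthetic_data = model.sample(len(real_data))

# report_errors=True prints any metrics that raised exceptions before the
# aggregate score over the remaining metrics is returned as before.
score = evaluate(synthetic_data, real_data, report_errors=True)
print(score)

With report_errors=False (the default), behavior is unchanged: errored metrics are silently omitted before aggregation.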

0 comments on commit 13d2c23
