From ffb37a9054b62a2b7ff4bedb30f060bb5ac4915f Mon Sep 17 00:00:00 2001 From: skadio Date: Thu, 7 Sep 2023 10:21:43 -0400 Subject: [PATCH] update --- jurity/fairness/for_difference.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/jurity/fairness/for_difference.py b/jurity/fairness/for_difference.py index 1d9a8e8..6e352a3 100644 --- a/jurity/fairness/for_difference.py +++ b/jurity/fairness/for_difference.py @@ -10,10 +10,11 @@ from jurity.fairness.base import _BaseBinaryFairness from jurity.utils import check_and_convert_list_types -from jurity.utils import check_inputs_validity +from jurity.utils import check_inputs from jurity.utils import performance_measures from jurity.utils import split_array_based_on_membership_label + class FORDifference(_BaseBinaryFairness): def __init__(self): @@ -26,7 +27,7 @@ def __init__(self): @staticmethod def get_score(labels: Union[List, np.ndarray, pd.Series], predictions: Union[List, np.ndarray, pd.Series], - is_member: Union[List, np.ndarray, pd.Series], + memberships: Union[List, np.ndarray, pd.Series], membership_label: Union[str, float, int] = 1) -> float: """ The equality (or lack thereof) of the false omission rates across groups is an important fairness metric. @@ -43,7 +44,7 @@ def get_score(labels: Union[List, np.ndarray, pd.Series], Binary ground truth labels for the provided dataset (0/1). predictions: Union[List, np.ndarray, pd.Series] Binary predictions from some black-box classifier (0/1). - is_member: Union[List, np.ndarray, pd.Series] + memberships: Union[List, np.ndarray, pd.Series] Binary membership labels (0/1). membership_label: Union[str, float, int] Value indicating group membership. @@ -54,10 +55,11 @@ def get_score(labels: Union[List, np.ndarray, pd.Series], False Omission Rate difference between groups. """ # Logic to check input types. - check_inputs_validity(labels=labels, predictions=predictions, is_member=is_member, optional_labels=False) + check_inputs(predictions=predictions, memberships=memberships, membership_labels=membership_label, + must_have_labels=True, labels=labels) # List needs to be converted to np for indexing - is_member = check_and_convert_list_types(is_member) + is_member = check_and_convert_list_types(memberships) predictions = check_and_convert_list_types(predictions) labels = check_and_convert_list_types(labels)