From c0e3663156991ae3639e1ee707d613705f60f6f8 Mon Sep 17 00:00:00 2001
From: Ryan Mullins
Date: Wed, 8 Mar 2023 10:10:24 -0800
Subject: [PATCH] Adds multi-label metrics (accuracy, precision, recall, F1) to LIT.

PiperOrigin-RevId: 515069808
---
 lit_nlp/components/core.py         |   1 +
 lit_nlp/components/metrics.py      | 153 ++++++++++++
 lit_nlp/components/metrics_test.py | 371 +++++++++++++++++++++++++++++
 3 files changed, 525 insertions(+)

diff --git a/lit_nlp/components/core.py b/lit_nlp/components/core.py
index 39b664b1..80860965 100644
--- a/lit_nlp/components/core.py
+++ b/lit_nlp/components/core.py
@@ -137,6 +137,7 @@ def default_metrics() -> ComponentGroup:
   return ComponentGroup({
       'regression': metrics.RegressionMetrics(),
       'multiclass': metrics.MulticlassMetrics(),
+      'multilabel': metrics.MultilabelMetrics(),
       'paired': metrics.MulticlassPairedMetrics(),
       'bleu': metrics.CorpusBLEU(),
       'rouge': metrics.RougeL(),
diff --git a/lit_nlp/components/metrics.py b/lit_nlp/components/metrics.py
index 027a2f9a..f3d531ab 100644
--- a/lit_nlp/components/metrics.py
+++ b/lit_nlp/components/metrics.py
@@ -28,6 +28,7 @@
 from scipy import stats as scipy_stats
 from scipy.spatial import distance as scipy_distance
 from sklearn import metrics as sklearn_metrics
+from sklearn import preprocessing as sklearn_preprocessing

 from rouge_score import rouge_scorer

@@ -36,6 +37,8 @@
 LitType = types.LitType
 Spec = types.Spec

+_MultiLabelBinarizer = sklearn_preprocessing.MultiLabelBinarizer
+

 def map_pred_keys(
     data_spec: Spec, model_output_spec: Spec,
@@ -559,6 +562,156 @@ def __init__(self):
     ClassificationMetricsWrapper.__init__(self, MulticlassPairedMetricsImpl())


+class MultilabelMetrics(SimpleMetrics):
+  """Metrics for assessing the performance of multi-label learning models."""
+
+  def is_field_compatible(
+      self, pred_spec: types.LitType, parent_spec: Optional[types.LitType]
+  ) -> bool:
+    """Determines the compatibility of a field with these metrics.
+
+    Args:
+      pred_spec: The field in the model's output spec containing the predicted
+        labels, must be a `SparseMultilabelPreds` type.
+      parent_spec: The field in the dataset containing the ground truth, must be
+        a `SparseMultilabel` field.
+
+    Returns:
+      True if the pred_spec and parent_spec pair are compatible.
+    """
+    pred_supported = isinstance(pred_spec, types.SparseMultilabelPreds)
+    parent_supported = isinstance(parent_spec, types.SparseMultilabel)
+    return pred_supported and parent_supported
+
+  def meta_spec(self) -> types.Spec:
+    """Returns the Spec describing the computed metrics."""
+    return {
+        'exactmatch': types.MetricResult(
+            best_value=types.MetricBestValue.HIGHEST,
+            description=(
+                'Multi-label accuracy is the exact match ratio; the proportion '
+                'of exact matches between the predicted labels and the true '
+                'labels across all examples. Closer to 1 is better.'
+            ),
+        ),
+        'precision': types.MetricResult(
+            best_value=types.MetricBestValue.HIGHEST,
+            description=(
+                'The mean proportion of correctly predicted labels out of all '
+                'predicted labels across examples. Closer to 1 is better.'
+            ),
+        ),
+        'recall': types.MetricResult(
+            best_value=types.MetricBestValue.HIGHEST,
+            description=(
+                'The mean proportion of correctly predicted labels relative to '
+                'the true labels across examples. Closer to 1 is better.'
+            ),
+        ),
+        'f1': types.MetricResult(
+            best_value=types.MetricBestValue.HIGHEST,
+            description=(
+                'The mean performance of the model (i.e., the harmonic mean of '
+                'precision and recall) across examples. Closer to 1 is better.'
+            ),
+        ),
+    }
+
+  def compute(
+      self,
+      labels: Sequence[Sequence[str]],
+      preds: Sequence[types.ScoredTextCandidates],
+      label_spec: types.LitType,
+      pred_spec: types.LitType,
+      config: Optional[types.JsonDict] = None,
+  ) -> lit_components.MetricsDict:
+    """Computes standard metrics for multi-label classification models.
+
+    Args:
+      labels: Ground truth against which predictions are compared.
+      preds: The predictions made by the model.
+      label_spec: A `SparseMultilabel` instance describing the types of elements
+        in `labels`.
+      pred_spec: A `SparseMultilabelPreds` instance describing the types of
+        elements in `preds`.
+      config: unused parameter from base class.
+
+    Returns:
+      A dict containing the accuracy (exact match ratio), precision, recall, and
+      F1 score for the provided predictions given true labels.
+
+    Raises:
+      TypeError: If `label_spec` is not a `SparseMultilabel` instance or
+        `pred_spec` is not a `SparseMultilabelPreds` instance.
+      ValueError: If `labels` is not the same length as `preds`.
+    """
+    # TODO(b/271864674): Use this config dict to get user-defined thresholds
+    del config  # unused in multi-label metrics, for now.
+
+    if not labels or not preds:
+      return {}
+
+    num_labels = len(labels)
+    num_preds = len(preds)
+    if num_labels != num_preds:
+      raise ValueError(
+          'Must have exactly as many labels as predictions. Received '
+          f'{num_labels} labels and {num_preds} preds.'
+      )
+
+    if not isinstance(label_spec, types.SparseMultilabel):
+      raise TypeError(
+          'label_spec must be a SparseMultilabel, received '
+          f'{type(label_spec).__name__}'
+      )
+
+    if not isinstance(pred_spec, types.SparseMultilabelPreds):
+      raise TypeError(
+          'pred_spec must be a SparseMultilabelPreds, received '
+          f'{type(pred_spec).__name__}'
+      )
+
+    # Learn the complete vocabulary of the possible labels
+    if pred_spec.vocab:  # Try to get the labels from the model's output spec
+      all_labels: list[Sequence[str]] = [pred_spec.vocab]
+    elif label_spec.vocab:  # Or, try to get them from the dataset spec
+      all_labels: list[Sequence[str]] = [label_spec.vocab]
+    else:  # Otherwise, derive them from the observed labels
+      # WARNING: this is only correct for metrics like precision, recall, and
+      # exact-match accuracy which do not depend on knowing the full label set.
+      # For per-label accuracy this will give incorrect results if not all
+      # labels are observed in a given sample.
+      all_labels: list[Sequence[str]] = []
+      all_labels.extend(labels)
+      all_labels.extend([{l for l, _ in p} for p in preds])
+
+    binarizer = _MultiLabelBinarizer()
+    binarizer.fit(all_labels)
+
+    # Next, extract the labels from the ScoredTextCandidates for binarization.
+    pred_labels = [
+        # TODO(b/271864674): Update this set comprehension to respect
+        # user-defined margins from the config dict or pred_spec.threshold.
+        {l for l, s in p if s is not None and s > 0.5} for p in preds
+    ]
+
+    # Transform the true and predicted labels into the binarized vector space.
+    v_true = binarizer.transform(labels)
+    v_pred = binarizer.transform(pred_labels)
+
+    # Compute and return the metrics
+    return {
+        'exactmatch': sklearn_metrics.accuracy_score(v_true, v_pred),
+        'precision': sklearn_metrics.precision_score(
+            v_true, v_pred, average='samples'
+        ),
+        'recall': sklearn_metrics.recall_score(
+            v_true, v_pred, average='samples'
+        ),
+        'f1': sklearn_metrics.f1_score(v_true, v_pred, average='samples'),
+    }
+
+
 class CorpusBLEU(SimpleMetrics):
   """Corpus BLEU score using SacreBLEU."""
diff --git a/lit_nlp/components/metrics_test.py b/lit_nlp/components/metrics_test.py
index 7ce9abc1..acc9eca7 100644
--- a/lit_nlp/components/metrics_test.py
+++ b/lit_nlp/components/metrics_test.py
@@ -292,6 +292,377 @@ def test_compute_with_metadata_empty(self):
     testing_utils.assert_deep_almost_equal(self, result, {})


+_MULTI_LABEL_VOCAB = ['1', '2', '3', '4', '5']
+
+
+class MultilabelMetricsTest(parameterized.TestCase):
+
+  def setUp(self):
+    super(MultilabelMetricsTest, self).setUp()
+    self.metrics = metrics.MultilabelMetrics()
+
+  @parameterized.named_parameters(
+      ('bad_parent', types.SparseMultilabelPreds(), types.Scalar(), False),
+      ('bad_pred', types.RegressionScore(), types.SparseMultilabel(), False),
+      ('yes', types.SparseMultilabelPreds(), types.SparseMultilabel(), True),
+  )
+  def test_is_field_compatible(
+      self, pred: LitType, parent: LitType, expected: bool
+  ):
+    self.assertEqual(self.metrics.is_field_compatible(pred, parent), expected)
+
+  def test_meta_spec(self):
+    meta_spec = self.metrics.meta_spec()
+    self.assertLen(meta_spec, 4)
+    self.assertIn('exactmatch', meta_spec)
+    self.assertIn('precision', meta_spec)
+    self.assertIn('recall', meta_spec)
+    self.assertIn('f1', meta_spec)
+    for spec in meta_spec.values():
+      self.assertIsInstance(spec, types.MetricResult)
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='all_correct_inferred_full_vocab',
+          labels=[
+              ['1', '3'],
+              ['2', '4'],
+              ['5']
+          ],
+          preds=[
+              [('1', 1.0), ('2', 0.0), ('3', 1.0), ('4', 0.0), ('5', 0.0)],
+              [('1', 0.0), ('2', 1.0), ('3', 0.0), ('4', 1.0), ('5', 0.0)],
+              [('1', 0.0), ('2', 0.0), ('3', 0.0), ('4', 0.0), ('5', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(vocab=_MULTI_LABEL_VOCAB),
+          expected={
+              'exactmatch': 1,
+              'precision': 1,
+              'recall': 1,
+              'f1': 1,
+          },
+      ),
+      dict(
+          testcase_name='all_correct_inferred_sparse_vocab',
+          labels=[
+              ['1', '3'],
+              ['2', '4'],
+              ['5']
+          ],
+          preds=[
+              [('1', 1.0), ('3', 1.0)],
+              [('2', 1.0), ('4', 1.0)],
+              [('5', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(),
+          expected={
+              'exactmatch': 1,
+              'precision': 1,
+              'recall': 1,
+              'f1': 1,
+          },
+      ),
+      dict(
+          testcase_name='all_correct_label_spec_vocab',
+          labels=[
+              ['1'],
+              ['2'],
+              ['5']
+          ],
+          preds=[
+              [('1', 1.0)],
+              [('2', 1.0)],
+              [('5', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(vocab=_MULTI_LABEL_VOCAB),
+          pred_spec=types.SparseMultilabelPreds(),
+          expected={
+              'exactmatch': 1,
+              'precision': 1,
+              'recall': 1,
+              'f1': 1,
+          },
+      ),
+      dict(
+          testcase_name='all_correct_pred_spec_vocab',
+          labels=[
+              ['1'],
+              ['2'],
+              ['5']
+          ],
+          preds=[
+              [('1', 1.0)],
+              [('2', 1.0)],
+              [('5', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(vocab=_MULTI_LABEL_VOCAB),
+          expected={
+              'exactmatch': 1,
+              'precision': 1,
+              'recall': 1,
+              'f1': 1,
+          },
+      ),
+      dict(
+          testcase_name='all_wrong_inferred_disjoint_vocab',
+          labels=[['1']],
+          preds=[[('3', 1.0), ('4', 1.0)]],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(),
+          expected={
+              'exactmatch': 0,
+              'precision': 0,
+              'recall': 0,
+              'f1': 0,
+          },
+      ),
+      dict(
+          testcase_name='all_wrong_inferred_full_vocab',
+          labels=[
+              ['1', '3'],
+              ['2', '4'],
+              ['5']
+          ],
+          preds=[
+              [('1', 0.0), ('2', 0.0), ('3', 0.0), ('4', 0.0), ('5', 0.5)],
+              [('1', 0.9), ('2', 0.0), ('3', 0.9), ('4', 0.0), ('5', 0.0)],
+              [('1', 0.0), ('2', 1.0), ('3', 0.0), ('4', 1.0), ('5', 0.0)],
+          ],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(),
+          expected={
+              'exactmatch': 0,
+              'precision': 0,
+              'recall': 0,
+              'f1': 0,
+          },
+      ),
+      dict(
+          testcase_name='all_wrong_inferred_sparse_vocab',
+          labels=[
+              ['1', '3'],
+              ['2', '4'],
+              ['5']
+          ],
+          preds=[
+              [('5', 0.5)],
+              [('1', 0.9), ('3', 0.9)],
+              [('2', 1.0), ('4', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(),
+          expected={
+              'exactmatch': 0,
+              'precision': 0,
+              'recall': 0,
+              'f1': 0,
+          },
+      ),
+      dict(
+          testcase_name='mixed_inferred_disjoint_vocab',
+          labels=[
+              ['1', '3'],
+              ['5']
+          ],
+          preds=[
+              [('1', 1.0), ('4', 1.0)],
+              [('4', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(),
+          expected={
+              'exactmatch': 0,
+              'precision': 0.25,
+              'recall': 0.25,
+              'f1': 0.25,
+          },
+      ),
+      dict(
+          testcase_name='mixed_inferred_full_vocab',
+          labels=[
+              ['1', '3'],
+              ['2', '4'],
+              ['5']
+          ],
+          preds=[
+              [('1', 1.0), ('2', 0.0), ('3', 0.0), ('4', 1.0), ('5', 0.0)],
+              [('1', 0.0), ('2', 1.0), ('3', 0.0), ('4', 0.0), ('5', 0.0)],
+              [('1', 0.0), ('2', 0.0), ('3', 0.0), ('4', 0.0), ('5', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(),
+          expected={
+              'exactmatch': 0.3333,
+              'precision': 0.8333,
+              'recall': 0.6667,
+              'f1': 0.7222,
+          },
+      ),
+      dict(
+          testcase_name='mixed_inferred_sparse_vocab',
+          labels=[
+              ['1', '3'],
+              ['2', '4'],
+              ['5']
+          ],
+          preds=[
+              [('1', 1.0), ('4', 1.0)],
+              [('2', 1.0)],
+              [('5', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(),
+          expected={
+              'exactmatch': 0.3333,
+              'precision': 0.8333,
+              'recall': 0.6667,
+              'f1': 0.7222,
+          },
+      ),
+      dict(
+          testcase_name='mixed_label_spec_vocab',
+          labels=[
+              ['1', '3'],
+              ['2', '4'],
+              ['5']
+          ],
+          preds=[
+              [('1', 1.0), ('4', 1.0)],
+              [('2', 1.0)],
+              [('5', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(vocab=_MULTI_LABEL_VOCAB),
+          pred_spec=types.SparseMultilabelPreds(),
+          expected={
+              'exactmatch': 0.3333,
+              'precision': 0.8333,
+              'recall': 0.6667,
+              'f1': 0.7222,
+          },
+      ),
+      dict(
+          testcase_name='mixed_label_spec_vocab_superset_of_observed_vocab',
+          labels=[
+              ['1', '3'],
+              ['5']
+          ],
+          preds=[
+              [('1', 1.0), ('4', 1.0)],
+              [('4', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(vocab=_MULTI_LABEL_VOCAB),
+          pred_spec=types.SparseMultilabelPreds(),
+          expected={
+              'exactmatch': 0,
+              'precision': 0.25,
+              'recall': 0.25,
+              'f1': 0.25,
+          },
+      ),
+      dict(
+          testcase_name='mixed_pred_spec_vocab',
+          labels=[
+              ['1', '3'],
+              ['2', '4'],
+              ['5']
+          ],
+          preds=[
+              [('1', 1.0), ('4', 1.0)],
+              [('2', 1.0)],
+              [('5', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(vocab=_MULTI_LABEL_VOCAB),
+          expected={
+              'exactmatch': 0.3333,
+              'precision': 0.8333,
+              'recall': 0.6667,
+              'f1': 0.7222,
+          },
+      ),
+      dict(
+          testcase_name='mixed_pred_spec_vocab_superset_of_observed_vocab',
+          labels=[
+              ['1', '3'],
+              ['5']
+          ],
+          preds=[
+              [('1', 1.0), ('4', 1.0)],
+              [('4', 1.0)],
+          ],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(vocab=_MULTI_LABEL_VOCAB),
+          expected={
+              'exactmatch': 0,
+              'precision': 0.25,
+              'recall': 0.25,
+              'f1': 0.25,
+          },
+      ),
+  )
+  def test_compute(self, labels, preds, label_spec, pred_spec, expected):
+    result = self.metrics.compute(labels, preds, label_spec, pred_spec)
+    testing_utils.assert_deep_almost_equal(self, result, expected)
+
+  @parameterized.named_parameters(
+      ('no_labels', [], [('1', 0.1)]),
+      ('no_preds', ['1'], []),
+      ('no_labels_no_preds', [], []),
+  )
+  def test_compute_empty(self, labels, preds):
+    self.assertEmpty(
+        self.metrics.compute(
+            labels,
+            preds,
+            types.SparseMultilabel(),
+            types.SparseMultilabelPreds(),
+        )
+    )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='bad_label_spec',
+          error_type=TypeError,
+          labels=[['1']],
+          preds=[[('1', 0.1)]],
+          label_spec=types.CategoryLabel(),
+          pred_spec=types.SparseMultilabelPreds(),
+      ),
+      dict(
+          testcase_name='bad_pred_spec',
+          error_type=TypeError,
+          labels=[['1']],
+          preds=[[('1', 0.1)]],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.MulticlassPreds(vocab=[]),
+      ),
+      dict(
+          testcase_name='more_labels_than_preds',
+          error_type=ValueError,
+          labels=[['1'], ['2']],
+          preds=[[('1', 0.1)]],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(),
+      ),
+      dict(
+          testcase_name='more_preds_than_labels',
+          error_type=ValueError,
+          labels=[['1']],
+          preds=[[('1', 0.1)], [('2', 0.2)]],
+          label_spec=types.SparseMultilabel(),
+          pred_spec=types.SparseMultilabelPreds(),
+      ),
+  )
+  def test_compute_errors(
+      self, error_type, labels, preds, label_spec, pred_spec
+  ):
+    with self.assertRaises(error_type):
+      self.metrics.compute(labels, preds, label_spec, pred_spec)
+
+
 class CorpusBLEUTest(parameterized.TestCase):

   def setUp(self):
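
As a sanity check on the expected values in the `mixed_*` test cases, the following standalone sketch (not part of the patch) reproduces the `mixed_inferred_full_vocab` expectations with the same scikit-learn calls that `MultilabelMetrics.compute` makes: binarize the label sets with `MultiLabelBinarizer`, keep predicted labels scored above the fixed 0.5 threshold, then take subset accuracy plus samples-averaged precision, recall, and F1.

from sklearn import metrics as sklearn_metrics
from sklearn.preprocessing import MultiLabelBinarizer

# Ground truth and scored predictions from the 'mixed_inferred_full_vocab' case.
labels = [['1', '3'], ['2', '4'], ['5']]
preds = [
    [('1', 1.0), ('2', 0.0), ('3', 0.0), ('4', 1.0), ('5', 0.0)],
    [('1', 0.0), ('2', 1.0), ('3', 0.0), ('4', 0.0), ('5', 0.0)],
    [('1', 0.0), ('2', 0.0), ('3', 0.0), ('4', 0.0), ('5', 1.0)],
]

# Binarize over the union of observed true and predicted labels, as compute()
# does when neither the model's output spec nor the dataset spec has a vocab.
binarizer = MultiLabelBinarizer()
binarizer.fit(list(labels) + [{l for l, _ in p} for p in preds])

# Keep only labels whose score exceeds the fixed 0.5 threshold.
pred_labels = [{l for l, s in p if s > 0.5} for p in preds]
v_true = binarizer.transform(labels)
v_pred = binarizer.transform(pred_labels)

print(sklearn_metrics.accuracy_score(v_true, v_pred))                      # ~0.3333 (exact match)
print(sklearn_metrics.precision_score(v_true, v_pred, average='samples'))  # ~0.8333
print(sklearn_metrics.recall_score(v_true, v_pred, average='samples'))     # ~0.6667
print(sklearn_metrics.f1_score(v_true, v_pred, average='samples'))         # ~0.7222

Per example, the predicted sets are {'1', '4'}, {'2'}, and {'5'} against true sets {'1', '3'}, {'2', '4'}, and {'5'}; only the third example matches exactly (1/3), per-sample precision is (0.5 + 1 + 1)/3 and recall is (0.5 + 0.5 + 1)/3, which yields the 0.3333/0.8333/0.6667/0.7222 figures asserted in the test.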