From 51842babef63f9aa29d1d2add14633c4640627fc Mon Sep 17 00:00:00 2001
From: Googler
Date: Sun, 13 Feb 2022 20:57:28 -0800
Subject: [PATCH] An interpreter that generates data for ROC and PR plots.

PiperOrigin-RevId: 428411165
---
 lit_nlp/api/types.py              |  10 ++
 lit_nlp/client/lib/types.ts       |   2 +-
 lit_nlp/components/curves.py      | 127 +++++++++++++++++
 lit_nlp/components/curves_test.py | 225 ++++++++++++++++++++++++++++++
 4 files changed, 363 insertions(+), 1 deletion(-)
 create mode 100644 lit_nlp/components/curves.py
 create mode 100644 lit_nlp/components/curves_test.py

diff --git a/lit_nlp/api/types.py b/lit_nlp/api/types.py
index 36c34551..a6c917a8 100644
--- a/lit_nlp/api/types.py
+++ b/lit_nlp/api/types.py
@@ -472,4 +472,14 @@ class Boolean(LitType):
   default: bool = False
 
 
+@attr.s(auto_attribs=True, frozen=True, kw_only=True)
+class CurveDataPoints(LitType):
+  """Represents the data points of a curve.
+
+  A list of tuples where the first and second elements of each tuple are the
+  x and y coordinates of the corresponding curve point, respectively.
+  """
+  pass
+
+
 # LINT.ThenChange(../client/lib/types.ts)
diff --git a/lit_nlp/client/lib/types.ts b/lit_nlp/client/lib/types.ts
index 00651a76..f026b5e9 100644
--- a/lit_nlp/client/lib/types.ts
+++ b/lit_nlp/client/lib/types.ts
@@ -33,7 +33,7 @@ export type LitName = 'type'|'LitType'|'String'|'TextSegment'|'GeneratedText'|
   'AttentionHeads'|'SparseMultilabel'|'FieldMatcher'|'MultiFieldMatcher'|
   'Gradients'|'Boolean'|'TokenSalience'|'ImageBytes'|'SparseMultilabelPreds'|
   'ImageGradients'|'ImageSalience'|'SequenceSalience'|'ReferenceScores'|
-  'FeatureSalience'|'TopTokens';
+  'FeatureSalience'|'TopTokens'|'CurveDataPoints';
 
 export const listFieldTypes: LitName[] =
     ['Tokens', 'SequenceTags', 'SpanLabels', 'EdgeLabels', 'SparseMultilabel'];
diff --git a/lit_nlp/components/curves.py b/lit_nlp/components/curves.py
new file mode 100644
index 00000000..49bc8937
--- /dev/null
+++ b/lit_nlp/components/curves.py
@@ -0,0 +1,127 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Lint as: python3
+"""An interpreter for generating data for ROC and PR curves."""
+
+from typing import cast, List, Optional, Sequence, Text
+
+from lit_nlp.api import components as lit_components
+from lit_nlp.api import dataset as lit_dataset
+from lit_nlp.api import model as lit_model
+from lit_nlp.api import types
+from lit_nlp.lib import utils as lit_utils
+import sklearn.metrics as metrics
+
+JsonDict = types.JsonDict
+IndexedInput = types.IndexedInput
+Spec = types.Spec
+
+# The config key for specifying the model output to use for calculations.
+TARGET_PREDICTION_KEY = 'Prediction field'
+# The config key for specifying the class label to use for calculations.
+TARGET_LABEL_KEY = 'Label'
+# The field name in the interpreter output that contains ROC curve data.
+ROC_DATA = 'roc_data'
+# The field name in the interpreter output that contains PR curve data.
+PR_DATA = 'pr_data'
+
+
+class CurvesInterpreter(lit_components.Interpreter):
+  """Returns data for rendering ROC and Precision-Recall curves."""
+
+  def run_with_metadata(self,
+                        indexed_inputs: Sequence[IndexedInput],
+                        model: lit_model.Model,
+                        dataset: lit_dataset.IndexedDataset,
+                        model_outputs: Optional[List[JsonDict]] = None,
+                        config: Optional[JsonDict] = None):
+
+    if (not config) or (TARGET_LABEL_KEY not in config):
+      raise ValueError(
+          f'The config \'{TARGET_LABEL_KEY}\' field should contain the positive'
+          f' class label.')
+    target_label = config.get(TARGET_LABEL_KEY)
+
+    # Find the prediction field in the model output to use for calculations.
+    output_spec = model.output_spec()
+    supported_keys = self._find_supported_pred_keys(output_spec)
+    if len(supported_keys) == 1:
+      predictions_key = supported_keys[0]
+    else:
+      if TARGET_PREDICTION_KEY not in config:
+        raise ValueError(
+            f'The config \'{TARGET_PREDICTION_KEY}\' should contain the name'
+            f' of the prediction field to use for calculations.')
+      predictions_key = config.get(TARGET_PREDICTION_KEY)
+
+    # Run prediction if needed:
+    if model_outputs is None:
+      model_outputs = list(model.predict_with_metadata(indexed_inputs))
+
+    # Get the scores assigned to the target label.
+    pred_spec = cast(types.MulticlassPreds, output_spec[predictions_key])
+    labels = pred_spec.vocab
+    target_index = labels.index(target_label)
+    scores = [o[predictions_key][target_index] for o in model_outputs]
+
+    # Get binary ground truth: 1.0 iff the example has the target label.
+    parent_key = pred_spec.parent
+    ground_truth_list = []
+    for indexed_input in indexed_inputs:
+      ground_truth_label = indexed_input['data'][parent_key]
+      ground_truth = 1.0 if ground_truth_label == target_label else 0.0
+      ground_truth_list.append(ground_truth)
+
+    # Compute ROC curve data as (FPR, TPR) points.
+    x, y, _ = metrics.roc_curve(ground_truth_list, scores)
+    roc_data = list(zip(x, y))
+
+    # Compute PR curve data as (precision, recall) points.
+    x, y, _ = metrics.precision_recall_curve(ground_truth_list, scores)
+    pr_data = list(zip(x, y))
+
+    # Create and return the result.
+    return {ROC_DATA: roc_data, PR_DATA: pr_data}
+
+  def is_compatible(self, model: lit_model.Model) -> bool:
+    # A model is compatible if it is a classification model and has a
+    # reference to the ground truth field in the dataset.
+    output_spec = model.output_spec()
+    return bool(self._find_supported_pred_keys(output_spec))
+
+  def config_spec(self) -> types.Spec:
+    # The user must specify which class label to use for plotting the
+    # curves; run_with_metadata() raises a ValueError if the label is
+    # missing from the config.
+    return {TARGET_LABEL_KEY: types.CategoryLabel()}
+
+  def meta_spec(self) -> types.Spec:
+    return {ROC_DATA: types.CurveDataPoints(), PR_DATA: types.CurveDataPoints()}
+
+  def _find_supported_pred_keys(self, output_spec: types.Spec) -> List[Text]:
+    """Returns the list of supported prediction keys in the model output.
+
+    Args:
+      output_spec: The model output specification.
+
+    Returns:
+      The list of keys.
+ """ + all_keys = lit_utils.find_spec_keys(output_spec, types.MulticlassPreds) + supported_keys = [ + k for k in all_keys + if cast(types.MulticlassPreds, output_spec[k]).parent + ] + return supported_keys diff --git a/lit_nlp/components/curves_test.py b/lit_nlp/components/curves_test.py new file mode 100644 index 00000000..c456c570 --- /dev/null +++ b/lit_nlp/components/curves_test.py @@ -0,0 +1,225 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Lint as: python3 +"""Tests for lit_nlp.components.curves.""" + +from typing import List, NamedTuple, Text, Tuple +from absl.testing import absltest + +from lit_nlp.api import dataset as lit_dataset +from lit_nlp.api import model as lit_model +from lit_nlp.api import types as lit_types +from lit_nlp.api.dataset import JsonDict +from lit_nlp.api.dataset import Spec +from lit_nlp.components import curves +from lit_nlp.lib import caching + +# Labels used in the test dataset. +COLORS = ['red', 'green', 'blue'] + + +class TestDataEntry(NamedTuple): + prediction: Tuple[float, float, float] + label: Text + + +TEST_DATA = { + 0: TestDataEntry((0.7, 0.2, 0.1), 'red'), + 1: TestDataEntry((0.3, 0.5, 0.2), 'red'), + 2: TestDataEntry((0.6, 0.1, 0.3), 'blue'), +} + + +class TestModel(lit_model.Model): + """A test model for the interpreter that uses 'TEST_DATA' as model output.""" + + def input_spec(self) -> lit_types.Spec: + return { + 'x': lit_types.Scalar(), + } + + def output_spec(self) -> lit_types.Spec: + return { + 'pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label'), + 'aux_pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label') + } + + def predict_minibatch(self, inputs: List[lit_types.JsonDict], + **unused) -> List[lit_types.JsonDict]: + output = [] + + def predict_example(ex: lit_types.JsonDict) -> Tuple[float, float, float]: + x = ex['x'] + return TEST_DATA[x].prediction + + for example in inputs: + output.append({ + 'pred': predict_example(example), + 'aux_pred': [1 / 3, 1 / 3, 1 / 3] + }) + return output + + +class IncompatiblePredictionTestModel(lit_model.Model): + """A model with unsupported output type.""" + + def input_spec(self) -> lit_types.Spec: + return { + 'x': lit_types.Scalar(), + } + + def output_spec(self) -> lit_types.Spec: + return {'pred': lit_types.RegressionScore(parent='label')} + + def predict_minibatch(self, inputs: List[lit_types.JsonDict], + **unused) -> List[lit_types.JsonDict]: + return [] + + +class NoParentTestModel(lit_model.Model): + """A model that doesn't specify the ground truth field in the dataset.""" + + def input_spec(self) -> lit_types.Spec: + return { + 'x': lit_types.Scalar(), + } + + def output_spec(self) -> lit_types.Spec: + return {'pred': lit_types.MulticlassPreds(vocab=COLORS)} + + def predict_minibatch(self, inputs: List[lit_types.JsonDict], + **unused) -> List[lit_types.JsonDict]: + return [] + + +class TestDataset(lit_dataset.Dataset): + """Dataset for testing the 
interpreter that uses 'TEST_DATA' as the source."""
+
+  def spec(self) -> Spec:
+    return {
+        'x': lit_types.Scalar(),
+        'label': lit_types.CategoryLabel(vocab=COLORS),
+    }
+
+  @property
+  def examples(self) -> List[JsonDict]:
+    data = []
+    for x, entry in TEST_DATA.items():
+      data.append({'x': x, 'label': entry.label})
+    return data
+
+
+class CurvesInterpreterTest(absltest.TestCase):
+  """Tests CurvesInterpreter."""
+
+  def setUp(self):
+    super().setUp()
+    self.dataset = lit_dataset.IndexedDataset(
+        base=TestDataset(), id_fn=caching.input_hash)
+    self.model = TestModel()
+
+  def test_label_not_in_config(self):
+    """The interpreter raises an error if the config doesn't have Label."""
+    ci = curves.CurvesInterpreter()
+    with self.assertRaisesRegex(
+        ValueError, 'The config \'Label\' field should contain the positive'
+        ' class label.'):
+      ci.run_with_metadata(
+          indexed_inputs=self.dataset.indexed_examples,
+          model=self.model,
+          dataset=self.dataset,
+      )
+
+  def test_model_output_is_missing_in_config(self):
+    """Tests the case when the name of the model output is absent in config.
+
+    The interpreter raises an error if the name of the output is absent.
+    """
+    ci = curves.CurvesInterpreter()
+    with self.assertRaisesRegex(
+        ValueError, 'The config \'Prediction field\' should contain'):
+      ci.run_with_metadata(
+          indexed_inputs=self.dataset.indexed_examples,
+          model=self.model,
+          dataset=self.dataset,
+          config={'Label': 'red'})
+
+  def test_interpreter_honors_user_selected_label(self):
+    """Tests that the interpreter returns curve data for the selected label."""
+    ci = curves.CurvesInterpreter()
+    self.assertTrue(ci.is_compatible(self.model))
+
+    # Test the curve data for the 'red' label.
+    curves_data = ci.run_with_metadata(
+        indexed_inputs=self.dataset.indexed_examples,
+        model=self.model,
+        dataset=self.dataset,
+        config={
+            'Label': 'red',
+            'Prediction field': 'pred'
+        })
+    self.assertIn('roc_data', curves_data)
+    self.assertIn('pr_data', curves_data)
+    roc_data = curves_data['roc_data']
+    self.assertEqual(roc_data, [(0.0, 0.0), (0.0, 0.5), (1.0, 0.5), (1.0, 1.0)])
+    pr_data = curves_data['pr_data']
+    self.assertEqual(pr_data, [(2 / 3, 1.0), (0.5, 0.5), (1.0, 0.5),
+                               (1.0, 0.0)])
+
+    # Test the curve data for the 'blue' label.
+    curves_data = ci.run_with_metadata(
+        indexed_inputs=self.dataset.indexed_examples,
+        model=self.model,
+        dataset=self.dataset,
+        config={
+            'Label': 'blue',
+            'Prediction field': 'pred'
+        })
+    self.assertIn('roc_data', curves_data)
+    self.assertIn('pr_data', curves_data)
+    roc_data = curves_data['roc_data']
+    self.assertEqual(roc_data, [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0)])
+    pr_data = curves_data['pr_data']
+    self.assertEqual(pr_data, [(1.0, 1.0), (1.0, 0.0)])
+
+  def test_config_spec(self):
+    """Tests that the interpreter config spec has the expected fields."""
+    ci = curves.CurvesInterpreter()
+    spec = ci.config_spec()
+    self.assertIn('Label', spec)
+    self.assertIsInstance(spec['Label'], lit_types.CategoryLabel)
+
+  def test_meta_spec(self):
+    """Tests that the interpreter meta spec has the expected fields."""
+    ci = curves.CurvesInterpreter()
+    spec = ci.meta_spec()
+    self.assertIn('roc_data', spec)
+    self.assertIsInstance(spec['roc_data'], lit_types.CurveDataPoints)
+    self.assertIn('pr_data', spec)
+    self.assertIsInstance(spec['pr_data'], lit_types.CurveDataPoints)
+
+  def test_incompatible_model_prediction(self):
+    """A model is incompatible if its prediction is not MulticlassPreds."""
+    ci = curves.CurvesInterpreter()
+    self.assertFalse(ci.is_compatible(IncompatiblePredictionTestModel()))
+
+  def test_no_parent_in_model_spec(self):
+    """A model is incompatible if its predictions have no parent reference."""
+    ci = curves.CurvesInterpreter()
+    self.assertFalse(ci.is_compatible(NoParentTestModel()))
+
+
+if __name__ == '__main__':
+  absltest.main()
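
Usage sketch (illustrative, not part of the patch): the snippet below shows one
way to run the new interpreter end to end. It assumes it is executed inside
curves_test.py, where TestModel and TestDataset are defined; everything else
comes from the patch above.

  from lit_nlp.api import dataset as lit_dataset
  from lit_nlp.components import curves
  from lit_nlp.lib import caching

  # Wrap the test dataset so examples carry ids, as in CurvesInterpreterTest.
  dataset = lit_dataset.IndexedDataset(
      base=TestDataset(), id_fn=caching.input_hash)
  model = TestModel()

  interpreter = curves.CurvesInterpreter()
  result = interpreter.run_with_metadata(
      indexed_inputs=dataset.indexed_examples,
      model=model,
      dataset=dataset,
      config={curves.TARGET_LABEL_KEY: 'red',
              curves.TARGET_PREDICTION_KEY: 'pred'})

  # result[curves.ROC_DATA] holds (FPR, TPR) tuples and
  # result[curves.PR_DATA] holds (precision, recall) tuples.
  for fpr, tpr in result[curves.ROC_DATA]:
    print(f'FPR={fpr:.2f}, TPR={tpr:.2f}')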
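
The expected values in test_interpreter_honors_user_selected_label follow
directly from sklearn, since the interpreter simply zips the arrays returned
by roc_curve and precision_recall_curve. A standalone sketch that reproduces
the 'red'-label expectations from TEST_DATA:

  import sklearn.metrics as metrics

  # 'red' is the positive class: examples 0 and 1 are red, example 2 is blue.
  y_true = [1.0, 1.0, 0.0]
  # Scores assigned to the 'red' class in the TEST_DATA predictions.
  y_score = [0.7, 0.3, 0.6]

  fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
  print(list(zip(fpr, tpr)))
  # [(0.0, 0.0), (0.0, 0.5), (1.0, 0.5), (1.0, 1.0)]

  precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
  print(list(zip(precision, recall)))
  # [(0.6666..., 1.0), (0.5, 0.5), (1.0, 0.5), (1.0, 0.0)]

Note that precision_recall_curve returns (precision, recall) in that order, so
the tuples in pr_data are (precision, recall) rather than the conventional
plot order of (recall, precision).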