-
Notifications
You must be signed in to change notification settings - Fork 355
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
An interpreter that generates data for ROC and PR plots.
PiperOrigin-RevId: 428411165
- Loading branch information
Googler
committed
Feb 14, 2022
1 parent
74aedcc
commit 51842ba
Showing
4 changed files
with
363 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# Copyright 2022 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
# Lint as: python3 | ||
"""An interpreters for generating data for ROC and PR curves.""" | ||
|
||
from typing import cast, List, Optional, Sequence, Text | ||
|
||
from lit_nlp.api import components as lit_components | ||
from lit_nlp.api import dataset as lit_dataset | ||
from lit_nlp.api import model as lit_model | ||
from lit_nlp.api import types | ||
from lit_nlp.lib import utils as lit_utils | ||
import sklearn.metrics as metrics | ||
|
||
JsonDict = types.JsonDict | ||
IndexedInput = types.IndexedInput | ||
Spec = types.Spec | ||
|
||
# The config key for specifying model output to use for calculations. | ||
TARGET_PREDICTION_KEY = 'Prediction field' | ||
# The config key for specifying the class label to use for calculations. | ||
TARGET_LABEL_KEY = 'Label' | ||
# The field name in the interpreter output that contains ROC curve data.
ROC_DATA = 'roc_data' | ||
# The field name in the interpreter output that contains PR curve data.
PR_DATA = 'pr_data' | ||
|
||
|
||
class CurvesInterpreter(lit_components.Interpreter):
  """Returns data for rendering ROC and Precision-Recall curves."""

  def run_with_metadata(self,
                        indexed_inputs: Sequence[IndexedInput],
                        model: lit_model.Model,
                        dataset: lit_dataset.IndexedDataset,
                        model_outputs: Optional[List[JsonDict]] = None,
                        config: Optional[JsonDict] = None):
    """Computes ROC and PR curve data for the given model and examples.

    Args:
      indexed_inputs: Examples to compute the curves over.
      model: A classification model; its MulticlassPreds output is scored
        against the ground-truth field named by the prediction's `parent`.
      dataset: The dataset the examples belong to (unused here; part of the
        Interpreter interface).
      model_outputs: Optional precomputed predictions for `indexed_inputs`.
        If None, the model is run on the inputs here.
      config: Must contain TARGET_LABEL_KEY (the positive class label). When
        the model has more than one supported prediction field, it must also
        contain TARGET_PREDICTION_KEY naming the field to use.

    Returns:
      A dict with ROC_DATA and PR_DATA entries, each a list of (x, y) points.

    Raises:
      ValueError: If the label (or, when required, the prediction field) is
        missing from the config, or the label is not in the model's vocab.
    """
    if (not config) or (TARGET_LABEL_KEY not in config):
      raise ValueError(
          f'The config \'{TARGET_LABEL_KEY}\' field should contain the positive'
          f' class label.')
    target_label = config.get(TARGET_LABEL_KEY)

    # Find the prediction field key in the model output to use for calculations.
    # If there is exactly one supported field, it is used implicitly; otherwise
    # the caller must disambiguate via the config.
    output_spec = model.output_spec()
    supported_keys = self._find_supported_pred_keys(output_spec)
    if len(supported_keys) == 1:
      predictions_key = supported_keys[0]
    else:
      if TARGET_PREDICTION_KEY not in config:
        raise ValueError(
            f'The config \'{TARGET_PREDICTION_KEY}\' should contain the name'
            f' of the prediction field to use for calculations.')
      predictions_key = config.get(TARGET_PREDICTION_KEY)

    # Run prediction if needed:
    if model_outputs is None:
      model_outputs = list(model.predict_with_metadata(indexed_inputs))

    # Get scores for the target label.
    pred_spec = cast(types.MulticlassPreds, output_spec[predictions_key])
    labels = pred_spec.vocab
    if target_label not in labels:
      # Fail with a clear message instead of the bare ValueError that
      # list.index() would raise.
      raise ValueError(
          f'The label \'{target_label}\' is not in the model vocab {labels}.')
    target_index = labels.index(target_label)
    scores = [o[predictions_key][target_index] for o in model_outputs]

    # Get ground truth for the target label: 1.0 for positive examples and
    # 0.0 otherwise, as expected by the sklearn metrics below.
    parent_key = pred_spec.parent
    ground_truth_list = [
        1.0 if ex['data'][parent_key] == target_label else 0.0
        for ex in indexed_inputs
    ]

    # Compute ROC curve data.
    x, y, _ = metrics.roc_curve(ground_truth_list, scores)
    roc_data = list(zip(x, y))

    # Compute PR curve data.
    x, y, _ = metrics.precision_recall_curve(ground_truth_list, scores)
    pr_data = list(zip(x, y))

    # Create and return the result.
    return {ROC_DATA: roc_data, PR_DATA: pr_data}

  def is_compatible(self, model: lit_model.Model) -> bool:
    # A model is compatible if it is a classification model and has
    # reference to the ground truth in the dataset.
    return bool(self._find_supported_pred_keys(model.output_spec()))

  def config_spec(self) -> types.Spec:
    # The positive class label for plotting the curves. Note that the label
    # is required: run_with_metadata raises a ValueError if it is absent
    # from the config.
    return {TARGET_LABEL_KEY: types.CategoryLabel()}

  def meta_spec(self) -> types.Spec:
    # The interpreter output: one list of curve points per plot.
    return {ROC_DATA: types.CurveDataPoints(), PR_DATA: types.CurveDataPoints()}

  def _find_supported_pred_keys(self, output_spec: types.Spec) -> List[Text]:
    """Returns the list of supported prediction keys in the model output.

    A key is supported if it refers to MulticlassPreds whose `parent` points
    at the ground-truth field in the dataset.

    Args:
      output_spec: The model output specification.

    Returns:
      The list of keys.
    """
    all_keys = lit_utils.find_spec_keys(output_spec, types.MulticlassPreds)
    return [
        k for k in all_keys
        if cast(types.MulticlassPreds, output_spec[k]).parent
    ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
# Copyright 2022 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
# Lint as: python3 | ||
"""Tests for lit_nlp.components.curves.""" | ||
|
||
from typing import List, NamedTuple, Text, Tuple | ||
from absl.testing import absltest | ||
|
||
from lit_nlp.api import dataset as lit_dataset | ||
from lit_nlp.api import model as lit_model | ||
from lit_nlp.api import types as lit_types | ||
from lit_nlp.api.dataset import JsonDict | ||
from lit_nlp.api.dataset import Spec | ||
from lit_nlp.components import curves | ||
from lit_nlp.lib import caching | ||
|
||
# Labels used in the test dataset. | ||
COLORS = ['red', 'green', 'blue'] | ||
|
||
|
||
# A canned model output for a single test example: the per-class score
# vector and the ground-truth color label.
TestDataEntry = NamedTuple('TestDataEntry', [
    ('prediction', Tuple[float, float, float]),
    ('label', Text),
])
|
||
|
||
# Canned predictions and ground-truth labels, keyed by the example's 'x'
# feature value. Examples 0 and 1 are labeled 'red'; example 2 is 'blue'.
TEST_DATA = {
    0: TestDataEntry((0.7, 0.2, 0.1), 'red'),
    1: TestDataEntry((0.3, 0.5, 0.2), 'red'),
    2: TestDataEntry((0.6, 0.1, 0.3), 'blue'),
}
|
||
|
||
class TestModel(lit_model.Model):
  """A test model for the interpreter that uses 'TEST_DATA' as model output."""

  def input_spec(self) -> lit_types.Spec:
    # A single scalar feature, used as the lookup key into TEST_DATA.
    return {'x': lit_types.Scalar()}

  def output_spec(self) -> lit_types.Spec:
    # Two prediction heads, so tests can exercise prediction-field selection.
    return {
        'pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label'),
        'aux_pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label')
    }

  def predict_minibatch(self, inputs: List[lit_types.JsonDict],
                        **unused) -> List[lit_types.JsonDict]:
    # 'pred' is the canned score vector for the example; 'aux_pred' is a
    # uniform distribution over the three classes.
    return [{
        'pred': TEST_DATA[example['x']].prediction,
        'aux_pred': [1 / 3, 1 / 3, 1 / 3],
    } for example in inputs]
|
||
|
||
class IncompatiblePredictionTestModel(lit_model.Model):
  """A model with unsupported output type."""

  def input_spec(self) -> lit_types.Spec:
    return {'x': lit_types.Scalar()}

  def output_spec(self) -> lit_types.Spec:
    # RegressionScore is not MulticlassPreds, so CurvesInterpreter should
    # report this model as incompatible.
    return {'pred': lit_types.RegressionScore(parent='label')}

  def predict_minibatch(self, inputs: List[lit_types.JsonDict],
                        **unused) -> List[lit_types.JsonDict]:
    # Never invoked by the compatibility tests; return no predictions.
    return []
|
||
|
||
class NoParentTestModel(lit_model.Model):
  """A model that doesn't specify the ground truth field in the dataset."""

  def input_spec(self) -> lit_types.Spec:
    return {'x': lit_types.Scalar()}

  def output_spec(self) -> lit_types.Spec:
    # No 'parent' on the prediction, so there is no ground-truth reference
    # and CurvesInterpreter should report this model as incompatible.
    return {'pred': lit_types.MulticlassPreds(vocab=COLORS)}

  def predict_minibatch(self, inputs: List[lit_types.JsonDict],
                        **unused) -> List[lit_types.JsonDict]:
    # Never invoked by the compatibility tests; return no predictions.
    return []
|
||
|
||
class TestDataset(lit_dataset.Dataset):
  """Dataset for testing the interpreter that uses 'TEST_DATA' as the source."""

  def spec(self) -> Spec:
    return {
        'x': lit_types.Scalar(),
        'label': lit_types.Scalar(),
    }

  @property
  def examples(self) -> List[JsonDict]:
    # One example per TEST_DATA entry: the lookup key and its label.
    return [{
        'x': key,
        'label': entry.label,
    } for key, entry in TEST_DATA.items()]
|
||
|
||
class CurvesInterpreterTest(absltest.TestCase):
  """Tests CurvesInterpreter."""

  def setUp(self):
    super().setUp()
    # Wrap the test dataset so examples carry ids, as run_with_metadata
    # expects indexed inputs.
    self.dataset = lit_dataset.IndexedDataset(
        base=TestDataset(), id_fn=caching.input_hash)
    self.model = TestModel()

  def test_label_not_in_config(self):
    """The interpreter throws an error if the config doesn't have Label."""
    ci = curves.CurvesInterpreter()
    with self.assertRaisesRegex(
        ValueError, 'The config \'Label\' field should contain the positive'
        ' class label.'):
      ci.run_with_metadata(
          indexed_inputs=self.dataset.indexed_examples,
          model=self.model,
          dataset=self.dataset,
      )

  def test_model_output_is_missing_in_config(self):
    """Tests the case when the name of the model output is absent in config.

    The interpreter throws an error if the name of the output is absent,
    because TestModel has two supported prediction fields and the field to
    use cannot be inferred.
    """
    ci = curves.CurvesInterpreter()
    with self.assertRaisesRegex(
        ValueError, 'The config \'Prediction field\' should contain'):
      ci.run_with_metadata(
          indexed_inputs=self.dataset.indexed_examples,
          model=self.model,
          dataset=self.dataset,
          config={'Label': 'red'})

  def test_interpreter_honors_user_selected_label(self):
    """Tests that the curves are computed for the user-selected class label."""
    ci = curves.CurvesInterpreter()
    self.assertTrue(ci.is_compatible(self.model))

    # Test the curve data for 'red' label.
    curves_data = ci.run_with_metadata(
        indexed_inputs=self.dataset.indexed_examples,
        model=self.model,
        dataset=self.dataset,
        config={
            'Label': 'red',
            'Prediction field': 'pred'
        })
    self.assertIn('roc_data', curves_data)
    self.assertIn('pr_data', curves_data)
    roc_data = curves_data['roc_data']
    self.assertEqual(roc_data, [(0.0, 0.0), (0.0, 0.5), (1.0, 0.5), (1.0, 1.0)])
    pr_data = curves_data['pr_data']
    self.assertEqual(pr_data, [(2 / 3, 1.0), (0.5, 0.5), (1.0, 0.5),
                               (1.0, 0.0)])

    # Test the curve data for 'blue' label.
    curves_data = ci.run_with_metadata(
        indexed_inputs=self.dataset.indexed_examples,
        model=self.model,
        dataset=self.dataset,
        config={
            'Label': 'blue',
            'Prediction field': 'pred'
        })
    self.assertIn('roc_data', curves_data)
    self.assertIn('pr_data', curves_data)
    roc_data = curves_data['roc_data']
    self.assertEqual(roc_data, [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0)])
    pr_data = curves_data['pr_data']
    self.assertEqual(pr_data, [(1.0, 1.0), (1.0, 0.0)])

  def test_config_spec(self):
    """Tests that the interpreter config has correct fields of correct type."""
    ci = curves.CurvesInterpreter()
    spec = ci.config_spec()
    self.assertIn('Label', spec)
    self.assertIsInstance(spec['Label'], lit_types.CategoryLabel)

  def test_meta_spec(self):
    """Tests that the interpreter meta has correct fields of correct type."""
    ci = curves.CurvesInterpreter()
    spec = ci.meta_spec()
    self.assertIn('roc_data', spec)
    self.assertIsInstance(spec['roc_data'], lit_types.CurveDataPoints)
    self.assertIn('pr_data', spec)
    self.assertIsInstance(spec['pr_data'], lit_types.CurveDataPoints)

  def test_incompatible_model_prediction(self):
    """A model is incompatible if prediction is not MulticlassPreds."""
    ci = curves.CurvesInterpreter()
    self.assertFalse(ci.is_compatible(IncompatiblePredictionTestModel()))

  def test_no_parent_in_model_spec(self):
    """A model is incompatible if there is no reference to the parent."""
    ci = curves.CurvesInterpreter()
    self.assertFalse(ci.is_compatible(NoParentTestModel()))
|
||
|
||
# Run all tests when executed as a script.
if __name__ == '__main__':
  absltest.main()