An interpreter that generates data for ROC and PR plots.
PiperOrigin-RevId: 428411165
Googler committed Feb 14, 2022
1 parent 74aedcc commit 51842ba
Showing 4 changed files with 363 additions and 1 deletion.
10 changes: 10 additions & 0 deletions lit_nlp/api/types.py
@@ -472,4 +472,14 @@ class Boolean(LitType):
default: bool = False


@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class CurveDataPoints(LitType):
  """Represents data points of a curve.

  A list of tuples where the first and second elements of the tuple are the
  x and y coordinates of the corresponding curve point, respectively.
  """


# LINT.ThenChange(../client/lib/types.ts)
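
For reference, a minimal sketch (not part of the commit) of the payload shape a CurveDataPoints field carries; the values here are the 'red' ROC points asserted in curves_test.py below:

# Sketch only: a CurveDataPoints value is a plain list of (x, y) tuples.
# For an ROC curve each tuple is (false_positive_rate, true_positive_rate);
# for a PR curve the interpreter below emits (precision, recall) pairs.
roc_data = [(0.0, 0.0), (0.0, 0.5), (1.0, 0.5), (1.0, 1.0)]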
2 changes: 1 addition & 1 deletion lit_nlp/client/lib/types.ts
@@ -33,7 +33,7 @@ export type LitName = 'type'|'LitType'|'String'|'TextSegment'|'GeneratedText'|
'AttentionHeads'|'SparseMultilabel'|'FieldMatcher'|'MultiFieldMatcher'|
'Gradients'|'Boolean'|'TokenSalience'|'ImageBytes'|'SparseMultilabelPreds'|
'ImageGradients'|'ImageSalience'|'SequenceSalience'|'ReferenceScores'|
'FeatureSalience'|'TopTokens';
'FeatureSalience'|'TopTokens'|'CurveDataPoints';

export const listFieldTypes: LitName[] =
['Tokens', 'SequenceTags', 'SpanLabels', 'EdgeLabels', 'SparseMultilabel'];
127 changes: 127 additions & 0 deletions lit_nlp/components/curves.py
@@ -0,0 +1,127 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""An interpreters for generating data for ROC and PR curves."""

from typing import cast, List, Optional, Sequence, Text

from lit_nlp.api import components as lit_components
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types
from lit_nlp.lib import utils as lit_utils
import sklearn.metrics as metrics

JsonDict = types.JsonDict
IndexedInput = types.IndexedInput
Spec = types.Spec

# The config key for specifying model output to use for calculations.
TARGET_PREDICTION_KEY = 'Prediction field'
# The config key for specifying the class label to use for calculations.
TARGET_LABEL_KEY = 'Label'
# The field name in the interpreter output that contains ROC curve data.
ROC_DATA = 'roc_data'
# The field name in the interpreter output that contains PR curve data.
PR_DATA = 'pr_data'


class CurvesInterpreter(lit_components.Interpreter):
"""Returns data for rendering ROC and Precision-Recall curves."""

def run_with_metadata(self,
indexed_inputs: Sequence[IndexedInput],
model: lit_model.Model,
dataset: lit_dataset.IndexedDataset,
model_outputs: Optional[List[JsonDict]] = None,
config: Optional[JsonDict] = None):

if (not config) or (TARGET_LABEL_KEY not in config):
raise ValueError(
f'The config \'{TARGET_LABEL_KEY}\' field should contain the positive'
f' class label.')
target_label = config.get(TARGET_LABEL_KEY)

# Find the prediction field key in the model output to use for calculations.
output_spec = model.output_spec()
supported_keys = self._find_supported_pred_keys(output_spec)
if len(supported_keys) == 1:
predictions_key = supported_keys[0]
else:
if TARGET_PREDICTION_KEY not in config:
raise ValueError(
f'The config \'{TARGET_PREDICTION_KEY}\' should contain the name'
f' of the prediction field to use for calculations.')
predictions_key = config.get(TARGET_PREDICTION_KEY)

# Run prediction if needed:
if model_outputs is None:
model_outputs = list(model.predict_with_metadata(indexed_inputs))

# Get scores for the target label.
pred_spec = cast(types.MulticlassPreds, output_spec[predictions_key])
labels = pred_spec.vocab
target_index = labels.index(target_label)
scores = [o[predictions_key][target_index] for o in model_outputs]

# Get ground truth for the target label.
parent_key = pred_spec.parent
ground_truth_list = []
for indexed_input in indexed_inputs:
ground_truth_label = indexed_input['data'][parent_key]
ground_truth = 1.0 if ground_truth_label == target_label else 0.0
ground_truth_list.append(ground_truth)

# Compute ROC curve data.
x, y, _ = metrics.roc_curve(ground_truth_list, scores)
roc_data = list(zip(x, y))

# Compute PR curve data.
x, y, _ = metrics.precision_recall_curve(ground_truth_list, scores)
pr_data = list(zip(x, y))

# Create and return the result.
return {ROC_DATA: roc_data, PR_DATA: pr_data}

def is_compatible(self, model: lit_model.Model) -> bool:
# A model is compatible if it is a classification model and has
# reference to the ground truth in the dataset.
output_spec = model.output_spec()
return bool(self._find_supported_pred_keys(output_spec))

def config_spec(self) -> types.Spec:
# If a model is a multiclass classifier, a user can specify which
# class label to use for plotting the curves. If the label is not
# specified then the label with index 0 is used by default.
return {TARGET_LABEL_KEY: types.CategoryLabel()}

def meta_spec(self) -> types.Spec:
return {ROC_DATA: types.CurveDataPoints(), PR_DATA: types.CurveDataPoints()}

def _find_supported_pred_keys(self, output_spec: types.Spec) -> List[Text]:
"""Returns the list of supported prediction keys in the model output.

Args:
output_spec: The model output specification.

Returns:
The list of keys.
"""
all_keys = lit_utils.find_spec_keys(output_spec, types.MulticlassPreds)
supported_keys = [
k for k in all_keys
if cast(types.MulticlassPreds, output_spec[k]).parent
]
return supported_keys
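
For context, a minimal usage sketch (not part of the commit); here `model` and `dataset` stand for any compatible classification model and IndexedDataset, such as the TestModel/TestDataset pair defined in curves_test.py below:

from lit_nlp.components import curves

ci = curves.CurvesInterpreter()
result = ci.run_with_metadata(
    indexed_inputs=dataset.indexed_examples,
    model=model,
    dataset=dataset,
    config={'Label': 'red', 'Prediction field': 'pred'})
roc_points = result['roc_data']  # List of (fpr, tpr) tuples.
pr_points = result['pr_data']    # List of (precision, recall) tuples.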
225 changes: 225 additions & 0 deletions lit_nlp/components/curves_test.py
@@ -0,0 +1,225 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for lit_nlp.components.curves."""

from typing import List, NamedTuple, Text, Tuple
from absl.testing import absltest

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.api.dataset import JsonDict
from lit_nlp.api.dataset import Spec
from lit_nlp.components import curves
from lit_nlp.lib import caching

# Labels used in the test dataset.
COLORS = ['red', 'green', 'blue']


class TestDataEntry(NamedTuple):
prediction: Tuple[float, float, float]
label: Text


TEST_DATA = {
0: TestDataEntry((0.7, 0.2, 0.1), 'red'),
1: TestDataEntry((0.3, 0.5, 0.2), 'red'),
2: TestDataEntry((0.6, 0.1, 0.3), 'blue'),
}


class TestModel(lit_model.Model):
"""A test model for the interpreter that uses 'TEST_DATA' as model output."""

def input_spec(self) -> lit_types.Spec:
return {
'x': lit_types.Scalar(),
}

def output_spec(self) -> lit_types.Spec:
return {
'pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label'),
'aux_pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label')
}

def predict_minibatch(self, inputs: List[lit_types.JsonDict],
**unused) -> List[lit_types.JsonDict]:
output = []

def predict_example(ex: lit_types.JsonDict) -> Tuple[float, float, float]:
x = ex['x']
return TEST_DATA[x].prediction

for example in inputs:
output.append({
'pred': predict_example(example),
'aux_pred': [1 / 3, 1 / 3, 1 / 3]
})
return output


class IncompatiblePredictionTestModel(lit_model.Model):
"""A model with unsupported output type."""

def input_spec(self) -> lit_types.Spec:
return {
'x': lit_types.Scalar(),
}

def output_spec(self) -> lit_types.Spec:
return {'pred': lit_types.RegressionScore(parent='label')}

def predict_minibatch(self, inputs: List[lit_types.JsonDict],
**unused) -> List[lit_types.JsonDict]:
return []


class NoParentTestModel(lit_model.Model):
"""A model that doesn't specify the ground truth field in the dataset."""

def input_spec(self) -> lit_types.Spec:
return {
'x': lit_types.Scalar(),
}

def output_spec(self) -> lit_types.Spec:
return {'pred': lit_types.MulticlassPreds(vocab=COLORS)}

def predict_minibatch(self, inputs: List[lit_types.JsonDict],
**unused) -> List[lit_types.JsonDict]:
return []


class TestDataset(lit_dataset.Dataset):
"""Dataset for testing the interpreter that uses 'TEST_DATA' as the source."""

def spec(self) -> Spec:
return {
'x': lit_types.Scalar(),
'label': lit_types.Scalar(),
}

@property
def examples(self) -> List[JsonDict]:
data = []
for x, entry in TEST_DATA.items():
data.append({'x': x, 'label': entry.label})
return data


class CurvesInterpreterTest(absltest.TestCase):
"""Tests CurvesInterpreter."""

def setUp(self):
super().setUp()
self.dataset = lit_dataset.IndexedDataset(
base=TestDataset(), id_fn=caching.input_hash)
self.model = TestModel()

def test_label_not_in_config(self):
"""The interpreter throws an error if the config doesn't have Label."""
ci = curves.CurvesInterpreter()
with self.assertRaisesRegex(
ValueError, 'The config \'Label\' field should contain the positive'
' class label.'):
ci.run_with_metadata(
indexed_inputs=self.dataset.indexed_examples,
model=self.model,
dataset=self.dataset,
)

def test_model_output_is_missing_in_config(self):
"""Tests the case when the name of the model output is absent in config.
The interpreter throws an error if the name of the output is absent.
"""
ci = curves.CurvesInterpreter()
with self.assertRaisesRegex(
ValueError, 'The config \'Prediction field\' should contain'):
ci.run_with_metadata(
indexed_inputs=self.dataset.indexed_examples,
model=self.model,
dataset=self.dataset,
config={'Label': 'red'})

def test_interpreter_honors_user_selected_label(self):
"""Tests a happy scenario when a user doesn't specify the class label."""
ci = curves.CurvesInterpreter()
self.assertTrue(ci.is_compatible(self.model))

# Test the curve data for 'red' label.
curves_data = ci.run_with_metadata(
indexed_inputs=self.dataset.indexed_examples,
model=self.model,
dataset=self.dataset,
config={
'Label': 'red',
'Prediction field': 'pred'
})
self.assertIn('roc_data', curves_data)
self.assertIn('pr_data', curves_data)
roc_data = curves_data['roc_data']
self.assertEqual(roc_data, [(0.0, 0.0), (0.0, 0.5), (1.0, 0.5), (1.0, 1.0)])
pr_data = curves_data['pr_data']
self.assertEqual(pr_data, [(2 / 3, 1.0), (0.5, 0.5), (1.0, 0.5),
(1.0, 0.0)])

# Test the curve data for 'blue' label.
curves_data = ci.run_with_metadata(
indexed_inputs=self.dataset.indexed_examples,
model=self.model,
dataset=self.dataset,
config={
'Label': 'blue',
'Prediction field': 'pred'
})
self.assertIn('roc_data', curves_data)
self.assertIn('pr_data', curves_data)
roc_data = curves_data['roc_data']
self.assertEqual(roc_data, [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0)])
pr_data = curves_data['pr_data']
self.assertEqual(pr_data, [(1.0, 1.0), (1.0, 0.0)])

def test_config_spec(self):
"""Tests that the interpreter config has correct fields of correct type."""
ci = curves.CurvesInterpreter()
spec = ci.config_spec()
self.assertIn('Label', spec)
self.assertIsInstance(spec['Label'], lit_types.CategoryLabel)

def test_meta_spec(self):
"""Tests that the interpreter meta has correct fields of correct type."""
ci = curves.CurvesInterpreter()
spec = ci.meta_spec()
self.assertIn('roc_data', spec)
self.assertIsInstance(spec['roc_data'], lit_types.CurveDataPoints)
self.assertIn('pr_data', spec)
self.assertIsInstance(spec['pr_data'], lit_types.CurveDataPoints)

def test_incompatible_model_prediction(self):
"""A model is incompatible if prediction is not MulticlassPreds."""
ci = curves.CurvesInterpreter()
self.assertFalse(ci.is_compatible(IncompatiblePredictionTestModel()))

def test_no_parent_in_model_spec(self):
"""A model is incompatible if there is no reference to the parent."""
ci = curves.CurvesInterpreter()
self.assertFalse(ci.is_compatible(NoParentTestModel()))


if __name__ == '__main__':
absltest.main()
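
As a sanity check on the expected values in test_interpreter_honors_user_selected_label, the 'red' ROC points can be reproduced with sklearn directly. A standalone sketch (not part of the commit), using the same scores and binarized labels the interpreter derives from TEST_DATA:

import sklearn.metrics as metrics

y_true = [1.0, 1.0, 0.0]   # Examples 0 and 1 are labeled 'red'.
y_score = [0.7, 0.3, 0.6]  # P('red') for each example in TEST_DATA.
fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
points = [(float(x), float(y)) for x, y in zip(fpr, tpr)]
print(points)  # [(0.0, 0.0), (0.0, 0.5), (1.0, 0.5), (1.0, 1.0)]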
