diff --git a/lit_nlp/components/curves_test.py b/lit_nlp/components/curves_test.py index 23fbfd00..cef377c4 100644 --- a/lit_nlp/components/curves_test.py +++ b/lit_nlp/components/curves_test.py @@ -148,38 +148,6 @@ def test_model_output_is_missing_in_config(self): config={'Label': 'red'}, ) - @parameterized.named_parameters( - dict( - testcase_name='red', - label='red', - exp_roc=[(0.0, 0.0), (0.0, 0.5), (1.0, 0.5), (1.0, 1.0)], - exp_pr=[(0.5, 0.5), (2 / 3, 1.0), (1.0, 0.5), (1.0, 0.0)], - ), - dict( - testcase_name='blue', - label='blue', - exp_roc=[(0.0, 0.0), (0.0, 1.0), (1.0, 1.0)], - exp_pr=[(1.0, 1.0), (1.0, 0.0)], - ), - ) - def test_interpreter_honors_user_selected_label( - self, label: str, exp_roc: _Curve, exp_pr: _Curve - ): - """Tests a happy scenario when a user doesn't specify the class label.""" - curves_data = self.ci.run( - inputs=self.dataset.examples, - model=self.model, - dataset=self.dataset, - config={ - curves.TARGET_LABEL_KEY: label, - curves.TARGET_PREDICTION_KEY: 'pred', - }, - ) - self.assertIn(curves.ROC_DATA, curves_data) - self.assertIn(curves.PR_DATA, curves_data) - self.assertEqual(curves_data[curves.ROC_DATA], exp_roc) - self.assertEqual(curves_data[curves.PR_DATA], exp_pr) - def test_config_spec(self): """Tests that the interpreter config has correct fields of correct type.""" spec = self.ci.config_spec()