From 9baac29b96970ef7fa64f2f36ce2c79ff73707b7 Mon Sep 17 00:00:00 2001 From: Ryan Mullins <ryanmullins@google.com> Date: Thu, 5 Sep 2024 16:15:39 +0000 Subject: [PATCH] Code health update on model server tests --- lit_nlp/examples/gcp/model_server_test.py | 53 ++++++++++++++--------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/lit_nlp/examples/gcp/model_server_test.py b/lit_nlp/examples/gcp/model_server_test.py index cdeb8242..59f0d949 100644 --- a/lit_nlp/examples/gcp/model_server_test.py +++ b/lit_nlp/examples/gcp/model_server_test.py @@ -2,21 +2,23 @@ from unittest import mock from absl.testing import absltest +from absl.testing import parameterized from lit_nlp.examples.gcp import model_server from lit_nlp.examples.prompt_debugging import utils as pd_utils import webtest -class TestWSGIApp(absltest.TestCase): +class TestWSGIApp(parameterized.TestCase): - @mock.patch('lit_nlp.examples.prompt_debugging.models.get_models') - def test_predict_endpoint(self, mock_get_models): + @classmethod + def setUpClass(cls): test_model_name = 'lit_on_gcp_test_model' + sal_name, tok_name = pd_utils.generate_model_group_names(test_model_name) test_model_config = f'{test_model_name}:test_model_path' os.environ['MODEL_CONFIG'] = test_model_config - mock_model = mock.MagicMock() - mock_model.predict.side_effect = [[{'response': 'test output text'}]] + generation_model = mock.MagicMock() + generation_model.predict.side_effect = [[{'response': 'test output text'}]] salience_model = mock.MagicMock() salience_model.predict.side_effect = [[{ @@ -30,33 +32,42 @@ def test_predict_endpoint(self, mock_get_models): [{'tokens': ['test', 'output', 'text']}] ] - sal_name, tok_name = pd_utils.generate_model_group_names(test_model_name) - - mock_get_models.return_value = { - test_model_name: mock_model, + cls.mock_models = { + test_model_name: generation_model, sal_name: salience_model, tok_name: tokenize_model, } - app = webtest.TestApp(model_server.get_wsgi_app()) - response = app.post_json('/predict', {'inputs': 'test_input'}) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, [{'response': 'test output text'}]) - response = app.post_json('/salience', {'inputs': 'test_input'}) - self.assertEqual(response.status_code, 200) - self.assertEqual( - response.json, - [{ + @parameterized.named_parameters( + dict( + testcase_name='predict', + endpoint='/predict', + expected=[{'response': 'test output text'}], + ), + dict( + testcase_name='salience', + endpoint='/salience', + expected=[{ 'tokens': ['test', 'output', 'text'], 'grad_l2': [0.1234, 0.3456, 0.5678], 'grad_dot_input': [0.1234, -0.3456, 0.5678], }], - ) + ), + dict( + testcase_name='tokenize', + endpoint='/tokenize', + expected=[{'tokens': ['test', 'output', 'text']}], + ), + ) + @mock.patch('lit_nlp.examples.prompt_debugging.models.get_models') + def test_endpoint(self, mock_get_models, endpoint, expected): + mock_get_models.return_value = self.mock_models + app = webtest.TestApp(model_server.get_wsgi_app()) - response = app.post_json('/tokenize', {'inputs': 'test_input'}) + response = app.post_json(endpoint, {'inputs': [{'prompt': 'test input'}]}) self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, [{'tokens': ['test', 'output', 'text']}]) + self.assertEqual(response.json, expected) if __name__ == '__main__':