From 9baac29b96970ef7fa64f2f36ce2c79ff73707b7 Mon Sep 17 00:00:00 2001
From: Ryan Mullins <ryanmullins@google.com>
Date: Thu, 5 Sep 2024 16:15:39 +0000
Subject: [PATCH] Code health update on model server tests

---
 lit_nlp/examples/gcp/model_server_test.py | 53 ++++++++++++++---------
 1 file changed, 32 insertions(+), 21 deletions(-)

diff --git a/lit_nlp/examples/gcp/model_server_test.py b/lit_nlp/examples/gcp/model_server_test.py
index cdeb8242..59f0d949 100644
--- a/lit_nlp/examples/gcp/model_server_test.py
+++ b/lit_nlp/examples/gcp/model_server_test.py
@@ -2,21 +2,23 @@
 from unittest import mock
 
 from absl.testing import absltest
+from absl.testing import parameterized
 from lit_nlp.examples.gcp import model_server
 from lit_nlp.examples.prompt_debugging import utils as pd_utils
 import webtest
 
 
-class TestWSGIApp(absltest.TestCase):
+class TestWSGIApp(parameterized.TestCase):
 
-  @mock.patch('lit_nlp.examples.prompt_debugging.models.get_models')
-  def test_predict_endpoint(self, mock_get_models):
+  @classmethod
+  def setUpClass(cls):
     test_model_name = 'lit_on_gcp_test_model'
+    sal_name, tok_name = pd_utils.generate_model_group_names(test_model_name)
     test_model_config = f'{test_model_name}:test_model_path'
     os.environ['MODEL_CONFIG'] = test_model_config
 
-    mock_model = mock.MagicMock()
-    mock_model.predict.side_effect = [[{'response': 'test output text'}]]
+    generation_model = mock.MagicMock()
+    generation_model.predict.side_effect = [[{'response': 'test output text'}]]
 
     salience_model = mock.MagicMock()
     salience_model.predict.side_effect = [[{
@@ -30,33 +32,42 @@ def test_predict_endpoint(self, mock_get_models):
         [{'tokens': ['test', 'output', 'text']}]
     ]
 
-    sal_name, tok_name = pd_utils.generate_model_group_names(test_model_name)
-
-    mock_get_models.return_value = {
-        test_model_name: mock_model,
+    cls.mock_models = {
+        test_model_name: generation_model,
         sal_name: salience_model,
         tok_name: tokenize_model,
     }
-    app = webtest.TestApp(model_server.get_wsgi_app())
 
-    response = app.post_json('/predict', {'inputs': 'test_input'})
-    self.assertEqual(response.status_code, 200)
-    self.assertEqual(response.json, [{'response': 'test output text'}])
 
-    response = app.post_json('/salience', {'inputs': 'test_input'})
-    self.assertEqual(response.status_code, 200)
-    self.assertEqual(
-        response.json,
-        [{
+  @parameterized.named_parameters(
+    dict(
+        testcase_name='predict',
+        endpoint='/predict',
+        expected=[{'response': 'test output text'}],
+    ),
+    dict(
+        testcase_name='salience',
+        endpoint='/salience',
+        expected=[{
             'tokens': ['test', 'output', 'text'],
             'grad_l2': [0.1234, 0.3456, 0.5678],
             'grad_dot_input': [0.1234, -0.3456, 0.5678],
         }],
-    )
+    ),
+    dict(
+        testcase_name='tokenize',
+        endpoint='/tokenize',
+        expected=[{'tokens': ['test', 'output', 'text']}],
+    ),
+  )
+  @mock.patch('lit_nlp.examples.prompt_debugging.models.get_models')
+  def test_endpoint(self, mock_get_models, endpoint, expected):
+    mock_get_models.return_value = self.mock_models
+    app = webtest.TestApp(model_server.get_wsgi_app())
 
-    response = app.post_json('/tokenize', {'inputs': 'test_input'})
+    response = app.post_json(endpoint, {'inputs': [{'prompt': 'test input'}]})
     self.assertEqual(response.status_code, 200)
-    self.assertEqual(response.json, [{'tokens': ['test', 'output', 'text']}])
+    self.assertEqual(response.json, expected)
 
 
 if __name__ == '__main__':