# Extracted from a git patch:
#   Commit:  08fa6599ebd054962f2c611d54beed37e255ecaa
#   From:    goldpulpy
#   Date:    Wed, 9 Oct 2024 04:09:06 +0300
#   Subject: [PATCH] Tests for the compute score function
#   Target:  tests/test_utils.py (new file)
"""Test the compute_score function of pysentence_similarity."""
import unittest

import numpy as np
from pysentence_similarity import Model, compute_score


class TestModel(unittest.TestCase):
    """Test similarity scores computed from model embeddings."""

    @classmethod
    def setUpClass(cls) -> None:
        """Create the model that is shared by every test in this class."""
        cls.model = Model("all-MiniLM-L6-v2", dtype="fp16")

    def test_similarity_score(self) -> None:
        """Test similarity score between two sentences.

        A single target embedding must yield a one-element list whose
        value lies in the closed interval [-1, 1].
        """
        source_embedding = self.model.encode("This is a test.")
        target_embedding = self.model.encode("This is another test.")

        score = compute_score(source_embedding, [target_embedding])

        self.assertIsInstance(score, list)
        self.assertEqual(len(score), 1)
        self.assertGreaterEqual(score[0], -1)
        self.assertLessEqual(score[0], 1)

    def test_similarity_score_invalid_input(self) -> None:
        """Test similarity score raises ValueError on an invalid source."""
        with self.assertRaises(ValueError):
            compute_score(123, [np.array([0.5, 0.1])])

    def test_similarity_score_rounding(self) -> None:
        """Test similarity score with rounding values from 0 through 10."""
        source_embedding = self.model.encode("This is a test.")
        target_embedding = self.model.encode("This is another test.")

        for rounding in range(11):
            # subTest labels a failure with the offending rounding value and
            # lets the remaining values run instead of stopping at the first.
            with self.subTest(rounding=rounding):
                score = compute_score(
                    source_embedding, [target_embedding],
                    rounding=rounding,
                )
                self.assertIsInstance(score, list)
                self.assertEqual(len(score), 1)

    def test_similarity_score_multiple_embeddings(self) -> None:
        """Test similarity score with multiple embeddings.

        The same two encoded sentences are used as both source and
        targets; the result must hold two entries, and the first value
        of the first entry must lie in [-1, 1].
        """
        embeddings = self.model.encode(
            ["This is a test.",
             "This is another test."]
        )

        score = compute_score(embeddings, embeddings)

        self.assertIsInstance(score, list)
        self.assertEqual(len(score), 2)
        self.assertGreaterEqual(score[0][0], -1)
        self.assertLessEqual(score[0][0], 1)


if __name__ == "__main__":
    unittest.main()