From 953a933a45196f04d8a3e0bad3a07ff16d03076a Mon Sep 17 00:00:00 2001 From: goldpulpy Date: Thu, 10 Oct 2024 22:58:25 +0300 Subject: [PATCH] Moved functions, added search similar function, added tests, updated function documentation --- pysentence_similarity/utils/__init__.py | 5 + .../{utils.py => utils/_compute_score.py} | 14 +- .../utils/_search_similar.py | 126 ++++++++++++++++++ tests/test_utils.py | 58 +++++++- 4 files changed, 197 insertions(+), 6 deletions(-) create mode 100644 pysentence_similarity/utils/__init__.py rename pysentence_similarity/{utils.py => utils/_compute_score.py} (92%) create mode 100644 pysentence_similarity/utils/_search_similar.py diff --git a/pysentence_similarity/utils/__init__.py b/pysentence_similarity/utils/__init__.py new file mode 100644 index 0000000..3e68658 --- /dev/null +++ b/pysentence_similarity/utils/__init__.py @@ -0,0 +1,5 @@ +"""Utils functions module.""" +from ._compute_score import compute_score +from ._search_similar import search_similar + +__all__ = ["compute_score", "search_similar"] diff --git a/pysentence_similarity/utils.py b/pysentence_similarity/utils/_compute_score.py similarity index 92% rename from pysentence_similarity/utils.py rename to pysentence_similarity/utils/_compute_score.py index a1aad28..ebc39ee 100644 --- a/pysentence_similarity/utils.py +++ b/pysentence_similarity/utils/_compute_score.py @@ -1,12 +1,12 @@ -"""Utils functions module.""" +"""Compute functions module.""" import logging from typing import List, Union, Callable import numpy as np from tqdm import tqdm -from .compute import cosine -from ._storage import Storage +from ..compute import cosine +from .._storage import Storage # Set up logging logger = logging.getLogger("pysentence-similarity:utils") @@ -99,9 +99,15 @@ def compute_score( rounding: int = 2, progress_bar: bool = False ) -> List[float]: - """Compute similarity scores between a source embedding and an array of + """Compute similarity scores between a source embedding and an array of embeddings. + This function calculates similarity scores between a given source + embedding (or a list of embeddings) and a set of embeddings using + a specified similarity computation function. It allows for + flexibility in the input types and provides options for rounding + the scores and displaying a progress bar. + :param source: Source embedding for comparison. :type source: Union[np.ndarray, List[np.ndarray]] :param embeddings: Embeddings to compare against. diff --git a/pysentence_similarity/utils/_search_similar.py b/pysentence_similarity/utils/_search_similar.py new file mode 100644 index 0000000..63d53ee --- /dev/null +++ b/pysentence_similarity/utils/_search_similar.py @@ -0,0 +1,126 @@ +"""Search similar module.""" +import logging +from typing import List, Callable, Tuple, Optional +import numpy as np + +from ..compute import cosine +from .._storage import Storage, InvalidDataError +from ._compute_score import compute_score + +# Set up logging +logger = logging.getLogger("pysentence-similarity:utils") + + +def search_similar( + query_embedding: np.ndarray, + sentences: Optional[List[str]] = None, + embeddings: Optional[List[np.ndarray]] = None, + storage: Optional[Storage] = None, + top_k: int = 5, + compute_function: Callable = cosine, + rounding: int = 2, + progress_bar: bool = False, + sort_order: str = 'desc' +) -> List[Tuple[str, float]]: + """ + Search for similar sentences based on the provided query embedding. + + This function retrieves and computes similarity scores between a given + query embedding and a set of candidate sentences (and their corresponding + embeddings). It returns the top K most similar sentences based on the + specified similarity metric. + + - If `storage` is provided, it will be used to retrieve both sentences and + embeddings, allowing the other parameters (`sentences` and `embeddings`) to + be omitted. + - Similarity scores are calculated using the specified `compute_function`. + - Results can be sorted in either ascending or descending order based + on the specified `sort_order`. + + :param query_embedding: The embedding of the query sentence. + :type query_embedding: np.ndarray + :param sentences: List of candidate sentences. Optional if `storage` is + provided. + :type sentences: Optional[List[str]] + :param embeddings: List of embeddings corresponding to the sentences. + Optional if `storage` is provided. + :type embeddings: Optional[List[np.ndarray]] + :param storage: An instance of `Storage` to retrieve sentences and + embeddings from. + :type storage: Optional[Storage] + :param top_k: The number of top similar sentences to return. Default is 5. + :type top_k: int + :param compute_function: Function used to compute similarity between + embeddings. Default is cosine similarity. + :type compute_function: Callable + :param rounding: The number of decimal places to round the similarity + scores to. Default is 2. + :type rounding: int + :param progress_bar: Whether to show a progress bar during score + computation. Default is False. + :type progress_bar: bool + :param sort_order: The order to sort results by similarity score. 'asc' + for ascending, 'desc' for descending. Default is 'desc'. + :type sort_order: str + :return: A list of tuples containing the top `top_k` similar sentences and + their similarity scores. + :rtype: List[Tuple[str, float]] + :raises ValueError: If both `sentences` and `embeddings` or `storage` are + not provided. + :raises InvalidDataError: If there is an inconsistency in data (e.g., + different lengths of `sentences` and `embeddings`). + """ + + if storage is not None: + try: + sentences = storage.get_sentences() + embeddings = storage.get_embeddings() + except Exception as err: + logger.error("Failed to retrieve data from storage: %s", err) + raise InvalidDataError( + "Failed to retrieve data from storage.") from err + + if not sentences: + logger.error("No sentences provided.") + raise ValueError("No sentences provided.") + + if not embeddings: + logger.error("No embeddings provided.") + raise ValueError("No embeddings provided.") + + if len(sentences) != len(embeddings): + logger.error( + "Mismatch between number of sentences (%d) and embeddings " + "(%d).", len(sentences), len(embeddings)) + raise InvalidDataError( + "Number of sentences and embeddings must match.") + + if sort_order not in ['asc', 'desc']: + logger.error("Invalid sort order: %s", sort_order) + raise ValueError("Invalid sort order, must be 'asc' or 'desc'.") + + try: + scores = compute_score( + source=query_embedding, + embeddings=embeddings, + compute_function=compute_function, + rounding=rounding, + progress_bar=progress_bar + ) + if sort_order == 'asc': + sorted_indices = np.argsort(scores)[:top_k] + else: + sorted_indices = np.argsort(scores)[-top_k:][::-1] + + # Return top_k sentences and their scores + top_similar = [ + (sentences[i], round(scores[i], rounding)) + for i in sorted_indices + ] + return top_similar + + except Exception as err: + logger.exception("An error occurred during similarity search: %s", err) + raise RuntimeError( + "An error occurred while searching for similar sentences." + ) from err diff --git a/tests/test_utils.py b/tests/test_utils.py index 21366ac..58d858f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,8 +2,8 @@ import unittest import numpy as np -from pysentence_similarity import Model -from pysentence_similarity.utils import compute_score +from pysentence_similarity import Model, Storage +from pysentence_similarity.utils import compute_score, search_similar class TestModel(unittest.TestCase): @@ -57,6 +57,60 @@ def test_similarity_score_multiple_embeddings(self) -> None: self.assertGreaterEqual(score[0][0], -1) self.assertLessEqual(score[0][0], 1) + def test_search_similar(self) -> None: + """Test search_similar returns the correct similar sentences.""" + query_embedding = self.model.encode("This is a test.") + sentences = [ + "This is another test.", + "This is a test.", + "This is yet another test." + ] + embeddings = self.model.encode(sentences) + + results = search_similar( + query_embedding=query_embedding, + sentences=sentences, + embeddings=embeddings, + top_k=1 + ) + + self.assertIsInstance(results, list) + self.assertEqual(len(results), 1) + self.assertEqual(results[0], ("This is a test.", 1.0)) + + def test_search_similar_no_sentences(self) -> None: + """Test search_similar with no sentences raises error.""" + query_embedding = self.model.encode("This is a test.") + embeddings = np.zeros((0, 0)) + with self.assertRaises(ValueError) as context: + search_similar( + query_embedding=query_embedding, + sentences=[], + embeddings=embeddings + ) + self.assertEqual(str(context.exception), "No sentences provided.") + + def test_search_similar_empty_embeddings(self) -> None: + """Test search_similar with no sentences raises error.""" + query_embedding = self.model.encode("This is a test.") + with self.assertRaises(ValueError) as context: + search_similar( + query_embedding=query_embedding, + sentences=["This is a test."], + embeddings=[] + ) + self.assertEqual(str(context.exception), "No embeddings provided.") + + def test_search_similar_empty_storage(self) -> None: + """Test search_similar with empty storage raises error.""" + query_embedding = self.model.encode("This is a test.") + empty_storage = Storage() + with self.assertRaises(ValueError): + search_similar( + query_embedding=query_embedding, + storage=empty_storage + ) + if __name__ == "__main__": unittest.main()