Skip to content

Commit

Permalink
Moved functions, added search similar function, added tests, updated …
Browse files Browse the repository at this point in the history
…function documentation
  • Loading branch information
goldpulpy committed Oct 10, 2024
1 parent 5d9cae1 commit 953a933
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 6 deletions.
5 changes: 5 additions & 0 deletions pysentence_similarity/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Utils functions module."""
from ._compute_score import compute_score
from ._search_similar import search_similar

__all__ = ["compute_score", "search_similar"]
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Utils functions module."""
"""Compute functions module."""
import logging
from typing import List, Union, Callable

import numpy as np
from tqdm import tqdm

from .compute import cosine
from ._storage import Storage
from ..compute import cosine
from .._storage import Storage

# Set up logging
logger = logging.getLogger("pysentence-similarity:utils")
Expand Down Expand Up @@ -99,9 +99,15 @@ def compute_score(
rounding: int = 2,
progress_bar: bool = False
) -> List[float]:
"""Compute similarity scores between a source embedding and an array of
"""Compute similarity scores between a source embedding and an array of
embeddings.
This function calculates similarity scores between a given source
embedding (or a list of embeddings) and a set of embeddings using
a specified similarity computation function. It allows for
flexibility in the input types and provides options for rounding
the scores and displaying a progress bar.
:param source: Source embedding for comparison.
:type source: Union[np.ndarray, List[np.ndarray]]
:param embeddings: Embeddings to compare against.
Expand Down
126 changes: 126 additions & 0 deletions pysentence_similarity/utils/_search_similar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Search similar module."""
import logging
from typing import List, Callable, Tuple, Optional
import numpy as np

from ..compute import cosine
from .._storage import Storage, InvalidDataError
from ._compute_score import compute_score

# Set up logging
logger = logging.getLogger("pysentence-similarity:utils")


def search_similar(
query_embedding: np.ndarray,
sentences: Optional[List[str]] = None,
embeddings: Optional[List[np.ndarray]] = None,
storage: Optional[Storage] = None,
top_k: int = 5,
compute_function: Callable = cosine,
rounding: int = 2,
progress_bar: bool = False,
sort_order: str = 'desc'
) -> List[Tuple[str, float]]:
"""
Search for similar sentences based on the provided query embedding.
This function retrieves and computes similarity scores between a given
query embedding and a set of candidate sentences (and their corresponding
embeddings). It returns the top K most similar sentences based on the
specified similarity metric.
- If `storage` is provided, it will be used to retrieve both sentences and
embeddings, allowing the other parameters (`sentences` and `embeddings`) to
be omitted.
- Similarity scores are calculated using the specified `compute_function`.
- Results can be sorted in either ascending or descending order based
on the specified `sort_order`.
:param query_embedding: The embedding of the query sentence.
:type query_embedding: np.ndarray
:param sentences: List of candidate sentences. Optional if `storage` is
provided.
:type sentences: Optional[List[str]]
:param embeddings: List of embeddings corresponding to the sentences.
Optional if `storage` is provided.
:type embeddings: Optional[List[np.ndarray]]
:param storage: An instance of `Storage` to retrieve sentences and
embeddings from.
:type storage: Optional[Storage]
:param top_k: The number of top similar sentences to return. Default is 5.
:type top_k: int
:param compute_function: Function used to compute similarity between
embeddings. Default is cosine similarity.
:type compute_function: Callable
:param rounding: The number of decimal places to round the similarity
scores to. Default is 2.
:type rounding: int
:param progress_bar: Whether to show a progress bar during score
computation. Default is False.
:type progress_bar: bool
:param sort_order: The order to sort results by similarity score. 'asc'
for ascending, 'desc' for descending. Default is 'desc'.
:type sort_order: str
:return: A list of tuples containing the top `top_k` similar sentences and
their similarity scores.
:rtype: List[Tuple[str, float]]
:raises ValueError: If both `sentences` and `embeddings` or `storage` are
not provided.
:raises InvalidDataError: If there is an inconsistency in data (e.g.,
different lengths of `sentences` and `embeddings`).
"""

if storage is not None:
try:
sentences = storage.get_sentences()
embeddings = storage.get_embeddings()
except Exception as err:
logger.error("Failed to retrieve data from storage: %s", err)
raise InvalidDataError(
"Failed to retrieve data from storage.") from err

if not sentences:
logger.error("No sentences provided.")
raise ValueError("No sentences provided.")

if not embeddings:
logger.error("No embeddings provided.")
raise ValueError("No embeddings provided.")

if len(sentences) != len(embeddings):
logger.error(
"Mismatch between number of sentences (%d) and embeddings "
"(%d).", len(sentences), len(embeddings))
raise InvalidDataError(
"Number of sentences and embeddings must match.")

if sort_order not in ['asc', 'desc']:
logger.error("Invalid sort order: %s", sort_order)
raise ValueError("Invalid sort order, must be 'asc' or 'desc'.")

try:
scores = compute_score(
source=query_embedding,
embeddings=embeddings,
compute_function=compute_function,
rounding=rounding,
progress_bar=progress_bar
)
if sort_order == 'asc':
sorted_indices = np.argsort(scores)[:top_k]
else:
sorted_indices = np.argsort(scores)[-top_k:][::-1]

# Return top_k sentences and their scores
top_similar = [
(sentences[i], round(scores[i], rounding))
for i in sorted_indices
]
return top_similar

except Exception as err:
logger.exception("An error occurred during similarity search: %s", err)
raise RuntimeError(
"An error occurred while searching for similar sentences."
) from err
58 changes: 56 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import unittest

import numpy as np
from pysentence_similarity import Model
from pysentence_similarity.utils import compute_score
from pysentence_similarity import Model, Storage
from pysentence_similarity.utils import compute_score, search_similar


class TestModel(unittest.TestCase):
Expand Down Expand Up @@ -57,6 +57,60 @@ def test_similarity_score_multiple_embeddings(self) -> None:
self.assertGreaterEqual(score[0][0], -1)
self.assertLessEqual(score[0][0], 1)

def test_search_similar(self) -> None:
"""Test search_similar returns the correct similar sentences."""
query_embedding = self.model.encode("This is a test.")
sentences = [
"This is another test.",
"This is a test.",
"This is yet another test."
]
embeddings = self.model.encode(sentences)

results = search_similar(
query_embedding=query_embedding,
sentences=sentences,
embeddings=embeddings,
top_k=1
)

self.assertIsInstance(results, list)
self.assertEqual(len(results), 1)
self.assertEqual(results[0], ("This is a test.", 1.0))

def test_search_similar_no_sentences(self) -> None:
"""Test search_similar with no sentences raises error."""
query_embedding = self.model.encode("This is a test.")
embeddings = np.zeros((0, 0))
with self.assertRaises(ValueError) as context:
search_similar(
query_embedding=query_embedding,
sentences=[],
embeddings=embeddings
)
self.assertEqual(str(context.exception), "No sentences provided.")

def test_search_similar_empty_embeddings(self) -> None:
"""Test search_similar with no sentences raises error."""
query_embedding = self.model.encode("This is a test.")
with self.assertRaises(ValueError) as context:
search_similar(
query_embedding=query_embedding,
sentences=["This is a test."],
embeddings=[]
)
self.assertEqual(str(context.exception), "No embeddings provided.")

def test_search_similar_empty_storage(self) -> None:
"""Test search_similar with empty storage raises error."""
query_embedding = self.model.encode("This is a test.")
empty_storage = Storage()
with self.assertRaises(ValueError):
search_similar(
query_embedding=query_embedding,
storage=empty_storage
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 953a933

Please sign in to comment.