From 4b5b907d5eae07fa7466fca7847ebd4ccf6838ef Mon Sep 17 00:00:00 2001 From: goldpulpy Date: Thu, 10 Oct 2024 15:42:33 +0300 Subject: [PATCH] Added a function to get embedding by sentence --- pysentence_similarity/_storage.py | 17 +++++++++++++++++ tests/test_storage.py | 14 +++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/pysentence_similarity/_storage.py b/pysentence_similarity/_storage.py index 2a86bb0..c8dbf96 100644 --- a/pysentence_similarity/_storage.py +++ b/pysentence_similarity/_storage.py @@ -178,6 +178,23 @@ def get_sentences(self) -> List[str]: """ return self._sentences + def get_embedding_by_sentence(self, sentence: str) -> np.ndarray: + """ + Get the embedding for the specified sentence. + + :param sentence: The sentence to get the embedding for. + :type sentence: str + :return: The embedding for the specified sentence. + :rtype: np.ndarray + :raises ValueError: If the sentence is not found in the storage. + """ + try: + index = self._sentences.index(sentence) + return self._embeddings[index] + except ValueError as err: + logger.error("Sentence not found: %s", err) + raise + def get_embeddings(self) -> List[np.ndarray]: """ Get the list of embeddings. diff --git a/tests/test_storage.py b/tests/test_storage.py index cfdf372..2195a2a 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -14,7 +14,7 @@ def setUp(self) -> None: "This is a test sentence.", "This is another sentence.", ] - self.embeddings = [np.random.rand(3), np.random.rand(3),] + self.embeddings = [np.random.rand(3), np.random.rand(3)] self.storage = Storage( sentences=self.sentences, embeddings=self.embeddings @@ -115,6 +115,18 @@ def test_remove_by_sentence_not_found(self) -> None: with self.assertRaises(ValueError): self.storage.remove_by_sentence("non-existent sentence") + def test_get_embedding_by_sentence(self) -> None: + """Test getting an embedding by sentence.""" + embedding = self.storage.get_embedding_by_sentence( + "This is a test sentence." + ) + self.assertIsInstance(embedding, np.ndarray) + + def test_get_embedding_by_sentence_not_found(self) -> None: + """Test getting an embedding by sentence that does not exist.""" + with self.assertRaises(ValueError): + self.storage.get_embedding_by_sentence("non-existent sentence") + if __name__ == "__main__": unittest.main()