Skip to content

Commit

Permalink
Added a function to get embedding by sentence
Browse files Browse the repository at this point in the history
  • Loading branch information
goldpulpy committed Oct 10, 2024
1 parent 52f7c1b commit 4b5b907
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
17 changes: 17 additions & 0 deletions pysentence_similarity/_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,23 @@ def get_sentences(self) -> List[str]:
"""
return self._sentences

def get_embedding_by_sentence(self, sentence: str) -> np.ndarray:
"""
Get the embedding for the specified sentence.
:param sentence: The sentence to get the embedding for.
:type sentence: str
:return: The embedding for the specified sentence.
:rtype: np.ndarray
:raises ValueError: If the sentence is not found in the storage.
"""
try:
index = self._sentences.index(sentence)
return self._embeddings[index]
except ValueError as err:
logger.error("Sentence not found: %s", err)
raise

def get_embeddings(self) -> List[np.ndarray]:
"""
Get the list of embeddings.
Expand Down
14 changes: 13 additions & 1 deletion tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def setUp(self) -> None:
"This is a test sentence.",
"This is another sentence.",
]
self.embeddings = [np.random.rand(3), np.random.rand(3),]
self.embeddings = [np.random.rand(3), np.random.rand(3)]
self.storage = Storage(
sentences=self.sentences,
embeddings=self.embeddings
Expand Down Expand Up @@ -115,6 +115,18 @@ def test_remove_by_sentence_not_found(self) -> None:
with self.assertRaises(ValueError):
self.storage.remove_by_sentence("non-existent sentence")

def test_get_embedding_by_sentence(self) -> None:
"""Test getting an embedding by sentence."""
embedding = self.storage.get_embedding_by_sentence(
"This is a test sentence."
)
self.assertIsInstance(embedding, np.ndarray)

def test_get_embedding_by_sentence_not_found(self) -> None:
"""Test getting an embedding by sentence that does not exist."""
with self.assertRaises(ValueError):
self.storage.get_embedding_by_sentence("non-existent sentence")


if __name__ == "__main__":
unittest.main()

0 comments on commit 4b5b907

Please sign in to comment.