Skip to content

Commit

Permalink
Added functions of remove by index and by sentence, as well as tests …
Browse files Browse the repository at this point in the history
…for them
  • Loading branch information
goldpulpy committed Oct 9, 2024
1 parent 319e0bf commit 7cb5c5f
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
33 changes: 33 additions & 0 deletions pysentence_similarity/_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,39 @@ def add(
if save:
self.save(filename)

def remove_by_index(self, index: int) -> None:
    """
    Delete the entry stored at ``index``: the sentence and its embedding.

    :param index: Position of the entry to delete.
    :type index: int
    :raises IndexError: If popping at ``index`` fails in either the
        sentence or the embedding collection.
    :return: None
    """
    try:
        dropped = self._sentences.pop(index)
        self._embeddings.pop(index)
    except IndexError as exc:
        # Record the failure, then let the caller see the original error.
        logger.error("Index out of range: %s", exc)
        raise
    else:
        logger.info("Removed sentence: %s", dropped)

def remove_by_sentence(self, sentence: str) -> None:
    """
    Remove a sentence and its corresponding embedding, looked up by text.

    The entry removed is the one at the position returned by
    ``self._sentences.index(sentence)``; the removal itself is delegated
    to :meth:`remove_by_index`.

    :param sentence: The sentence to remove.
    :type sentence: str
    :raises ValueError: If the sentence is not found in the storage.
    :return: None
    """
    try:
        index = self._sentences.index(sentence)
    except ValueError as err:
        logger.error("Sentence not found: %s", err)
        raise

    # Kept outside the ``try`` so that only a failed lookup is reported as
    # "Sentence not found"; errors from the removal step propagate as-is.
    self.remove_by_index(index)

def get_sentences(self) -> List[str]:
"""
Get the list of sentences.
Expand Down
36 changes: 34 additions & 2 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ def setUp(self) -> None:
"""Set up test data before each test."""
self.sentences = [
"This is a test sentence.",
"This is another sentence."
"This is another sentence.",
]
self.embeddings = [np.random.rand(3), np.random.rand(3)]
self.embeddings = [np.random.rand(3), np.random.rand(3),]
self.storage = Storage(
sentences=self.sentences,
embeddings=self.embeddings
Expand Down Expand Up @@ -83,6 +83,38 @@ def test_index_out_of_range(self) -> None:
with self.assertRaises(IndexError):
self.storage[10]

def test_remove_by_index_valid(self) -> None:
    """Removing index 1 leaves only the first sentence in storage."""
    self.storage.remove_by_index(1)
    remaining = self.storage.get_sentences()
    self.assertEqual(remaining, ["This is a test sentence."])

def test_remove_by_index_out_of_range(self) -> None:
    """Removing at index 5 must raise IndexError."""
    self.assertRaises(IndexError, self.storage.remove_by_index, 5)

def test_remove_by_index_boundary(self) -> None:
    """Test removing the first element and then the last remaining one."""
    # Boundary 1: the first position.
    self.storage.remove_by_index(0)
    self.assertEqual(
        self.storage.get_sentences(),
        ["This is another sentence."]
    )

    # Boundary 2: the last position of what is left in storage.
    last_index = len(self.storage.get_sentences()) - 1
    self.storage.remove_by_index(last_index)
    self.assertEqual(self.storage.get_sentences(), [])

def test_remove_by_sentence_valid(self) -> None:
    """Removing an existing sentence by its text leaves the other one."""
    target = "This is a test sentence."
    self.storage.remove_by_sentence(target)
    self.assertEqual(
        self.storage.get_sentences(),
        ["This is another sentence."],
    )

def test_remove_by_sentence_not_found(self) -> None:
    """Removing a sentence that is not stored must raise ValueError."""
    self.assertRaises(
        ValueError,
        self.storage.remove_by_sentence,
        "non-existent sentence",
    )


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 7cb5c5f

Please sign in to comment.