Skip to content

Commit

Permalink
Function documentation has been updated
Browse files Browse the repository at this point in the history
  • Loading branch information
goldpulpy committed Oct 10, 2024
1 parent 953a933 commit 50b5cce
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 22 deletions.
33 changes: 27 additions & 6 deletions pysentence_similarity/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import time
import logging
import copy
from typing import List, Union, Callable

import onnxruntime as ort
Expand Down Expand Up @@ -37,13 +38,23 @@ def __init__(
"""
Initialize the sentence similarity task model.
This constructor initializes the necessary components to load a model
for sentence similarity tasks, including the model, tokenizer, and the
device configurations.
:param model: The name of the model to be used.
:type model: str
:param dtype: The dtype of the model ('fp32', 'fp16', 'int8').
:param dtype: The data type of the model. Options include 'fp32' for
32-bit floating point, 'fp16' for 16-bit floating point, and 'int8' for
8-bit integer. Default is 'fp32'.
:type dtype: str
:param cache_dir: Directory to cache the model and tokenizer.
:param cache_dir: The directory where the model and tokenizer should be
cached. If not provided, a default cache directory is used based on
the package name.
:type cache_dir: str
:param device: Device to use for inference ('cuda', 'cpu').
:param device: The device to use for inference. Options include 'cuda'
for GPU acceleration and 'cpu' for running on the CPU.
Default is 'cpu'.
:type device: str
:raises ValueError: If the model or tokenizer cannot be loaded.
"""
Expand All @@ -56,8 +67,8 @@ def __init__(

try:
self._providers = self._get_providers()
self._tokenizer = self._load_tokenizer()
self._session = self._load_model()
self._tokenizer = self._load_tokenizer()
except Exception as err:
logger.error("Error initializing model: %s", err)
raise
Expand All @@ -77,7 +88,11 @@ def encode(
pooling_function: Callable = mean_pooling,
progress_bar: bool = False
) -> Union[np.ndarray, List[np.ndarray]]:
"""Convert a single sentence to an embedding vector.
"""Convert a single sentence or a list of sentences to an embedding
vector.
This method takes one or more sentences as input and converts them
into embedding vectors using a specified pooling function.
:param sentences: Sentence or list of sentences to convert.
:type sentences: Union[str, List[str]]
Expand Down Expand Up @@ -320,8 +335,14 @@ def __repr__(self) -> str:
"""Return a string representation of the Model object."""
return self.__str__()

def __copy__(self):
def __copy__(self) -> "Model":
    """
    Return a shallow copy of this Model.

    The clone is created without invoking ``__init__`` (so no model or
    tokenizer is reloaded); it shares all attribute values with the
    original instance.
    """
    cls = type(self)
    clone = cls.__new__(cls)
    clone.__dict__.update(self.__dict__)
    return clone

def __deepcopy__(self, memo) -> "Model":
    """
    Create a deep copy of the Model object.

    The previous implementation delegated back to
    ``copy.deepcopy(self, memo)``; ``copy.deepcopy`` dispatches to
    ``__deepcopy__`` again, so that call recursed forever. Instead,
    build a new instance directly and deep-copy the attribute dict.

    :param memo: Memo dictionary used by ``copy.deepcopy`` to track
        already-copied objects (handles reference cycles).
    :return: A deep copy of this Model instance.
    """
    cls = self.__class__
    new_instance = cls.__new__(cls)
    # Register the copy before descending so cyclic references back to
    # `self` resolve to `new_instance` instead of recursing.
    memo[id(self)] = new_instance
    for key, value in self.__dict__.items():
        new_instance.__dict__[key] = copy.deepcopy(value, memo)
    return new_instance
38 changes: 32 additions & 6 deletions pysentence_similarity/_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
class Splitter:
"""
A class to split text into sentences.
Supports splitting by periods, exclamation marks, question marks, and
newline characters.
"""

def __init__(
Expand All @@ -25,7 +23,10 @@ def __init__(
preserve_markers: bool = False,
) -> None:
"""
Initializes the Splitter object.
Initializes the Splitter object, which is used to split a given text
based on specific characters or markers. This class allows flexible
splitting based on one or more characters and provides the option to
preserve these markers in the split result.
:param markers_to_split: A string or list of characters (e.g.,
punctuation marks) used to split the text. Default is a newline
Expand All @@ -51,7 +52,13 @@ def split_from_text(
text: str,
) -> List[str]:
"""
Splits the given text into sentences based on punctuation and newlines.
Splits the given text into sentences based on specified punctuation and
newlines.
This method uses regular expressions to identify splitting points in
the input text. It can preserve split markers (such as punctuation)
based on the `preserve_markers` attribute set during initialization.
:param text: The input text to split.
:type text: str
Expand Down Expand Up @@ -93,8 +100,12 @@ def split_from_file(
file_path: str,
) -> List[str]:
"""
Splits the contents of a txt file into sentences based on punctuation
and newlines.
Splits the contents of a text file into sentences based on specified
punctuation and newlines.
This method reads the entire content of the specified text file and
utilizes the `split_from_text` method to split the content into
sentences. It expects the file to be encoded in UTF-8.
:param file_path: The path to the file to split.
:type file_path: str
Expand Down Expand Up @@ -127,6 +138,10 @@ def split_from_url(
Fetches the content from a URL, removes HTML tags, and splits the
cleaned text into sentences.
This method retrieves the content from the provided URL, removes all
HTML tags, and splits the remaining plain text into sentences based on
the specified split markers.
:param url: The URL of the webpage to split.
:type url: str
:param timeout: The number of seconds to wait for the request to
Expand Down Expand Up @@ -166,6 +181,11 @@ def split_from_csv(
Reads a CSV file and splits the text from specified columns into
sentences.
This method reads the contents of a CSV file, extracts text from the
specified columns, and then splits the text into sentences based on
the markers defined in the `Splitter` object. It can handle multiple
columns and combines the results into a single list of sentences.
:param file_path: The path to the CSV file to read.
:type file_path: str
:param column_names: A list of column names to extract text from.
Expand Down Expand Up @@ -234,6 +254,12 @@ def split_from_json(self, file_path: str, keys: List[str]) -> List[str]:
"""
Reads a JSON file and splits text from specified keys into sentences.
This method processes a JSON file by extracting text values from
specified keys. The extracted text is then split into sentences based
on the markers defined in the `Splitter` object. It can handle nested
JSON structures and recursively extract values from deeply nested
objects.
:param file_path: The path to the JSON file to read.
:type file_path: str
:param keys: A list of keys to extract text from.
Expand Down
47 changes: 42 additions & 5 deletions pysentence_similarity/_storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Class to store embeddings in memory."""
import logging
from typing import List, Optional, Union
import copy
from typing import List, Optional, Union, Tuple

import h5py
import numpy as np
Expand All @@ -25,6 +26,10 @@ def __init__(
"""
Initialize the storage class.
This constructor initializes an instance of the storage class, allowing
for the optional provision of a list of sentences and their
corresponding embeddings.
:param sentences: List of sentences.
:type sentences: List[str], optional
:param embeddings: List of embeddings.
Expand All @@ -41,7 +46,9 @@ def save(self, filename: str) -> None:
"""
Save the embeddings and sentences to a file.
This method saves the embeddings and sentences into an HDF5 file format
using the h5py library. It validates the data before saving to ensure
that it is in the correct format.
:param filename: The name of the file to save the embeddings to.
:type filename: str
Expand All @@ -68,6 +75,10 @@ def load(filename: str) -> "Storage":
Factory method to load the embeddings and sentences from a file and
return a new Storage instance.
This method reads the embeddings and sentences from an HDF5 file
using the h5py library. It constructs and returns a new instance of
the Storage class with the loaded data.
:param filename: The name of the file to load the embeddings from.
:type filename: str
:return: A new instance of Storage class populated with the loaded
Expand Down Expand Up @@ -99,7 +110,11 @@ def add(
filename: str = None
) -> None:
"""
Add a new sentences and embeddings to the storage.
Add new sentences and embeddings to the storage.
This method appends new sentences and their corresponding embeddings
to the internal storage. If specified, it can also save the updated
data to a file.
:param sentence: The sentence to add.
:type sentence: Union[str, List[str]]
Expand Down Expand Up @@ -140,6 +155,10 @@ def remove_by_index(self, index: int) -> None:
"""
Remove the sentence and embedding at the specified index.
This method removes a sentence and its corresponding embedding
from the storage based on the provided index. If the index is
out of range, it raises an IndexError.
:param index: Index of the item to remove.
:type index: int
:raises IndexError: If the index is out of bounds.
Expand All @@ -157,6 +176,10 @@ def remove_by_sentence(self, sentence: str) -> None:
"""
Remove the sentence and its corresponding embedding by sentence.
This method searches for a specific sentence in the storage and
removes it along with its corresponding embedding. If the sentence
is not found, it raises a ValueError.
:param sentence: The sentence to remove.
:type sentence: str
:raises ValueError: If the sentence is not found in the storage.
Expand All @@ -173,6 +196,8 @@ def get_sentences(self) -> List[str]:
"""
Get the list of sentences.
This method retrieves the stored sentences from the storage.
:return: The list of sentences.
:rtype: List[str]
"""
Expand All @@ -182,6 +207,10 @@ def get_embedding_by_sentence(self, sentence: str) -> np.ndarray:
"""
Get the embedding for the specified sentence.
This method retrieves the stored embedding corresponding to the given
sentence.
:param sentence: The sentence to get the embedding for.
:type sentence: str
:return: The embedding for the specified sentence.
Expand All @@ -199,6 +228,8 @@ def get_embeddings(self) -> List[np.ndarray]:
"""
Get the list of embeddings.
This method retrieves all stored embeddings.
:return: The list of embeddings.
:rtype: List[np.ndarray]
"""
Expand Down Expand Up @@ -252,11 +283,17 @@ def __copy__(self):
new_instance.__dict__.update(self.__dict__)
return new_instance

def __deepcopy__(self, memo) -> "Storage":
    """
    Create a deep copy of the Storage object.

    The previous implementation called ``copy.deepcopy(self, memo)``,
    which dispatches straight back to ``__deepcopy__`` and therefore
    recursed infinitely. Build the copy manually instead: a fresh
    instance whose attributes (sentences, embeddings, ...) are
    deep-copied one by one.

    :param memo: Memo dictionary used by ``copy.deepcopy`` to track
        already-copied objects (handles reference cycles).
    :return: A deep copy of this Storage instance.
    """
    cls = self.__class__
    new_instance = cls.__new__(cls)
    # Register the copy first so cyclic references back to `self`
    # resolve correctly instead of recursing.
    memo[id(self)] = new_instance
    for key, value in self.__dict__.items():
        new_instance.__dict__[key] = copy.deepcopy(value, memo)
    return new_instance

def __len__(self) -> int:
    """Return the number of sentences currently held in storage."""
    sentence_count = len(self._sentences)
    return sentence_count

def __getitem__(self, index: int) -> List[Union[str, np.ndarray]]:
def __getitem__(self, index: int) -> Tuple[str, np.ndarray]:
    """
    Return the (sentence, embedding) pair stored at the given index.

    :param index: Position of the item to fetch.
    :type index: int
    :return: The sentence and its embedding.
    :rtype: Tuple[str, np.ndarray]
    :raises IndexError: If the index is out of bounds.
    """
    try:
        sentence = self._sentences[index]
        embedding = self._embeddings[index]
    except IndexError as e:
        logger.error("Index out of range: %s", e)
        raise
    return sentence, embedding
Loading

0 comments on commit 50b5cce

Please sign in to comment.