Skip to content

Commit

Permalink
added basic test
Browse files Browse the repository at this point in the history
  • Loading branch information
koaning committed Apr 5, 2024
1 parent 5a29e15 commit d92bdae
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
7 changes: 2 additions & 5 deletions embetter/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from embetter.error import NotInstalled

try:
from embetter.text._sbert import SentenceEncoder
except ModuleNotFoundError:
SentenceEncoder = NotInstalled("SentenceEncoder", "sentence-tfm")
from embetter.text._sbert import SentenceEncoder, MatrouskaEncoder

try:
from embetter.text._s2v import Sense2VecEncoder
Expand Down Expand Up @@ -36,6 +32,7 @@

__all__ = [
"SentenceEncoder",
"MatrouskaEncoder",
"Sense2VecEncoder",
"BytePairEncoder",
"spaCyEncoder",
Expand Down
8 changes: 4 additions & 4 deletions embetter/text/_sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
from embetter.base import EmbetterBase


class MatrouskaEncoder(EmbetterBase):
def __init__(self, name="tomaarsen/mpnet-base-nli-matryoshka", **kwargs):
return SentenceEncoder(name=name, **kwargs)

class SentenceEncoder(EmbetterBase):
"""
Encoder that can numerically encode sentences.
Expand Down Expand Up @@ -98,3 +94,7 @@ def transform(self, X, y=None):
X = X.to_numpy()

return self.tfm.encode(X)


def MatrouskaEncoder(name="tomaarsen/mpnet-base-nli-matryoshka", **kwargs):
return SentenceEncoder(name=name, **kwargs)
12 changes: 7 additions & 5 deletions tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
SentenceEncoder,
GensimEncoder,
spaCyEncoder,
MatrouskaEncoder,
learn_lite_text_embeddings,
LiteTextEncoder,
)
Expand Down Expand Up @@ -46,17 +47,18 @@ def test_word2vec(setting):
assert repr(encoder)


def test_basic_sentence_encoder():
@pytest.mark.parametrize("encoder", [MatrouskaEncoder, SentenceEncoder])
def test_basic_sentence_encoder(encoder):
"""Check correct dimensions and repr for SentenceEncoder."""
encoder = SentenceEncoder()
enc = encoder()
# Embedding dim of underlying model
output_dim = encoder.tfm._modules["1"].word_embedding_dimension
output = encoder.fit_transform(test_sentences)
output_dim = enc.tfm._modules["1"].word_embedding_dimension
output = enc.fit_transform(test_sentences)
assert isinstance(output, np.ndarray)
assert output.shape == (len(test_sentences), output_dim)
# scikit-learn configures repr dynamically from defined attributes.
# To test correct implementation we should test if calling repr breaks.
assert repr(encoder)
assert repr(enc)


@pytest.mark.parametrize("setting", ["max", "mean", "both"])
Expand Down

0 comments on commit d92bdae

Please sign in to comment.