Skip to content

Commit

Permalink
fix: fixes implementation of similarity() (#1748)
Browse files Browse the repository at this point in the history
* fix(#1594): fixes implementation of similarity()

* fix: add similarity to SentenceTransformerWrapper

---------

Co-authored-by: sam021313 <40773225+sam021313@users.noreply.github.com>
  • Loading branch information
sam-hey and sam-hey authored Jan 10, 2025
1 parent edd9d7f commit 3fe9264
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions mteb/evaluation/evaluators/RetrievalEvaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,12 @@ def search(
self.corpus_embeddings[request_qid].append(sub_corpus_embeddings)

# Compute similarites using self defined similarity otherwise default to cosine-similarity
similarity_scores = cos_sim(query_embeddings, sub_corpus_embeddings)
if hasattr(self.model, "similarity"):
similarity_scores = self.model.similarity(
float(self.model.similarity(e1, e2))
for e1, e2 in zip(query_embeddings, sub_corpus_embeddings)
query_embeddings, sub_corpus_embeddings
)
else:
similarity_scores = cos_sim(query_embeddings, sub_corpus_embeddings)
is_nan = torch.isnan(similarity_scores)
if is_nan.sum() > 0:
logger.warning(
Expand Down Expand Up @@ -376,6 +376,9 @@ def __init__(self, model, **kwargs):
self.save_corpus_embeddings = kwargs.get("save_corpus_embeddings", False)
self.corpus_embeddings = {}

if hasattr(self.model, "similarity") and callable(self.model.similarity):
self.similarity = self.model.similarity

def encode_corpus(
self,
corpus: list[dict[str, str]],
Expand Down
2 changes: 1 addition & 1 deletion mteb/models/sentence_transformer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
if isinstance(self.model, CrossEncoder):
self.predict = self._predict

if hasattr(self.model, "similarity"):
if hasattr(self.model, "similarity") and callable(self.model.similarity):
self.similarity = self.model.similarity

def encode(
Expand Down

0 comments on commit 3fe9264

Please sign in to comment.