From 3fe92644fa53c0c8cedc92d17fb25f0012a26aab Mon Sep 17 00:00:00 2001 From: Sam <40773225+sam-hey@users.noreply.github.com> Date: Fri, 10 Jan 2025 16:26:04 +0100 Subject: [PATCH] fix: fixes implementation of similarity() (#1748) * fix(#1594): fixes implementation of similarity() * fix: add similarity to SentenceTransformerWrapper --------- Co-authored-by: sam021313 <40773225+sam021313@users.noreply.github.com> --- mteb/evaluation/evaluators/RetrievalEvaluator.py | 9 ++++++--- mteb/models/sentence_transformer_wrapper.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mteb/evaluation/evaluators/RetrievalEvaluator.py b/mteb/evaluation/evaluators/RetrievalEvaluator.py index ed3a50d71f..251498d6b3 100644 --- a/mteb/evaluation/evaluators/RetrievalEvaluator.py +++ b/mteb/evaluation/evaluators/RetrievalEvaluator.py @@ -167,12 +167,12 @@ def search( self.corpus_embeddings[request_qid].append(sub_corpus_embeddings) # Compute similarites using self defined similarity otherwise default to cosine-similarity - similarity_scores = cos_sim(query_embeddings, sub_corpus_embeddings) if hasattr(self.model, "similarity"): similarity_scores = self.model.similarity( - float(self.model.similarity(e1, e2)) - for e1, e2 in zip(query_embeddings, sub_corpus_embeddings) + query_embeddings, sub_corpus_embeddings ) + else: + similarity_scores = cos_sim(query_embeddings, sub_corpus_embeddings) is_nan = torch.isnan(similarity_scores) if is_nan.sum() > 0: logger.warning( @@ -376,6 +376,9 @@ def __init__(self, model, **kwargs): self.save_corpus_embeddings = kwargs.get("save_corpus_embeddings", False) self.corpus_embeddings = {} + if hasattr(self.model, "similarity") and callable(self.model.similarity): + self.similarity = self.model.similarity + def encode_corpus( self, corpus: list[dict[str, str]], diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py index e580ef8959..4a7dbd8ffa 100644 --- a/mteb/models/sentence_transformer_wrapper.py +++ b/mteb/models/sentence_transformer_wrapper.py @@ -59,7 +59,7 @@ def __init__( if isinstance(self.model, CrossEncoder): self.predict = self._predict - if hasattr(self.model, "similarity"): + if hasattr(self.model, "similarity") and callable(self.model.similarity): self.similarity = self.model.similarity def encode(