diff --git a/lit_nlp/components/nearest_neighbors.py b/lit_nlp/components/nearest_neighbors.py index ecc6862b..0b9a5eb3 100644 --- a/lit_nlp/components/nearest_neighbors.py +++ b/lit_nlp/components/nearest_neighbors.py @@ -124,7 +124,11 @@ def run( # [emb_size] dataset_embs = [output[nnconf.embedding_name] for output in dataset_outputs] + dataset_embs = [emb.astype(np.float32) for emb in dataset_embs] + example_embs = [example_output[nnconf.embedding_name]] + example_embs = [emb.astype(np.float32) for emb in example_embs] + distances = distance.cdist(example_embs, dataset_embs)[0] sorted_indices = np.argsort(distances) k = nnconf.num_neighbors