From 7ab790ad630ed0ae003b996f4da078800a9994cc Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 8 Jan 2025 00:26:51 +0530 Subject: [PATCH] Update clip_score.py --- src/torchmetrics/functional/multimodal/clip_score.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index 399866672b4..8672d59a899 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -73,10 +73,8 @@ def _detect_modality(input_data: Union[Tensor, List[Tensor], List[str], str]) -> def _process_image_data(images: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]: """Helper function to process image data.""" - if not isinstance(images, list): - if images.ndim == 3: - images = [images] - + if not isinstance(images, list) and if images.ndim == 3: + images = [images] if not all(i.ndim == 3 for i in images): raise ValueError("Expected all images to be 3d but found image that has either more or less") return images @@ -126,8 +124,7 @@ def _get_features( ) processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings] processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings] - features = model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device)) - return features + return model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device)) raise ValueError(f"invalid modality {modality}")