Skip to content

Commit

Permalink
Update clip_score.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 authored Jan 7, 2025
1 parent 4ff62e8 commit 7ab790a
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/torchmetrics/functional/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,8 @@ def _detect_modality(input_data: Union[Tensor, List[Tensor], List[str], str]) ->

def _process_image_data(images: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]:
"""Helper function to process image data."""
if not isinstance(images, list):
if images.ndim == 3:
images = [images]

if not isinstance(images, list) and if images.ndim == 3:
images = [images]
if not all(i.ndim == 3 for i in images):
raise ValueError("Expected all images to be 3d but found image that has either more or less")
return images
Expand Down Expand Up @@ -126,8 +124,7 @@ def _get_features(
)
processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]
features = model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device))
return features
return model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device))
raise ValueError(f"invalid modality {modality}")


Expand Down

0 comments on commit 7ab790a

Please sign in to comment.