From 84f0186e8971a21bcda9b446a8a74f0f1a958f1c Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 4 Mar 2025 12:24:33 +0100 Subject: [PATCH] Add aya (#36521) * initial commit * small fix * move stuff to image processing file * remove stuff in validate turn and fix return tensor * remove liquid stuff * in the process of addressing comments * changes to get the right tokenization * new __init__ works * fixing defulat std and mean * works * small testing scipt -- to be deleted before merge * remove redundant code * addressing comments * fix inits, add docs templates * refactor processor, switch to gotocr image processor * remove image proc from init * refactor to working llava-style architecture * Change AyaVisionModel to AyaVisionForConditionalGeneration * add tests * fixups * update doc * Adding logits_to_keep explicitly in ayavision forward to enable compatibility with cohere model * better variable names + remove code paths * Updates to aya_vision.md * address comments * adding copied from * make style and remove unused projector_hidden_act from config * sort init * include usage of fast image proc and proc on cuda in doc * update checkpoint iin test processor * update checkpoint in test processor 2 * remove test_model and update docstring * skip failing tests --------- Co-authored-by: Saurabh Dash Co-authored-by: yonigozlan --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/aya_vision.md | 243 ++++++++ src/transformers/__init__.py | 7 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 1 + .../models/aya_vision/__init__.py | 28 + .../aya_vision/configuration_aya_vision.py | 109 ++++ .../models/aya_vision/modeling_aya_vision.py | 503 +++++++++++++++ .../aya_vision/processing_aya_vision.py | 264 ++++++++ src/transformers/utils/dummy_pt_objects.py | 14 + tests/generation/test_utils.py | 13 +- tests/models/aya_vision/__init__.py | 0 .../aya_vision/test_modeling_aya_vision.py | 576 ++++++++++++++++++ .../aya_vision/test_processor_aya_vision.py | 164 +++++ 17 files changed, 1928 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/model_doc/aya_vision.md create mode 100644 src/transformers/models/aya_vision/__init__.py create mode 100644 src/transformers/models/aya_vision/configuration_aya_vision.py create mode 100644 src/transformers/models/aya_vision/modeling_aya_vision.py create mode 100644 src/transformers/models/aya_vision/processing_aya_vision.py create mode 100644 tests/models/aya_vision/__init__.py create mode 100644 tests/models/aya_vision/test_modeling_aya_vision.py create mode 100644 tests/models/aya_vision/test_processor_aya_vision.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 055b1d0844a0..624f4d7352ec 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -874,6 +874,8 @@ title: AltCLIP - local: model_doc/aria title: Aria + - local: model_doc/aya_vision + title: AyaVision - local: model_doc/blip title: BLIP - local: model_doc/blip-2 diff --git a/docs/source/en/model_doc/aya_vision.md b/docs/source/en/model_doc/aya_vision.md new file mode 100644 index 000000000000..17daf4949206 --- /dev/null +++ b/docs/source/en/model_doc/aya_vision.md @@ -0,0 +1,243 @@ + + +# AyaVision + +## Overview + +The Aya Vision 8B and 32B models is a state-of-the-art multilingual multimodal models developed by Cohere 
For AI. They build on the Aya Expanse recipe to handle both visual and textual information without compromising on the strong multilingual textual performance of the original model. + +Aya Vision 8B combines the `Siglip2-so400-384-14` vision encoder with the Cohere CommandR-7B language model further post-trained with the Aya Expanse recipe, creating a powerful vision-language model capable of understanding images and generating text across 23 languages. Aya Vision 32B, in turn, uses Aya Expanse 32B as its language model. + +Key features of Aya Vision include: +- Multimodal capabilities in 23 languages +- Strong text-only multilingual capabilities inherited from CommandR-7B post-trained with the Aya Expanse recipe and Aya Expanse 32B +- High-quality visual understanding using the Siglip2-so400-384-14 vision encoder +- Seamless integration of visual and textual information in 23 languages. + + + +Tips: + +- Aya Vision is a multimodal model that takes images and text as input and produces text as output. +- Images are represented using the `<image>` tag in the templated input. +- For best results, use the `apply_chat_template` method of the processor to format your inputs correctly. +- The model can process multiple images in a single conversation. +- Aya Vision can understand and generate text in 23 languages, making it suitable for multilingual multimodal applications. + +This model was contributed by [saurabhdash](https://huggingface.co/saurabhdash) and [yonigozlan](https://huggingface.co/yonigozlan). + + +## Usage + +Here's how to use Aya Vision for inference: + +```python +from transformers import AutoProcessor, AutoModelForImageTextToText +import torch + +model_id = "CohereForAI/aya-vision-8b" +torch_device = "cuda:0" + +# Use fast image processor +processor = AutoProcessor.from_pretrained(model_id, use_fast=True) +model = AutoModelForImageTextToText.from_pretrained( + model_id, device_map=torch_device, torch_dtype=torch.float16 +) + +# Format message with the aya-vision chat template +messages = [ + {"role": "user", + "content": [ + {"type": "image", "url": "https://pbs.twimg.com/media/Fx7YvfQWYAIp6rZ?format=jpg&name=medium"}, + {"type": "text", "text": "चित्र में लिखा पाठ क्या कहता है?"}, + ]}, + ] + +# Process image on CUDA +inputs = processor.apply_chat_template( + messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", device=torch_device +).to(model.device) + +gen_tokens = model.generate( + **inputs, + max_new_tokens=300, + do_sample=True, + temperature=0.3, +) + +print(processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)) +``` +### Pipeline + +```python +from transformers import pipeline + +pipe = pipeline(model="CohereForAI/aya-vision-8b", task="image-text-to-text", device_map="auto") + +# Format message with the aya-vision chat template +messages = [ + {"role": "user", + "content": [ + {"type": "image", "url": "https://media.istockphoto.com/id/458012057/photo/istanbul-turkey.jpg?s=612x612&w=0&k=20&c=qogAOVvkpfUyqLUMr_XJQyq-HkACXyYUSZbKhBlPrxo="}, + {"type": "text", "text": "Bu resimde hangi anıt gösterilmektedir?"}, + ]}, + ] +outputs = pipe(text=messages, max_new_tokens=300, return_full_text=False) + +print(outputs) +``` + +### Multiple Images and Batched Inputs + +Aya Vision can process multiple images in a single conversation.
Here's how to use it with multiple images: + +```python +from transformers import AutoProcessor, AutoModelForImageTextToText +import torch + +model_id = "CohereForAI/aya-vision-8b" + +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForImageTextToText.from_pretrained( + model_id, device_map="cuda:0", torch_dtype=torch.float16 +) + +# Example with multiple images in a single message +messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", + }, + { + "type": "image", + "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg", + }, + { + "type": "text", + "text": "These images depict two different landmarks. Can you identify them?", + }, + ], + }, +] + +inputs = processor.apply_chat_template( + messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" +).to(model.device) + +gen_tokens = model.generate( + **inputs, + max_new_tokens=300, + do_sample=True, + temperature=0.3, +) + +gen_text = processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) +print(gen_text) +``` + +For processing batched inputs (multiple conversations at once): + +```python +from transformers import AutoProcessor, AutoModelForImageTextToText +import torch + +model_id = "CohereForAI/aya-vision-8b" + +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForImageTextToText.from_pretrained( + model_id, device_map="cuda:0", torch_dtype=torch.float16 +) + +# Prepare two different conversations +batch_messages = [ + # First conversation with a single image + [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"}, + {"type": "text", "text": "Write a haiku for this image"}, + ], + }, + ], + # Second conversation with multiple images + [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", + }, + { + "type": "image", + "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg", + }, + { + "type": "text", + "text": "These images depict two different landmarks. 
Can you identify them?", + }, + ], + }, + ], +] + +# Process each conversation separately and combine into a batch +batch_inputs = processor.apply_chat_template( + batch_messages, + padding=True, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt" +).to(model.device) + +# Generate responses for the batch +batch_outputs = model.generate( + **batch_inputs, + max_new_tokens=300, + do_sample=True, + temperature=0.3, +) + +# Decode the generated responses +for i, output in enumerate(batch_outputs): + response = processor.tokenizer.decode( + output[batch_inputs.input_ids.shape[1]:], + skip_special_tokens=True + ) + print(f"Response {i+1}:\n{response}\n") +``` + +## AyaVisionProcessor + +[[autodoc]] AyaVisionProcessor + +## AyaVisionConfig + +[[autodoc]] AyaVisionConfig + +## AyaVisionForConditionalGeneration + +[[autodoc]] AyaVisionForConditionalGeneration + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f05a7b3b2c19..4b6994654832 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -194,6 +194,7 @@ "AutoTokenizer", ], "models.autoformer": ["AutoformerConfig"], + "models.aya_vision": ["AyaVisionConfig", "AyaVisionProcessor"], "models.bamba": ["BambaConfig"], "models.bark": [ "BarkCoarseConfig", @@ -1600,6 +1601,7 @@ "AutoformerPreTrainedModel", ] ) + _import_structure["models.aya_vision"].extend(["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"]) _import_structure["models.bamba"].extend( [ "BambaForCausalLM", @@ -5320,6 +5322,10 @@ from .models.autoformer import ( AutoformerConfig, ) + from .models.aya_vision import ( + AyaVisionConfig, + AyaVisionProcessor, + ) from .models.bamba import BambaConfig from .models.bark import ( BarkCoarseConfig, @@ -6765,6 +6771,7 @@ AutoformerModel, AutoformerPreTrainedModel, ) + from .models.aya_vision import AyaVisionForConditionalGeneration, AyaVisionPreTrainedModel from .models.bamba import BambaForCausalLM, BambaModel, BambaPreTrainedModel from .models.bark import ( BarkCausalModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 74dad4a2418b..3884daabd973 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -20,6 +20,7 @@ audio_spectrogram_transformer, auto, autoformer, + aya_vision, bamba, bark, bart, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 8b2b514496d8..fa4de1955430 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -39,6 +39,7 @@ ("aria_text", "AriaTextConfig"), ("audio-spectrogram-transformer", "ASTConfig"), ("autoformer", "AutoformerConfig"), + ("aya_vision", "AyaVisionConfig"), ("bamba", "BambaConfig"), ("bark", "BarkConfig"), ("bart", "BartConfig"), @@ -359,6 +360,7 @@ ("aria_text", "AriaText"), ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"), ("autoformer", "Autoformer"), + ("aya_vision", "AyaVision"), ("bamba", "Bamba"), ("bark", "Bark"), ("bart", "BART"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index cf6518c41760..d9fd502c1fae 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -818,6 +818,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict( [ ("aria", "AriaForConditionalGeneration"), + ("aya_vision", 
"AyaVisionForConditionalGeneration"), ("blip", "BlipForConditionalGeneration"), ("blip-2", "Blip2ForConditionalGeneration"), ("chameleon", "ChameleonForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 03b8c860f60b..2d6da5ac13b4 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -48,6 +48,7 @@ ("align", "AlignProcessor"), ("altclip", "AltCLIPProcessor"), ("aria", "AriaProcessor"), + ("aya_vision", "AyaVisionProcessor"), ("bark", "BarkProcessor"), ("blip", "BlipProcessor"), ("blip-2", "Blip2Processor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 61c2c2e23d2f..57bcd31296cc 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -69,6 +69,7 @@ ), ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("bart", ("BartTokenizer", "BartTokenizerFast")), ( diff --git a/src/transformers/models/aya_vision/__init__.py b/src/transformers/models/aya_vision/__init__.py new file mode 100644 index 000000000000..f8be47cb228b --- /dev/null +++ b/src/transformers/models/aya_vision/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_aya_vision import * + from .modeling_aya_vision import * + from .processing_aya_vision import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/aya_vision/configuration_aya_vision.py b/src/transformers/models/aya_vision/configuration_aya_vision.py new file mode 100644 index 000000000000..574a5755abd6 --- /dev/null +++ b/src/transformers/models/aya_vision/configuration_aya_vision.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2025 Cohere team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""AyaVision model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class AyaVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AyaVisionForConditionalGeneration`]. It is used to instantiate an + AyaVision model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of AyaVision. + e.g. [CohereForAI/aya-vision-8b](https://huggingface.co/CohereForAI/aya-vision-8b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + The config object or dictionary of the vision backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. + If `"full"`, the full vision features are used. + vision_feature_layer (`int`, *optional*, defaults to -1): + The index of the layer to select the vision feature. + downsample_factor (`int`, *optional*, defaults to 2): + The downsample factor to apply to the vision features. + adapter_layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon value used for layer normalization in the adapter. + image_token_index (`int`, *optional*, defaults to 255036): + The image token index to encode the image prompt. + """ + + model_type = "aya_vision" + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + + def __init__( + self, + vision_config=None, + text_config=None, + vision_feature_select_strategy="full", + vision_feature_layer=-1, + downsample_factor=2, + adapter_layer_norm_eps=1e-6, + image_token_index=255036, + **kwargs, + ): + self.image_token_index = image_token_index + self.downsample_factor = downsample_factor + self.adapter_layer_norm_eps = adapter_layer_norm_eps + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." 
+ f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["siglip_vision_model"]( + hidden_size=1152, + intermediate_size=4304, + patch_size=14, + image_size=384, + num_hidden_layers=26, + num_attention_heads=14, + vision_use_head=False, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["cohere2"]() + + self.text_config = text_config + + super().__init__(**kwargs) + + +__all__ = ["AyaVisionConfig"] diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py new file mode 100644 index 000000000000..f2ae7b8858e3 --- /dev/null +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -0,0 +1,503 @@ +# coding=utf-8 +# Copyright 2025 the Cohere Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch AyaVision model.""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_aya_vision import AyaVisionConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "AyaVisionConfig" + + +# copied from transformers.models.Llava.modeling_llava.LlavaCausalLMOutputWithPast +@dataclass +class AyaVisionCausalLMOutputWithPast(ModelOutput): + """ + Base class for AyaVision causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class AyaVisionMultiModalProjector(nn.Module): + def __init__(self, config: AyaVisionConfig): + super().__init__() + self.config = config + self.downsample_factor = config.downsample_factor + self.alignment_intermediate_size = getattr( + config, "alignment_intermediate_size", config.text_config.hidden_size + ) + self.layernorm = nn.LayerNorm( + config.vision_config.hidden_size * (config.downsample_factor**2), eps=config.adapter_layer_norm_eps + ) + + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * (config.downsample_factor**2), + self.alignment_intermediate_size, + bias=True, + ) + + self.act = ACT2FN["silu"] # SwiGLU uses SiLU activation + # For SwiGLU, project down to half size since we split intermediate dim + self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, config.text_config.hidden_size, bias=True) + + def forward(self, image_features): + image_features = self.pixel_shuffle(image_features) + image_features = self.layernorm(image_features) + hidden_states = self.linear_1(image_features) + + # Split along last dimension and apply SwiGLU + x, gate = hidden_states.chunk(2, dim=-1) + hidden_states = self.act(gate) * x + + hidden_states = self.linear_2(hidden_states) + return hidden_states + + def pixel_shuffle(self, image_features): # B, S, D + batch_size, seq_length, feature_dim = image_features.shape + height = width = int(seq_length**0.5) + image_features = image_features.reshape(image_features.shape[0], width, height, -1) + channels = image_features.shape[-1] + image_features = image_features.reshape( + batch_size, width, int(height / self.downsample_factor), 
int(channels * self.downsample_factor) + ) + image_features = image_features.permute(0, 2, 1, 3) + image_features = image_features.reshape( + batch_size, int(height / self.downsample_factor), int(width / self.downsample_factor), -1 + ) + image_features = image_features.permute(0, 2, 1, 3) + return image_features + + +AYA_VISION_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`AyaVisionConfig`] or [`AyaVisionVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Aya Vision Model outputting raw hidden-states without any specific head on top.", + AYA_VISION_START_DOCSTRING, +) +class AyaVisionPreTrainedModel(PreTrainedModel): + config_class = AyaVisionConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["AyaVisionVisionAttention"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + # important: this ported version of AyaVision isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +AYA_VISION_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`GotOcr2ImageProcessor.__call__`] for details. [`CohereProcessor`] uses + [`GotOcr2ImageProcessor`] for processing images. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + vision_feature_layer (`int`, *optional*, defaults to -2): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. + If `"full"`, the full vision features are used. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. 
It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + """The AyaVision model which consists of a vision backbone and a language model.""", + AYA_VISION_START_DOCSTRING, +) +class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin): + def __init__(self, config: AyaVisionConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = AyaVisionMultiModalProjector(config) + + self.vocab_size = config.text_config.vocab_size + self.language_model = AutoModelForCausalLM.from_config(config.text_config) + if self.language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{key}" for key in self.language_model._tied_weights_keys] + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: int, + vision_feature_select_strategy: str, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layer (`int`): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches + and are of shape `(num_patches, image_length, embed_dim)`). 
+ """ + + image_features = self.vision_tower(pixel_values, output_hidden_states=True) + selected_image_feature = image_features.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + image_features = self.multi_modal_projector(selected_image_feature) + + return image_features + + @add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=AyaVisionCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, + num_logits_to_keep: int = 0, + **lm_kwargs, + ) -> Union[Tuple, AyaVisionCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ + Returns: + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + last_cache_position=last_cache_position, + num_logits_to_keep=num_logits_to_keep, + logits_to_keep=0, + **lm_kwargs, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return AyaVisionCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + num_logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + **kwargs, + ) + + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + +__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"] diff --git a/src/transformers/models/aya_vision/processing_aya_vision.py b/src/transformers/models/aya_vision/processing_aya_vision.py new file mode 100644 index 000000000000..392c44993c54 --- /dev/null +++ b/src/transformers/models/aya_vision/processing_aya_vision.py @@ -0,0 +1,264 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import List, Optional, Union + +from transformers.processing_utils import ( + ImagesKwargs, + ProcessingKwargs, + ProcessorMixin, + Unpack, +) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +from ...image_processing_utils import BatchFeature +from ...image_utils import ( + ImageInput, + make_flat_list_of_images, +) + + +class AyaVisionImagesKwargs(ImagesKwargs, total=False): + crop_to_patches: Optional[bool] + min_patches: Optional[int] + max_patches: Optional[int] + + +class AyaVisionProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: AyaVisionImagesKwargs + _defaults = { + "text_kwargs": { + "padding_side": "left", + "padding": True, + }, + "images_kwargs": { + "crop_to_patches": True, + }, + } + + +class AyaVisionProcessor(ProcessorMixin): + r""" + Constructs a AyaVision processor which wraps a [`AutoImageProcessor`] and + [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and + tokenizer functionalities. See the [`~AyaVisionProcessor.__call__`] and [`~AyaVisionProcessor.decode`] for more information. + Args: + image_processor ([`AutoImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*): + The tokenizer is a required input. + patch_size (`int`, *optional*, defaults to 28): + The size of image patches for tokenization. + img_size (`int`, *optional*, defaults to 364): + The size of the image to be tokenized. This should correspond to the size given to the image processor. + vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`): + The feature selection strategy used to select the vision feature from the vision backbone. + image_token (`str`, *optional*, defaults to `""`): + The token to be used to represent an image in the text. + downsample_factor (`int`, *optional*, defaults to 1): + The factor by which to scale the patch size. + start_of_img_token (`str`, *optional*, defaults to `"<|START_OF_IMG|>"`): + The token to be used to represent the start of an image in the text. + end_of_img_token (`str`, *optional*, defaults to `"<|END_OF_IMG|>"`): + The token to be used to represent the end of an image in the text. + img_patch_token (`str`, *optional*, defaults to `"<|IMG_PATCH|>"`): + The token to be used to represent an image patch in the text. + img_line_break_token (`str`, *optional*, defaults to `"<|IMG_LINE_BREAK|>"`): + The token to be used to represent a line break in the text. + tile_token (`str`, *optional*, defaults to `"TILE"`): + The token to be used to represent an image patch in the text. + tile_global_token (`str`, *optional*, defaults to `"TILE_GLOBAL"`): + The token to be used to represent the cover image in the text. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. 
+ """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = [ + "chat_template", + "image_token", + "patch_size", + "img_size", + "downsample_factor", + "vision_feature_select_strategy", + "start_of_img_token", + "end_of_img_token", + "img_patch_token", + "img_line_break_token", + "tile_token", + "tile_global_token", + ] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + patch_size: int = 28, + img_size: int = 364, + vision_feature_select_strategy="full", + image_token="", # set the default and let users change if they have peculiar special tokens in rare cases + downsample_factor: int = 1, + start_of_img_token="<|START_OF_IMG|>", + end_of_img_token="<|END_OF_IMG|>", + img_patch_token="<|IMG_PATCH|>", + img_line_break_token="<|IMG_LINE_BREAK|>", + tile_token="TILE", + tile_global_token="TILE_GLOBAL", + chat_template=None, + **kwargs, + ): + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + self.image_token = image_token + self.patch_size = patch_size * downsample_factor + self.img_size = img_size + self.vision_feature_select_strategy = vision_feature_select_strategy + + self.start_of_img_token = start_of_img_token + self.end_of_img_token = end_of_img_token + self.img_patch_token = img_patch_token + self.img_line_break_token = img_line_break_token + self.tile_token = tile_token + self.tile_global_token = tile_global_token + + def _prompt_split_image(self, num_patches): + """ + Create a structured string representation of image tokens + + Args: + num_patches: Number of patches in the image + + Returns: + String with appropriate image tokens + """ + + img_patches_per_tile = (self.img_size // self.patch_size) ** 2 + img_string = f"{self.start_of_img_token}" + if num_patches > 1: + for idx in range(1, num_patches): + img_string += f"{self.tile_token}_{idx}" + f"{self.img_patch_token}" * img_patches_per_tile + + img_string += f"{self.tile_global_token}" + f"{self.img_patch_token}" * img_patches_per_tile + img_string += f"{self.end_of_img_token}" + return img_string + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[AyaVisionProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text. + To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to + GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 
+ return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if text is None: + raise ValueError("You have to specify text.") + + output_kwargs = self._merge_kwargs( + AyaVisionProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if not isinstance(text, (list, tuple)): + text = [text] + + # Process images + image_inputs = {} + if images is not None: + images = make_flat_list_of_images(images) + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + num_patches = image_inputs.pop("num_patches") + image_index = 0 + img_start_idx = 0 + processed_text = [] + image_num_patches = [] + for prompt in text: + new_prompt = prompt + curr_num_image_patches = 0 + while "<image>" in new_prompt: + # Replace the image placeholder with structured image tokens + image_tokens = self._prompt_split_image(num_patches[image_index]) + new_prompt = new_prompt.replace("<image>", image_tokens, 1) + curr_num_image_patches += num_patches[image_index] + image_index += 1 + + processed_text.append(new_prompt) + image_num_patches.append(curr_num_image_patches.item()) + img_start_idx += curr_num_image_patches + + if image_index != len(images): + raise ValueError("Number of image placeholders in the prompt does not match the number of images.") + + text = processed_text + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information.
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(tokenizer_input_names) + list(image_processor_input_names) + + +__all__ = ["AyaVisionProcessor"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index d409238588d0..5b43469abe5f 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1146,6 +1146,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class AyaVisionForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class AyaVisionPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class BambaForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 119c466d9ed2..f66dca2125d8 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -113,7 +113,18 @@ # TODO: raushan remove this when VLMs start accepting input embeds -VLM_CLASS_NAMES = ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3", "gotocr2", "qwen2vl", "qwen2_5_vl"] +VLM_CLASS_NAMES = [ + "llava", + "idefics2", + "idefics3", + "mllama", + "paligemma", + "emu3", + "gotocr2", + "qwen2vl", + "qwen2_5_vl", + "ayavision", +] class GenerationTesterMixin: diff --git a/tests/models/aya_vision/__init__.py b/tests/models/aya_vision/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/aya_vision/test_modeling_aya_vision.py b/tests/models/aya_vision/test_modeling_aya_vision.py new file mode 100644 index 000000000000..dc8fe2503df1 --- /dev/null +++ b/tests/models/aya_vision/test_modeling_aya_vision.py @@ -0,0 +1,576 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch GotOcr2 model.""" + +import unittest + +from parameterized import parameterized + +from transformers import ( + AutoProcessor, + AyaVisionConfig, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + cleanup, + require_read_token, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + AyaVisionForConditionalGeneration, + ) + + +if is_vision_available(): + pass + + +class AyaVisionVisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=3, + seq_length=7, + vision_feature_layer=-1, + downsample_factor=2, + ignore_index=-100, + bos_token_id=0, + eos_token_id=0, + pad_token_id=0, + image_token_index=1, + num_channels=3, + image_size=64, + model_type="aya_vision", + is_training=True, + text_config={ + "model_type": "cohere2", + "vocab_size": 99, + "hidden_size": 128, + "intermediate_size": 37, + "num_hidden_layers": 4, + "num_attention_heads": 4, + "output_channels": 64, + "hidden_act": "silu", + "max_position_embeddings": 512, + "tie_word_embeddings": True, + "bos_token_id": 0, + "eos_token_id": 0, + "pad_token_id": 0, + }, + vision_config={ + "model_type": "siglip_vision_model", + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 128, + "image_size": 64, + "patch_size": 8, + "vision_use_head": False, + }, + ): + self.parent = parent + self.ignore_index = ignore_index + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.image_token_index = image_token_index + self.model_type = model_type + self.text_config = text_config + self.vision_config = vision_config + self.batch_size = batch_size + self.vision_feature_layer = vision_feature_layer + self.downsample_factor = downsample_factor + self.is_training = is_training + self.num_channels = num_channels + self.image_size = image_size + self.image_seq_length = (image_size // (vision_config["patch_size"] * downsample_factor)) ** 2 + self.seq_length = seq_length + self.image_seq_length + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + + def get_config(self): + return AyaVisionConfig( + text_config=self.text_config, + vision_config=self.vision_config, + model_type=self.model_type, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + image_token_index=self.image_token_index, + vision_feature_layer=self.vision_feature_layer, + downsample_factor=self.downsample_factor, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + print("attention_mask", 
+        # replace any randomly generated image tokens, then place the image placeholder tokens at the start of the sequence
+        input_ids[input_ids == self.image_token_index] = self.pad_token_id
+        input_ids[:, : self.image_seq_length] = self.image_token_index
+
+        inputs_dict = {
+            "pixel_values": pixel_values,
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        return config, inputs_dict
+
+    def create_and_check_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
+        model = AyaVisionForConditionalGeneration(config=config)
+        model.to(torch_device)
+        model.half()
+        model.eval()
+        logits = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            return_dict=True,
+        )["logits"]
+        self.parent.assertFalse(torch.isnan(logits).any().item())
+
+    def create_and_check_model_fp16_autocast_forward(self, config, input_ids, pixel_values, attention_mask):
+        config.torch_dtype = torch.float16
+        model = AyaVisionForConditionalGeneration(config=config)
+        model.to(torch_device)
+        model.eval()
+        with torch.autocast(device_type="cuda", dtype=torch.float16):
+            logits = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                pixel_values=pixel_values,
+                return_dict=True,
+            )["logits"]
+        self.parent.assertFalse(torch.isnan(logits).any().item())
+
+
+@require_torch
+class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+    all_model_classes = (AyaVisionForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (AyaVisionForConditionalGeneration,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "image-text-to-text": AyaVisionForConditionalGeneration,
+        }
+        if is_torch_available()
+        else {}
+    )
+    fx_compatible = False
+    test_pruning = False
+    test_torchscript = False
+    test_head_masking = False
+    _is_composite = True
+
+    def setUp(self):
+        self.model_tester = AyaVisionVisionText2TextModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=AyaVisionConfig, has_text_modality=False)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
+    def test_inputs_embeds(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+            del inputs["pixel_values"]
+
+            wte = model.get_input_embeddings()
+            inputs["inputs_embeds"] = wte(input_ids)
+
+            with torch.no_grad():
+                model(**inputs)
+
+    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
+    # while some other models require pixel_values to be present
+    def test_inputs_embeds_matches_input_ids(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+            del inputs["pixel_values"]
+
+            inputs_embeds = model.get_input_embeddings()(input_ids)
+
+            with torch.no_grad():
+                out_ids = model(input_ids=input_ids, **inputs)[0]
+                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
+            torch.testing.assert_close(out_embeds, out_ids)
+
+    @unittest.skip("Failing because of unique cache 
(HybridCache)") + def test_model_outputs_equivalence(self, **kwargs): + pass + + @unittest.skip("Cohere2's forcefully disables sdpa due to softcapping") + def test_sdpa_can_dispatch_non_composite_models(self): + pass + + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different") + def test_eager_matches_sdpa_inference(self): + pass + + @unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different") + def test_eager_matches_sdpa_generate(self): + pass + + @parameterized.expand([("random",), ("same",)]) + @unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding") + def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip("Cohere2 has HybridCache which is not compatible with dola decoding") + def test_dola_decoding_sample(self): + pass + + @unittest.skip("Cohere2 has HybridCache and doesn't support continue from past kv") + def test_generate_continue_from_past_key_values(self): + pass + + @unittest.skip("Cohere2 has HybridCache and doesn't support low_memory generation") + def test_beam_search_low_memory(self): + pass + + @unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate(self): + pass + + @unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_low_memory(self): + pass + + @unittest.skip("Cohere2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + def test_generate_with_static_cache(self): + pass + + @unittest.skip("Cohere2 has HybridCache and doesn't support StaticCache. 
Though it could, it shouldn't support.") + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @unittest.skip("Cohere2 has HybridCache and doesn't support progressive generation using input embeds.") + def test_generate_continue_from_inputs_embeds(self): + pass + + @unittest.skip("Failing because of unique cache (HybridCache)") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different") + def test_sdpa_equivalence(self): + pass + + @unittest.skip(reason="SiglipVisionModel does not support standalone training") + def test_training(self): + pass + + @unittest.skip(reason="SiglipVisionModel does not support standalone training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="SiglipVisionModel does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="SiglipVisionModel does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") + def test_initialization(self): + pass + + @unittest.skip(reason="Compile not yet supported because in LLava models") + def test_sdpa_can_compile_dynamic(self): + pass + + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + # todo: yoni - fix or improve the test + @unittest.skip("Difference is slightly higher than the threshold") + def test_batching_equivalence(self): + pass + + +@require_read_token +@require_torch +class AyaVisionIntegrationTest(unittest.TestCase): + def setUp(self): + self.model_checkpoint = "CohereForAI/aya-vision-8b" + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + @require_torch_gpu + def test_small_model_integration_forward(self): + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + model = AyaVisionForConditionalGeneration.from_pretrained( + self.model_checkpoint, device_map=torch_device, torch_dtype=torch.float16 + ) + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}, + {"type": "text", "text": "Please describe the image explicitly."}, + ], + } + ] + + inputs = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(torch_device, dtype=torch.float16) + # Forward + with torch.inference_mode(): + output = model(**inputs) + + actual_logits = output.logits[0, -1, :5].cpu() + print("actual_logits", actual_logits) + expected_logits = torch.tensor([0.4109, 0.1532, 0.8018, 2.1328, 0.5483], dtype=torch.float16) + self.assertTrue( + torch.allclose(actual_logits, expected_logits, atol=0.1), + f"Actual logits: {actual_logits}" + f"\nExpected logits: {expected_logits}" + f"\nDifference: {torch.abs(actual_logits - expected_logits)}", + ) + + @slow + @require_torch_gpu + def test_small_model_integration_generate_text_only(self): + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + model = AyaVisionForConditionalGeneration.from_pretrained( + self.model_checkpoint, device_map=torch_device, torch_dtype=torch.float16 + ) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Write a haiku"}, + ], + } + ] + + inputs = processor.apply_chat_template( + messages, 
add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(torch_device, dtype=torch.float16) + with torch.no_grad(): + generate_ids = model.generate(**inputs, max_new_tokens=25, do_sample=False) + decoded_output = processor.decode( + generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True + ) + print("decoded_output", decoded_output) + expected_output = "Whispers on the breeze,\nLeaves dance under moonlit skies,\nNature's quiet song." + self.assertEqual(decoded_output, expected_output) + + @slow + @require_torch_gpu + def test_small_model_integration_generate_chat_template(self): + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + model = AyaVisionForConditionalGeneration.from_pretrained( + self.model_checkpoint, device_map=torch_device, torch_dtype=torch.float16 + ) + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}, + {"type": "text", "text": "Please describe the image explicitly."}, + ], + } + ] + + inputs = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(torch_device, dtype=torch.float16) + with torch.no_grad(): + generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False) + decoded_output = processor.decode( + generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True + ) + print("decoded_output", decoded_output) + expected_output = "The image depicts a cozy scene of two cats resting on a bright pink blanket. The cats," # fmt: skip + self.assertEqual(decoded_output, expected_output) + + @slow + @require_torch_gpu + def test_small_model_integration_batched_generate(self): + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + model = AyaVisionForConditionalGeneration.from_pretrained( + self.model_checkpoint, device_map=torch_device, torch_dtype=torch.float16 + ) + # Prepare inputs + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"}, + {"type": "text", "text": "Write a haiku for this image"}, + ], + }, + ], + [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "Describe this image"}, + ], + }, + ], + ] + inputs = processor.apply_chat_template( + messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(model.device, dtype=torch.float16) + + output = model.generate(**inputs, do_sample=False, max_new_tokens=25) + + # Check first output + decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True) + print("decoded_output", decoded_output) + expected_output = "Wooden path to water,\nMountains echo in stillness,\nPeaceful forest scene." # fmt: skip + self.assertEqual( + decoded_output, + expected_output, + f"Decoded output: {decoded_output}\nExpected output: {expected_output}", + ) + + # Check second output + decoded_output = processor.decode(output[1, inputs["input_ids"].shape[1] :], skip_special_tokens=True) + print("decoded_output", decoded_output) + expected_output = 'This image captures a vibrant street scene in a bustling urban area, likely in an Asian city. 
The focal point is a' # fmt: skip + + self.assertEqual( + decoded_output, + expected_output, + f"Decoded output: {decoded_output}\nExpected output: {expected_output}", + ) + + @slow + @require_torch_gpu + def test_small_model_integration_batched_generate_multi_image(self): + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + model = AyaVisionForConditionalGeneration.from_pretrained( + self.model_checkpoint, device_map=torch_device, torch_dtype=torch.float16 + ) + # Prepare inputs + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"}, + {"type": "text", "text": "Write a haiku for this image"}, + ], + }, + ], + [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", + }, + { + "type": "image", + "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg", + }, + { + "type": "text", + "text": "These images depict two different landmarks. Can you identify them?", + }, + ], + }, + ], + ] + inputs = processor.apply_chat_template( + messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(model.device, dtype=torch.float16) + output = model.generate(**inputs, do_sample=False, max_new_tokens=25) + + # Check first output + decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True) + # Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232 + expected_output = "Wooden path to water,\nMountains echo in stillness,\nPeaceful forest scene." # fmt: skip + print("decoded_output", decoded_output) + self.assertEqual( + decoded_output, + expected_output, + f"Decoded output: {decoded_output}\nExpected output: {expected_output}", + ) + + # Check second output + decoded_output = processor.decode(output[1, inputs["input_ids"].shape[1] :], skip_special_tokens=True) + print("decoded_output", decoded_output) + expected_output = "The first image showcases the Statue of Liberty, a colossal neoclassical sculpture on Liberty Island in New York Harbor. Standing at a" # fmt: skip + self.assertEqual( + decoded_output, + expected_output, + f"Decoded output: {decoded_output}\nExpected output: {expected_output}", + ) diff --git a/tests/models/aya_vision/test_processor_aya_vision.py b/tests/models/aya_vision/test_processor_aya_vision.py new file mode 100644 index 000000000000..8830f85c50c1 --- /dev/null +++ b/tests/models/aya_vision/test_processor_aya_vision.py @@ -0,0 +1,164 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import shutil
+import tempfile
+import unittest
+from typing import Optional
+
+from transformers import AutoProcessor, AutoTokenizer, AyaVisionProcessor
+from transformers.testing_utils import require_read_token, require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_torch_available():
+    import torch
+
+
+if is_vision_available():
+    from transformers import GotOcr2ImageProcessor
+
+
+@require_read_token
+@require_vision
+class AyaVisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = AyaVisionProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+
+        image_processor = GotOcr2ImageProcessor(
+            do_resize=True,
+            size={"height": 20, "width": 20},
+            max_patches=2,
+            do_rescale=True,
+            rescale_factor=1 / 255,
+            do_normalize=True,
+            image_mean=[0.485, 0.456, 0.406],
+            image_std=[0.229, 0.224, 0.225],
+            do_convert_rgb=True,
+        )
+        tokenizer = AutoTokenizer.from_pretrained("CohereForAI/aya-vision-8b", padding_side="left")
+        processor_kwargs = self.prepare_processor_dict()
+        processor = AyaVisionProcessor.from_pretrained(
+            "CohereForAI/aya-vision-8b",
+            image_processor=image_processor,
+            tokenizer=tokenizer,
+            **processor_kwargs,
+        )
+        processor.save_pretrained(self.tmpdirname)
+
+    def prepare_processor_dict(self):
+        return {"patch_size": 10, "img_size": 20}
+
+    def get_tokenizer(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def get_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
+    # todo: yoni, fix this test
+    @unittest.skip("Chat template has long system prompt")
+    def test_chat_template_accepts_processing_kwargs(self, **kwargs):
+        pass
+
+    # Override as AyaVisionProcessor needs image tokens in prompts
+    def prepare_text_inputs(self, batch_size: Optional[int] = None):
+        if batch_size is None:
+            return "lower newer <image>"
+
+        if batch_size < 1:
+            raise ValueError("batch_size must be greater than 0")
+
+        if batch_size == 1:
+            return ["lower newer <image>"]
+        return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
+            batch_size - 2
+        )
+
+    @require_torch
+    def test_process_interleaved_images_videos(self):
+        processor = self.get_processor()
+
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image",
+                            "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
+                        },
+                        {
+                            "type": "image",
+                            "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
+                        },
+                        {"type": "text", "text": "What are the differences between these two images?"},
+                    ],
+                },
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image",
+                            "url": "https://llava-vl.github.io/static/images/view.jpg",
+                        },
+                        {"type": "text", "text": "Write a haiku for this image"},
+                    ],
+                }
+            ],
+        ]
+
+        inputs_batched = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            padding=True,
+        )
+
+        # Process non-batched inputs to check that pixel_values and input_ids are reconstructed in the correct order when batched together
+        images_patches_index = 0
+        for i, message in enumerate(messages):
+            inputs = processor.apply_chat_template(
+                message,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+                padding=True,
+            )
+            # We slice with [-inputs["input_ids"].shape[1] :] as the input_ids are left padded
+            torch.testing.assert_close(
+                inputs["input_ids"][0], inputs_batched["input_ids"][i][-inputs["input_ids"].shape[1] :]
+            )
+            torch.testing.assert_close(
+                inputs["pixel_values"],
+                inputs_batched["pixel_values"][
+                    images_patches_index : images_patches_index + inputs["pixel_values"].shape[0]
+                ],
+            )
+            images_patches_index += inputs["pixel_values"].shape[0]
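Reviewer note (not part of the patch): the assertions above rely on the tokenizer's left padding, so slicing each batched row with `[-inputs["input_ids"].shape[1]:]` recovers the un-batched encoding, while `pixel_values` are compared through the running `images_patches_index` offset because image patches are concatenated in conversation order. A minimal standalone sketch of the slicing idea, using made-up token ids (pad id 0) and only standard `torch`:

```python
import torch

# Hypothetical values: the un-batched conversation tokenizes to 4 ids,
# and the batched row was left-padded to length 6 with pad id 0.
single = torch.tensor([[7, 8, 9, 10]])
batched_row = torch.tensor([0, 0, 7, 8, 9, 10])

# With left padding, the last `single.shape[1]` positions of the batched row
# are exactly the un-batched ids, so the test's slice recovers them.
torch.testing.assert_close(single[0], batched_row[-single.shape[1] :])
```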