Add PaliGemma #30814

Merged · 120 commits · May 14, 2024

Changes from all commits

Commits (120)
2fddcc9
add new model like
molbap Mar 4, 2024
e536f6a
Merge pull request #9 from huggingface/update
molbap Mar 4, 2024
6ca0bf7
add state dict slicing + new model config
molbap Mar 5, 2024
21db4a7
update palma config and weights, passes vision activations
molbap Mar 6, 2024
7985fc4
fix
molbap Mar 6, 2024
bcb341d
update
molbap Mar 7, 2024
38ad70e
reorder loading/unpacking
molbap Mar 8, 2024
929746a
clean up
molbap Mar 11, 2024
ae55ad9
add debug statements
molbap Mar 12, 2024
9d5f8fb
change device
molbap Mar 12, 2024
524a073
Merge branch 'add_palma' of github.com:huggingface/new-model-addition…
molbap Mar 12, 2024
d171c4e
fix
molbap Mar 12, 2024
1638a73
debugging
molbap Mar 13, 2024
4aad850
fix noncausal mask
molbap Mar 14, 2024
2972674
fixup sdpa + causal mask
molbap Mar 14, 2024
71fa912
fix activation function
molbap Mar 14, 2024
94e4806
remove debug before changing modeling file
molbap Mar 14, 2024
7b6f0b3
add variants
molbap Mar 15, 2024
4e8e1c6
debug attention mask in generate
molbap Mar 15, 2024
ba8fb4e
revert to non-debug sdpa
molbap Mar 15, 2024
6c2348d
revert gemma modifications
molbap Mar 15, 2024
906e87f
add custom language modeling
molbap Mar 15, 2024
f26361d
use Processor
molbap Mar 15, 2024
b3e4a03
add language modeling file to init
molbap Mar 15, 2024
500a360
try thin wrapper around generate
molbap Mar 15, 2024
96a82e2
Update
molbap Mar 18, 2024
347df2c
update mask
molbap Mar 19, 2024
d01b502
breakpoints galore
molbap Mar 19, 2024
bb8030c
remove conflict
molbap Mar 19, 2024
a6056c6
switch to left-padding
molbap Mar 19, 2024
dacd297
add incomplete model doc
molbap Mar 19, 2024
f133a1d
add paligemma global files
molbap Mar 19, 2024
37c1368
batch rename paligemma
molbap Mar 19, 2024
9935e2e
make generation match outputs and captioning
molbap Mar 19, 2024
e7fd049
style
molbap Mar 19, 2024
edcf1b1
style
molbap Mar 19, 2024
2a29a2e
remove copied from + doc
molbap Mar 19, 2024
5607c9d
remove more copied from
molbap Mar 19, 2024
8ce77f3
remove copy from projector
molbap Mar 19, 2024
3563926
minor fix
molbap Mar 19, 2024
f48f61d
update config and style
molbap Mar 19, 2024
b84d1c1
add readme - dummy
molbap Mar 21, 2024
deb35ba
CORRECT image captioning
molbap Mar 22, 2024
9275d53
moving to args
molbap Mar 26, 2024
5487734
add siglip proper + fix merging image + text features
molbap Mar 27, 2024
7f9c479
take update_causal_mask from upstream
molbap Mar 27, 2024
438b143
remove breakpoint
molbap Mar 27, 2024
00d2922
leverage AutoModel
molbap Mar 28, 2024
72697b6
fix input_ids slicing
molbap Mar 28, 2024
17b30fd
make siglip head conditional
molbap Mar 28, 2024
26cb46e
remove encoder_decoder value
molbap Mar 28, 2024
f49f389
remove unneeded modeling file
molbap Mar 28, 2024
d1964ea
add commented 4d attention mask
molbap Mar 28, 2024
6da4d2e
FIXED generation with 4D mask
molbap Mar 29, 2024
ef8c0fb
Merge branch 'main' of github.com:huggingface/new-model-addition
ArthurZucker Mar 30, 2024
9f261d7
Update src/transformers/models/siglip/modeling_siglip.py
molbap Mar 31, 2024
11c3488
fix left padding detection
molbap Apr 2, 2024
6d527e4
shuffle order of verifications
molbap Apr 2, 2024
8b679b1
Merge branch 'main' into add_palma
molbap Apr 2, 2024
fdba83a
fix missing labels for training
molbap Apr 3, 2024
eee3c1e
fix
molbap Apr 3, 2024
e4869dd
vectorize merging of features, improve slicing
molbap Apr 10, 2024
b973365
improve testing before conversion
molbap Apr 10, 2024
8cab59c
handle merging in processor
molbap Apr 10, 2024
9186b79
image token index depends on checkpoint
molbap Apr 10, 2024
78398a1
add variants, save processor too
molbap Apr 15, 2024
4ede86e
save processors, base tokenizer off spm file
molbap Apr 15, 2024
c899136
Merge branch 'main' of github.com:huggingface/transformers into main
ArthurZucker Apr 17, 2024
a912273
expand model embeddings due to additional image token
molbap Apr 24, 2024
2c8a508
pass image processing args
molbap Apr 24, 2024
7201d43
add convert rgb to siglip processor
molbap Apr 24, 2024
c74f3a3
add \n token separately
molbap May 3, 2024
23b12a3
Merge branch 'main' of github.com:huggingface/transformers
ArthurZucker May 9, 2024
4e4a957
Merge branch 'main' of github.com:huggingface/new-model-addition
ArthurZucker May 9, 2024
1f465c0
fix tokenizer and prompts
molbap May 9, 2024
38a0401
fix docstrings
molbap May 9, 2024
16baabc
change to camel
molbap May 9, 2024
9119db4
fix casing
molbap May 9, 2024
267a0da
debug pos_ids and sdpa
molbap May 9, 2024
3e32118
Merge branch 'main' into add_palma
molbap May 9, 2024
4ed0594
pass and use cache_position
molbap May 9, 2024
46c5c0b
add flag for newline tokenization
molbap May 10, 2024
5b47a5b
Update src/transformers/models/paligemma/processing_paligemma.py
molbap May 10, 2024
b80c591
simplify conversion script
molbap May 10, 2024
7d21409
add copied from
molbap May 10, 2024
b1eda0e
add precision to conversion script
molbap May 12, 2024
404abd8
Update src/transformers/models/paligemma/modeling_paligemma.py
molbap May 13, 2024
a7e2446
clean up
molbap May 13, 2024
2f6b17d
Merge branch 'add_palma' of github.com:huggingface/new-model-addition…
molbap May 13, 2024
60ad9c5
Shift attention mask from `1:`
pcuenca May 13, 2024
daccbe7
Merge pull request #6 from pcuenca/add_palma_shift_mask
molbap May 13, 2024
91c7aab
add docs, fix quality
molbap May 13, 2024
372d566
Merge branch 'add_palma' of github.com:huggingface/new-model-addition…
molbap May 13, 2024
63464ec
quality, tied weights inheritance, and logits/label alignment
molbap May 13, 2024
55187a1
fix more tests
molbap May 13, 2024
5ea6c9f
pass attn_implementation to language model correctly
molbap May 13, 2024
ddb7ac7
add SiglipVisionTransformer to no split modules
molbap May 13, 2024
7f32718
skip paligemma test for sdpa dispatch to flash
molbap May 13, 2024
ac9fd9a
skip incompatible tests
molbap May 13, 2024
cceb3d0
quality
molbap May 13, 2024
a264824
[broken archive maps]
molbap May 13, 2024
9310873
Apply suggestions
molbap May 14, 2024
0711b12
Update src/transformers/utils/dummy_pt_objects.py
molbap May 14, 2024
e7ec216
simplify conversion script
molbap May 14, 2024
8b0724d
add suggestions
molbap May 14, 2024
7bcea3e
add suggestions
molbap May 14, 2024
498bbde
add copied from
molbap May 14, 2024
a8bd223
fix
molbap May 14, 2024
04d962f
move labels out
molbap May 14, 2024
e7caa8a
revert
molbap May 14, 2024
ac5ed67
fix
molbap May 14, 2024
72f6fdc
remove placeholder labels if None
molbap May 14, 2024
c824771
use cache_position
molbap May 14, 2024
7a8e62e
fix quality + docstrings
molbap May 14, 2024
4913c07
fix quality
molbap May 14, 2024
0c8f2c9
Merge branch 'main' of github.com:huggingface/transformers into main
molbap May 14, 2024
54fd284
Merge branch 'main' into add_palma
molbap May 14, 2024
99c3ac5
fix paligemma 4d gemma mask incompatibility
molbap May 14, 2024
75c36c2
fix config docstring
molbap May 14, 2024
9b49838
fix query and attn_mask dtype
molbap May 14, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -784,6 +784,8 @@
title: OWL-ViT
- local: model_doc/owlv2
title: OWLv2
- local: model_doc/paligemma
title: PaliGemma
- local: model_doc/perceiver
title: Perceiver
- local: model_doc/pix2struct
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -230,6 +230,7 @@ Flax), PyTorch, and/or TensorFlow.
| [OPT](model_doc/opt) | ✅ | ✅ | ✅ |
| [OWL-ViT](model_doc/owlvit) | ✅ | ❌ | ❌ |
| [OWLv2](model_doc/owlv2) | ✅ | ❌ | ❌ |
| [PaliGemma](model_doc/paligemma) | ✅ | ❌ | ❌ |
| [PatchTSMixer](model_doc/patchtsmixer) | ✅ | ❌ | ❌ |
| [PatchTST](model_doc/patchtst) | ✅ | ❌ | ❌ |
| [Pegasus](model_doc/pegasus) | ✅ | ✅ | ✅ |
38 changes: 38 additions & 0 deletions docs/source/en/model_doc/paligemma.md
@@ -0,0 +1,38 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# PaliGemma

## Overview

The PaliGemma model was proposed by Google. It is a 3B vision-language model composed of a SigLIP-400m vision encoder and a Gemma-2B decoder, linked by a multimodal linear projection. It is not a chat model with images: it cuts an image into a fixed number of ViT tokens and prepends them to an optional text prompt. One particularity is that the model uses full block attention on all the image tokens plus the input text tokens. It comes in three resolutions, 224x224, 448x448 and 896x896, with three base models, 55 fine-tuned versions for different tasks, and two mix models.


This model was contributed by [Molbap](https://huggingface.co/Molbap).
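
A minimal captioning sketch built on the classes this PR adds. The repo id `your-org/paligemma-3b-224` is a placeholder (no released checkpoint is named in this diff), so substitute whichever repository hosts the converted weights:

```python
from PIL import Image
import requests

from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# Placeholder repo id for illustration only.
model_id = "your-org/paligemma-3b-224"

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)

# The processor prepends the image tokens to the text prompt and handles the "\n" separator.
inputs = processor(text="caption en", images=image, return_tensors="pt")

output = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(output[0], skip_special_tokens=True))
```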


## PaliGemmaConfig

[[autodoc]] PaliGemmaConfig

## PaliGemmaProcessor

[[autodoc]] PaliGemmaProcessor

## PaliGemmaForConditionalGeneration

[[autodoc]] PaliGemmaForConditionalGeneration
- forward
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -203,6 +203,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
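Selecting the SDPA path for the newly listed PaliGemma class goes through the usual `attn_implementation` flag — a quick sketch, again with a placeholder repo id:

```python
import torch
from transformers import PaliGemmaForConditionalGeneration

# "your-org/paligemma-3b-224" is a placeholder repo id for illustration.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    "your-org/paligemma-3b-224",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
```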
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
@@ -582,6 +582,7 @@
"OwlViTTextConfig",
"OwlViTVisionConfig",
],
"models.paligemma": ["PaliGemmaConfig"],
"models.patchtsmixer": ["PatchTSMixerConfig"],
"models.patchtst": ["PatchTSTConfig"],
"models.pegasus": [
@@ -2651,6 +2652,13 @@
"OwlViTVisionModel",
]
)
_import_structure["models.paligemma"].extend(
[
"PaliGemmaForConditionalGeneration",
"PaliGemmaPreTrainedModel",
"PaliGemmaProcessor",
]
)
_import_structure["models.patchtsmixer"].extend(
[
"PatchTSMixerForPrediction",
@@ -5126,6 +5134,9 @@
OwlViTTextConfig,
OwlViTVisionConfig,
)
from .models.paligemma import (
PaliGemmaConfig,
)
from .models.patchtsmixer import (
PatchTSMixerConfig,
)
@@ -6956,6 +6967,11 @@
OwlViTTextModel,
OwlViTVisionModel,
)
from .models.paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaPreTrainedModel,
PaliGemmaProcessor,
)
from .models.patchtsmixer import (
PatchTSMixerForPrediction,
PatchTSMixerForPretraining,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -173,6 +173,7 @@
opt,
owlv2,
owlvit,
paligemma,
patchtsmixer,
patchtst,
pegasus,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -182,6 +182,7 @@
("opt", "OPTConfig"),
("owlv2", "Owlv2Config"),
("owlvit", "OwlViTConfig"),
("paligemma", "PaliGemmaConfig"),
("patchtsmixer", "PatchTSMixerConfig"),
("patchtst", "PatchTSTConfig"),
("pegasus", "PegasusConfig"),
@@ -464,6 +465,7 @@
("opt", "OPT"),
("owlv2", "OWLv2"),
("owlvit", "OWL-ViT"),
("paligemma", "PaliGemma"),
("patchtsmixer", "PatchTSMixer"),
("patchtst", "PatchTST"),
("pegasus", "Pegasus"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
@@ -93,6 +93,7 @@
("oneformer", "OneFormerImageProcessor"),
("owlv2", "Owlv2ImageProcessor"),
("owlvit", "OwlViTImageProcessor"),
("paligemma", "CLIPImageProcessor"),
("perceiver", "PerceiverImageProcessor"),
("pix2struct", "Pix2StructImageProcessor"),
("poolformer", "PoolFormerImageProcessor"),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -313,6 +313,7 @@
("nezha", "NezhaForPreTraining"),
("nllb-moe", "NllbMoeForConditionalGeneration"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("retribert", "RetriBertModel"),
("roberta", "RobertaForMaskedLM"),
("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
@@ -697,6 +698,7 @@
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"),
("vipllava", "VipLlavaForConditionalGeneration"),
("vision-encoder-decoder", "VisionEncoderDecoderModel"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -74,6 +74,7 @@
("oneformer", "OneFormerProcessor"),
("owlv2", "Owlv2Processor"),
("owlvit", "OwlViTProcessor"),
("paligemma", "PaliGemmaProcessor"),
("pix2struct", "Pix2StructProcessor"),
("pop2piano", "Pop2PianoProcessor"),
("sam", "SamProcessor"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -331,6 +331,7 @@
("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"pegasus",
(
3 changes: 0 additions & 3 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -794,7 +794,6 @@ def _init_weights(self, module):
"The bare Gemma Model outputting raw hidden-states without any specific head on top.",
GEMMA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMA,Llama->Gemma
class GemmaModel(GemmaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
@@ -988,8 +987,6 @@ def _update_causal_mask(

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
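The strict "inverted form" check is removed here because PaliGemma hands Gemma a custom 4D attention mask implementing its prefix-LM scheme: image and prompt tokens attend to each other bidirectionally, while generated tokens stay causal. A boolean sketch of that mask shape (simplified; the modeling code works with additive float masks rather than booleans):

```python
import torch

def prefix_lm_mask(prefix_len: int, seq_len: int) -> torch.Tensor:
    """Full attention inside the image+prompt prefix, causal attention after it."""
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # causal base
    mask[:, :prefix_len] = True  # every position may attend to the whole prefix
    return mask[None, None]  # shape (1, 1, q_len, kv_len) for batch and heads

# Example: 4 prefix tokens (image + prompt) out of 6 total positions.
print(prefix_lm_mask(prefix_len=4, seq_len=6)[0, 0].int())
```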
54 changes: 54 additions & 0 deletions src/transformers/models/paligemma/__init__.py
@@ -0,0 +1,54 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {"configuration_paligemma": ["PaliGemmaConfig"]}


try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_paligemma"] = [
"PaliGemmaForConditionalGeneration",
"PaliGemmaPreTrainedModel",
]
_import_structure["processing_paligemma"] = ["PaliGemmaProcessor"]


if TYPE_CHECKING:
from .configuration_paligemma import PaliGemmaConfig

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaPreTrainedModel,
)
from .processing_paligemma import PaliGemmaProcessor


else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
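Because the modeling and processing entries sit behind the `is_torch_available` guard above, the config imports without torch; the model class resolves to a placeholder (see the `dummy_pt_objects.py` update later in this PR) unless torch is installed. A small illustration of that split:

```python
# The config lives in the base import structure, so this works in a torch-free environment.
from transformers import PaliGemmaConfig

config = PaliGemmaConfig()
print(config.model_type)  # "paligemma"

# This import also succeeds without torch, but resolves to a dummy object
# that raises on use; a real torch install is needed to instantiate the model.
from transformers import PaliGemmaForConditionalGeneration
```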
130 changes: 130 additions & 0 deletions src/transformers/models/paligemma/configuration_paligemma.py
@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PaliGemmamodel configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING


logger = logging.get_logger(__name__)


class PaliGemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PaliGemmaForConditionalGeneration`]. It is used to instantiate a
PaliGemma model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a configuration similar to that of PaliGemma-2B.

e.g. [paligemma-hf/paligemma-2b](https://huggingface.co/paligemma-hf/paligemma-2b)

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
vision_config (`Union[SiglipVisionConfig, dict]`, *optional*):
Custom vision config or dict. Defaults to a SigLIP vision configuration if not provided.
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Defaults to a `GemmaConfig` if not provided.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 256000):
The image token index to encode the image prompt.
vocab_size (`int`, *optional*, defaults to 257152):
Vocabulary size of the PaliGemma model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`~PaliGemmaForConditionalGeneration`].
projection_dim (`int`, *optional*, defaults to 2048):
Dimension of the multimodal projection space.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden layer of the Language model.

Example:

```python
>>> from transformers import PaliGemmaForConditionalGeneration, PaliGemmaConfig, SiglipVisionConfig, GemmaConfig

>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()

>>> # Initializing a PaliGemma config
>>> text_config = GemmaConfig()

>>> # Initializing a PaliGemma paligemma-3b-224 style configuration
>>> configuration = PaliGemmaConfig(vision_config, text_config)

>>> # Initializing a model from the paligemma-3b-224 style configuration
>>> model = PaliGemmaForConditionalGeneration(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "paligemma"
is_composition = False

def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=256000,
vocab_size=257152,
projection_dim=2048,
hidden_size=2048,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.vocab_size = vocab_size
self.projection_dim = projection_dim
self.hidden_size = hidden_size
self.vision_config = vision_config
self.is_encoder_decoder = False

if isinstance(self.vision_config, dict):
vision_config["model_type"] = (
vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
)
self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
elif vision_config is None:
self.vision_config = CONFIG_MAPPING["siglip_vision_model"](
intermediate_size=4096,
hidden_size=1152,
patch_size=14,
image_size=224,
num_hidden_layers=27,
num_attention_heads=16,
vocab_size=257152,
vision_use_head=False,
)
self.vocab_size = self.vocab_size

self.text_config = text_config

if isinstance(self.text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
self.vocab_size = self.text_config.vocab_size
elif text_config is None:
self.text_config = CONFIG_MAPPING["gemma"](
hidden_size=2048,
num_hidden_layers=18,
intermediate_size=16384,
num_attention_heads=8,
num_key_value_heads=1,
is_encoder_decoder=False,
)
self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
self.vision_config.projection_dim = projection_dim
super().__init__(**kwargs)
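With the default SigLIP settings above (`image_size=224`, `patch_size=14`), the `num_image_tokens` assignment works out to 256 tokens per image — a quick check of that arithmetic:

```python
image_size, patch_size = 224, 14
num_image_tokens = (image_size // patch_size) ** 2
print(num_image_tokens)  # (224 // 14) ** 2 = 16 ** 2 = 256
```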