
Commit
LLaVa: add cache class attribute (#32278)
cache class flag
zucchini-nlp authored Aug 1, 2024
1 parent 14ee232 commit 453e748
Showing 6 changed files with 6 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/transformers/models/llava/modeling_llava.py
@@ -126,6 +126,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of Llava isn't meant for training from scratch - only
1 change: 1 addition & 0 deletions src/transformers/models/llava_next/modeling_llava_next.py
@@ -232,6 +232,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaNextVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of LlavaNext isn't meant for training from scratch - only
1 change: 1 addition & 0 deletions src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -272,6 +272,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaNextVideoVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of LlavaNextVideo isn't meant for training from scratch - only
1 change: 1 addition & 0 deletions src/transformers/models/paligemma/modeling_paligemma.py
@@ -127,6 +127,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = False
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of PaliGemma isn't meant for training from scratch - only
1 change: 1 addition & 0 deletions src/transformers/models/video_llava/modeling_video_llava.py
@@ -126,6 +126,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["VideoLlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = (
1 change: 1 addition & 0 deletions src/transformers/models/vipllava/modeling_vipllava.py
@@ -135,6 +135,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["VipLlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of VipLlava isn't meant for training from scratch - only
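
For context, the `_supports_cache_class` attribute tells `generate()` that these models accept the newer `Cache` objects (e.g. `DynamicCache`) as `past_key_values` instead of legacy key/value tuples. A minimal sketch of such a call is shown below; the checkpoint, image URL, and generation settings are illustrative assumptions, not part of this commit.

# Minimal sketch (assumed checkpoint and image URL, not specified by the commit).
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, DynamicCache, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # example checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

# Passing an explicit Cache instance is what `_supports_cache_class = True` enables.
output = model.generate(**inputs, past_key_values=DynamicCache(), max_new_tokens=20)
print(processor.decode(output[0], skip_special_tokens=True))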
