
Gemma2: add cache warning #32279

Merged: 11 commits, Aug 7, 2024
9 changes: 9 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -391,6 +391,15 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- get_seq_length
- reset

[[autodoc]] HybridCache
- update
- get_seq_length
- reset

[[autodoc]] SlidingWindowCache
- update
- reset

Collaborator:

Perfect!

[[autodoc]] EncoderDecoderCache
- get_seq_length
- to_legacy_cache
16 changes: 14 additions & 2 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -747,7 +747,19 @@ def forward(
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            if past_key_values is None:
+                cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            else:
+                raise ValueError("When `past_key_values` is passed, `cache_position` must be too")
Collaborator:

That is actually valid; someone might be using `past_key_values` outside of `generate`.

+        # `use_cache=True` was requested but no cache object was passed (outside of
+        # training), so warn the user and disable caching for this call.
+        if use_cache and past_key_values is None and not self.training:
+            logger.warning_once(
+                "You are calling the model with `use_cache=True` but didn't pass `past_key_values`. "
+                "Make sure to pass an instance of `HybridCache`. Caching will be disabled. For more information, see "
+                "https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache"
+            )
+            use_cache = False
Collaborator:

TBH I am bothered by adding a lot of if/else branching like this, which should more often than not be handled outside the model. Though I guess we don't have a choice.

Collaborator:

What if we change the good old

use_cache = use_cache if use_cache is not None else self.config.use_cache or not self.training

Member Author:

We could, but the main concern, as the user mentioned in the issue, is that people expect Gemma2 to behave like all other models and return a cache when `use_cache` is True. Simply changing `use_cache` won't make it any more explicit for those not familiar with cache classes. My main point is to make users aware that Gemma2 is different and that they should not expect the same behavior as in other models.

Collaborator:

Got it. There is a tradeoff between people's expectations and warning everywhere. IMO we should:

  • better document Gemma2: add in gemma2.md that Gemma2 has a different cache class
  • return the HybridCache only when `use_cache` is set, as in:

        return_legacy_cache = False
        if (
            use_cache and not isinstance(past_key_values, Cache) and not self.training
        ):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
            )

In Llama we already have this, meaning that if `past_key_values` is None we only create a DynamicCache when not training.

Member Author:

Yes, we should also document it; I'll add some info. So initializing an empty HybridCache should be okay, as long as we warn and document it.
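
For readers following along, here is a minimal sketch of what passing an explicit, empty HybridCache could look like from the user side. The checkpoint name, the 512 cache length, and the explicit `cache_position` are illustrative assumptions rather than code from this PR, and the HybridCache constructor arguments are passed positionally because their keyword names have shifted between releases:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache

    model_id = "google/gemma-2-9b"  # illustrative Gemma2 checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    inputs = tokenizer("Hello", return_tensors="pt").to(model.device)

    # Pre-allocate an empty HybridCache (config, max batch size, max cache length)
    # and hand it to the model, so the warning added in this PR never fires.
    past_key_values = HybridCache(model.config, 1, 512, device=model.device, dtype=model.dtype)
    cache_position = torch.arange(inputs["input_ids"].shape[1], device=model.device)

    outputs = model(**inputs, past_key_values=past_key_values, cache_position=cache_position, use_cache=True)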


        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

@@ -832,7 +844,7 @@ def _update_causal_mask(
        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
-        if past_key_values is not None:
+        if past_key_values is not None and past_key_values.get_max_length() is not None:
            target_length = past_key_values.get_max_length()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
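
Side note on the `get_max_length()` guard above: dynamic caches report no fixed maximum length, so the extra check makes `target_length` fall back to the attention-mask/input length for them, while a preallocated HybridCache reports its capacity. A rough illustration, with an arbitrary checkpoint name and cache length:

    from transformers import AutoConfig, DynamicCache, HybridCache

    config = AutoConfig.from_pretrained("google/gemma-2-9b")  # illustrative

    dynamic = DynamicCache()
    static = HybridCache(config, 1, 512)  # max batch size 1, max cache length 512

    # DynamicCache has no fixed capacity, so the new check falls back to the
    # attention mask / input length; HybridCache reports its preallocated length.
    print(dynamic.get_max_length())  # None
    print(static.get_max_length())   # 512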