Gemma2: add cache warning #32279
Changes from 6 commits
@@ -747,7 +747,19 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            if past_key_values is None:
+                cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            else:
+                raise ValueError("When `past_key_values` is passed, `cache_position` must be too")
> That is actually valid: someone may be using `past_key_values` outside `generate`.
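For context, a minimal sketch (not part of this PR) of the pattern the comment above refers to: driving a cache by hand outside `generate`. The checkpoint name and cache size are illustrative, and the exact `HybridCache` constructor arguments may differ between transformers versions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache

# Illustrative checkpoint; any Gemma2 model follows the same pattern.
model_id = "google/gemma-2-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
# Constructor arguments are assumptions and may vary across versions.
cache = HybridCache(model.config, max_batch_size=1, max_cache_len=128)

# Prefill: cache_position covers the prompt tokens.
cache_position = torch.arange(inputs.input_ids.shape[1])
out = model(**inputs, past_key_values=cache, use_cache=True, cache_position=cache_position)

# Manual decoding step: a caller outside `generate` may pass `past_key_values`
# without recomputing `cache_position` itself; the `raise ValueError` above
# would reject that call, which is what the comment objects to.
next_token = out.logits[:, -1:].argmax(-1)
out = model(next_token, past_key_values=cache, use_cache=True)
```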
+        # Probably a forward call with caching, so we set up cache for one call only
+        if use_cache and past_key_values is None and not self.training:
+            logger.warning_once(
+                "You are calling the model with `use_cache=True` but didn't pass `past_key_values`. "
+                "Make sure to pass an instance of `HybridCache`. Caching will be disabled. See for more: "
+                "(https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)"
+            )
+            use_cache = False

> zucchini-nlp marked this conversation as resolved.
> TBH I am bothered by adding a LOT of `if`/`else` branches like this, which should more often than not be handled outside. Though I guess we don't have a choice.

> What if we change the good old `use_cache = use_cache if use_cache is not None else self.config.use_cache` to also require `not self.training`?

> We could do that; the main concern, mentioned by the user in the issue, is that people expect Gemma2 to behave like all other models and return a cache when `use_cache` is True. If we simply change the

> Got it. There is a tradeoff between people's expectations and warning everywhere. IMO we should:
return_legacy_cache = False
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
> In llama we already have this, meaning that if `past_key_values` is None we only create a `DynamicCache` when not training.

> Yes, we should also document it, I will add some info. So initializing an empty `HybridCache` should be okay, as long as we warn and document it.
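As a rough sketch of the alternative floated above (lazily creating an empty `HybridCache` inside `forward` instead of warning and disabling caching), the branch below shows how that might look. This is not what the diff in this PR does, and the capacity choice and exact `HybridCache` arguments are assumptions.

```python
# Hypothetical alternative inside Gemma2Model.forward (not what this PR ships):
# lazily build a cache instead of warning and setting use_cache = False.
if use_cache and past_key_values is None and not self.training:
    past_key_values = HybridCache(
        self.config,
        max_batch_size=inputs_embeds.shape[0],
        max_cache_len=self.config.max_position_embeddings,  # illustrative capacity
        device=inputs_embeds.device,
        dtype=inputs_embeds.dtype,
    )
```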

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

@@ -832,7 +844,7 @@ def _update_causal_mask(
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if past_key_values is not None:
+        if past_key_values is not None and past_key_values.get_max_length() is not None:
             target_length = past_key_values.get_max_length()
         else:
             target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
> Perfect!
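For reference, a small illustration of why the added `get_max_length()` guard matters, assuming the cache behavior in the transformers version this PR targets (a `DynamicCache` has no preset capacity, while a `HybridCache` does; the method was renamed in later releases). The checkpoint name is only an example.

```python
from transformers import AutoConfig, DynamicCache, HybridCache

config = AutoConfig.from_pretrained("google/gemma-2-2b")  # illustrative checkpoint

dynamic = DynamicCache()
hybrid = HybridCache(config, max_batch_size=1, max_cache_len=256)

# A DynamicCache grows on the fly and reports no maximum length, so without the
# extra `is not None` check, target_length would become None and break the
# causal mask construction in _update_causal_mask.
print(dynamic.get_max_length())  # None
print(hybrid.get_max_length())   # 256
```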