Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gemma 2 #31659

Merged
merged 16 commits into from
Jun 27, 2024
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Funnel Transformer](model_doc/funnel) | ✅ | ✅ | ❌ |
| [Fuyu](model_doc/fuyu) | ✅ | ❌ | ❌ |
| [Gemma](model_doc/gemma) | ✅ | ❌ | ✅ |
| [Gemma2](model_doc/gemma2) | ✅ | ❌ | ❌ |
| [GIT](model_doc/git) | ✅ | ❌ | ❌ |
| [GLPN](model_doc/glpn) | ✅ | ❌ | ❌ |
| [GPT Neo](model_doc/gpt_neo) | ✅ | ❌ | ✅ |
Expand Down
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@
],
"models.fuyu": ["FuyuConfig"],
"models.gemma": ["GemmaConfig"],
"models.gemma2": ["Gemma2Config"],
"models.git": [
"GitConfig",
"GitProcessor",
Expand Down Expand Up @@ -2181,6 +2182,15 @@
"GemmaPreTrainedModel",
]
)
_import_structure["models.gemma2"].extend(
[
"Gemma2ForCausalLM",
"Gemma2ForSequenceClassification",
"Gemma2ForTokenClassification",
"Gemma2Model",
"Gemma2PreTrainedModel",
]
)
_import_structure["models.git"].extend(
[
"GitForCausalLM",
Expand Down Expand Up @@ -5062,6 +5072,7 @@
)
from .models.fuyu import FuyuConfig
from .models.gemma import GemmaConfig
from .models.gemma2 import Gemma2Config
from .models.git import (
GitConfig,
GitProcessor,
Expand Down Expand Up @@ -6694,6 +6705,13 @@
GemmaModel,
GemmaPreTrainedModel,
)
from .models.gemma2 import (
Gemma2ForCausalLM,
Gemma2ForSequenceClassification,
Gemma2ForTokenClassification,
Gemma2Model,
Gemma2PreTrainedModel,
)
from .models.git import (
GitForCausalLM,
GitModel,
Expand Down
122 changes: 122 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,125 @@ def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return None


class HybridCache(Cache):
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)

self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.is_sliding = torch.tensor(
[i % 2 for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
sliding_cache_shape = (
max_batch_size,
self.num_key_value_heads,
min(config.sliding_window, max_cache_len),
self.head_dim,
)
for i in range(config.num_hidden_layers):
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
if cache_position.shape[0] > max_cache_len:
k_out = key_states[:, :, -max_cache_len:, :]
v_out = value_states[:, :, -max_cache_len:, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states

slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, max_cache_len - 1)
to_shift = cache_position >= max_cache_len - 1
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]

k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
return k_out, v_out

def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
return k_out, v_out

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
sliding_window: Optional[int] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
if sliding_window:
update_fn = self._sliding_update
else:
update_fn = self._static_update

return update_fn(
cache_position,
layer_idx,
key_states,
value_states,
k_out,
v_out,
k_out.shape[2],
)

def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return self.max_cache_len

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return None

def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
2 changes: 1 addition & 1 deletion src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def __init__(self, **kwargs):
# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None:
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
Expand Down
17 changes: 11 additions & 6 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Cache,
DynamicCache,
HQQQuantizedCache,
HybridCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SlidingWindowCache,
Expand Down Expand Up @@ -112,7 +113,7 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}


Expand Down Expand Up @@ -1395,10 +1396,12 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):

past_length = 0
if model_kwargs.get("past_key_values") is not None:
if isinstance(model_kwargs["past_key_values"], Cache):
past_length = model_kwargs["past_key_values"].get_seq_length()
else:
past_length = model_kwargs["past_key_values"][0][0].shape[2]
cache = model_kwargs["past_key_values"]
if not isinstance(cache, Cache):
past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length"):
past_length = cache.get_seq_length()

if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
else:
Expand Down Expand Up @@ -1739,7 +1742,9 @@ def generate(
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation, batch_size, generation_config.max_length
generation_config.cache_implementation,
getattr(generation_config, "num_beams", 1) * batch_size,
generation_config.max_length,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
funnel,
fuyu,
gemma,
gemma2,
git,
glpn,
gpt2,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
("funnel", "FunnelConfig"),
("fuyu", "FuyuConfig"),
("gemma", "GemmaConfig"),
("gemma2", "Gemma2Config"),
("git", "GitConfig"),
("glpn", "GLPNConfig"),
("gpt-sw3", "GPT2Config"),
Expand Down Expand Up @@ -385,6 +386,7 @@
("funnel", "Funnel Transformer"),
("fuyu", "Fuyu"),
("gemma", "Gemma"),
("gemma2", "Gemma2"),
("git", "GIT"),
("glpn", "GLPN"),
("gpt-sw3", "GPT-Sw3"),
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
("fsmt", "FSMTModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("gemma", "GemmaModel"),
("gemma2", "Gemma2Model"),
("git", "GitModel"),
("glpn", "GLPNModel"),
("gpt-sw3", "GPT2Model"),
Expand Down Expand Up @@ -454,6 +455,7 @@
("falcon", "FalconForCausalLM"),
("fuyu", "FuyuForCausalLM"),
("gemma", "GemmaForCausalLM"),
("gemma2", "Gemma2ForCausalLM"),
("git", "GitForCausalLM"),
("gpt-sw3", "GPT2LMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
Expand Down Expand Up @@ -863,6 +865,7 @@
("fnet", "FNetForSequenceClassification"),
("funnel", "FunnelForSequenceClassification"),
("gemma", "GemmaForSequenceClassification"),
("gemma2", "Gemma2ForSequenceClassification"),
("gpt-sw3", "GPT2ForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
Expand Down Expand Up @@ -1044,6 +1047,7 @@
("fnet", "FNetForTokenClassification"),
("funnel", "FunnelForTokenClassification"),
("gemma", "GemmaForTokenClassification"),
("gemma2", "Gemma2ForTokenClassification"),
("gpt-sw3", "GPT2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"gemma2",
(
"GemmaTokenizer" if is_sentencepiece_available() else None,
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/gemma/diff_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.scaling = 1 / math.sqrt(config.head_dim)

if self.hidden_size % self.num_heads != 0:
raise ValueError(
Expand Down Expand Up @@ -305,7 +306,7 @@ def forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
Expand Down
14 changes: 11 additions & 3 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.scaling = 1 / math.sqrt(config.head_dim)

if self.hidden_size % self.num_heads != 0:
raise ValueError(
Expand Down Expand Up @@ -288,7 +289,7 @@ def forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
Expand Down Expand Up @@ -898,6 +899,13 @@ def forward(
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)

# decoder layers
all_hidden_states = () if output_hidden_states else None
Expand Down Expand Up @@ -1397,7 +1405,7 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
Expand All @@ -1407,7 +1415,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
Expand Down
Loading
Loading