From 3bfd3e4803d62afacac81a0ab5dc7eb69c676263 Mon Sep 17 00:00:00 2001
From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Date: Wed, 28 Aug 2024 09:24:06 +0200
Subject: [PATCH] Fix: Jamba batched generation (#32914)

* init fix

* fix mask during cached forward, move mask related stuff to own function

* adjust tests as left padding does not change logits as much anymore + batch gen (with todo on logits comp)

* revert overwriting new integration tests

* move some comments to docstring
---
 .../models/jamba/modeling_jamba.py        | 54 ++++++++++++++++---
 tests/models/jamba/test_modeling_jamba.py | 52 ++----------------
 2 files changed, 50 insertions(+), 56 deletions(-)

diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 2722a5e06909..60e1670a3c27 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -649,7 +649,12 @@ def __init__(self, config: JambaConfig, layer_idx):
                 " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
             )
 
-    def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None):
+    def cuda_kernels_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params: HybridMambaAttentionDynamicCache = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
         batch_size, seq_len, _ = hidden_states.shape
         use_precomputed_states = (
             cache_params is not None
@@ -666,6 +671,9 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Hybrid
         # inner layernorms which isn't supported by this fused kernel
         hidden_states, gate = projected_states.chunk(2, dim=1)
 
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
         # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
         if use_precomputed_states:
@@ -683,6 +691,9 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Hybrid
                 cache_params.conv_states[self.layer_idx].copy_(conv_states)
             hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
 
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
         # 3. State Space Model sequence transformation
         # 3.a. input varying initialization of time_step, B and C
         ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -742,14 +753,17 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Hybrid
         return contextualized_states
 
     # fmt: off
-    def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None):
+    def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask: Optional[torch.LongTensor] = None):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(input_states).transpose(1, 2)                   # [batch, 2 * intermediate_size, seq_len]
         hidden_states, gate = projected_states.chunk(2, dim=1)
 
-        use_cache = isinstance(cache_params,HybridMambaAttentionDynamicCache)
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+        use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
         # 2. Convolution sequence transformation
         if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
             if self.training:
@@ -784,6 +798,9 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
             )
             hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])     # [batch, intermediate_size, seq_len]
 
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
         # 3. State Space Model sequence transformation
         # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
         ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -821,14 +838,19 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
         return contextualized_states
     # fmt: on
 
-    def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None):
+    def forward(
+        self,
+        hidden_states,
+        cache_params: HybridMambaAttentionDynamicCache = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
         if self.use_fast_kernels:
             if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
                 raise ValueError(
                     "Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
                 )
-            return self.cuda_kernels_forward(hidden_states, cache_params)
-        return self.slow_forward(hidden_states, cache_params)
+            return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
+        return self.slow_forward(hidden_states, cache_params, attention_mask)
 
 
 # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
@@ -1040,6 +1062,7 @@ def forward(
         hidden_states = self.mamba(
             hidden_states=hidden_states,
             cache_params=past_key_value,
+            attention_mask=attention_mask,
         )
         self_attn_weights = None
 
@@ -1279,12 +1302,16 @@ def forward(
             position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
 
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         all_router_logits = () if output_router_logits else None
 
         for decoder_layer in self.layers:
+            # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
+            layer_mask = mamba_mask if isinstance(decoder_layer, JambaMambaDecoderLayer) else causal_mask
+
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -1292,7 +1319,7 @@ def forward(
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
-                    causal_mask,
+                    layer_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
@@ -1303,7 +1330,7 @@ def forward(
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=causal_mask,
+                    attention_mask=layer_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
@@ -1384,6 +1411,17 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
 
         return causal_mask
 
+    def _update_mamba_mask(self, attention_mask, cache_position):
+        """
+        No need for zeroing states when
+            1. Cached forward
+            2. Attending to all inputs
+        """
+        mamba_mask = attention_mask
+        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
+            mamba_mask = None
+        return mamba_mask
+
 
 # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
 class JambaForCausalLM(JambaPreTrainedModel):
diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py
index ed824586e223..6cbfe62cfe17 100644
--- a/tests/models/jamba/test_modeling_jamba.py
+++ b/tests/models/jamba/test_modeling_jamba.py
@@ -458,51 +458,6 @@ def test_attention_outputs(self):
                 [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
             )
 
-    def test_left_padding_compatibility(self):
-        r"""
-        Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences
-        effect of the left padding discussed in the issue in the note. Using a more permissive tolerance value.
-        """
-        import inspect
-        # NOTE: left-padding results in small numerical differences. This is expected.
-        # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
-
-        # First, filter out models that don't support left padding - generative and decoder-only.
-        # Jamba is a decoder-only architecture
-        decoder_only_classes = self.all_generative_model_classes
-
-        # Then, test left-padding
-        def _prepare_model_kwargs(input_ids, attention_mask, signature):
-            model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
-            if "position_ids" in signature:
-                position_ids = torch.cumsum(attention_mask, dim=-1) - 1
-                position_ids.masked_fill_(attention_mask == 0, 1)
-                model_kwargs["position_ids"] = position_ids
-            if "cache_position" in signature:
-                cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
-                model_kwargs["cache_position"] = cache_position
-            return model_kwargs
-
-        for model_class in decoder_only_classes:
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-            model = model_class(config).to(torch_device).eval()
-            signature = inspect.signature(model.forward).parameters.keys()
-
-            # Without padding
-            model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
-            next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
-
-            # With left-padding (length 32)
-            pad_size = (input_ids.shape[0], 32)
-            padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
-            padded_input_ids = torch.cat((padding, input_ids), dim=1)
-            padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
-            model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
-            next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
-
-            # They should result in very similar logits
-            self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
-
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
@@ -692,7 +647,7 @@ def test_simple_generate(self):
         EXPECTED_LOGITS_NO_GRAD = torch.tensor(
             [
                 0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660,
-                -0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
+                -0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
                 0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292,
                 0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906,
                 0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562
@@ -737,10 +692,11 @@ def test_simple_batched_generate_with_padding(self):
         with torch.no_grad():
             logits = self.model(input_ids=inputs["input_ids"]).logits
 
+            # TODO fix logits
             EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
                 [
                     0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641,
-                    -0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
+                    -0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
                     0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261,
                     0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945,
                     0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583
@@ -749,7 +705,7 @@ def test_simple_batched_generate_with_padding(self):
 
             EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
                 [
-                    -0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
+                    -0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
                     0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852,
                     0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129,
                     0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891,
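
A minimal sketch (not part of the commit) of the left-padded batched generation path this patch targets. The checkpoint name and prompts below are illustrative assumptions and do not appear in the diff; the point is that tokenizing a batch with padding produces the 2D attention_mask that, after this fix, is forwarded to the Mamba mixer so padded positions are zeroed before and after the causal convolution instead of leaking into the SSM state.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "ai21labs/Jamba-v0.1"  # assumed checkpoint; any Jamba model exercises the same code path
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.padding_side = "left"  # decoder-only models are left-padded for generation
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

    prompts = ["Hey how are you doing on this lovely evening?", "Tell me a short story."]
    # padding=True yields a 2D attention_mask with zeros over pad tokens; with this patch,
    # JambaModel routes that mask to the Mamba layers (and the 4D causal mask to attention layers)
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    print(tokenizer.batch_decode(out, skip_special_tokens=True))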