
Output from model.generate & model.forward not the same when output_attentions/output_hidden_states is True #32117

Closed
AamodThakur opened this issue Jul 21, 2024 · 2 comments

@AamodThakur
System Info

  • transformers version: 4.42.4
  • Platform: Linux-5.15.0-107-generic-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.24.0
  • Safetensors version: 0.4.3
  • Accelerate version: 0.31.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.2+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: No
  • GPU type: Tesla V100-SXM2-32GB

Who can help?

@gante @arth

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Common Step:

import torch
from transformers import GemmaTokenizer, GemmaForCausalLM

device = "cuda"  # later steps move tensors to "cuda"
# model_path points to the Gemma checkpoint being tested

tokenizer = GemmaTokenizer.from_pretrained(model_path, device_map=device)
model = GemmaForCausalLM.from_pretrained(model_path, device_map=device)

prompt = ["[INST]What is USA?[/INST]"]
input_tokens = tokenizer(prompt, return_tensors='pt', padding=True).to(device)

## Let's call this common_output
generate_ids = model.generate(**input_tokens, return_dict_in_generate=True,
                              temperature=None, top_p=None,
                              max_new_tokens=5, use_cache=False)
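
For reference, a minimal sketch (assuming the setup above) of how the greedy continuation in common_output can be inspected; `generate_ids.sequences` holds prompt plus generated token ids because return_dict_in_generate=True was passed:

# Decode the greedily generated continuation from common_output.
print(tokenizer.batch_decode(generate_ids.sequences, skip_special_tokens=True))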

Case 1:

test_input_tokens = input_tokens  # note: this aliases input_tokens rather than copying it

# Forward pass with attentions/hidden states requested, then greedily append the next token
generate_ids_1 = model.forward(**test_input_tokens, output_attentions=True, output_hidden_states=True)
next_token_logits = generate_ids_1.logits[:, -1, :]
test_input_tokens["input_ids"] = torch.concat(
    (test_input_tokens["input_ids"], next_token_logits.argmax().unsqueeze(dim=0).unsqueeze(dim=0)), dim=-1)
test_input_tokens["attention_mask"] = torch.concat(
    (test_input_tokens["attention_mask"], torch.ones((1, 1), dtype=torch.long).to("cuda")), dim=-1)

Case 2:

test_input_tokens = input_tokens

# Same step as Case 1, but without requesting attentions/hidden states
generate_ids_1 = model.forward(**test_input_tokens)
next_token_logits = generate_ids_1.logits[:, -1, :]
test_input_tokens["input_ids"] = torch.concat(
    (test_input_tokens["input_ids"], next_token_logits.argmax().unsqueeze(dim=0).unsqueeze(dim=0)), dim=-1)
test_input_tokens["attention_mask"] = torch.concat(
    (test_input_tokens["attention_mask"], torch.ones((1, 1), dtype=torch.long).to("cuda")), dim=-1)

The output of the generate function matches the output of Case 2, but the output of Case 1 is different because we set "output_attentions=True, output_hidden_states=True", which we believe should only be used for debugging purposes.
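
For reference, a minimal sketch of how the divergence between the two forward passes can be quantified (assumes model, tokenizer, prompt, and device from the common step; it re-tokenizes to avoid the in-place edits made in the cases above):

fresh_inputs = tokenizer(prompt, return_tensors='pt', padding=True).to(device)

with torch.no_grad():
    logits_plain = model(**fresh_inputs).logits[:, -1, :]
    logits_debug = model(**fresh_inputs, output_attentions=True, output_hidden_states=True).logits[:, -1, :]

# Magnitude of the numerical difference and whether the greedy next token changes
print(torch.max(torch.abs(logits_plain - logits_debug)).item())
print(torch.equal(logits_plain.argmax(dim=-1), logits_debug.argmax(dim=-1)))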

Expected behavior

The output of generate should be the same as the outputs of both Case 1 and Case 2.

gante (Member) commented Aug 23, 2024

Hi @AamodThakur 👋

The differences are expected :) You are hitting an edge case of the optimized SDPA attention (the implementation used by default) when output_attentions=True, where it falls back to the eager attention implementation. As with most optimized FP operations, there are tiny fluctuations to be expected when compared to other implementations (see here for an explanation of why it happens).
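
As an illustration (not from the original comment), here is a standalone sketch showing that SDPA and a manual softmax(QK^T/sqrt(d))V implementation are mathematically equivalent but not bit-identical in floating point:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))  # (batch, heads, seq, head_dim)

# Optimized SDPA path
sdpa_out = F.scaled_dot_product_attention(q, k, v)

# Manual ("eager") path with the same scaling
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
eager_out = torch.softmax(scores, dim=-1) @ v

# Tiny but non-zero difference (roughly 1e-6 or smaller in fp32; larger in fp16/bf16)
print(torch.max(torch.abs(sdpa_out - eager_out)))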

You should also be getting a warning like "GemmaModel is using GemmaSdpaAttention, but torch.nn.functional.scaled_dot_product_attention does not support output_attentions=True (...)" 🤗

For reference, here are the relevant code lines:

if output_attentions:
    # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
    logger.warning_once(
        "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
        'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
    )
    return super().forward(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
    )
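
A minimal sketch of the workaround that warning refers to (assuming the same model_path and device as in the reproduction): load the model with the eager attention implementation so that both code paths use the same kernels.

# Force the eager attention implementation at load time, so generate() and
# forward(output_attentions=True) go through the same attention code path.
model = GemmaForCausalLM.from_pretrained(model_path, device_map=device, attn_implementation="eager")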


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
