Fixes default value of softmax_scale in PhiFlashAttention2. (#28537)
* fix(phi): Phi does not use softmax_scale in Flash-Attention.

* chore(docs): Update Phi docs.
gugarosa authored Jan 17, 2024
1 parent a6adc05 commit d93ef7d
Showing 2 changed files with 17 additions and 18 deletions.
33 changes: 16 additions & 17 deletions docs/source/en/model_doc/phi.md
@@ -23,6 +23,7 @@ The Phi-1 model was proposed in [Textbooks Are All You Need](https://arxiv.org/a
The Phi-1.5 model was proposed in [Textbooks Are All You Need II: phi-1.5 technical report](https://arxiv.org/abs/2309.05463) by Yuanzhi Li, Sébastien Bubeck, Ronen Eldan, Allie Del Giorno, Suriya Gunasekar and Yin Tat Lee.

### Summary

In the Phi-1 and Phi-1.5 papers, the authors showed how important the quality of the training data is relative to the model size.
They selected high-quality "textbook" data alongside synthetically generated data to train their small Transformer-based
model Phi-1 with 1.3B parameters. Despite this small scale, phi-1 attains a pass@1 accuracy of 50.6% on HumanEval and 55.5% on MBPP.
@@ -31,7 +32,6 @@ to models 5x larger, and surpassing most non-frontier LLMs. Phi-1.5 exhibits man
to “think step by step” or perform some rudimentary in-context learning.
With these two experiments, the authors showed the huge impact that the quality of training data has when training machine learning models.


The abstract from the Phi-1 paper is the following:

*We introduce phi-1, a new large language model for code, with significantly smaller size than
@@ -60,32 +60,32 @@ including hallucinations and the potential for toxic and biased generations –e
are seeing improvement on that front thanks to the absence of web data. We open-source phi-1.5 to
promote further research on these urgent topics.*


This model was contributed by [Susnato Dhar](https://huggingface.co/susnato).
The original code for Phi-1 and Phi-1.5 can be found [here](https://huggingface.co/microsoft/phi-1/blob/main/modeling_mixformer_sequential.py) and [here](https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py) respectively.

The original code for Phi-2 can be found [here](https://huggingface.co/microsoft/phi-2).

The original code for Phi-1, Phi-1.5 and Phi-2 can be found [here](https://huggingface.co/microsoft/phi-1), [here](https://huggingface.co/microsoft/phi-1_5) and [here](https://huggingface.co/microsoft/phi-2), respectively.

## Usage tips

- This model is quite similar to `Llama`, the main difference being [`PhiDecoderLayer`], which uses [`PhiAttention`] and [`PhiMLP`] layers in a parallel configuration (see the sketch after this list).
- The tokenizer used for this model is identical to the [`CodeGenTokenizer`].
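For intuition, here is a minimal sketch of that parallel configuration (an illustration with simplified, assumed module names and shapes, not the actual `PhiDecoderLayer` source): the attention block and the MLP both read the same layer-normed hidden states, and their outputs are summed into a single residual update instead of being applied one after the other as in `Llama`.

```python
# Minimal sketch of a "parallel" decoder block (simplified assumption, not the
# real PhiDecoderLayer implementation): one input LayerNorm feeds both the
# attention and the MLP, and their outputs are added into the residual stream.
import torch
import torch.nn as nn


class ParallelBlockSketch(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out, _ = self.attn(normed, normed, normed)
        mlp_out = self.mlp(normed)
        # Parallel residual: attention and MLP contributions are summed,
        # rather than the sequential attention -> MLP order used in Llama.
        return residual + attn_out + mlp_out
```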


## How to use Phi-2

<Tip warning={true}>

The current weights at [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) are not in proper order to be used with the library model. Until that is resolved, please use [susnato/phi-2](https://huggingface.co/susnato/phi-2) to load using the library `phi` model.
Phi-2 has been integrated into the development version (4.37.0.dev) of `transformers`. Until the official version is released through `pip`, ensure that you do one of the following:

* When loading the model, ensure that `trust_remote_code=True` is passed as an argument of the `from_pretrained()` function.

* Update your local `transformers` to the development version: `pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers`. The previous command is an alternative to cloning and installing from source.

</Tip>
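For the first option, the flag is passed directly to `from_pretrained`; the snippet below is an illustrative sketch added here (the checkpoint name and whether the flag is required depend on your installed version):

```python
>>> from transformers import AutoModelForCausalLM

>>> # Option 1 from the tip above: allow the modeling code shipped with the checkpoint
>>> # to run (sketch; not needed once the development version of transformers is installed).
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
```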

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("susnato/phi-2")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-2")
>>> model = AutoModelForCausalLM.from_pretrained("phi-2")

pashminacameron (Contributor) commented on Jan 18, 2024:
Isn't the key microsoft/phi-2?

amyeroberts (Collaborator) commented on Jan 18, 2024:
Yep! I've made a small PR to fix here: #28581

>>> tokenizer = AutoTokenizer.from_pretrained("phi-2")

>>> inputs = tokenizer('Can you help me write a formal email to a potential business partner proposing a joint venture?', return_tensors="pt", return_attention_mask=False)

@@ -95,15 +95,14 @@ The current weights at [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
'Can you help me write a formal email to a potential business partner proposing a joint venture?\nInput: Company A: ABC Inc.\nCompany B: XYZ Ltd.\nJoint Venture: A new online platform for e-commerce'
```


### Example:

```python
>>> from transformers import PhiForCausalLM, AutoTokenizer

>>> # define the model and tokenizer.
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
>>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1_5")
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")

>>> # feel free to change the prompt to your liking.
>>> prompt = "If I were an AI that had just achieved"
@@ -118,7 +117,6 @@ The current weights at [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
'If I were an AI that had just achieved a breakthrough in machine learning, I would be thrilled'
```


## Combining Phi and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
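A typical installation command is shown below (an assumption about your environment; check the official flash-attn instructions for the hardware and PyTorch versions it supports):

```bash
# Typical Flash Attention 2 install (assumes a compatible CUDA GPU and PyTorch build).
pip install -U flash-attn --no-build-isolation
```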
@@ -136,8 +134,8 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> import torch
>>> from transformers import PhiForCausalLM, AutoTokenizer

>>> # define the model and tokenizer and push the model and tokens to the GPU.
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
>>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1_5", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")

>>> # feel free to change the prompt to your liking.
>>> prompt = "If I were an AI that had just achieved"
@@ -153,12 +151,13 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
```

### Expected speedups
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `susnato/phi-1_dev` checkpoint and the Flash Attention 2 version of the model, using a sequence length of 2048.

Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `microsoft/phi-1` checkpoint and the Flash Attention 2 version of the model, using a sequence length of 2048.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/phi_1_speedup_plot.jpg">
</div>


## PhiConfig

[[autodoc]] PhiConfig
2 changes: 1 addition & 1 deletion src/transformers/models/phi/modeling_phi.py
@@ -506,7 +506,7 @@ def forward(
value_states = value_states.to(target_dtype)

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0
query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
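For context on the one-line change above: in the flash-attn kernels, leaving `softmax_scale` as `None` falls back to the standard `1/sqrt(head_dim)` scaling of the attention scores, whereas the previously hard-coded `1.0` disabled scaling for Phi. The sketch below illustrates the intended behavior (an illustration only, not the flash-attn or transformers source):

```python
# Illustration of what softmax_scale controls (not the flash-attn or transformers code).
# With softmax_scale=None, the effective scale falls back to 1/sqrt(head_dim), i.e.
# standard scaled dot-product attention; the old hard-coded 1.0 skipped that scaling.
import math

import torch


def naive_attention(query, key, value, softmax_scale=None):
    head_dim = query.shape[-1]
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, value)
```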
