
Falcon: batched generation #26137

Merged
merged 5 commits into huggingface:main from gante:falcon_batch on Sep 13, 2023
Conversation

gante (Member) commented Sep 13, 2023

What does this PR do?

This PR does three things:

  1. Fixes the minimum float value added to the attention mask at the positions where the attention mask is 0. In some numerical precisions, the additive attention mask was becoming -inf, which wrecked downstream computations.
  2. Adds the position_ids input to Falcon, which is needed for proper batched generation (see the sketch after this list). When it is not passed, the forward pass builds the position ids from the sequence length, which does not account for the left-padding used in batched generation -- the model could still generate without the fix, but results should be slightly better with it.
  3. Adds tests for batched generation with left-padding.
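
A minimal sketch of the idea behind item 2, assuming the usual 0/1 attention-mask convention with left-padding (the helper name is hypothetical, not code from this PR):

    import torch

    def build_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
        # attention_mask: (batch, seq_len), 0 on left-padded positions, 1 on real tokens
        position_ids = attention_mask.long().cumsum(-1) - 1
        # padded positions get a dummy (but valid) index; they are masked out anyway
        position_ids.masked_fill_(attention_mask == 0, 1)
        return position_ids

    mask = torch.tensor([[0, 0, 1, 1, 1],
                         [1, 1, 1, 1, 1]])
    print(build_position_ids(mask))
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])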

    total_length = seq_len + past_key_values_length
    if total_length > self.seq_len_cached:
        self._set_cos_sin_cache(total_length, device, dtype)
    return (
        self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
gante (Member Author):

The slicing here is equivalent to building position ids from the sequence length, without taking into account any potential left-padding.
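
In other words, the old cache slicing implicitly assumes the same positions for every row of the batch, something like this (an illustrative snippet, not code from the PR):

    import torch

    seq_len, past_key_values_length = 3, 4
    implicit_positions = torch.arange(past_key_values_length, past_key_values_length + seq_len)
    print(implicit_positions)  # tensor([4, 5, 6]) for every row, whether or not it is left-padded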

HuggingFaceDocBuilderDev commented Sep 13, 2023

The documentation is not available anymore as the PR was closed or merged.

@gante gante requested a review from ArthurZucker September 13, 2023 10:47
gante (Member Author) commented Sep 13, 2023

@ArthurZucker whoops, sorry, there are still tests to fix; I will ping you again when they are fixed!

@gante gante marked this pull request as ready for review September 13, 2023 11:18
gante (Member Author) commented Sep 13, 2023

@ArthurZucker ready now

@@ -415,7 +437,11 @@ def forward(
    else:
        present = None

    attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
gante (Member Author) commented Sep 13, 2023:

cc @Rocketknight1 this -1e9 was causing problems in some numerical precisions (it would be converted to -inf) :p
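
For context, a quick way to see the problem, together with the usual remedy of using the dtype's own minimum (assumed here to match the spirit of the fix):

    import torch

    print(torch.tensor(-1e9, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16): fp16 overflows past ~±65504
    print(torch.finfo(torch.float16).min)           # -65504.0, a safe "very negative" additive mask value
    print(torch.finfo(torch.bfloat16).min)          # ≈ -3.39e+38, bfloat16 keeps fp32's exponent range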

Member:
Ah, my bad!

ArthurZucker (Collaborator) left a comment:

Thanks for adding batch support! Let's be careful with padding, as we have been getting a lot of issues regarding it! If we have one reference implementation (the one used in Llama), it would be great to re-use it!
Otherwise, LGTM

@@ -99,19 +99,40 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.cos_cached = self.cos_cached.type(dtype)
         self.sin_cached = self.sin_cached.type(dtype)

-    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
+    def cos_sin(
ArthurZucker (Collaborator) commented:

I'm gonna be a bit noisy here, but this looks a LOT like the rotary embedding we have in Llama, no?
The query expansion is also supported there; I'm not sure how much of an overhead it is to first apply rotary and then expand:

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

and also, storing the full-size keys and values is less memory efficient, no? (unrelated to the PR)
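
For reference, the grouped-query expansion mentioned above (repeat_kv) follows roughly this pattern (a from-memory sketch, not a verbatim copy of the Llama code):

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
        batch, num_kv_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
        return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)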

gante (Member Author) replied Sep 13, 2023:
I think they are the same. So, we would benefit from copying the structure (at least in terms of complexity for us, the maintainers) 👍

I would like to push it to the future, though, as I'm about to go on long holidays and I'd like to enable batched generation on Falcon :D

ArthurZucker (Collaborator) replied:
Let's just add a TODO then 😉

gante (Member Author) commented Sep 13, 2023

> Let's be careful with padding as we have been getting a lot of issues regarding this!

@ArthurZucker 100% agreed! If you come across a new model, please make sure there is a test for this 🙏
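
As a usage illustration of what such a test exercises (a minimal sketch, assuming a Falcon checkpoint and the standard transformers generation API; this is not the PR's actual test):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "tiiuae/falcon-7b"  # any Falcon checkpoint; the real test uses a much smaller model
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

    # prompts of different lengths force left-padding in the shorter one
    prompts = ["Hello, my dog is a little", "Today I am"]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    output = model.generate(**inputs, max_new_tokens=10, do_sample=False)
    print(tokenizer.batch_decode(output, skip_special_tokens=True))
    # with the fix, the padded row matches what the same prompt produces when generated alone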

gante (Member Author) commented Sep 13, 2023

@ArthurZucker suggestions applied 💪

ArthurZucker (Collaborator) left a comment:

Thanks a lot! Let's add a TODO so that we don't ever forget we have to refactor this in the future! 😉


@gante gante merged commit a796f7e into huggingface:main Sep 13, 2023
@gante gante deleted the falcon_batch branch September 13, 2023 16:02
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023