Falcon: batched generation #26137
Conversation
total_length = seq_len + past_key_values_length
if total_length > self.seq_len_cached:
    self._set_cos_sin_cache(total_length, device, dtype)
return (
    self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
the slicing here is equivalent to building position ids from the sequence length, without taking into account any potential left-padding
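For context, a minimal sketch (not from this PR's diff) of how position ids can instead be derived from the attention mask so that left-padding is respected -- this mirrors the pattern used for other decoder-only models in the library:

```python
import torch

# Two sequences: the first is left-padded (real length 3), the second is full length
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Count positions only over real tokens; padded slots get a dummy value (they are masked out anyway)
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```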
The documentation is not available anymore as the PR was closed or merged.
@ArthurZucker woops, sorry, there are still tests to fix, I will ping you again when they are fixed!
@ArthurZucker ready now
@@ -415,7 +437,11 @@ def forward(
        else:
            present = None

        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
cc @Rocketknight1 this `-1e9` was causing problems in some numerical precisions (it would be converted to `-inf`) :p
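For illustration (a small check, not part of the PR): float16 cannot represent magnitudes above roughly 65504, so a literal `-1e9` overflows to `-inf` when cast, whereas the dtype's own minimum stays finite:

```python
import torch

# -1e9 is outside the float16 range, so the cast overflows to -inf
print(torch.tensor(-1e9).to(torch.float16))  # tensor(-inf, dtype=torch.float16)

# torch.finfo exposes a dtype-aware minimum, a safer choice for an additive attention mask
print(torch.finfo(torch.float16).min)        # -65504.0
```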
Ah, my bad!
Thanks for adding batch support! Let's be careful with padding, as we have been getting a lot of issues regarding this! If we have a reference implementation (the one used in Llama), it would be great to re-use it!
Otherwise, LGTM
@@ -99,19 +99,40 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.cos_cached = self.cos_cached.type(dtype)
        self.sin_cached = self.sin_cached.type(dtype)

    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
    def cos_sin(
I'm gonna be a bit noisy here, but this looks a LOT like the rotary embedding we have in Llama, no?
The query expansion is also supported there, not sure how much of an overhead it is to first apply rotary then expand:

```python
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
```

and also storing the full-size keys and values is less memory efficient, no? (unrelated to the PR)
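For reference, the `repeat_kv` helper mentioned above looks roughly like this in the Llama modeling code (reproduced here as a sketch): the cache holds only `num_key_value_heads` key/value heads and they are expanded to the full head count right before attention, which is why caching the un-expanded tensors is the more memory-efficient option:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim) -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```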
I think they are the same. So, we would benefit from copying the structure (at least in terms of complexity for us, the maintainers) 👍
I would like to push it to the future, though, as I'm about to go on long holidays and I'd like to enable batched generation on Falcon :D
Let's just add a TODO then 😉
@ArthurZucker 100% agreed! If you come across a new model, plz make sure there is a test for this 🙏
@ArthurZucker suggestions applied 💪
Thanks a lot! Let's add a TODO so that we don't ever forget we have to refactor this in the future! 😉
What does this PR do?
This PR does three things:
1. Fixes the attention mask fill value: in some numerical precisions, the numerical attention mask was getting `-inf`, which wrecked downstream computations.
2. Adds the `position_ids` input to Falcon, which is needed for proper batched generation. When it is not passed, the forward pass builds the position ids from the sequence length, which does not account for the left-padding in batched generation -- the model could still generate, but the results should be slightly better after the fix.
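As a usage sketch (illustrative, not from the PR -- the checkpoint and prompts are placeholders), batched generation with a decoder-only model like Falcon relies on left-padding, and passing the attention mask lets the model build correct position ids for the padded rows:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "tiiuae/falcon-7b"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Falcon has no pad token by default
tokenizer.padding_side = "left"            # decoder-only batching requires left-padding

model = AutoModelForCausalLM.from_pretrained(model_name)

prompts = ["The best thing about open source is", "Hello, my name is"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# the attention mask marks the padded positions, so position ids can be built correctly
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```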