Falcon: batched generation (#26137)
gante authored Sep 13, 2023
1 parent 95a9041 commit a796f7e
Showing 2 changed files with 105 additions and 14 deletions.
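For orientation, the change lets Falcon generate correctly from left-padded batches by threading `position_ids` through the rotary embeddings. A minimal usage sketch of what this enables, not part of the diff (it assumes a GPU and the `tiiuae/falcon-7b` checkpoint exercised by the new test below):

```python
# Illustrative only: batched generation with left padding, mirroring the new test below.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # Falcon ships without a dedicated pad token
model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b", device_map="auto")

prompts = ["A sequence: 1, 2", "This is a longer text " * 2]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
inputs.pop("token_type_ids", None)  # Falcon's forward does not accept token_type_ids

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

Before this change, cos and sin were sliced by absolute cache position regardless of per-row padding, so a left-padded row could be rotated at the wrong positions and its continuation could drift from the unpadded result.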
78 changes: 66 additions & 12 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -67,6 +67,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities
class FalconRotaryEmbedding(nn.Module):
"""Implementation of RotaryEmbedding from GPT-NeoX.
This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
@@ -99,19 +100,40 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)

-    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
+    def cos_sin(
+        self, seq_len: int, past_key_values_length: int, position_ids: torch.Tensor, device="cpu", dtype=torch.bfloat16
+    ) -> torch.Tensor:
         total_length = seq_len + past_key_values_length
         if total_length > self.seq_len_cached:
             self._set_cos_sin_cache(total_length, device, dtype)
-        return (
-            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
-            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
-        )
+        # Gather cos, sin at the designated position ids
+        cos = self.cos_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
+        sin = self.sin_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
+        return cos, sin

-    def forward(self, query, key, past_key_values_length=0):
-        batch, seq_len, head_dim = query.shape
-        cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
-        return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
+    def forward(self, query, key, past_key_values_length, position_ids):
+        _, seq_len, _ = query.shape
+        cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype)
+        # Query and key's shapes are [bs * num_heads, seq_len, dim], might need manual expansion. Ifs and elses used to
+        # avoid unnecessary repeat_interleave operations.
+        query_expansion_factor = int(query.shape[0] / cos.shape[0])
+        if query_expansion_factor > 1:
+            query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0)
+            query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0)
+        else:
+            query_cos, query_sin = cos, sin
+
+        key_expansion_factor = int(key.shape[0] / cos.shape[0])
+        if key_expansion_factor > 1:
+            if key_expansion_factor != query_expansion_factor:
+                key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0)
+                key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0)
+            else:
+                key_cos, key_sin = query_cos, query_sin
+        else:
+            key_cos, key_sin = cos, sin
+
+        return (query * query_cos) + (rotate_half(query) * query_sin), (key * key_cos) + (rotate_half(key) * key_sin)


class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
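An aside, not part of the diff: a minimal sketch of the gather-and-expand pattern that the new `cos_sin` and `forward` use above. All shapes and position values here are made up for illustration.

```python
import torch

bs, seq_len, dim, num_heads = 2, 4, 8, 3
cos_cached = torch.randn(1, 10, dim)                       # [1, max_positions, dim] cache
position_ids = torch.tensor([[0, 1, 2, 3], [2, 3, 4, 5]])  # positions chosen per row

# Gather one cos row per (batch, position) instead of one contiguous slice shared by the whole batch
cos = cos_cached.squeeze(0)[position_ids]                  # [bs, seq_len, dim]

# Queries are laid out as [bs * num_heads, seq_len, dim], so cos is repeated per head when needed
query = torch.randn(bs * num_heads, seq_len, dim)
query_cos = torch.repeat_interleave(cos, query.shape[0] // cos.shape[0], dim=0)
assert query_cos.shape == query.shape                      # [bs * num_heads, seq_len, dim]
```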
@@ -270,7 +292,7 @@ def __init__(self, config: FalconConfig):
f" {self.num_heads})."
)

-        self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t: (q, k)
+        self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t, p: (q, k)

# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -378,6 +400,7 @@ def forward(
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@@ -399,7 +422,7 @@ def forward(
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)

past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)

if layer_past is not None:
past_key, past_value = layer_past
@@ -415,7 +438,8 @@ def forward(
else:
present = None

-        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
+        float_min = torch.finfo(query_layer.dtype).min
+        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(query_layer.dtype)

query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
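A quick aside on the mask-fill change above: the dtype's own minimum stays representable in half precision, while the old hard-coded `-1e9` overflows in float16. The values below are the standard torch dtype limits; this snippet is illustrative and not part of the diff.

```python
import torch

print(torch.finfo(torch.float16).min)           # -65504.0
print(torch.finfo(torch.bfloat16).min)          # about -3.39e38
print(torch.finfo(torch.float32).min)           # about -3.40e38
print(torch.tensor(-1e9, dtype=torch.float16))  # tensor(-inf); the old constant overflows in float16
```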
@@ -536,6 +560,7 @@ def forward(
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@@ -554,6 +579,7 @@ def forward(
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
@@ -632,6 +658,11 @@ def forward(
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
@@ -836,6 +867,7 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -892,6 +924,14 @@ def forward(
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()

causal_mask = self._prepare_attn_mask(
attention_mask,
@@ -922,13 +962,15 @@ def custom_forward(*inputs):
hidden_states,
alibi,
causal_mask,
position_ids,
head_mask[i],
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
@@ -988,13 +1030,23 @@ def prepare_inputs_for_generation(
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

return {
"input_ids": input_ids,
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
@@ -1011,6 +1063,7 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
@@ -1032,6 +1085,7 @@ def forward(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
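The `prepare_inputs_for_generation` hunk above derives `position_ids` from the attention mask so that left-padded rows start counting at their first real token. A standalone illustration of that cumsum trick, using a toy mask that is not taken from the diff:

```python
import torch

# Two left-padded rows: row 0 has two padding tokens, row 1 has none
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)  # padded slots get a dummy position

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```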
41 changes: 39 additions & 2 deletions tests/models/falcon/test_modeling_falcon.py
@@ -19,8 +19,16 @@

from parameterized import parameterized

-from transformers import AutoConfig, AutoModel, AutoTokenizer, FalconConfig, is_torch_available, set_seed
-from transformers.testing_utils import CaptureLogger, require_torch, slow, tooslow, torch_device
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    FalconConfig,
+    is_torch_available,
+    set_seed,
+)
+from transformers.testing_utils import CaptureLogger, require_bitsandbytes, require_torch, slow, tooslow, torch_device
from transformers.utils import logging as transformers_logging

from ...generation.test_utils import GenerationTesterMixin
@@ -502,6 +510,35 @@ def test_lm_generation_use_cache(self):
outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0)

@require_bitsandbytes
@slow
def test_batched_generation(self):
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
"tiiuae/falcon-7b",
device_map="auto",
load_in_4bit=True,
)

test_text = "A sequence: 1, 2" # should generate the rest of the sequence

unpadded_inputs = tokenizer([test_text], return_tensors="pt").to("cuda:0")
unpadded_inputs.pop("token_type_ids")
unpadded_gen_out = model.generate(**unpadded_inputs, max_new_tokens=20)
unpadded_gen_text = tokenizer.batch_decode(unpadded_gen_out, skip_special_tokens=True)

dummy_text = "This is a longer text " * 2 # forces left-padding on `test_text`
padded_inputs = tokenizer([test_text, dummy_text], return_tensors="pt", padding=True).to("cuda:0")
padded_inputs.pop("token_type_ids")
padded_gen_out = model.generate(**padded_inputs, max_new_tokens=20)
padded_gen_text = tokenizer.batch_decode(padded_gen_out, skip_special_tokens=True)

expected_output = "A sequence: 1, 2, 3, 4, 5, 6, 7, 8, "
self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1]) # left-padding exists
self.assertEqual(unpadded_gen_text[0], expected_output)
self.assertEqual(padded_gen_text[0], expected_output)


# TODO Lysandre: Remove this in version v4.34
class FalconOverrideTest(unittest.TestCase):
