
Commit

fix falcon
fxmarty committed Dec 13, 2023
1 parent df039a3 commit 460d621
Showing 3 changed files with 33 additions and 36 deletions.
60 changes: 30 additions & 30 deletions optimum/onnxruntime/modeling_decoder.py
@@ -36,7 +36,7 @@
from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN
from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache
from .utils import MULTI_QUERY_ATTN_MODELS, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_WEIGHTS_NAME
from .utils import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_WEIGHTS_NAME


if TYPE_CHECKING:
@@ -328,6 +328,7 @@ def prepare_past_key_values(
if use_torch and use_cache_branch is not None:
use_cache_branch = use_cache_branch.to(self.device)

pkv_output_shape = {}
# Generate dummy past for the first forward if uses a merged decoder
if past_key_values is None:
batch_size = input_ids.shape[0]
@@ -338,6 +339,7 @@
embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads

dtype = constructor.float16 if self.use_fp16 else constructor.float32

# TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY.
# "1" is the dummy sequence length
if self.model_type == "bloom":
@@ -353,6 +355,13 @@
past_key_values = tuple(
key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value]
)

for name, value in zip(self.key_value_output_names, past_key_values):
shape = [*value.shape]
index = 1 if "value" in name else 2

shape[index] += sequence_length
pkv_output_shape[name] = shape
elif self.model_type == "gpt_bigcode":
# GPT BigCode uses multi-query attention and packs both the key and the value into a single cache tensor.
shape_key_and_value = (batch_size, 0, embed_size_per_head * 2)
@@ -362,34 +371,26 @@
key_and_value = key_and_value.to(self.device)

past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names)))
elif self.model_type == "falcon":
shape = (batch_size * self.num_key_value_heads, 0, embed_size_per_head)
key_or_value = constructor.zeros(shape, dtype=dtype)

if use_torch:
key_or_value = key_or_value.to(self.device)

past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))
for name, value in zip(self.key_value_output_names, past_key_values):
shape = [*value.shape]
shape[1] += sequence_length
pkv_output_shape[name] = shape
else:
shape = (batch_size, num_attention_heads, 0, embed_size_per_head)
num_key_value_heads = self.num_key_value_heads if self.model_type == "falcon" else num_attention_heads

shape = (batch_size, num_key_value_heads, 0, embed_size_per_head)
key_or_value = constructor.zeros(shape, dtype=dtype)

if use_torch:
key_or_value = key_or_value.to(self.device)

past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))

pkv_output_shape = {}
for name, value in zip(self.key_value_output_names, past_key_values):
shape = [*value.shape]
index = (
1
if self.model_type in MULTI_QUERY_ATTN_MODELS or (self.model_type == "bloom" and "value" in name)
else 2
)

shape[index] += sequence_length
pkv_output_shape[name] = shape
for name, value in zip(self.key_value_output_names, past_key_values):
shape = [*value.shape]
shape[2] += sequence_length
pkv_output_shape[name] = shape

return use_cache_branch, past_key_values, pkv_output_shape
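
For readers skimming the hunk above: each branch now pairs its dummy zero-length cache with the output shapes ONNX Runtime should expect after the first pass, instead of the old shared loop keyed on MULTI_QUERY_ATTN_MODELS. A minimal standalone sketch of that bookkeeping for the default 4D layout, with made-up sizes (not Optimum code):

import torch

batch_size, num_key_value_heads, embed_size_per_head, sequence_length = 2, 8, 64, 5

# Dummy past with an empty sequence axis, as in the `else` branch above.
key_or_value = torch.zeros(batch_size, num_key_value_heads, 0, embed_size_per_head)

# The sequence axis (dim 2) is the one that grows by `sequence_length`.
expected_shape = list(key_or_value.shape)
expected_shape[2] += sequence_length
print(expected_shape)  # [2, 8, 5, 64]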

@@ -841,18 +842,19 @@ def __init__(
self.num_key_value_heads = (
config.num_kv_heads if (config.new_decoder_architecture or not config.multi_query) else 1
)
self.use_alibi = config.alibi

# Copied from https://github.com/huggingface/transformers/pull/26199
# Copied from transformers.models.falcon.modeling_falcon.FalconForCausalLM._reorder_cache
def _reorder_cache(
self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Output shares the same memory storage as `past`.
"""
standardized_past = self._convert_cache_to_standard_format(past, batch_size=len(beam_idx))

# Get a copy of `beam_idx` on all the devices where we need those indices.
device_to_beam_idx = {
@@ -863,11 +865,11 @@ def _reorder_cache(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
)
for layer_past in standardized_past
for layer_past in past
)
return self._convert_to_rw_cache(reordered_past)
return reordered_past

# Copied from https://github.com/huggingface/transformers/pull/26199
# Adapted from transformers.models.falcon.modeling_falcon.FalconForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
@@ -877,21 +879,19 @@
**kwargs,
) -> dict:
if past_key_values is not None:
if past_key_values[0][0].ndim != 3:
# Compared to transformers, we do not use _convert_cache_to_standard_format in the model itself, hence the 3D cache.
raise ValueError("Falcon uses 3D KV cache.")
past_length = past_key_values[0][0].shape[2]

past_length = past_key_values[0][0].shape[1]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
if not self.config.alibi and attention_mask is not None and position_ids is None:
if not self.use_alibi and attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
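
As an aside, the two steps shown in this last hunk are easy to check in isolation. A standalone illustration with hypothetical values (not Optimum code): trim input_ids to the tokens not yet present in the KV cache, then build position_ids from the attention mask for RoPE checkpoints (alibi checkpoints skip the second step):

import torch

input_ids = torch.tensor([[11, 12, 13, 14]])
attention_mask = torch.tensor([[0, 1, 1, 1]])  # one left-padded position
past_length = 3                                # tokens already stored in the KV cache

# Some generation methods already pass only the last input ID.
if input_ids.shape[1] > past_length:
    remove_prefix_length = past_length
else:
    remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]    # tensor([[14]])

# Create position_ids on the fly, ignoring padded positions.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)  # tensor([[1, 0, 1, 2]])
print(input_ids, position_ids)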
6 changes: 3 additions & 3 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -34,7 +34,7 @@
AutoModelForVision2Seq,
GenerationConfig,
Pix2StructForConditionalGeneration, # Pix2struct does not support AutoModel
)
)
from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
@@ -1755,7 +1755,7 @@ def generate(

if segment_input_slice.shape[-1] < num_segment_frames:
# pad to 3000 if necessary
segment_input_slice = F.pad(
segment_input_slice = torch.nn.functional.pad(
segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
)

@@ -1829,7 +1829,7 @@ def generate(
max_total_length = max(max_total_length, len(sequences[-1]))

for i in range(batch_size):
sequences[i] = F.pad(
sequences[i] = torch.nn.functional.pad(
sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id
)
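
The last two hunks in this file simply swap `F.pad` for the fully qualified `torch.nn.functional.pad`. For reference, a tiny self-contained example of that call with made-up values (not Optimum code), right-padding a token sequence to a target length:

import torch

sequence = torch.tensor([101, 7592, 2088, 102])
max_total_length = 6
pad_token_id = 0

# Pad only on the right of the last dimension, filling with the pad token id.
padded = torch.nn.functional.pad(sequence, pad=(0, max_total_length - len(sequence)), value=pad_token_id)
print(padded)  # tensor([101, 7592, 2088, 102, 0, 0])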

3 changes: 0 additions & 3 deletions optimum/onnxruntime/utils.py
@@ -54,9 +54,6 @@
"tensor(double)": np.float64,
}

# TODO: this is likely bugged as Falcon handles both the MQA and non-MQA implem
MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}


def _is_gpu_available():
"""
