diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index ac33834398..e51943bd4a 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -36,7 +36,7 @@
 from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN
 from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
 from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache
-from .utils import MULTI_QUERY_ATTN_MODELS, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_WEIGHTS_NAME
+from .utils import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_WEIGHTS_NAME
 
 
 if TYPE_CHECKING:
@@ -328,6 +328,7 @@ def prepare_past_key_values(
         if use_torch and use_cache_branch is not None:
             use_cache_branch = use_cache_branch.to(self.device)
 
+        pkv_output_shape = {}
         # Generate dummy past for the first forward if uses a merged decoder
         if past_key_values is None:
             batch_size = input_ids.shape[0]
@@ -338,6 +339,7 @@ def prepare_past_key_values(
             embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
             dtype = constructor.float16 if self.use_fp16 else constructor.float32
 
+            # TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY.
             # "1" is the dummy sequence length
             if self.model_type == "bloom":
@@ -353,6 +355,13 @@ def prepare_past_key_values(
                 past_key_values = tuple(
                     key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value]
                 )
+
+                for name, value in zip(self.key_value_output_names, past_key_values):
+                    shape = [*value.shape]
+                    index = 1 if "value" in name else 2
+
+                    shape[index] += sequence_length
+                    pkv_output_shape[name] = shape
             elif self.model_type == "gpt_bigcode":
                 # GPT BigCode uses muti-query attention, and has the specificity of putting both key and value in the same cache tensor.
                 shape_key_and_value = (batch_size, 0, embed_size_per_head * 2)
@@ -362,16 +371,15 @@ def prepare_past_key_values(
                     key_and_value = key_and_value.to(self.device)
 
                 past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names)))
-            elif self.model_type == "falcon":
-                shape = (batch_size * self.num_key_value_heads, 0, embed_size_per_head)
-                key_or_value = constructor.zeros(shape, dtype=dtype)
-                if use_torch:
-                    key_or_value = key_or_value.to(self.device)
-
-                past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))
+
+                for name, value in zip(self.key_value_output_names, past_key_values):
+                    shape = [*value.shape]
+                    shape[1] += sequence_length
+                    pkv_output_shape[name] = shape
             else:
-                shape = (batch_size, num_attention_heads, 0, embed_size_per_head)
+                num_key_value_heads = self.num_key_value_heads if self.model_type == "falcon" else num_attention_heads
+
+                shape = (batch_size, num_key_value_heads, 0, embed_size_per_head)
                 key_or_value = constructor.zeros(shape, dtype=dtype)
 
                 if use_torch:
@@ -379,17 +387,10 @@
                 past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))
 
-        pkv_output_shape = {}
-        for name, value in zip(self.key_value_output_names, past_key_values):
-            shape = [*value.shape]
-            index = (
-                1
-                if self.model_type in MULTI_QUERY_ATTN_MODELS or (self.model_type == "bloom" and "value" in name)
-                else 2
-            )
-
-            shape[index] += sequence_length
-            pkv_output_shape[name] = shape
+                for name, value in zip(self.key_value_output_names, past_key_values):
+                    shape = [*value.shape]
+                    shape[2] += sequence_length
+                    pkv_output_shape[name] = shape
 
         return use_cache_branch, past_key_values, pkv_output_shape
@@ -841,8 +842,9 @@ def __init__(
         self.num_key_value_heads = (
             config.num_kv_heads if (config.new_decoder_architecture or not config.multi_query) else 1
         )
+        self.use_alibi = config.alibi
 
-    # Copied from https://github.com/huggingface/transformers/pull/26199
+    # Copied from transformers.models.falcon.modeling_falcon.FalconForCausalLM._reorder_cache
     def _reorder_cache(
         self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
@@ -850,9 +852,9 @@ def _reorder_cache(
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
         beam_idx at every generation step.
+        Output shares the same memory storage as `past`.
         """
-        standardized_past = self._convert_cache_to_standard_format(past, batch_size=len(beam_idx))
 
         # Get a copy of `beam_idx` on all the devices where we need those indices.
         device_to_beam_idx = {
@@ -863,11 +865,11 @@ def _reorder_cache(
                 layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                 layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
             )
-            for layer_past in standardized_past
+            for layer_past in past
         )
-        return self._convert_to_rw_cache(reordered_past)
+        return reordered_past
 
-    # Copied from https://github.com/huggingface/transformers/pull/26199
+    # Adapted from transformers.models.falcon.modeling_falcon.FalconForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
@@ -877,21 +879,19 @@ def prepare_inputs_for_generation(
         **kwargs,
     ) -> dict:
         if past_key_values is not None:
-            if past_key_values[0][0].ndim != 3:
-                # Compared to transformers, we do not use _convert_cache_to_standard_format in the model itself, hence the 3D cache.
-                raise ValueError("Falcon uses 3D KV cache.")
+            past_length = past_key_values[0][0].shape[2]
 
-            past_length = past_key_values[0][0].shape[1]
             # Some generation methods already pass only the last input ID
             if input_ids.shape[1] > past_length:
                 remove_prefix_length = past_length
             else:
                 # Default to old behavior: keep only final ID
                 remove_prefix_length = input_ids.shape[1] - 1
+
             input_ids = input_ids[:, remove_prefix_length:]
 
         # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
-        if not self.config.alibi and attention_mask is not None and position_ids is None:
+        if not self.use_alibi and attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index c64654d67c..e4a1b02007 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -34,7 +34,7 @@
     AutoModelForVision2Seq,
     GenerationConfig,
     Pix2StructForConditionalGeneration,  # Pix2struct does not support AutoModel
-    )
+)
 from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward
 from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
@@ -1755,7 +1755,7 @@ def generate(
                 if segment_input_slice.shape[-1] < num_segment_frames:
                     # pad to 3000 if necessary
-                    segment_input_slice = F.pad(
+                    segment_input_slice = torch.nn.functional.pad(
                         segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
                     )
@@ -1829,7 +1829,7 @@ def generate(
                 max_total_length = max(max_total_length, len(sequences[-1]))
 
             for i in range(batch_size):
-                sequences[i] = F.pad(
+                sequences[i] = torch.nn.functional.pad(
                     sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id
                 )
diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py
index 1ec21c9802..aea997eb39 100644
--- a/optimum/onnxruntime/utils.py
+++ b/optimum/onnxruntime/utils.py
@@ -54,9 +54,6 @@
     "tensor(double)": np.float64,
 }
 
-# TODO: this is likely bugged as Falcon handles both the MQA and non-MQA implem
-MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}
-
 
 def _is_gpu_available():
     """
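
For readers following the shape bookkeeping in prepare_past_key_values, the standalone sketch below restates the per-architecture KV-cache layouts the patch converges on and which axis grows with the sequence. It is a minimal illustration only: the helper names (dummy_cache_layout, expected_output_shapes) and the present.0.* output names are assumptions for the example, not optimum's API.

# Minimal, self-contained sketch (not part of optimum) of the cache layouts used above.
from typing import Dict, List, Tuple


def dummy_cache_layout(
    model_type: str,
    batch_size: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    head_dim: int,
) -> Dict[str, Tuple[Tuple[int, ...], int]]:
    """Return {output_name: (dummy_shape_with_empty_sequence, sequence_axis)} for one layer."""
    if model_type == "bloom":
        # Bloom stores keys as (batch * heads, head_dim, seq) and values as (batch * heads, seq, head_dim),
        # hence the `index = 1 if "value" in name else 2` branch in the patch.
        return {
            "present.0.key": ((batch_size * num_attention_heads, head_dim, 0), 2),
            "present.0.value": ((batch_size * num_attention_heads, 0, head_dim), 1),
        }
    if model_type == "gpt_bigcode":
        # GPT BigCode packs key and value into one (batch, seq, 2 * head_dim) tensor, so axis 1 grows.
        return {"present.0.key_value": ((batch_size, 0, 2 * head_dim), 1)}
    # Default 4D layout (batch, kv_heads, seq, head_dim); Falcon only differs from the generic case
    # through num_key_value_heads, which is why the patch folds it into the `else` branch.
    heads = num_key_value_heads if model_type == "falcon" else num_attention_heads
    return {
        "present.0.key": ((batch_size, heads, 0, head_dim), 2),
        "present.0.value": ((batch_size, heads, 0, head_dim), 2),
    }


def expected_output_shapes(
    layout: Dict[str, Tuple[Tuple[int, ...], int]], sequence_length: int
) -> Dict[str, List[int]]:
    """Grow the sequence axis by sequence_length, mirroring how pkv_output_shape is filled."""
    shapes = {}
    for name, (shape, seq_axis) in layout.items():
        grown = list(shape)
        grown[seq_axis] += sequence_length
        shapes[name] = grown
    return shapes


if __name__ == "__main__":
    # Falcon-7B-like multi-query configuration: 71 query heads, 1 key/value head, head_dim 64.
    layout = dummy_cache_layout("falcon", batch_size=2, num_attention_heads=71, num_key_value_heads=1, head_dim=64)
    print(expected_output_shapes(layout, sequence_length=7))
    # -> {'present.0.key': [2, 1, 7, 64], 'present.0.value': [2, 1, 7, 64]}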