Fix no output of 1st token and 2nd tokens in transformers 4.39 (openvinotoolkit#335)

1. Copied greedy search and beam search from transformers v4.39.2 for hooking.
2. LLM bench now reports 1st and 2nd token latency for transformers versions from 4.36.0 to 4.39.2.
wgzintel authored Apr 2, 2024
1 parent 373889b commit 9c742a9
Showing 3 changed files with 338 additions and 92 deletions.
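
The commit message describes a hooking approach: the transformers beam-search loop is copied, instrumented with timers, and bound onto the model in place of the original method so the benchmark can read per-step latency afterwards. Below is a minimal, assumed sketch of that patching mechanic only; the names attach_latency_hook, step_times, and _FakeModel are illustrative and do not appear in the repository, which uses new_beam_search, tm_list, and new_forward as shown in the diff that follows.

import time
import types

# Simplified sketch (assumed names, not the real hook_beam_search.py code): replace a
# model's beam-search method with a bound wrapper that records how long each call takes.
# The real hook copies the full transformers v4.39.2 loop and appends a timing for every
# decoding iteration to tm_list / tm_infer_list instead of timing the call as a whole.

step_times = []  # in the real hook, element 0 is the 1st-token latency, element 1 the 2nd-token latency


def attach_latency_hook(model, method_name="_beam_search"):
    """Wrap model.<method_name> so each invocation appends its duration to step_times."""
    original = getattr(model, method_name)  # already bound to `model`

    def timed(self, *args, **kwargs):
        tic = time.perf_counter()
        out = original(*args, **kwargs)
        step_times.append(time.perf_counter() - tic)
        return out

    # Bind the wrapper back onto the instance, mirroring new_beam_search.__get__(model, model.__class__).
    setattr(model, method_name, types.MethodType(timed, model))


class _FakeModel:
    """Stand-in for a transformers model, used only to demonstrate the patching mechanics."""

    def _beam_search(self, prompt):
        return prompt + " -> generated"


model = _FakeModel()
attach_latency_hook(model)
model._beam_search("hello")
print(step_times)  # one measured duration per call
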
168 changes: 117 additions & 51 deletions llm_bench/python/utils/hook_beam_search.py
@@ -8,6 +8,7 @@
import transformers
import torch.distributed as dist
import logging as log
import utils.hook_common as hook_common
from torch import nn
from packaging import version
from typing import Optional, Tuple, Union, List
@@ -20,35 +21,39 @@
from transformers.utils import ModelOutput


class BeamSearchEncoderDecoderOutput(ModelOutput):
class GenerateBeamDecoderOnlyOutput(ModelOutput):
sequences: torch.LongTensor = None
sequences_scores: Optional[torch.FloatTensor] = None
scores: Optional[Tuple[torch.FloatTensor]] = None
logits: Optional[Tuple[torch.FloatTensor]] = None
beam_indices: Optional[torch.LongTensor] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


class BeamSearchDecoderOnlyOutput(ModelOutput):
class GenerateBeamEncoderDecoderOutput(ModelOutput):
sequences: torch.LongTensor = None
sequences_scores: Optional[torch.FloatTensor] = None
scores: Optional[Tuple[torch.FloatTensor]] = None
logits: Optional[Tuple[torch.FloatTensor]] = None
beam_indices: Optional[torch.LongTensor] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput]

tm_list = []
tm_infer_list = []


# Transformers version: Release/v4.35.2 514de24abfd4416aeba6a6455ad5920f57f3567d
# Copied from https://github.com/huggingface/transformers/blob/514de24abfd4416aeba6a6455ad5920f57f3567d/src/transformers/generation/utils.py#L2894
# Transformers version: Release/v4.39.2 97c00cdfe132164dbd793447a088432fa359fd36
# Copied from https://github.com/huggingface/transformers/blob/v4.39-release/src/transformers/generation/utils.py#L2823
# Add the function of collecting latency
def new_beam_search(
self,
@@ -62,17 +67,19 @@ def new_beam_search(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
output_logits: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
sequential: Optional[bool] = None,
**model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
) -> Union[GenerateBeamOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
<Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate()
In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate()
instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
@@ -103,21 +110,28 @@ def new_beam_search(
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_logits (`bool`, *optional*, defaults to `False`):
Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for
more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
sequential (`bool`, defaults to `False`):
By default, beam search has `batch_size * num_beams` as effective batch size (see `beam_search()` for
more details). This flag will avoid parallelizing the beam search and will instead run beam search
sequentially.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`generation.BeamSearchDecoderOnlyOutput`], [`~generation.BeamSearchEncoderDecoderOutput`] or
[`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.BeamSearchEncoderDecoderOutput`] if
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
@@ -133,8 +147,8 @@ def new_beam_search(
... )
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> encoder_input_str = "translate English to German: How old are you?"
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
@@ -167,18 +181,19 @@ def new_beam_search(
... ]
... )
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
>>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Wie alt bist du?']
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
sequential = sequential if sequential is not None else self.generation_config.low_memory
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use"
" `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
" `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
@@ -189,6 +204,8 @@ def new_beam_search(
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_logits = output_logits if output_logits is not None else \
(self.generation_config.output_logits if hasattr(self.generation_config, 'output_logits') else None)
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
@@ -205,6 +222,7 @@ def new_beam_search(
num_beams = beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if num_beams * batch_size != batch_beam_size:
raise ValueError(
@@ -213,6 +231,7 @@

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
beam_indices = (
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
)
@@ -233,27 +252,57 @@ def new_beam_search(
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False # used by synced_gpus only
while True:
tic = time.perf_counter()
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break
this_peer_finished = False

decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder

while hook_common._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
tic = time.perf_counter()
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

# if sequential is True, split the input to batches of batch_size and run sequentially
tic_infer = time.perf_counter()
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
if sequential:
if any(
model_name in self.__class__.__name__.lower()
for model_name in [
"fsmt",
"reformer",
"bloom",
"ctrl",
"gpt_bigcode",
"transo_xl",
"xlnet",
"cpm",
]
):
raise RuntimeError(
f"Currently generation for {self.__class__.__name__} is not supported "
f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
)

inputs_per_sub_batches = hook_common._split_model_inputs(
model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
)
outputs_per_sub_batch = [
self(
**inputs_per_sub_batch,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
for inputs_per_sub_batch in inputs_per_sub_batches
]

outputs = hook_common.stack_model_outputs(outputs_per_sub_batch)

else: # Unchanged original behavior
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
tm_infer_list.append(time.perf_counter() - tic_infer)

if synced_gpus and this_peer_finished:
@@ -274,13 +323,14 @@ def new_beam_search(
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores_processed,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)

if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
@@ -310,6 +360,7 @@ def new_beam_search(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)

beam_scores = beam_outputs["next_beam_scores"]
@@ -319,21 +370,27 @@ def new_beam_search(
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)

if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

# increase cur_len
cur_len = cur_len + 1
tm_list.append(time.perf_counter() - tic)
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
if not synced_gpus:
break
else:
bStop = stopping_criteria(input_ids, scores)
if isinstance(bStop, bool):
if beam_scorer.is_done or bStop:
this_peer_finished = True
else:
if beam_scorer.is_done or all(bStop):
this_peer_finished = True

sequence_outputs = beam_scorer.finalize(
@@ -345,32 +402,37 @@ def new_beam_search(
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)

if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None

if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput(
return GenerateBeamEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
logits=raw_logits,
beam_indices=sequence_outputs["beam_indices"],
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return BeamSearchDecoderOnlyOutput(
return GenerateBeamDecoderOnlyOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
logits=raw_logits,
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return sequence_outputs["sequences"]
@@ -405,10 +467,14 @@ def get_time_infer_list(self):

def new_forward(self, model, model_type=None):
"""Define a new beam search function."""
min_version = version.parse('4.34.0')
min_version = version.parse(hook_common.TRANS_MIN_VERSION)
trans_version = version.parse(transformers.__version__)
if trans_version < min_version:
log.warning(f'The function of getting latency of beam search will not be available with current transformers version:{trans_version}')
else:
bound_method = new_beam_search.__get__(model, model.__class__)
model.beam_search = bound_method
min_second_version = version.parse(hook_common.TRANS_SENCOND_VERSION)
if trans_version >= min_second_version:
model._beam_search = bound_method
else:
model.beam_search = bound_method
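
For reference, here is an assumed usage sketch showing how a benchmark script could apply this hook and read the collected latencies. It presumes the methods shown above (new_forward, get_time_infer_list) belong to a hook class in utils/hook_beam_search.py; the class name BeamSearchHook and the get_time_list accessor are assumptions made for illustration and are not shown in this diff.

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import utils.hook_beam_search as hook_beam_search

model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")

hook = hook_beam_search.BeamSearchHook()    # assumed class wrapping new_forward / get_time_*
hook.new_forward(model)                     # patches model._beam_search or model.beam_search, depending on the transformers version

inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")
model.generate(**inputs, num_beams=4, max_new_tokens=32)

step_latency = hook.get_time_list()         # per-step wall time: [1st token, 2nd token, ...] (accessor assumed)
infer_latency = hook.get_time_infer_list()  # forward-pass time only, excluding beam bookkeeping
print(step_latency[:2], infer_latency[:2])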