Add support for beam search's num_return_sequences flag in Flax (#23082)
* add code for num_return_sequences

* add Flax support for num_return_sequences

* fix up for changes

* add test for num_return_sequences

* lint
mayankagarwals authored May 3, 2023
1 parent ee4bc07 commit c4e32e2
Showing 2 changed files with 21 additions and 3 deletions.
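
Before the diff itself, a usage sketch of what this change enables (the checkpoint, prompt, and shape values below are illustrative assumptions, not part of the commit): Flax beam search can now return several beams per input, with the returned `sequences` stacked along the batch axis.

```python
# Illustrative sketch only; assumes the "gpt2" checkpoint and these sizes.
from transformers import FlaxGPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer(["The capital of France is"], return_tensors="np")

# num_return_sequences selects the top beams, so it should not exceed num_beams.
outputs = model.generate(
    inputs["input_ids"],
    num_beams=4,
    num_return_sequences=2,
    max_length=20,
)

# Shape: (batch_size * num_return_sequences, max_length) == (2, 20) here.
print(outputs.sequences.shape)
```
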
11 changes: 8 additions & 3 deletions src/transformers/generation/flax_utils.py
@@ -463,6 +463,7 @@ def generate(
                 logits_processor=logits_processor,
                 trace=trace,
                 params=params,
+                num_return_sequences=generation_config.num_return_sequences,
                 model_kwargs=model_kwargs,
             )
         else:
@@ -749,6 +750,7 @@ def _beam_search(
         logits_processor: Optional[FlaxLogitsProcessorList] = None,
         trace: bool = True,
         params: Optional[Dict[str, jnp.ndarray]] = None,
+        num_return_sequences: Optional[int] = None,
         model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
     ):
         """
@@ -793,6 +795,9 @@ def gather_fn(tensor):
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
         early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
+        num_return_sequences = (
+            num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences
+        )
 
         batch_size, num_beams, cur_len = input_ids.shape

@@ -996,8 +1001,8 @@ def beam_search_body_fn(state, input_ids_length=1):
         sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
         scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
 
-        # take best beam for each batch
-        sequences = sequences[:, 0]
-        scores = scores[:, 0]
+        # Take best beams for each batch (the score is sorted in descending order)
+        sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
+        scores = flatten_beam_dim(scores[:, :num_return_sequences])
 
         return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
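
To make the new return shape concrete: beams are stored sorted best-first, so `sequences[:, :num_return_sequences, :]` keeps the top beams, and `flatten_beam_dim` folds the beam axis into the batch axis. A minimal shape sketch follows; the `flatten_beam_dim` here is a stand-in assumed to have the same reshape semantics as the helper defined earlier in `flax_utils.py`.

```python
import jax.numpy as jnp

def flatten_beam_dim(tensor):
    # Stand-in for the helper in flax_utils.py:
    # merge the beam axis into the batch axis.
    return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

batch_size, num_beams, seq_len = 2, 4, 5
num_return_sequences = 2

sequences = jnp.zeros((batch_size, num_beams, seq_len), dtype=jnp.int32)
top_beams = sequences[:, :num_return_sequences, :]  # (2, 2, 5), best beams first
flat = flatten_beam_dim(top_beams)                  # (4, 5): batch * num_return_sequences rows
print(top_beams.shape, flat.shape)
```
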
13 changes: 13 additions & 0 deletions tests/generation/test_flax_utils.py
@@ -158,6 +158,19 @@ def test_beam_search_generate(self):

         self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
 
+    def test_beam_search_generate_num_return_sequences(self):
+        config, input_ids, _, max_length = self._get_input_ids_and_config()
+        config.do_sample = False
+        config.max_length = max_length
+        config.num_beams = 2
+        config.num_return_sequences = 2
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+
+            generation_outputs = model.generate(input_ids).sequences
+            self.assertEqual(generation_outputs.shape[0], input_ids.shape[0] * config.num_return_sequences)
+
     def test_sample_generate_logits_warper(self):
         config, input_ids, _, max_length = self._get_input_ids_and_config()
         config.do_sample = True
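
The new test only exercises the eager path; a hypothetical companion test, not part of the commit, could mirror the `jax.jit` check from `test_beam_search_generate` above (it assumes the same test-mixin fixtures such as `_get_input_ids_and_config` and `all_generative_model_classes`):

```python
import jax

def test_beam_search_generate_num_return_sequences_jit(self):
    # Hypothetical companion test: the jit-compiled path should return
    # the same number of rows as the eager path.
    config, input_ids, _, max_length = self._get_input_ids_and_config()
    config.do_sample = False
    config.max_length = max_length
    config.num_beams = 2
    config.num_return_sequences = 2

    for model_class in self.all_generative_model_classes:
        model = model_class(config)

        jit_generate = jax.jit(model.generate)
        jit_outputs = jit_generate(input_ids).sequences
        self.assertEqual(jit_outputs.shape[0], input_ids.shape[0] * config.num_return_sequences)
```
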
