diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 98c1f0803b..e8488abc69 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -2364,12 +2364,6 @@ def _contrastive_search( ) # contrastive_search main logic end - # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) if synced_gpus and this_peer_finished: continue @@ -2387,6 +2381,11 @@ def _contrastive_search( if streamer is not None: streamer.put(next_tokens.cpu()) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) # increase cur_len cur_len = cur_len + 1