From 1ccf3074ee392ca557a08e709cacb0d696faa58b Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 4 Mar 2024 14:17:38 +0000
Subject: [PATCH 1/2] torch.compile gen config preparation

---
 src/transformers/generation/utils.py | 90 ++++++++++++++++++----------
 src/transformers/utils/__init__.py   |  1 +
 2 files changed, 58 insertions(+), 33 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 5b7d18e06c1d..b812b8058c21 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -34,7 +34,7 @@
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
-from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
+from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging
 from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
 from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
 from .candidate_generator import (
@@ -1217,6 +1217,57 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
                 UserWarning,
             )
 
+    def _prepare_generation_config(
+        self, generation_config: GenerationConfig, **kwargs: Dict
+    ) -> Tuple[GenerationConfig, Dict]:
+        """
+        Prepares the base generation config, then applies any generation configuration options from kwargs.
+        """
+        # TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
+        # replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
+        # the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.
+
+        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+        if generation_config is None:
+            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
+            # three conditions must be met
+            # 1) the generation config must have been created from the model config (`_from_model_config` field);
+            # 2) the generation config must have seen no modification since its creation (the hash is the same);
+            # 3) the user must have set generation parameters in the model config.
+            # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
+            if (
+                not is_torchdynamo_compiling()
+                and self.generation_config._from_model_config
+                and self.generation_config._original_object_hash == hash(self.generation_config)
+                and self.config._has_non_default_generation_parameters()
+            ):
+                new_generation_config = GenerationConfig.from_model_config(self.config)
+                if new_generation_config != self.generation_config:
+                    warnings.warn(
+                        "You have modified the pretrained model configuration to control generation. This is a"
+                        " deprecated strategy to control generation and will be removed soon, in a future version."
+                        " Please use and modify the model generation configuration (see"
+                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
+                    )
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config
+
+        # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
+        # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled.
+        if is_torchdynamo_compiling():
+            model_kwargs = kwargs
+            generate_attributes_in_kwargs = [key for key in kwargs.keys() if hasattr(generation_config, key)]
+            if len(generate_attributes_in_kwargs) > 0:
+                raise ValueError(
+                    "`torch.compile` exception: all generation configuration attributes must be passed within a "
+                    f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})."
+                )
+        else:
+            generation_config = copy.deepcopy(generation_config)
+            model_kwargs = generation_config.update(**kwargs)
+
+        return generation_config, model_kwargs
+
     @torch.no_grad()
     def generate(
         self,
@@ -1315,44 +1366,17 @@ def generate(
 
                     - [`~generation.GenerateEncoderDecoderOutput`],
                     - [`~generation.GenerateBeamEncoderDecoderOutput`]
         """
+        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+        self._validate_model_class()
+        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
+        self._validate_model_kwargs(model_kwargs.copy())
+
+        # 2. Set generation parameters if not already defined
         if synced_gpus is None:
             if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
                 synced_gpus = True
             else:
                 synced_gpus = False
-
-        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        self._validate_model_class()
-
-        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
-        if generation_config is None:
-            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
-            # three conditions must be met
-            # 1) the generation config must have been created from the model config (`_from_model_config` field);
-            # 2) the generation config must have seen no modification since its creation (the hash is the same);
-            # 3) the user must have set generation parameters in the model config.
-            if (
-                self.generation_config._from_model_config
-                and self.generation_config._original_object_hash == hash(self.generation_config)
-                and self.config._has_non_default_generation_parameters()
-            ):
-                new_generation_config = GenerationConfig.from_model_config(self.config)
-                if new_generation_config != self.generation_config:
-                    warnings.warn(
-                        "You have modified the pretrained model configuration to control generation. This is a"
-                        " deprecated strategy to control generation and will be removed soon, in a future version."
-                        " Please use and modify the model generation configuration (see"
-                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
-                    )
-                    self.generation_config = new_generation_config
-            generation_config = self.generation_config
-
-        generation_config = copy.deepcopy(generation_config)
-        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
-        self._validate_model_kwargs(model_kwargs.copy())
-
-        # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 154077924bea..021faed7cb94 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -192,6 +192,7 @@
     is_torchaudio_available,
     is_torchdistx_available,
     is_torchdynamo_available,
+    is_torchdynamo_compiling,
     is_torchvision_available,
     is_training_run_on_sagemaker,
     is_vision_available,

From f76ddedcde2423bb97ca4e28e614ddf0c32ce38b Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 6 Mar 2024 11:10:48 +0000
Subject: [PATCH 2/2] ignore redundant parameterization

---
 src/transformers/generation/utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index b812b8058c21..437b687b7f4b 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1256,7 +1256,9 @@ def _prepare_generation_config(
         # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled.
         if is_torchdynamo_compiling():
             model_kwargs = kwargs
-            generate_attributes_in_kwargs = [key for key in kwargs.keys() if hasattr(generation_config, key)]
+            generate_attributes_in_kwargs = [
+                key for key, value in kwargs.items() if getattr(generation_config, key, None) != value
+            ]
             if len(generate_attributes_in_kwargs) > 0:
                 raise ValueError(
                     "`torch.compile` exception: all generation configuration attributes must be passed within a "
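
Note (illustrative sketch, not part of the patches above): the snippet below shows, from the caller's side, which call styles the `is_torchdynamo_compiling()` branch added in these commits accepts or rejects. The checkpoint name ("gpt2"), prompt, and parameter values are placeholder assumptions, and whether `generate` actually compiles end to end depends on the model and library version; only the two call styles are the point here.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

compiled_generate = torch.compile(model.generate)

# Accepted under compilation: all generation parameters travel inside a `GenerationConfig`,
# so `_prepare_generation_config` never has to `copy.deepcopy` or `.update` the config.
config = GenerationConfig(max_new_tokens=20, do_sample=False)
out = compiled_generate(**inputs, generation_config=config)

# Rejected under compilation: `max_new_tokens` is a `GenerationConfig` attribute passed as a
# loose kwarg, so the ValueError introduced in the first commit is raised. After the second
# commit it is only raised when the value differs from what the config already holds.
out = compiled_generate(**inputs, max_new_tokens=99)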