From 9863b465c9b93dbfdc04fe2c2dafb697733a9b2c Mon Sep 17 00:00:00 2001 From: Eugen Ajechiloae Date: Wed, 31 Jan 2024 15:10:14 +0200 Subject: [PATCH] tidy code based on code review --- .../integrations/integration_utils.py | 43 +++++++------------ 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 8829d63a0189..7e433be7f1ab 100644 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -1502,10 +1502,7 @@ def setup(self, args, state, model, tokenizer, **kwargs): self._clearml_task = self._clearml.Task.init( project_name=os.getenv("CLEARML_PROJECT", "HuggingFace Transformers"), task_name=os.getenv("CLEARML_TASK", "Trainer"), - auto_connect_frameworks={ - "tensorboard": False, - "pytorch": False, - }, + auto_connect_frameworks={"tensorboard": False, "pytorch": False}, output_uri=True, ) self._log_model = os.getenv("CLEARML_LOG_MODEL", "TRUE").upper() in ENV_VARS_TRUE_VALUES.union( @@ -1515,27 +1512,21 @@ def setup(self, args, state, model, tokenizer, **kwargs): logger.info("ClearML Task has been initialized.") self._initialized = True - ignore_hparams_config_section = ( - ClearMLCallback._hparams_section - + ClearMLCallback.log_suffix - + "/" - + ClearMLCallback._ignore_hparams_overrides - ) + suffixed_hparams_section = ClearMLCallback._hparams_section + ClearMLCallback.log_suffix + ignore_hparams_config_section = suffixed_hparams_section + "/" + ClearMLCallback._ignore_hparams_overrides if self._clearml.Task.running_locally(): - self._copy_training_args_as_hparams( - args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix - ) + self._copy_training_args_as_hparams(args, suffixed_hparams_section) self._clearml_task.set_parameter( name=ignore_hparams_config_section, value=True, value_type=bool, description=( - "If True, ignore hyperparameters overrides done in the UI section" - + "when running remotely. Otherwise, the overrides will be used" + "If True, ignore Transformers hyperparameters overrides done in the UI/backend " + + "when running remotely. Otherwise, the overrides will be applied when running remotely" ), ) elif not self._clearml_task.get_parameter(ignore_hparams_config_section, default=True, cast=True): - self._clearml_task.connect(args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix) + self._clearml_task.connect(args, suffixed_hparams_section) else: self._copy_training_args_as_hparams( args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix @@ -1543,10 +1534,7 @@ def setup(self, args, state, model, tokenizer, **kwargs): if getattr(model, "config", None) is not None: ignore_model_config_section = ( - ClearMLCallback._hparams_section - + ClearMLCallback.log_suffix - + "/" - + ClearMLCallback._ignoge_model_config_overrides + suffixed_hparams_section + "/" + ClearMLCallback._ignoge_model_config_overrides ) configuration_object_description = ClearMLCallback._model_config_description.format( ClearMLCallback._model_connect_counter @@ -1559,8 +1547,8 @@ def setup(self, args, state, model, tokenizer, **kwargs): value=True, value_type=bool, description=( - "If True, ignore model configuration overrides done in the UI section " - + "when running remotely. Otherwise, the overrides will be used" + "If True, ignore Transformers model configuration overrides done in the UI/backend " + + "when running remotely. Otherwise, the overrides will be applied when running remotely" ), ) self._clearml_task.set_configuration_object( @@ -1652,8 +1640,8 @@ def on_save(self, args, state, control, **kwargs): if self._log_model and self._clearml_task and state.is_world_process_zero: ckpt_dir = f"checkpoint-{state.global_step}" artifact_path = os.path.join(args.output_dir, ckpt_dir) - logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.") name = ckpt_dir + ClearMLCallback.log_suffix + logger.info(f"Logging checkpoint artifact `{name}`. This may take some time.") output_model = self._clearml.OutputModel(task=self._clearml_task, name=name) output_model.connect(task=self._clearml_task, name=name) output_model.update_weights_package( @@ -1681,10 +1669,11 @@ def on_save(self, args, state, control, **kwargs): self._checkpoints_saved = self._checkpoints_saved[1:] def _copy_training_args_as_hparams(self, training_args, prefix): - as_dict = {field.name: getattr(training_args, field.name) for field in fields(training_args) if field.init} - token_keys = [k for k in as_dict.keys() if k.endswith("_token")] - for token_key in token_keys: - as_dict.pop(token_key, None) + as_dict = { + field.name: getattr(training_args, field.name) + for field in fields(training_args) + if field.init and not field.name.endswith("_token") + } flat_dict = {str(k): v for k, v in self._clearml.utilities.proxy_object.flatten_dictionary(as_dict).items()} self._clearml_task._arguments.copy_from_dict(flat_dict, prefix=prefix)