From fb2c2ce620f58fe77c6f023ab218ce82d489a3e1 Mon Sep 17 00:00:00 2001
From: swu671
Date: Mon, 6 Nov 2023 12:37:23 -0500
Subject: [PATCH 1/2] report grad_norm during training

---
 src/transformers/trainer.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 6850f4dca067..3b499ce5b661 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1783,6 +1783,7 @@ def _inner_training_loop(
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
         model.zero_grad()
+        grad_norm: Optional[float] = None
 
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
@@ -1900,18 +1901,19 @@ def _inner_training_loop(
                         # deepspeed does its own clipping
 
                         if is_sagemaker_mp_enabled() and args.fp16:
-                            self.optimizer.clip_master_grads(args.max_grad_norm)
+                            _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                         elif self.use_apex:
                             # Revert to normal clipping otherwise, handling Apex or full precision
-                            nn.utils.clip_grad_norm_(
+                            _grad_norm = nn.utils.clip_grad_norm_(
                                 amp.master_params(self.optimizer),
                                 args.max_grad_norm,
                             )
                         else:
-                            self.accelerator.clip_grad_norm_(
+                            _grad_norm = self.accelerator.clip_grad_norm_(
                                 model.parameters(),
                                 args.max_grad_norm,
                             )
+                        grad_norm = _grad_norm.item() if _grad_norm is not None else None
 
                     # Optimizer step
                     self.optimizer.step()
@@ -1926,7 +1928,7 @@ def _inner_training_loop(
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
 
-                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+                    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
 
@@ -1941,7 +1943,7 @@ def _inner_training_loop(
                 self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
 
             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 if is_torch_tpu_available():
@@ -2246,7 +2248,7 @@ def _issue_warnings_after_load(self, load_result):
                 f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
             )
 
-    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
+    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             if is_torch_tpu_available():
                 xm.mark_step()
@@ -2260,6 +2262,8 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for
             tr_loss -= tr_loss
 
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+            if grad_norm is not None:
+                logs["grad_norm"] = grad_norm
             logs["learning_rate"] = self._get_learning_rate()
 
             self._total_loss_scalar += tr_loss_scalar

From 111823a939eb5c63ccc7184367b707b6f4bedf60 Mon Sep 17 00:00:00 2001
From: swu671
Date: Wed, 17 Jan 2024 17:04:15 -0500
Subject: [PATCH 2/2] support getting grad_norm from deepspeed

---
 src/transformers/trainer.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3b499ce5b661..2ba41c7821f8 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -195,6 +195,7 @@
     from accelerate import __version__ as accelerate_version
     from accelerate.utils import (
         DistributedDataParallelKwargs,
+        DistributedType,
         GradientAccumulationPlugin,
         load_fsdp_model,
         load_fsdp_optimizer,
@@ -1913,7 +1914,14 @@ def _inner_training_loop(
                                 model.parameters(),
                                 args.max_grad_norm,
                             )
-                        grad_norm = _grad_norm.item() if _grad_norm is not None else None
+
+                        if (
+                            is_accelerate_available()
+                            and self.accelerator.distributed_type == DistributedType.DEEPSPEED
+                        ):
+                            grad_norm = model.get_global_grad_norm()
+                        else:
+                            grad_norm = _grad_norm.item() if _grad_norm is not None else None
 
                     # Optimizer step
                     self.optimizer.step()
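The sketch below is not part of the patch series. It is a minimal, hypothetical example of how user code could consume the new "grad_norm" log entry once these commits are applied, using the existing TrainerCallback.on_log hook from transformers; the GradNormPrinter class name and the print format are illustrative only.

    # Minimal sketch (assumption: these commits are applied to transformers).
    # "grad_norm" appears in `logs` only when max_grad_norm clipping ran and a norm was reported.
    from transformers import TrainerCallback

    class GradNormPrinter(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            # Print the gradient norm alongside the global step whenever it is logged.
            if logs and "grad_norm" in logs:
                print(f"step {state.global_step}: grad_norm = {logs['grad_norm']}")

    # Usage (illustrative): trainer = Trainer(..., callbacks=[GradNormPrinter()])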