clean up trainer
IlyasMoutawwakil committed Feb 25, 2025
1 parent 64af4df commit 1bcf275
Showing 1 changed file with 50 additions and 33 deletions.
83 changes: 50 additions & 33 deletions optimum/habana/transformers/trainer.py
@@ -16,7 +16,6 @@
The Gaudi Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""

import contextlib
import copy
import functools
import inspect
@@ -108,7 +107,6 @@
from optimum.utils import logging

from ..distributed import parallel_state
from ..local_accelerate.utils import FP8ContextWrapper, convert_model
from ..utils import (
HabanaProfile,
get_hpu_memory_stats,
@@ -748,9 +746,34 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio

self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)

# Is this necessary? Why is it not in accelerate?
# Wrap `_gradient_checkpointing_func` in the model with `transformer_engine` `activation_checkpointing` context.
if self.accelerator.state.mixed_precision == "fp8":
FP8ContextWrapper.gradient_checkpointing_wrap(self.model)
import intel_transformer_engine as te

def _gradient_checkpointing_wrap(func, *args, **kwargs):
"""
`_gradient_checkpointing_func` always takes the function to be recomputed as the first argument. The function
below wraps this first argument with `transformer_engine`'s `activation_checkpointing` context.
"""
_args = list(args)
_args[0] = te.distributed.activation_checkpointing()(_args[0])
args = tuple(_args)

return func(*args, **kwargs)

if hasattr(self.model, "gradient_checkpointing") and self.model.gradient_checkpointing:
self.model._gradient_checkpointing_func = functools.partial(
_gradient_checkpointing_wrap, self.model._gradient_checkpointing_func
)
return

for module in self.model.modules():
if hasattr(module, "gradient_checkpointing") and module.gradient_checkpointing:
module._gradient_checkpointing_func = functools.partial(
_gradient_checkpointing_wrap, module._gradient_checkpointing_func
)

else:
# Hack because `RegressionModel` in test_trainer.py doesn't have `gradient_checkpointing_disable`
if hasattr(self.model, "gradient_checkpointing_disable"):
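
For readers unfamiliar with the pattern used in the fp8 branch above, it can be reproduced in isolation with a dummy context in place of intel_transformer_engine. Everything in the sketch below (dummy_activation_checkpointing, plain_checkpointing_func, _wrap) is an illustrative stand-in, assuming only that te.distributed.activation_checkpointing() behaves as a decorator-style context, as the diff suggests:

import contextlib
import functools

@contextlib.contextmanager
def dummy_activation_checkpointing():
    # stand-in for te.distributed.activation_checkpointing(); just a no-op context
    yield

def plain_checkpointing_func(function, *args):
    # stand-in for the model's original _gradient_checkpointing_func
    return function(*args)

def _wrap(func, *args, **kwargs):
    # same trick as _gradient_checkpointing_wrap above: decorate the function
    # to be recomputed so it executes inside the context
    _args = list(args)
    _args[0] = dummy_activation_checkpointing()(_args[0])
    return func(*tuple(_args), **kwargs)

wrapped = functools.partial(_wrap, plain_checkpointing_func)
print(wrapped(lambda x: x + 1, 41))  # 42, computed inside the dummy context

Because functools.partial pre-binds the original checkpointing function, every later call goes through _wrap first, which pushes the recomputed function inside the context.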
@@ -783,9 +806,6 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
# In this case we are in DDP + LOMO, which should be supported
self.optimizer = self.accelerator.prepare(self.optimizer)

if self.accelerator.state.mixed_precision == "fp8":
self.model = convert_model(model, _minimize_memory=self.args.minimize_memory)

if self.is_fsdp_enabled:
self.model = self.model_wrapped = model
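
Dropping the explicit convert_model call relies on accelerate performing the FP8 model conversion itself. A hedged sketch of the expected call path (not taken from this repository; it assumes recent accelerate with transformer-engine support and FP8-capable hardware):

import torch
from accelerate import Accelerator

toy_model = torch.nn.Linear(16, 16)
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=1e-3)

# Assumption: with mixed_precision="fp8", prepare() converts eligible layers
# (e.g. to transformer-engine modules), so the trainer no longer needs to call
# convert_model explicitly.
accelerator = Accelerator(mixed_precision="fp8")
toy_model, toy_optimizer = accelerator.prepare(toy_model, toy_optimizer)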

@@ -1182,6 +1202,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio

return TrainOutput(self.state.global_step, train_loss, metrics)

# Why is this rewritten?
def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
@@ -1321,6 +1342,7 @@ def _maybe_log_save_evaluate(self, tr_loss, _grad_norm, model, trial, epoch, ign
if self.args.adjust_throughput:
self.log_evaluate_save_time += time.perf_counter() - save_start

# we should do this in transformers directly
def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint`
if checkpoint is None:
@@ -1360,6 +1382,7 @@ def _load_rng_state(self, checkpoint):
"\nThis won't yield the same results as if the training had not been interrupted."
)

# in transformers directly
def _save_rng_state(self, output_dir):
# Save RNG state in non-distributed training
rng_states = {
@@ -1531,24 +1554,25 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor,
return data.to(**kwargs)
return data

def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
"""
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation. Modified by Habana to enable using `autocast` on Gaudi devices.
"""
if self.use_cpu_amp:
ctx_manager = torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled)
elif self.use_hpu_amp:
ctx_manager = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True)
else:
ctx_manager = contextlib.nullcontext()

# Merge autocast context and `fp8_autocast` context if FP8 is enabled.
# Currently FP8 is enabled only for training.
if self.accelerator.state.mixed_precision == "fp8" and self.model.training:
ctx_manager = FP8ContextWrapper(ctx_manager, self.accelerator.fp8_recipe_handler)

return ctx_manager
# handled by accelerate now (in model preparation)
# def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
# """
# A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
# arguments, depending on the situation. Modified by Habana to enable using `autocast` on Gaudi devices.
# """
# if self.use_cpu_amp:
# ctx_manager = torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled)
# elif self.use_hpu_amp:
# ctx_manager = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True)
# else:
# ctx_manager = contextlib.nullcontext()

# # Merge autocast context and `fp8_autocast` context if FP8 is enabled.
# # Currently FP8 is enabled only for training.
# if self.accelerator.state.mixed_precision == "fp8" and self.model.training:
# ctx_manager = FP8ContextWrapper(ctx_manager, self.accelerator.fp8_recipe_handler)

# return ctx_manager
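
As the comment above says, the bf16/fp8 autocast context is now expected to come from accelerate's model preparation rather than from the trainer. A hand-rolled approximation of that behaviour, shown only to make the idea concrete (a sketch, not accelerate's actual implementation):

import torch

def wrap_forward_with_autocast(model, device_type="hpu", dtype=torch.bfloat16):
    # Run the model's forward under torch.autocast, roughly what a prepared
    # model does once mixed precision is delegated to accelerate.
    original_forward = model.forward

    def autocast_forward(*args, **kwargs):
        with torch.autocast(device_type=device_type, dtype=dtype):
            return original_forward(*args, **kwargs)

    model.forward = autocast_forward
    return model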

def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""
@@ -1602,15 +1626,8 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te
self.accelerator.backward(loss, **kwargs)
self.model.base_model.update_and_allocate(self.state.global_step)
else:
if self.accelerator.state.mixed_precision == "fp8" and self.args.gradient_checkpointing:
# The precision used in backward pass should be same as the one used in forward pass.
# However when training with gradient_checkpointing and FP8 precision, recompute forward
# in backward does not automatically run with FP8 precision. In order to handle this,
# the backward is run in `fp8_autocast` context
with FP8ContextWrapper.create_fp8_context(self.accelerator.fp8_recipe_handler):
self.accelerator.backward(loss, **kwargs)
else:
self.accelerator.backward(loss, **kwargs)
self.accelerator.backward(loss, **kwargs)

return loss.detach() / self.args.gradient_accumulation_steps
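
The FP8-specific context around backward is no longer needed because the function being recomputed was already wrapped with the activation-checkpointing context when gradient checkpointing was set up (see the earlier hunk), and activation checkpointing re-runs that wrapped function during backward. A toy illustration of the recompute-in-backward mechanism using torch.utils.checkpoint (illustrative only, no FP8 or HPU involved):

import torch
from torch.utils.checkpoint import checkpoint

calls = []

def forward_fn(x):
    calls.append("forward")
    return x * x

x = torch.randn(4, requires_grad=True)
y = checkpoint(forward_fn, x, use_reentrant=False)
y.sum().backward()
# calls == ["forward", "forward"]: the second run happens during backward, so any
# context wrapped around forward_fn is re-entered there as well.
print(calls)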

def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
