diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index e3a66f420424..67bb2ae4f594 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -130,7 +130,7 @@ from torch import nn from transformers import Trainer class CustomTrainer(Trainer): - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): labels = inputs.pop("labels") # forward pass outputs = model(**inputs) @@ -156,9 +156,7 @@ class EarlyStoppingCallback(TrainerCallback): def on_step_end(self, args, state, control, **kwargs): if state.global_step >= self.num_steps: - return {"should_training_stop": True} - else: - return {} + control.should_training_stop = True ``` Then pass it to the [`Trainer`]'s `callback` parameter. @@ -737,7 +735,7 @@ accelerate launch --num_processes=2 \ --fsdp_transformer_layer_cls_to_wrap="BertLayer" \ --fsdp_sharding_strategy=1 \ --fsdp_state_dict_type=FULL_STATE_DICT \ - ./examples/pytorch/text-classification/run_glue.py + ./examples/pytorch/text-classification/run_glue.py \ --model_name_or_path google-bert/bert-base-cased \ --task_name $TASK_NAME \ --do_train \