-
Notifications
You must be signed in to change notification settings - Fork 28.2k
/
Copy pathtraining_args.py
2968 lines (2675 loc) · 148 KB
/
training_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import io
import json
import math
import os
import warnings
from dataclasses import asdict, dataclass, field, fields
from datetime import timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from huggingface_hub import get_full_repo_name
from packaging import version
from .debug_utils import DebugOption
from .trainer_utils import (
EvaluationStrategy,
FSDPOption,
HubStrategy,
IntervalStrategy,
SchedulerType,
)
from .utils import (
ACCELERATE_MIN_VERSION,
ExplicitEnum,
cached_property,
is_accelerate_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_available,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_mlu_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tf32_available,
is_torch_xla_available,
is_torch_xpu_available,
logging,
requires_backends,
)
from .utils.generic import strtobool
from .utils.import_utils import is_optimum_neuron_available
# Module-level logger, obtained from the project's `logging` utility.
logger = logging.get_logger(__name__)
# Private copy of the mapping returned by `logging.get_log_levels_dict()`.
log_levels = logging.get_log_levels_dict().copy()
# Same entries as `log_levels`, plus an extra "passive" key mapped to -1.
trainer_log_levels = dict(**log_levels, passive=-1)

# Optional backends: each import below only runs when the matching availability check passes,
# so the names they bind (`torch`, `dist`, `xm`, `smp`, ...) exist only in that case.
if is_torch_available():
    import torch
    import torch.distributed as dist

    from .pytorch_utils import is_torch_greater_or_equal_than_2_0

if is_accelerate_available():
    from accelerate.state import AcceleratorState, PartialState
    from accelerate.utils import DistributedType

    from .trainer_pt_utils import AcceleratorConfig

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

# NOTE(review): the nesting of this section was restored from a flattened copy of the file;
# confirm that the XLA process-group setup belongs under the `else` branch.
if is_torch_neuroncore_available(check_device=False):
    # torchrun support
    # https://github.com/pytorch/xla/pull/3609
    if os.environ.get("TORCHELASTIC_RUN_ID"):
        if is_optimum_neuron_available():
            logger.info(
                "Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this "
                "will fail otherwise."
            )
        else:
            logger.warning(
                "Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform "
                "training on AWS Trainium instances. More information here: "
                "https://github.com/huggingface/optimum-neuron"
            )
            import torch_xla.distributed.xla_backend as xbn

            # Initialise the default process group with the XLA backend if it is not one already,
            # then check again and fail loudly if the initialisation did not take effect.
            if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
                dist.init_process_group(backend="xla")
                if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
                    raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

    # Runs at import time, whenever SageMaker model parallelism is reported as enabled.
    smp.init()
def default_logdir() -> str:
    """
    Build the default logging directory, using the same default as PyTorch.

    Returns:
        `str`: The path obtained by joining `"runs"` with `<timestamp>_<hostname>`, where the timestamp is the
        current local time formatted as `%b%d_%H-%M-%S` and the hostname comes from `socket.gethostname()`.
    """
    import socket
    from datetime import datetime

    timestamp = datetime.now().strftime("%b%d_%H-%M-%S")
    run_folder = f"{timestamp}_{socket.gethostname()}"
    return os.path.join("runs", run_folder)
def get_int_from_env(env_keys, default):
    """
    Return the first non-negative integer found among the given environment variables.

    Args:
        env_keys: Iterable of environment variable names, checked in the order given.
        default: Value returned when none of the variables yields an integer `>= 0`.

    Returns:
        The integer value of the first variable that is set to a value `>= 0` (so `0` is accepted), otherwise
        `default`. Unset variables are read as `-1` and therefore skipped.

    Raises:
        ValueError: If a variable reached during the scan holds a string that `int()` cannot parse.
    """
    # Lazy generator: variables after the first match are never read or parsed.
    parsed = (int(os.environ.get(name, -1)) for name in env_keys)
    return next((number for number in parsed if number >= 0), default)
def get_xla_device_type(device: "torch.device") -> Optional[str]:
    """
    Report the XLA device type (CPU|GPU|TPU) behind `device`.

    Args:
        device (`torch.device`): The device to inspect.

    Returns:
        `Optional[str]`: `None` when torch XLA is not available. Otherwise `"CPU"` when `device.type` is `"cpu"`,
        or the text before the first `":"` in the first entry returned by `xm.xla_real_devices([device])`.
    """
    if not is_torch_xla_available():
        return None
    if device.type == "cpu":
        return "CPU"
    real_device = xm.xla_real_devices([device])[0]
    return real_device.split(":")[0]
class OptimizerNames(ExplicitEnum):
    """
    Stores the acceptable string identifiers for optimizers.

    Each member's value is the string form of the identifier; the member name is the symbolic constant used in code.
    """

    # AdamW variants
    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_FUSED = "adamw_torch_fused"
    ADAMW_TORCH_XLA = "adamw_torch_xla"
    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_ANYPRECISION = "adamw_anyprecision"
    SGD = "sgd"
    ADAGRAD = "adagrad"
    # 8-bit, 32-bit and paged AdamW / Lion variants
    ADAMW_BNB = "adamw_bnb_8bit"
    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
    LION_8BIT = "lion_8bit"
    LION = "lion_32bit"
    PAGED_ADAMW = "paged_adamw_32bit"
    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
    PAGED_LION = "paged_lion_32bit"
    PAGED_LION_8BIT = "paged_lion_8bit"
    # RMSprop variants
    RMSPROP = "rmsprop"
    RMSPROP_BNB = "rmsprop_bnb"
    RMSPROP_8BIT = "rmsprop_bnb_8bit"
    RMSPROP_32BIT = "rmsprop_bnb_32bit"
    # GaLore variants, including the "layerwise" ones
    GALORE_ADAMW = "galore_adamw"
    GALORE_ADAMW_8BIT = "galore_adamw_8bit"
    GALORE_ADAFACTOR = "galore_adafactor"
    GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
    GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
    # LOMO variants
    LOMO = "lomo"
    ADALOMO = "adalomo"
# Sometimes users will pass in a `str` repr of a dict on the CLI.
# We need to track which fields those can be: each time a new argument
# has a dict type, its name must be added to this list.
# Important: these fields should be typed with Optional[Union[dict,str,...]]
_VALID_DICT_FIELDS = [
    "accelerator_config",
    "fsdp_config",
    "deepspeed",
    "gradient_checkpointing_kwargs",
    "lr_scheduler_kwargs",
]
def _convert_str_dict(passed_value: dict):
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
for key, value in passed_value.items():
if isinstance(value, dict):
passed_value[key] = _convert_str_dict(value)
elif isinstance(value, str):
# First check for bool and convert
if value.lower() in ("true", "false"):
passed_value[key] = value.lower() == "true"
# Check for digit
elif value.isdigit():
passed_value[key] = int(value)
elif value.replace(".", "", 1).isdigit():
passed_value[key] = float(value)
return passed_value
# TODO: `TrainingArguments` users rely on it being fully mutable. In the future, see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
@dataclass
class TrainingArguments:
"""
TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop
itself**.
Using [`HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
Parameters:
output_dir (`str`):
The output directory where the model predictions and checkpoints will be written.
overwrite_output_dir (`bool`, *optional*, defaults to `False`):
If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
points to a checkpoint directory.
do_train (`bool`, *optional*, defaults to `False`):
Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used
by your training/evaluation scripts instead. See the [example
scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
do_eval (`bool`, *optional*):
Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is
different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your
training/evaluation scripts instead. See the [example
scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
do_predict (`bool`, *optional*, defaults to `False`):
Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's
intended to be used by your training/evaluation scripts instead. See the [example
scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
The evaluation strategy to adopt during training. Possible values are:
- `"no"`: No evaluation is done during training.
- `"steps"`: Evaluation is done (and logged) every `eval_steps`.
- `"epoch"`: Evaluation is done at the end of each epoch.
prediction_loss_only (`bool`, *optional*, defaults to `False`):
When performing evaluation and generating predictions, only returns the loss.
per_device_train_batch_size (`int`, *optional*, defaults to 8):
The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training.
per_device_eval_batch_size (`int`, *optional*, defaults to 8):
The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation.
gradient_accumulation_steps (`int`, *optional*, defaults to 1):
Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
<Tip warning={true}>
When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging,
evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples.
</Tip>
eval_accumulation_steps (`int`, *optional*):
Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If
left unset, the whole predictions are accumulated on GPU/NPU/TPU before being moved to the CPU (faster but
requires more memory).
eval_delay (`float`, *optional*):
Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
eval_strategy.
learning_rate (`float`, *optional*, defaults to 5e-5):
The initial learning rate for [`AdamW`] optimizer.
weight_decay (`float`, *optional*, defaults to 0):
The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`]
optimizer.
adam_beta1 (`float`, *optional*, defaults to 0.9):
The beta1 hyperparameter for the [`AdamW`] optimizer.
adam_beta2 (`float`, *optional*, defaults to 0.999):
The beta2 hyperparameter for the [`AdamW`] optimizer.
adam_epsilon (`float`, *optional*, defaults to 1e-8):
The epsilon hyperparameter for the [`AdamW`] optimizer.
max_grad_norm (`float`, *optional*, defaults to 1.0):
Maximum gradient norm (for gradient clipping).
num_train_epochs(`float`, *optional*, defaults to 3.0):
Total number of training epochs to perform (if not an integer, will perform the decimal part percents of
the last epoch before stopping training).
max_steps (`int`, *optional*, defaults to -1):
If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
`max_steps` is reached.
lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`):
The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values.
lr_scheduler_kwargs ('dict', *optional*, defaults to {}):
The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values.
warmup_ratio (`float`, *optional*, defaults to 0.0):
Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
warmup_steps (`int`, *optional*, defaults to 0):
Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
log_level (`str`, *optional*, defaults to `passive`):
Logger log level to use on the main process. Possible choices are the log levels as strings: 'debug',
'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and keeps the
current log level for the Transformers library (which will be `"warning"` by default).
log_level_replica (`str`, *optional*, defaults to `"warning"`):
Logger log level to use on replicas. Same choices as `log_level`"
log_on_each_node (`bool`, *optional*, defaults to `True`):
In multinode distributed training, whether to log using `log_level` once per node, or only on the main
node.
logging_dir (`str`, *optional*):
[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
*output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.
logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
The logging strategy to adopt during training. Possible values are:
- `"no"`: No logging is done during training.
- `"epoch"`: Logging is done at the end of each epoch.
- `"steps"`: Logging is done every `logging_steps`.
logging_first_step (`bool`, *optional*, defaults to `False`):
Whether to log the first `global_step` or not.
logging_steps (`int` or `float`, *optional*, defaults to 500):
Number of update steps between two logs if `logging_strategy="steps"`. Should be an integer or a float in
range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
logging_nan_inf_filter (`bool`, *optional*, defaults to `True`):
Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is `nan`
or `inf` is filtered and the average loss of the current logging window is taken instead.
<Tip>
`logging_nan_inf_filter` only influences the logging of loss values, it does not change the behavior the
gradient is computed or applied to the model.
</Tip>
save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
The checkpoint save strategy to adopt during training. Possible values are:
- `"no"`: No save is done during training.
- `"epoch"`: Save is done at the end of each epoch.
- `"steps"`: Save is done every `save_steps`.
If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
very end of training, always.
save_steps (`int` or `float`, *optional*, defaults to 500):
Number of updates steps before two checkpoint saves if `save_strategy="steps"`. Should be an integer or a
float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
save_total_limit (`int`, *optional*):
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
`output_dir`. When `load_best_model_at_end` is enabled, the "best" checkpoint according to
`metric_for_best_model` will always be retained in addition to the most recent ones. For example, for
`save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained
alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two
checkpoints are saved: the last one and the best one (if they are different).
save_safetensors (`bool`, *optional*, defaults to `True`):
Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of
default `torch.load` and `torch.save`.
save_on_each_node (`bool`, *optional*, defaults to `False`):
When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
the main one.
This should not be activated when the different nodes use the same storage as the files will be saved with
the same names for each node.
save_only_model (`bool`, *optional*, defaults to `False`):
When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state.
Note that when this is true, you won't be able to resume training from checkpoint.
This enables you to save storage by not storing the optimizer, scheduler & rng state.
You can only load the model using `from_pretrained` with this option set to `True`.
restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`):
Whether to restore the callback states from the checkpoint. If `True`, will override
callbacks passed to the `Trainer` if they exist in the checkpoint."
use_cpu (`bool`, *optional*, defaults to `False`):
Whether or not to use cpu. If set to False, we will use cuda or mps device if available.
seed (`int`, *optional*, defaults to 42):
Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the
[`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized parameters.
data_seed (`int`, *optional*):
Random seed to be used with data samplers. If not set, random generators for data sampling will use the
same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model
seed.
jit_mode_eval (`bool`, *optional*, defaults to `False`):
Whether or not to use PyTorch jit trace for inference.
use_ipex (`bool`, *optional*, defaults to `False`):
Use Intel extension for PyTorch when it is available. [IPEX
installation](https://github.com/intel/intel-extension-for-pytorch).
bf16 (`bool`, *optional*, defaults to `False`):
Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
NVIDIA architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change.
fp16 (`bool`, *optional*, defaults to `False`):
Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
fp16_opt_level (`str`, *optional*, defaults to 'O1'):
For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on
the [Apex documentation](https://nvidia.github.io/apex/amp).
fp16_backend (`str`, *optional*, defaults to `"auto"`):
This argument is deprecated. Use `half_precision_backend` instead.
half_precision_backend (`str`, *optional*, defaults to `"auto"`):
The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will
use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the
requested backend.
bf16_full_eval (`bool`, *optional*, defaults to `False`):
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values. This is an experimental API and it may change.
fp16_full_eval (`bool`, *optional*, defaults to `False`):
Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values.
tf32 (`bool`, *optional*):
Whether to enable the TF32 mode, available in Ampere and newer GPU architectures. The default value depends
on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32`. For more details please refer to
the [TF32](https://huggingface.co/docs/transformers/performance#tf32) documentation. This is an
experimental API and it may change.
local_rank (`int`, *optional*, defaults to -1):
Rank of the process during distributed training.
ddp_backend (`str`, *optional*):
The backend to use for distributed training. Must be one of `"nccl"`, `"mpi"`, `"ccl"`, `"gloo"`, `"hccl"`.
tpu_num_cores (`int`, *optional*):
When training on TPU, the number of TPU cores (automatically passed by launcher script).
dataloader_drop_last (`bool`, *optional*, defaults to `False`):
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
or not.
eval_steps (`int` or `float`, *optional*):
Number of update steps between two evaluations if `eval_strategy="steps"`. Will default to the same
value as `logging_steps` if not set. Should be an integer or a float in range `[0,1)`. If smaller than 1,
will be interpreted as ratio of total training steps.
dataloader_num_workers (`int`, *optional*, defaults to 0):
Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the
main process.
past_index (`int`, *optional*, defaults to -1):
Some models like [TransformerXL](../model_doc/transformerxl) or [XLNet](../model_doc/xlnet) can make use of
the past hidden states for their predictions. If this argument is set to a positive int, the `Trainer` will
use the corresponding output (usually index 2) as the past state and feed it to the model at the next
training step under the keyword argument `mems`.
run_name (`str`, *optional*, defaults to `output_dir`):
A descriptor for the run. Typically used for [wandb](https://www.wandb.com/) and
[mlflow](https://www.mlflow.org/) logging. If not specified, will be the same as `output_dir`.
disable_tqdm (`bool`, *optional*):
Whether or not to disable the tqdm progress bars and table of metrics produced by
[`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
set to warn or lower (default), `False` otherwise.
remove_unused_columns (`bool`, *optional*, defaults to `True`):
Whether or not to automatically remove the columns unused by the model forward method.
label_names (`List[str]`, *optional*):
The list of keys in your dictionary of inputs that correspond to the labels.
Will eventually default to the list of argument names accepted by the model that contain the word "label",
except if the model used is one of the `XxxForQuestionAnswering` in which case it will also include the
`["start_positions", "end_positions"]` keys.
load_best_model_at_end (`bool`, *optional*, defaults to `False`):
Whether or not to load the best model found during training at the end of training. When this option is
enabled, the best checkpoint will always be saved. See
[`save_total_limit`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_total_limit)
for more.
<Tip>
When set to `True`, the parameter `save_strategy` needs to be the same as `eval_strategy`, and in
the case it is "steps", `save_steps` must be a round multiple of `eval_steps`.
</Tip>
metric_for_best_model (`str`, *optional*):
Use in conjunction with `load_best_model_at_end` to specify the metric to use to compare two different
models. Must be the name of a metric returned by the evaluation with or without the prefix `"eval_"`. Will
default to `"loss"` if unspecified and `load_best_model_at_end=True` (to use the evaluation loss).
If you set this value, `greater_is_better` will default to `True`. Don't forget to set it to `False` if
your metric is better when lower.
greater_is_better (`bool`, *optional*):
Use in conjunction with `load_best_model_at_end` and `metric_for_best_model` to specify if better models
should have a greater metric or not. Will default to:
- `True` if `metric_for_best_model` is set to a value that doesn't end in `"loss"`.
- `False` if `metric_for_best_model` is not set, or set to a value that ends in `"loss"`.
ignore_data_skip (`bool`, *optional*, defaults to `False`):
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
can take a long time) but will not yield the same results as the interrupted training would have.
fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`):
Use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training only).
A list of options along the following:
- `"full_shard"`: Shard parameters, gradients and optimizer states.
- `"shard_grad_op"`: Shard optimizer states and gradients.
- `"hybrid_shard"`: Apply `FULL_SHARD` within a node, and replicate parameters across nodes.
- `"hybrid_shard_zero2"`: Apply `SHARD_GRAD_OP` within a node, and replicate parameters across nodes.
- `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
`"shard_grad_op"`).
- `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
fsdp_config (`str` or `dict`, *optional*):
Config to be used with fsdp (Pytorch Distributed Parallel Training). The value is either a location of
fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`.
A List of config and its options:
- min_num_params (`int`, *optional*, defaults to `0`):
FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
passed).
- transformer_layer_cls_to_wrap (`List[str]`, *optional*):
List of transformer layer class names (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`,
`T5Block` .... (useful only when `fsdp` flag is passed).
- backward_prefetch (`str`, *optional*)
FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when
`fsdp` field is passed).
A list of options along the following:
- `"backward_pre"` : Prefetches the next set of parameters before the current set of parameters'
gradient computation.
- `"backward_post"` : Prefetches the next set of parameters after the current set of
parameters' gradient computation.
- forward_prefetch (`bool`, *optional*, defaults to `False`)
FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
forward pass.
- limit_all_gathers (`bool`, *optional*, defaults to `False`)
FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
all-gathers.
- use_orig_params (`bool`, *optional*, defaults to `True`)
If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed
frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. Please
refer to this
[blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019).
- sync_module_states (`bool`, *optional*, defaults to `True`)
If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
ensure they are the same across all ranks after initialization
- cpu_ram_efficient_loading (`bool`, *optional*, defaults to `False`)
If `"True"`, only the first process loads the pretrained model checkpoint while all other processes
have empty weights. When this setting is `"True"`, `sync_module_states` must also be `"True"`,
otherwise all the processes except the main process would have random weights leading to unexpected
behaviour during training.
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of
certain layers and recomputing them during a backward pass. Effectively, this trades extra
computation time for reduced memory usage.
- xla (`bool`, *optional*, defaults to `False`):
Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
and its API may evolve in the future.
- xla_fsdp_settings (`dict`, *optional*)
The value is a dictionary which stores the XLA FSDP wrapping parameters.
For a complete list of options, please see [here](
https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py).
- xla_fsdp_grad_ckpt (`bool`, *optional*, defaults to `False`):
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
`ds_config.json`) or an already loaded json file as a `dict`.
<Tip warning={true}>
If enabling any Zero-init, make sure that your model is not initialized until
*after* initializing the `TrainingArguments`, else it will not be applied.
</Tip>
accelerator_config (`str`, `dict`, or `AcceleratorConfig`, *optional*):
Config to be used with the internal `Accelerator` implementation. The value is either a location of
accelerator json config file (e.g., `accelerator_config.json`), an already loaded json file as `dict`,
or an instance of [`~trainer_pt_utils.AcceleratorConfig`].
A list of config and its options:
- split_batches (`bool`, *optional*, defaults to `False`):
Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
`True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
in your script multiplied by the number of processes.
- dispatch_batches (`bool`, *optional*):
If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
underlying dataset is an `IterableDataset`, `False` otherwise.
- even_batches (`bool`, *optional*, defaults to `True`):
If set to `True`, in cases where the total batch size across all processes does not exactly divide the
dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
all workers.
- use_seedable_sampler (`bool`, *optional*, defaults to `True`):
Whether or not to use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
- use_configured_state (`bool`, *optional*, defaults to `False`):
Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`.
If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues
with hyperparameter tuning.
label_smoothing_factor (`float`, *optional*, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
label_smoothing_factor/num_labels` respectively.
debug (`str` or list of [`~debug_utils.DebugOption`], *optional*, defaults to `""`):
Enable one or more debug features. This is an experimental feature.
Possible options are:
- `"underflow_overflow"`: detects overflow in model's input/outputs and reports the last frames that led to
the event
- `"tpu_metrics_debug"`: print debug metrics on TPU
The options should be separated by whitespaces.
optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`):
The optimizer to use: adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or
adafactor.
optim_args (`str`, *optional*):
Optional arguments that are supplied to AnyPrecisionAdamW.
group_by_length (`bool`, *optional*, defaults to `False`):
Whether or not to group together samples of roughly the same length in the training dataset (to minimize
padding applied and be more efficient). Only useful if applying dynamic padding.
length_column_name (`str`, *optional*, defaults to `"length"`):
Column name for precomputed lengths. If the column exists, grouping by length will use these values rather
than computing them on train startup. Ignored unless `group_by_length` is `True` and the dataset is an
instance of `Dataset`.
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
`"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
integrations.
ddp_find_unused_parameters (`bool`, *optional*):
When using distributed training, the value of the flag `find_unused_parameters` passed to
`DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
ddp_bucket_cap_mb (`int`, *optional*):
When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`.
ddp_broadcast_buffers (`bool`, *optional*):
When using distributed training, the value of the flag `broadcast_buffers` passed to
`DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
Whether you want to pin memory in data loaders or not. Will default to `True`.
dataloader_persistent_workers (`bool`, *optional*, defaults to `False`):
If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
This keeps the workers' Dataset instances alive. Can potentially speed up training, but will
increase RAM usage. Will default to `False`.
dataloader_prefetch_factor (`int`, *optional*):
Number of batches loaded in advance by each worker.
2 means there will be a total of 2 * num_workers batches prefetched across all workers.
skip_memory_metrics (`bool`, *optional*, defaults to `True`):
Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
down the training and evaluation speed.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push the model to the Hub every time the model is saved. If this is activated,
`output_dir` will begin a git directory synced with the repo (determined by `hub_model_id`) and the content
will be pushed each time a save is triggered (depending on your `save_strategy`). Calling
[`~Trainer.save_model`] will also trigger a push.
<Tip warning={true}>
If `output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be
pushed.
</Tip>
resume_from_checkpoint (`str`, *optional*):
The path to a folder with a valid checkpoint for your model. This argument is not directly used by
[`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
hub_model_id (`str`, *optional*):
The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in
which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
`"organization_name/model"`. Will default to `user_name/output_dir_name` with *output_dir_name* being the
name of `output_dir`.
hub_strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`):
Defines the scope of what is pushed to the Hub and when. Possible values are:
- `"end"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and a
draft of a model card when the [`~Trainer.save_model`] method is called.
- `"every_save"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and
a draft of a model card each time there is a model save. The pushes are asynchronous to not block
training, and in case the save are very frequent, a new push is only attempted if the previous one is
finished. A last push is made with the final model at the end of training.
- `"checkpoint"`: like `"every_save"` but the latest checkpoint is also pushed in a subfolder named
last-checkpoint, allowing you to resume training easily with
`trainer.train(resume_from_checkpoint="last-checkpoint")`.
- `"all_checkpoints"`: like `"checkpoint"` but all checkpoints are pushed like they appear in the output
folder (so you will get one checkpoint folder per folder in your final repository)
hub_token (`str`, *optional*):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
`huggingface-cli login`.
hub_private_repo (`bool`, *optional*, defaults to `False`):
If True, the Hub repo will be set to private.
hub_always_push (`bool`, *optional*, defaults to `False`):
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
gradient_checkpointing (`bool`, *optional*, defaults to `False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
Key word arguments to be passed to the `gradient_checkpointing_enable` method.
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
that need inputs, predictions and references for scoring calculation in Metric class.
eval_do_concat_batches (`bool`, *optional*, defaults to `True`):
Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`,
will instead store them as lists, with each batch kept separate.
auto_find_batch_size (`bool`, *optional*, defaults to `False`):
Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`).
full_determinism (`bool`, *optional*, defaults to `False`):
If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
distributed training. Important: this will negatively impact the performance, so only use it for debugging.
torchdynamo (`str`, *optional*):
If set, the backend compiler for TorchDynamo. Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`,
`"nvfuser"`, `"aot_nvfuser"`, `"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`.
ray_scope (`str`, *optional*, defaults to `"last"`):
The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will
then use the last checkpoint of all trials, compare those, and select the best one. However, other options
are also available. See the [Ray documentation](
https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for
more options.
ddp_timeout (`int`, *optional*, defaults to 1800):
The timeout for `torch.distributed.init_process_group` calls, used to avoid GPU socket timeouts when
performing slow operations in distributed runs. Please refer to the [PyTorch
documentation](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
information.
use_mps_device (`bool`, *optional*, defaults to `False`):
This argument is deprecated. The `mps` device will be used if it is available, similar to the `cuda` device.
torch_compile (`bool`, *optional*, defaults to `False`):
Whether or not to compile the model using PyTorch 2.0
[`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/).
This will use the best defaults for the [`torch.compile`
API](https://pytorch.org/docs/stable/generated/torch.compile.html?highlight=torch+compile#torch.compile).
You can customize the defaults with the arguments `torch_compile_backend` and `torch_compile_mode` but we
don't guarantee any of them will work as the support is progressively rolled out in PyTorch.
This flag and the whole compile API is experimental and subject to change in future releases.
torch_compile_backend (`str`, *optional*):
The backend to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
This flag is experimental and subject to change in future releases.
torch_compile_mode (`str`, *optional*):
The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
This flag is experimental and subject to change in future releases.
split_batches (`bool`, *optional*):
Whether or not the accelerator should split the batches yielded by the dataloaders across the devices
during distributed training. If set to `True`, the actual batch size used will be the same on any kind of
distributed processes, but it must be a round multiple of the number of processes you are using (such as
GPUs).
include_tokens_per_second (`bool`, *optional*):
Whether or not to compute the number of tokens per second per device for training speed metrics.
This will iterate over the entire training dataloader once beforehand,
and will slow down the entire process.
include_num_input_tokens_seen (`bool`, *optional*):
Whether or not to track the number of input tokens seen throughout training.
May be slower in distributed training as gather operations must be called.
neftune_noise_alpha (`Optional[float]`):
If not `None`, this will activate NEFTune noise embeddings. This can drastically improve model performance
for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the
[original code](https://github.com/neelsjain/NEFTune). Supports transformers `PreTrainedModel` and also
`PeftModel` from peft.
optim_target_modules (`Union[str, List[str]]`, *optional*):
The target modules to optimize, i.e. the module names that you would like to train. Right now this is used
only for the GaLore algorithm (https://arxiv.org/abs/2403.03507).
See https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaLore
optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor", and make sure that the
target modules are `nn.Linear` modules only.
batch_eval_metrics (`Optional[bool]`, defaults to `False`):
If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics
rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
eval_on_start (`bool`, *optional*, defaults to `False`):
Whether to perform an evaluation step (sanity check) before the training to ensure the validation step works correctly.
"""
# Framework tag; un-annotated, unlike the fields declared below.
framework = "pt"
# Destination for model predictions and checkpoints. No default is declared here.
output_dir: str = field(
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
)
# Off by default; per the help text, enable it to continue training when
# `output_dir` points to a checkpoint directory.
overwrite_output_dir: bool = field(
default=False,
metadata={
"help": (
"Overwrite the content of the output directory. "
"Use this to continue training if output_dir points to a checkpoint directory."
)
},
)
# Run-mode switches (train / eval / predict); each defaults to False.
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
# Evaluation schedule: an `IntervalStrategy` member or a plain string; defaults to "no".
eval_strategy: Union[IntervalStrategy, str] = field(
default="no",
metadata={"help": "The evaluation strategy to use."},
)
# When True, evaluation and prediction return only the loss (per the help text).
prediction_loss_only: bool = field(
default=False,
metadata={"help": "When performing evaluation and predictions, only returns the loss."},
)
# Per-device batch sizes for training and evaluation; both default to 8.
per_device_train_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for training."}
)
per_device_eval_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for evaluation."}
)
# Deprecated per-GPU spellings of the two fields above. They default to None,
# and their help text points users at the `per_device_*` arguments instead.
per_gpu_train_batch_size: Optional[int] = field(
default=None,
metadata={
"help": (
"Deprecated, the use of `--per_device_train_batch_size` is preferred. "
"Batch size per GPU/TPU core/CPU for training."
)
},
)
per_gpu_eval_batch_size: Optional[int] = field(
default=None,
metadata={
"help": (
"Deprecated, the use of `--per_device_eval_batch_size` is preferred. "
"Batch size per GPU/TPU core/CPU for evaluation."
)
},
)
# Number of steps whose gradients are accumulated before an update; 1 by default.
gradient_accumulation_steps: int = field(
default=1,
metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
)
# Number of prediction steps accumulated before tensors are moved to the CPU
# (per the help text); None by default.
eval_accumulation_steps: Optional[int] = field(
default=None,
metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."},
)
# Delay before the first evaluation, measured in epochs or steps depending on
# `eval_strategy` (per the help text). Annotated as float; the default is the int 0.
eval_delay: Optional[float] = field(
default=0,
metadata={
"help": (
"Number of epochs or steps to wait for before the first evaluation can be performed, depending on the"
" eval_strategy."
)
},
)
# AdamW hyperparameters (learning rate, weight decay, betas, epsilon) and the
# gradient-norm cap, each with its default value.
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
# Training length: epochs by default (3.0); per the help text a positive
# `max_steps` overrides `num_train_epochs`.
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
max_steps: int = field(
default=-1,
metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
)
# Learning-rate scheduler: a `SchedulerType` member or its string form; "linear" by default.
lr_scheduler_type: Union[SchedulerType, str] = field(
default="linear",
metadata={"help": "The scheduler type to use."},
)
# Extra scheduler parameters. Uses `default_factory=dict` rather than a literal
# mutable default; the annotation also admits a `str`.
lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
default_factory=dict,
metadata={
"help": (
"Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts."
)
},
)
# Linear warmup, expressed either as a fraction of total steps or as a step count.
warmup_ratio: float = field(
default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
)
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
# Log level for the main node. Accepted values are the keys of
# `trainer_log_levels`; the default "passive" leaves the level to the
# application (per the help text).
log_level: Optional[str] = field(
default="passive",
metadata={
"help": (
"Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug',"
" 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and"
" lets the application set the level. Defaults to 'passive'."
),
"choices": trainer_log_levels.keys(),
},
)
# Log level for replica nodes. It accepts the same choices as `log_level`, but
# its default is "warning" rather than "passive", so the help text states the
# default explicitly instead of claiming it matches `log_level`'s.
log_level_replica: Optional[str] = field(
    default="warning",
    metadata={
        "help": (
            "Logger log level to use on replica nodes. Same choices as ``log_level``. Defaults to 'warning'."
        ),
        "choices": trainer_log_levels.keys(),
    },
)
# Multi-node logging: log once per node (True, the default) or only once on the
# main node (per the help text).
log_on_each_node: bool = field(
default=True,
metadata={
"help": (
"When doing a multinode distributed training, whether to log once per node or just once on the main"
" node."
)
},
)
# Tensorboard log directory; None by default.
logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
# Logging schedule: an `IntervalStrategy` member or its string form; "steps" by default.
logging_strategy: Union[IntervalStrategy, str] = field(
default="steps",
metadata={"help": "The logging strategy to use."},
)
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
# Annotated `float` although the default is the integer 500: per the help text
# a value below 1 is read as a ratio of total training steps.
logging_steps: float = field(
default=500,
metadata={
"help": (
"Log every X updates steps. Should be an integer or a float in range `[0,1)`. "
"If smaller than 1, will be interpreted as ratio of total training steps."
)
},
)
# On by default: nan and inf losses are filtered for logging (per the help text).
logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
# Checkpoint schedule: an `IntervalStrategy` member or its string form; "steps" by default.
save_strategy: Union[IntervalStrategy, str] = field(
default="steps",
metadata={"help": "The checkpoint save strategy to use."},
)
# Annotated `float` although the default is the integer 500: per the help text
# a value below 1 is read as a ratio of total training steps.
save_steps: float = field(
default=500,
metadata={
"help": (
"Save checkpoint every X updates steps. Should be an integer or a float in range `[0,1)`. "
"If smaller than 1, will be interpreted as ratio of total training steps."
)
},
)
# Optional cap on the number of checkpoints kept; None (the default) means
# unlimited. The help text describes how it interacts with `load_best_model_at_end`.
save_total_limit: Optional[int] = field(
default=None,
metadata={
"help": (
"If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in"
" `output_dir`. When `load_best_model_at_end` is enabled, the 'best' checkpoint according to"
" `metric_for_best_model` will always be retained in addition to the most recent ones. For example,"
" for `save_total_limit=5` and `load_best_model_at_end=True`, the four last checkpoints will always be"
" retained alongside the best model. When `save_total_limit=1` and `load_best_model_at_end=True`,"
" it is possible that two checkpoints are saved: the last one and the best one (if they are different)."
" Default is unlimited checkpoints"
)
},
)
# On by default: state dicts are saved and loaded with safetensors instead of
# torch.save / torch.load (per the help text).
save_safetensors: Optional[bool] = field(
default=True,
metadata={
"help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
},
)
# Multi-node saving: off by default, i.e. save only on the main node (per the help text).
save_on_each_node: bool = field(
default=False,
metadata={
"help": (
"When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
" only on the main one"
)
},
)
# When True, checkpoints hold only the model, without the optimizer, scheduler
# and rng state; the help text spells out the trade-off. Each help fragment ends
# with a space because adjacent string literals are concatenated verbatim, and
# without it the sentences run together ("state.Note that ...").
save_only_model: bool = field(
    default=False,
    metadata={
        "help": (
            "When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state. "
            "Note that when this is true, you won't be able to resume training from checkpoint. "
            "This enables you to save storage by not storing the optimizer, scheduler & rng state. "
            "You can only load the model using from_pretrained with this option set to True."
        )
    },
)
# Off by default. Per the help text, when True the callback states stored in the
# checkpoint override callbacks passed to the `Trainer`.
restore_callback_states_from_checkpoint: bool = field(
default=False,
metadata={
"help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint."
},
)
# Deprecated flag; the help text announces its removal in version 5.0.
no_cuda: bool = field(
default=False,
metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."},
)
# Forces CPU when True; off by default.
# NOTE(review): the help string begins with a stray leading space.
use_cpu: bool = field(
default=False,
metadata={
"help": " Whether or not to use cpu. If set to False, we will use cuda/tpu/mps/npu device if available."
},
)
# Deprecated flag; the help text announces its removal in version 5.0. The help
# value is two adjacent string literals joined by implicit concatenation.
use_mps_device: bool = field(
default=False,
metadata={
"help": "This argument is deprecated. `mps` device will be used if available similar to `cuda` device."
" It will be removed in version 5.0 of 🤗 Transformers"
},
)
# Random seeds: `seed` (default 42) is set at the beginning of training;
# `data_seed` (default None) is for data samplers (per the help text).
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
jit_mode_eval: bool = field(