Multiple calls to .fit #138
I can see at least two primary ways in which multiple calls to `.fit` could be used: interactive development/debugging, and pre-training followed by fine-tuning.
With that in mind:
What could change between calls could be quite a lot, but those changes should mostly be left to the user, in my opinion. They could do things like change dataloaders, loss functions, or even make small model modifications between calls.
The only thing that seems clear as a parameter to `.fit` itself is how long to train.
Should definitely have a different concept for each call to `.fit`.
That makes sense and could be useful in the second use case of pre-training and fine-tuning. That would reset things like schedulers, I would think. The run directory could be subfoldered on calls to `.fit`.
This may depend on the scheduler.
I think Tyler and Austin hit all the key points, but I want to reiterate the importance of accommodating hparam changes for each call to `.fit`.
Thanks for these ideas! Breaking down the main questions:
For ease of use / usability, I'm thinking that all parameters for `.fit` should be optional.
I now realize that once #43 is resolved, this question will be irrelevant. I presume that no events will fire on `__init__`.
It seems like this question is much more nuanced than I thought. As an alternative design, what if we converted the Trainer into a function (rather than a class)? E.g. something like this:

```python
def fit(model, train_dataloader, eval_dataloader, max_epochs, all_other_hparams, ...):
    ...


def eval(model, eval_dataloader, ...):
    ...
```

Since users would have to supply the parameters to `fit` directly, the "create from hparams" logic would then move to the hparams file, e.g. calling `fit` with arguments built from the hparams.
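For illustration, usage under this functional design might look like the sketch below; every name here (`build_model`, `pretrain_dl`, `finetune_dl`, `eval_dl`, the hparams objects) is a placeholder, not an existing API.

```python
# Hypothetical usage of the functional design sketched above; all names are placeholders.
model = build_model()  # assumed helper returning a BaseMosaicModel

# Pre-training call.
fit(model, pretrain_dl, eval_dl, max_epochs=90, all_other_hparams=pretrain_hparams)

# Fine-tuning call: same model object, different dataloader and hparams.
fit(model, finetune_dl, eval_dl, max_epochs=10, all_other_hparams=finetune_hparams)

# Standalone evaluation, using the eval function from the sketch above.
eval(model, eval_dl)
```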
Notes from the discussion on 12/13:
Discuss further:
For training + fine-tuning, it seems like we'd want to support passing dataloaders to `.fit`.
@ajaysaini725 For interactive development and debugging, I am leaning towards letting users set attributes on the trainer directly.

Re: # of epochs, I do like having that as an optional parameter to `.fit`.

For pretraining + fine-tuning, based off of the discussion, I think you will want to create a new trainer. That is a different training job at that point, I think?
Support multiple calls to .fit for partial training, as discussed in #138. Specifically, this PR:

* Moves most initialization logic to `Trainer.__init__`, so it is invoked only once
* Added `num_epochs` and `num_batches` as optional parameters to `Trainer.fit`
* Significant refactor of the training loop to support the above
* Added test cases (and updated the synthetic dataloader to use a deterministic generator, so the model from the tests would be identical)
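For reference, partial training with these optional parameters might be used roughly as in the sketch below; the variable names are placeholders, and the exact signatures may differ from the merged code.

```python
# Sketch only: assumes the optional num_epochs / num_batches parameters this PR adds to Trainer.fit.
trainer = Trainer(model=model, train_dataloader=train_dl, eval_dataloader=eval_dl, max_epochs=10)

trainer.fit(num_batches=100)  # short partial run, useful for debugging
trainer.fit(num_epochs=2)     # train two more epochs
trainer.fit()                 # train the remainder, up to max_epochs
```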
Looking at #154 (which I think is a good first step), it may be a good idea to revisit this discussion, beyond the concept that each trainer corresponds to a single run.

Looking at the existing args, one way to split them would be "anything related to engineering" -> `__init__` and "anything related to the science of the run" -> `.fit`. Current args:

To me, these are all systems/logging concerns that obviously stay on `__init__`.

These are the args I think users could reasonably want to interact with because.. science?
Not sure, probably
It is worth noting that some frameworks (PTL) don't support interactive development well at all, and interactive development styles can make reproducibility difficult. My view is that it would be nice to support interactive development (as there are some clear benefits) but encourage people to refactor the final result into a script that can be run end to end in some way. I'm not sure where we officially stand on this spectrum.
Thanks for this detailed feedback @A-Jacobson. Given the way that YAHP (the hyperparameter library) works and the way our yamls are written, all arguments have to be specified in `__init__` -- otherwise we would lose whatever arguments were moved onto `.fit` when constructing the trainer from a yaml.

(It's not impossible, though. We could restructure our YAMLs to look like this:

```yaml
trainer:
  model:
  device:
  ddp_sync_strategy:
  ddp_timeout:
  seed:
  deterministic_mode:
  # logging
  callbacks:
  log_destinations:
  checkpoint_filepath:
  checkpoint_interval_unit:
  checkpoint_folder:
  checkpoint_interval:
  runs:
  - algorithms:        # first call to fit
    num_epochs:
    train_dataloader:
  - algorithms:        # second call to fit
    num_epochs:
    train_dataloader:
```

We'd also need a new entrypoint. But that's probably better left for another PR?)

That said, we can definitely support additional arguments on a subsequent call to `.fit`.
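If we did go that route, the new entrypoint could stay fairly thin. A sketch, where `load_hparams`, `create_trainer_from_hparams`, and the per-run field names are assumptions rather than existing APIs:

```python
# Hypothetical entrypoint for the restructured YAML above; helper names are assumptions.
def main(yaml_path: str) -> None:
    hparams = load_hparams(yaml_path)               # parse the trainer: section of the YAML
    trainer = create_trainer_from_hparams(hparams)  # everything except runs: feeds __init__

    # Each entry under runs: becomes one call to .fit with its own algorithms/data/duration.
    for run in hparams.runs:
        trainer.fit(
            train_dataloader=run.train_dataloader,
            num_epochs=run.num_epochs,
            algorithms=run.algorithms,
        )
```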
Sure, most of the proposal is just moving existing args around, though I did change `max_epochs` to a `max_duration` string on `.fit`:

```python
class Trainer:
    """Trainer for training a model with algorithms.

    Can be created either with ``__init__`` or by providing a
    :class:`~composer.trainer.TrainerHparams` object
    (see :meth:`~composer.trainer.Trainer.create_from_hparams`).

    Args:
        model (BaseMosaicModel): The model to train.

        # reproducibility
        seed (int, optional): The seed used in randomization. When not provided, a random seed
            will be created. (default: ``None``)
        deterministic_mode (bool, optional): Run the model deterministically. Experimental. Performance
            degradations expected. Certain Torch modules may not have deterministic implementations,
            which will result in a crash. (default: ``False``)

        # hardware
        device (Device, optional): The device to use for training. Either ``DeviceCPU`` or ``DeviceGPU``.
            (default: ``DeviceCPU(n_cpus=1)``)
        ddp_sync_strategy (DDPSyncStrategy, optional): The strategy to use for synchronizing gradients.
        ddp_timeout (float, optional): Timeout, in seconds, for initializing the DDP process group.
            (default: ``5.0``)
        precision (Precision, optional): Numerical precision to use for training. (default: ``Precision.FP32``)

        # logging
        validate_every_n_batches (int, optional): Compute metrics on evaluation data every N batches.
            Set to -1 to never validate on a batchwise frequency. (default: ``-1``)
        validate_every_n_epochs (int, optional): Compute metrics on evaluation data every N epochs.
            Set to -1 to never validate on an epochwise frequency. (default: ``1``)
        compute_training_metrics (bool, optional): True to compute metrics on training data, False otherwise.
            (default: ``False``)
        log_destinations (List[BaseLoggerBackend], optional): The destinations to log training information to.
            (default: ``[TQDMLoggerBackend()]``)
        callbacks (Sequence[Callback], optional): The callbacks to run during training. (default: ``[]``)
        checkpoint_filepath (str, optional): The path to a trainer checkpoint file. If provided,
            the trainer will load the state (along with its associated attributes) during initialization.
            (default: ``None``)
        checkpoint_interval_unit (str, optional): Unit for the checkpoint save interval -- 'ep'
            for epochs, 'ba' for batches, or None to disable checkpointing. (default: ``None``)
        checkpoint_folder (str, optional): The folder to save checkpoints to, relative to
            ``os.environ.get('RUN_DIRECTORY', '.')``. (default: ``checkpoints``)
        checkpoint_interval (int, optional): The frequency with which to checkpoint. (default: ``1``)
        config (Dict[str, Any], optional): Extra user-provided trainer configuration. Will be persisted
            along with the trainer state during checkpointing. (default: ``None``)
    """

    def fit(
            self,
            train_dataloader,
            eval_dataloader,
            optimizer,
            scheduler,
            max_duration,
            grad_accum,
            grad_clip_norm,
            algorithms):
        """
        Args:
            train_dataloader: dataloader for the training data
            eval_dataloader: dataloader for the evaluation data
            optimizer: pytorch optimizer object
            scheduler: scheduler object or function
            max_duration (str): following the same convention as our schedulers, e.g. ``"82ep"``
            grad_accum (int, optional): The number of microbatches to split a per-device batch into. Gradients
                are summed over the microbatches per device. (default: ``1``)
            grad_clip_norm (float, optional): The norm to clip gradient magnitudes to. Set to None for no gradient
                clipping. (default: ``None``)
            algorithms (List[Algorithm], optional): The algorithms to use during training.
                (default: ``[]``)
        """
```

This would easily support situations where the dataloaders, optimizer, scheduler, training duration, or algorithms vary between calls to `.fit`.
It also indirectly supports other situations.
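As a concrete illustration of that split, a pre-training run followed by fine-tuning on the same trainer might look like the sketch below; the signatures are the ones proposed above (not the current API), and all variables are placeholders.

```python
# Illustration of the proposed __init__ / fit split; signatures follow the proposal above, variables are placeholders.
trainer = Trainer(model=model, device=device, seed=42, log_destinations=[TQDMLoggerBackend()])

# Pre-training: one set of data, optimizer, schedule, and algorithms.
trainer.fit(pretrain_dl, eval_dl, pretrain_optimizer, pretrain_scheduler,
            max_duration="82ep", grad_accum=2, grad_clip_norm=None,
            algorithms=[blurpool, label_smoothing])

# Fine-tuning: different data, schedule, and algorithms, same trainer and model.
trainer.fit(finetune_dl, eval_dl, finetune_optimizer, finetune_scheduler,
            max_duration="10ep", grad_accum=1, grad_clip_norm=1.0,
            algorithms=[])
```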
Thanks for sketching out these use cases. However, would these use cases still be the same training run? I feel like changing algorithms, datasets, or schedulers would make it a different training run, in which case a new trainer (by the convention) should be created.

Using the existing trainer and changing pretty much any parameter mid-training without starting over will cause compatibility issues with algorithms. For example, randaugment (https://github.com/mosaicml/composer/blob/dev/composer/algorithms/randaugment/randaugment.py) modifies the dataset of the dataloader upon the training start event. If you added the randaugment algorithm mid-training (i.e. after training start fires), then it wouldn't do anything; if you changed the dataloader mid-training, then the new dataloader wouldn't have randaugment applied. I think it would be too much of a burden on the user to figure out how to manually apply the algorithms when changing training parameters, so I'd strongly prefer not to allow this (at least, not allow it easily).

If we don't want to go with the convention of "one trainer = one run", then I presume we'd want to have most of these parameters available on `.fit`. I think we still need to decide whether we want that convention.
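To make the ordering problem concrete, a sketch (the attribute assignment and constructor arguments here are hypothetical):

```python
# Hypothetical sketch of the compatibility issue described above.
trainer = Trainer(model=model, train_dataloader=original_dl,
                  algorithms=[RandAugment()])  # construction args omitted
trainer.fit()  # RandAugment wraps original_dl's dataset when the training-start event fires

trainer.train_dataloader = new_dl  # swapped in after training start: never wrapped by RandAugment
trainer.fit()                      # silently trains on un-augmented data
```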
Design brainstorm: https://www.notion.so/Multiple-Calls-to-Fit-3a682d9bece04c2e93b3c9a429365c63
This PR changes how callbacks are shut down. The overall goals are:

1. Ensure that any exit traceback is captured via the loggers (e.g. in `WandB` and the `FileLogger`)
2. Not shutting down the callbacks at the end of `Trainer.fit()`, which is a requirement for #138

To accomplish these goals, this PR changes where `Engine.close` (which shuts down the callbacks) is run. Specifically, this method now runs at any of these events:

1. Via `atexit`: If an exception causes the python process to crash, `atexit` callbacks will run after the traceback has been printed (test this for yourself! hard to write a test case for this)
2. On `__del__`: If the `trainer` instance is being garbage collected, it should be shut down.
3. If manually invoked via `trainer.close`: If the user is creating multiple trainers, they may want to shut one down at a specific point.

If `engine.close()` is invoked multiple times, it should be a no-op.

Implementation overview:

- Moved `close` to run on `__del__` / `atexit` as described above. The actual shutdown logic needs to be in a static method, so it does not hold a reference count in the `atexit` registry that would prevent garbage collection. Added test cases.
- Added `Event.FIT_END`, since closing is now decoupled from the end of training. This allows callbacks to log at the end of training and flush (but not close) any open file handles.
- Fixed a bug in the `FileLogger` where the prefix was printed multiple times if `write` was called multiple times per line.
- Fixed the documentation for the Event class to show all events in the mock training loop.
- Added tests and fixed callbacks and loggers to ensure that `.close()` and `.post_close()` implementations are idempotent.
- Added test cases to ensure that unhandled stack traces are captured in the logfile.
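For context, the shutdown scheme described here is roughly the standard combination of `atexit`, `__del__`, and an explicit, idempotent `close()`. A generic, self-contained sketch of that pattern (not Composer's actual implementation):

```python
import atexit


class Engine:
    """Generic sketch of the shutdown pattern described above (not Composer's actual code)."""

    def __init__(self, callbacks):
        self._callbacks = list(callbacks)
        self._closed = False
        # Register a static method, not a bound method, so the atexit registry
        # does not hold a reference to `self` and block garbage collection.
        atexit.register(Engine._close_callbacks, self._callbacks)

    @staticmethod
    def _close_callbacks(callbacks):
        for cb in callbacks:
            cb.close()  # each callback's close() is expected to be idempotent

    def close(self):
        # Explicit shutdown; safe to call multiple times.
        if self._closed:
            return
        Engine._close_callbacks(self._callbacks)
        self._closed = True

    def __del__(self):
        # Shut down when the engine (or the trainer owning it) is garbage collected.
        self.close()
```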
The trainer should support multiple calls to .fit.
Composer is going with the convention of "one run" = "one instance of the Trainer". So, if you want to do this, create a new trainer for each run:
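A minimal sketch of that convention, assuming the checkpoint arguments discussed earlier in this thread (filenames and variables are placeholders):

```python
# Sketch only: one trainer per run; checkpoint argument names follow the docstring above.
pretrain_trainer = Trainer(model=model, train_dataloader=pretrain_dl, eval_dataloader=eval_dl,
                           max_epochs=90, checkpoint_folder="checkpoints", checkpoint_interval_unit="ep")
pretrain_trainer.fit()

# Fine-tuning is a new run, so it gets a new trainer, resuming from the saved checkpoint.
finetune_trainer = Trainer(model=model, train_dataloader=finetune_dl, eval_dataloader=eval_dl,
                           max_epochs=10, checkpoint_filepath="checkpoints/ep90.pt")
finetune_trainer.fit()
```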
Nonetheless, there are valid reasons for calling `.fit()` multiple times, for example interactive development/debugging, or pre-training followed by fine-tuning.

To support this, we will allow `.fit()` to optionally take a `training_duration` parameter. If specified, then `.fit` will train for this much time. `.fit()` can be called multiple times, and each time it will train for the specified duration. If the duration is not specified, then it will train for `max_epochs`. The trainer will never train beyond `max_epochs`.

To support changing trainer behavior, almost all attributes that are specified upon `__init__` will be bound to the trainer as attributes or properties with proper getters and setters. However, when manually updating attributes in the middle of a `.fit`, the burden is on the user to make sure that changed attributes are in the correct state (e.g. adding a callback halfway through? make sure that you called `callback.run_event(Event.INIT)` before calling `.fit(training_duration)` again).

In pseudocode:
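A minimal sketch of that pseudocode, assuming the `training_duration` parameter and the attribute setters described above (variables are placeholders):

```python
# Sketch of the intended usage; training_duration and settable attributes are as described above.
trainer = Trainer(model=model, train_dataloader=train_dl, eval_dataloader=eval_dl, max_epochs=90)

trainer.fit(training_duration="10ep")  # train for 10 epochs

# Change trainer behavior mid-run; the user is responsible for the new attribute's state.
trainer.callbacks.append(my_callback)
my_callback.run_event(Event.INIT)      # fire INIT for the late-added callback, as noted above

trainer.fit(training_duration="80ep")  # train 80 more epochs
trainer.fit(training_duration="5ep")   # does nothing further: the trainer never exceeds max_epochs
```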
Todos:

- `__init__` arguments as properties/attributes #154 (which depends on "Removed dataclass from state" #153)
- `.fit(training_duration)`