
Multiple calls to .fit #138

Closed
ravi-mosaicml opened this issue Dec 7, 2021 · 12 comments · Fixed by #948
Labels: enhancement, Needs Design

Comments

@ravi-mosaicml
Contributor

ravi-mosaicml commented Dec 7, 2021

The trainer should support multiple calls to .fit.

Composer is going with the convention of "one run" = "one instance of the Trainer". So, for scenarios like the following, create a new trainer for each run:

  • Pre-training and fine-tuning
  • Sweeps across parameters

Nonetheless, there are valid reasons for calling .fit() multiple times, for example:

  • When doing interactive development of an algorithm, model, etc.
  • When you want to change trainer properties in the middle of a run (outside of an algorithm)

To support this, we will allow .fit() to optionally take a training_duration parameter. If specified, .fit() will train for that duration. .fit() can be called multiple times, and each time it will train for the specified duration. If the duration is not specified, it will train until max_epochs. The trainer will never train beyond max_epochs.

To support changing trainer behavior, almost all attributes that are specified upon __init__ will be bound to the trainer as attributes or properties with proper getters and setters. However, when manually updating attributes in the middle of a run, the burden is on the user to make sure that the changed attributes are in the correct state (e.g. adding a callback halfway through? Make sure you call callback.run_event(Event.INIT) before calling .fit(training_duration) again).

In pseudocode:

class Trainer:
    def __init__(self, model, train_dataloader, max_epochs):
        self.state = State(model, train_dataloader, max_epochs)

    @property
    def train_dataloader(self):
        return self.state.train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, train_dataloader):
        self.state.train_dataloader = train_dataloader

    def fit(self, duration=None):
        if duration is None:
            # train until max_epochs
            ...
        else:
            # train for the given duration, never exceeding max_epochs
            ...
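For illustration, a minimal usage sketch of that proposal (the string duration format, e.g. "2ep", is an assumption borrowed from later discussion, and new_dataloader is a placeholder):

trainer = Trainer(model, train_dataloader, max_epochs=10)

trainer.fit("2ep")    # train the first 2 epochs
trainer.train_dataloader = new_dataloader   # tweak a trainer property between calls
trainer.fit("3ep")    # train 3 more epochs
trainer.fit()         # train the remainder, never going beyond max_epochs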

Todos:

@ravi-mosaicml ravi-mosaicml added the enhancement label Dec 7, 2021
@siriuslee
Contributor

siriuslee commented Dec 8, 2021

I can see at least two primary ways in which multiple .fit() calls would be useful:

  1. Interactive development where you want to see how the model develops. In this case, something like trainer.fit(epochs=1) would be useful to fit for one extra epoch and then the user can evaluate their model again.
  2. Training curricula like pre-training followed by fine-tuning. In this case the user might switch a lot of hyperparameters (e.g. dataset, loss function, optimizer)

With that in mind:

What hyperparameters can be changed between calls to .fit()?

Could be quite a lot, but those should mostly be left to the user, in my opinion. They could do things like change dataloaders or loss functions, or even make small model modifications between calls.

Should .fit() take any arguments, or should all hyperparameters be specified as trainer attributes / properties? If the former, should attributes specified via .fit() be removed from .__init__() (and the hyperparameters)?

The only thing that seems clear as a parameter to .fit() would be some measure of training amount (epochs, steps, etc.).

Should Event.INIT fire for every call to .fit()? (Likely yes, as callbacks may need to reinitialize since they were closed at the end of .fit().)

We should definitely have a different concept for each call to .fit() that is not Event.INIT. If you want to keep it there, we definitely need documentation recommending that callbacks check whether they have already been initialized.
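For example, a minimal sketch of such a check (a hypothetical callback; the run_event signature follows the usage mentioned above, and the import path is assumed):

from composer.core import Callback, Event

class MyCallback(Callback):
    def __init__(self):
        super().__init__()
        self._initialized = False

    def run_event(self, event, state, logger):
        if event == Event.INIT:
            if self._initialized:
                return  # already initialized by an earlier .fit() call
            self._initialized = True
            # ... one-time setup goes here ...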

Should there be a trainer.reset()? This would help differentiate between "let me try the same experiment again" and "let me extend the training run"

That makes sense and could be useful in the second use case of pre-training and fine-tuning. That would reset things like schedulers, I would think. The run directory could be subfoldered on calls to reset as well, though I'm not sure I've thought that through. You definitely wouldn't want it done on every call to .fit(), though, because it would feel really awkward for the interactive use case (1).

@A-Jacobson
Contributor

A-Jacobson commented Dec 8, 2021

Would LR schedulers be reset?

This may depend on the scheduler.

  • For the most common schedulers (one cycle), I think you have to reset: the learning rate goes to zero after each call to fit, so the new call should start a new cycle, with all of this preferably being tracked in the same logger directory/project.

  • In a situation where we're reducing learning rate based on some validation metric, a user most likely would want to start off at the last learning rate used when they continue training. Making that final learning rate available and letting a user make that choice is also an option.

@growlix
Contributor

growlix commented Dec 8, 2021

I think Tyler and Austin hit all the key points, but I want to reiterate the importance of accommodating hparam changes for each call to .fit(). In the case of e.g. MLM pre-training of RoBERTa followed by fine-tuning on a suite of downstream tasks like GLUE, each call to .fit() would have different hparams, including dataset, dataloader, learning rate, batch size, methods/algorithms, logger/W&B run names, and initialization (pretraining is random init, fine-tuning is from final pretraining checkpoint). @moinnadeem probably has some thoughts on the matter. Multi-task evaluation of vision models will probably be pretty analogous.

@ravi-mosaicml
Contributor Author

Thanks for these ideas! Breaking down the main questions:

What hyperparameters can be changed between calls to .fit()?
Should .fit() take any arguments, or should all hyperparameters be specified as trainer attributes / properties? If the former, should attributes specified via .fit() be removed from .__init__() (and the hyperparameters)?

For ease of use / usability, I'm thinking that all parameters for __init__ would also be mapped to trainer attributes / properties (that update the appropriate field on the state, if necessary), and would also be supported as parameters to fit(). Specifying parameters on fit() would be equivalent to updating the property/attribute on the trainer immediately before calling fit() without arguments. @siriuslee would this be OK?
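Concretely, the proposed equivalence would be (a sketch only, not the final API):

# these two would do the same thing
trainer.fit(train_dataloader=new_dataloader)

trainer.train_dataloader = new_dataloader
trainer.fit()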

Should Event.INIT fire for every call to .fit()? (Likely yes, as callbacks may need to reinitialize since they were closed at the end of .fit().)

I now realize that once #43 is resolved, then this question will be irrelevant. I presume that no events will fire on __init__ and instead we'll just have Event.TRAINING_START in .fit().

Should there be a trainer.reset()? This would help differentiate between "let me try the same experiment again" and "let me extend the training run"

It seems like this question is much more nuanced than I thought. I think having one reset() is way too vague. Would it be OK to put the burden on the user to explicitly specify which components they want reset? (E.g. want to reset the model? Then reinitialize it. Want to reset the learning rate? Pass in a new LR scheduler object.)

As an alternative design, what if we converted the Trainer into a function (rather than a class)? E.g. something like this:

def fit(model, train_dataloader, eval_dataloader, max_epochs, all_other_hparams, ...):
   ...

def eval(model, eval_dataloader, ...):
   ...

Since users would have to supply the parameters to .fit(), they would know exactly what is being recreated vs what is being reused.

The "create from hparams" logic would then move to the hparams file e.g. calling TrainerHparams.initialize_object() would return a functools.partial(fit, training_arguments_from_hparams). While no parameters would be necessary, users could still selectively override the hparams by passing in keyword arguments to the returned partial function.

@ravi-mosaicml
Contributor Author

ravi-mosaicml commented Dec 13, 2021

Notes from the discussion on 12/13:

  • Make _train_batch() public, and add train_epoch() to support interactive mode.
  • Throw clean errors if someone mixes calls to train_batch / train_epoch.
  • If you call .fit after train_batch or train_epoch, continue training until the end.
  • If you want a new experiment, create a new trainer.
  • "Throw a fit": .fit can only be called once; throw a clean error if it is called after the model is fully trained.

Discuss further:

  • Ctrl-C

@ajaysaini725
Contributor

ajaysaini725 commented Dec 13, 2021

For pre-training + fine-tuning, it seems like we'd want to support passing dataloaders to fit() rather than to __init__() (assuming we are fine-tuning with a different dataset than we use for pre-training). What if we did:

def fit(self, train_dataloader, eval_dataloader, epochs=None)

If epochs is set to a number, then train for only that number of epochs; otherwise, train for self.state.max_epochs.
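A sketch of those semantics (illustrative only; _train_epoch is a hypothetical helper):

def fit(self, train_dataloader, eval_dataloader, epochs=None):
    self.state.train_dataloader = train_dataloader
    self.state.eval_dataloader = eval_dataloader
    n_epochs = epochs if epochs is not None else self.state.max_epochs
    for _ in range(n_epochs):
        self._train_epoch()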

@ravi-mosaicml
Contributor Author

@ajaysaini725 For interactive development and debugging, I am leaning towards letting users set trainer.train_dataloader = ... and trainer.eval_dataloader = ... (well, for eval, not exactly like that, but similar, since this will be merged in after #120), so those parameters would not be available in .fit(). See #154, which implements this. But I could be convinced otherwise to also allow these parameters in .fit.

Re: # of epochs, I do like having trainer.fit("10ep") more than doing something like for _ in range(10): trainer.train_epoch(). If trainer.fit(training_time=None) took a training time (similar to what @siriuslee mentioned), then train_epoch() and train_batch() would not need to be public. Instead, .fit() would be responsible for doing a clean resume when it is called again. It would not train beyond max_epochs, regardless of what was passed in. One advantage of this approach is that it would ensure that all events are properly called (e.g. if train_batch() were public, it would be somewhat messy to ensure that epoch_start and epoch_end events are invoked).

For pre-training + fine-tuning, based on the discussion, I think you will want to create a new trainer. That is a different training job at that point, I think?

ravi-mosaicml added a commit that referenced this issue Dec 15, 2021
Support multiple calls to .fit for partial training, as discussed in #138. Specifically, this PR:
* Moves most initialization logic to `Trainer.__init__`, so it is invoked only once
* Added `num_epochs` and `num_batches` as optional parameters to `Trainer.fit`
* Significant refactor of the training loop to support the above
* Added test cases (and updated the synthetic dataloader to use a deterministic generator, so the model from the tests would be identical)
@A-Jacobson
Contributor

A-Jacobson commented Dec 16, 2021

Looking at #154 (which I think is a good first step), it may be a good idea to revisit this discussion. Other than the concept that each Trainer object should represent one run (a concept we have yet to define), it seems we didn't land on much.

Looking at the existing args, one way to split them would be "anything related to engineering" -> __init__, "anything related to the science" -> fit. Here's my attempt:

Current args:

model
train_dataloader
eval_dataloader
max_epochs
algorithms
optimizer_hparams
schedulers_hparams
device
grad_accum
grad_clip_norm
validate_every_n_batches
validate_every_n_epochs
compute_training_metrics
precision
ddp_timeout
seed
deterministic_mode
log_destinations
callbacks
checkpoint_filepath
checkpoint_interval_unit
checkpoint_folder
checkpoint_interval
train_subset_num_batches
eval_subset_num_batches
config

To me, the following are all systems/logging concerns that obviously stay on __init__. I can't imagine changing them as part of a training curriculum, and they shouldn't have any effect on the quality of a trained model.

# systems
device
ddp_sync_strategy 
ddp_timeout 
seed
deterministic_mode 

# logging
callbacks
log_destinations
checkpoint_filepath 
checkpoint_interval_unit
checkpoint_folder
checkpoint_interval 
config

These are the args I think users could reasonably want to interact with because... science?

num_epochs (or one of the other duration measures we're working on in the timing discussion)
learning rate and other optimizer params (not sure about entire optimizer)
algorithms
training data and params (specifically batch size)
grad_accum 
grad_clip_norm
eval_dataloader  # changing out the evaluation data indicates a new run, as you now care about metrics on a new task, though it would be an awkward user experience to set one dataloader at `init` and another at `fit`

Not sure, probably init:

validate_every_n_batches 
validate_every_n_epochs
compute_training_metrics

It is worth noting that some frameworks (PTL) don't support interactive development well at all, and interactive development styles can make reproducibility difficult. My view is that it would be nice to support interactive development (as there are some clear benefits) but encourage people to refactor the final result into a script that can be run end to end. I'm not sure where we officially stand on this spectrum.

@ravi-mosaicml
Contributor Author

ravi-mosaicml commented Dec 16, 2021

Thanks for this detailed feedback @A-Jacobson. Given the way that YAHP (the hyperparameter library) works and the way our yamls are written, all arguments have to be specified on __init__ -- otherwise we would lose whatever .fit parameters are specified.

(It's not impossible, though. We could restructure our YAMLs to look like this:

trainer:
  model:
  device:
  ddp_sync_strategy:
  ddp_timeout:
  seed:
  deterministic_mode:
  
  # logging
  callbacks:
  log_destinations:
  checkpoint_filepath:
  checkpoint_interval_unit:
  checkpoint_folder:
  checkpoint_interval:

runs:
   - algorithms:  # first call to fit
     num_epochs:
     train_dataloader:
   - algorithms:  # second call to fit
     num_epochs:
     train_dataloader:

We'd also need a new entrypoint. But that's probably better left for another PR?)
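For what it's worth, a rough sketch of what such an entrypoint might look like (purely illustrative, and it glosses over YAHP turning the yaml values into actual objects):

import sys
import yaml

from composer.trainer import Trainer

def main(config_path):
    with open(config_path) as f:
        config = yaml.safe_load(f)

    # engineering / system arguments: constructed once
    trainer = Trainer(**config["trainer"])

    # science arguments: one .fit() call per run entry
    for run in config["runs"]:
        trainer.fit(**run)

if __name__ == "__main__":
    main(sys.argv[1])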

That said, we can definitely support additional arguments on a subsequent call to .fit (even if they are also specified on init). It would be really helpful if you could write out the function signature (and docstrings) so I can get an idea of what the expected behavior would be.

@A-Jacobson
Contributor

Sure, most of the proposal is just moving existing args around, though I did change the duration to mirror what we're doing with schedulers and yaml configs.

class Trainer:
    """Trainer for training a model with algorithms.

    Can be created either with ``__init__`` or by providing a
    :class:`~composer.trainer.TrainerHparams` object
    (see :meth:`~composer.trainer.Trainer.create_from_hparams`).

    Args:
        model (BaseMosaicModel): The model to train.

        # reproducibility
        precision (Precision, optional): Numerical precision to use for training. (default: ``Precision.FP32``)
        seed (int, optional): The seed used in randomization. When not provided, a random seed
            will be created. (default: ``None``)
        deterministic_mode (bool, optional): Run the model deterministically. Experimental. Performance
            degradations expected. Certain Torch modules may not have deterministic implementations,
            which will result in a crash. (default: ``False``)

        # hardware
        device (Device, optional): The device to use for training. Either ``DeviceCPU`` or ``DeviceGPU``.
            (default: ``DeviceCPU(n_cpus=1)``)
        ddp_sync_strategy (DDPSyncStrategy, optional): The strategy to use for synchronizing gradients.
        ddp_timeout (float, optional): Timeout, in seconds, for initializing the DDP process group.
            (default: ``5.0``)

        # logging
        validate_every_n_batches (int, optional): Compute metrics on evaluation data every N batches.
            Set to -1 to never validate on a batchwise frequency. (default: ``-1``)
        validate_every_n_epochs (int, optional): Compute metrics on evaluation data every N epochs.
            Set to -1 to never validate on an epochwise frequency. (default: ``1``)
        compute_training_metrics (bool, optional): True to compute metrics on training data, False to not.
            (default: ``False``)
        log_destinations (List[BaseLoggerBackend], optional): The destinations to log training information to.
            (default: ``[TQDMLoggerBackend()]``)
        callbacks (Sequence[Callback], optional): The callbacks to run during training. (default: ``[]``)
        checkpoint_filepath (str, optional): The path to a trainer checkpoint file. If provided,
            the trainer will load the state (along with its associated attributes) during initialization.
            (default: ``None``)
        checkpoint_interval_unit (str, optional): Unit for the checkpoint save interval -- 'ep'
            for epochs, 'ba' for batches, or None to disable checkpointing. (default: ``None``)
        checkpoint_folder (str, optional): The folder to save checkpoints to. Relative to
            ``os.environ.get('RUN_DIRECTORY', '.')``. (default: ``checkpoints``)
        checkpoint_interval (int, optional): The frequency with which to checkpoint. (default: ``1``)
        config (Dict[str, Any], optional): Extra user-provided trainer configuration. Will be persisted
            along with the trainer state during checkpointing. (default: ``None``)
    """

    def fit(
        self,
        train_dataloader,
        eval_dataloader,
        optimizer,
        scheduler,
        max_duration,
        grad_accum,
        grad_clip_norm,
        algorithms):
        """
        Args:
            train_dataloader: the training dataloader
            eval_dataloader: the evaluation dataloader
            optimizer: pytorch optimizer object
            scheduler: scheduler object or function
            max_duration (str): following the same convention as our schedulers, e.g. ``"82ep"``
            grad_accum (int, optional): The number of microbatches to split a per-device batch into. Gradients
                are summed over the microbatches per device. (default: ``1``)
            grad_clip_norm (float, optional): The norm to clip gradient magnitudes to. Set to None for no gradient
                clipping. (default: ``None``)
            algorithms (List[Algorithm], optional): The algorithms to use during training. (default: ``[]``)
        """

This would easily support situations where:

  1. You're training a model interactively and want to continue training with the same settings after an initial run (though something like a cosine annealing scheduler would have to be reset per run)
  2. You want to train with 1 set of algorithms, then another.
  3. You want to change the grad_accum/gradient clipping settings.
  4. You want to train on one dataset, then another.
  5. You want to train with one lr scheduler, then another.

This indirectly supports situations where:

  1. You want to manually change the learning rate between runs (you would have to access the learning rate via the optimizer object). Maybe this is worth directly supporting?
  2. You want to change the batch size between runs (have to access the batch size on the data loader object or create a new dataloader)

@ravi-mosaicml
Contributor Author

ravi-mosaicml commented Dec 21, 2021

Thanks for sketching out these use cases. However, would these use cases still be the same training run? I feel like changing algorithms, datasets, or schedulers would now be a different training run, in which case a new trainer (by the convention) should be created.

Using the existing trainer and changing pretty much any parameter mid-training without starting over will cause compatibility issues with algorithms. For example, randaugment (https://github.com/mosaicml/composer/blob/dev/composer/algorithms/randaugment/randaugment.py) modifies the dataloader's dataset upon the training start event. If you added the randaugment algorithm mid-training (i.e. after training start fires), then it wouldn't do anything; if you changed the dataloader mid-training, then the new dataloader wouldn't have randaugment applied. I think it would be too much of a burden on the user to figure out how to manually apply the algorithms when changing training parameters, so I'd strongly prefer not to allow this (at least, not allow it easily).

If we don't want to go with the convention of "one trainer = one run", then I presume we'd want to have .fit start over with Event.INIT? I presume that the model and dataloaders would be re-used as-is (unless the user passes in new dataloaders), and optimizers and schedulers would be reconstructed (unless the user passes in new optimizer or scheduler hparams). Callbacks and algorithms would need to be able to reset themselves on a 2nd call of the init event. The timer would be reset to 0, and all algorithms and callbacks would fire as if it's a new training run (so, it would not be possible to call .fit for a duration that is less than the training duration.)

I think we still need to decide whether .fit would reset the timer and effectively do a new training run on the existing model (to support the signature / docstring that @A-Jacobson suggested) or extend the current training run (which is where I think we were leaning during the 12/13 meeting). @abhi-mosaic @A-Jacobson @hanlint @siriuslee @growlix @moinnadeem and all -- thoughts?

@ravi-mosaicml ravi-mosaicml added the Needs Design label Jan 20, 2022
@ravi-mosaicml
Contributor Author

@ravi-mosaicml ravi-mosaicml added this to the Backlog milestone Feb 15, 2022
@ravi-mosaicml ravi-mosaicml modified the milestones: Backlog, v0.5 Feb 28, 2022
ravi-mosaicml added a commit that referenced this issue Apr 28, 2022
This PR changes how callbacks are shut down. The overall goals are:

1. Ensure that any exit traceback is captured via the loggers (e.g. in `WandB` and the `FileLogger`)
2. Not to shut down the callbacks at the end of `Trainer.fit()`, which is a requirement for #138

To accomplish these goals, this PR changes where `Engine.close` (which shuts down the callbacks) is run. Specifically, this method now runs at any of these events:
1. Via `atexit`: If an exception causes the python process to crash, `atexit` callbacks will run after the traceback has been printed (test this for yourself! it is hard to write a test case for this)
2. on `__del__`: If the `trainer` instance is being garbage collected, it should be shut down.
3. If manually invoked via `trainer.close`: If the user is creating multiple trainers, they may want to shut one down at a specific point.

If `engine.close()` is invoked multiple times, it should be a no-op.

Implementation overview:
- Moved `close` to run on `__del__` / `atexit` as described above. The actual shutdown logic needs to be in a static method, so it does not hold a reference count in the `atexit` registry that would prevent garbage collection. Added test cases.
- Added `Event.FIT_END` since closing is now decoupled from the end of training. This allows for callbacks to log stuff at the end of training and flush (but not close) any open file handles.
- Fixed a bug in the `FileLogger` where the prefix was being printed multiple times if `write` was called multiple times per line
- Fixed the documentation for the Event class to show all events in the mock training loop
- Added tests and fixed callbacks and loggers to ensure that `.close()` and `.post_close()` implementations are idempotent.
- Added test cases to ensure that unhandled stack traces are captured in the logfile.
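As an aside, a condensed sketch of the shutdown pattern described above (not the actual Composer implementation; details simplified):

import atexit
import weakref

class Engine:
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self._closed = False
        # register a static closer that holds only a weak reference, so the
        # atexit registry does not keep the engine (or trainer) alive
        atexit.register(Engine._close_via_ref, weakref.ref(self))

    @staticmethod
    def _close_via_ref(engine_ref):
        engine = engine_ref()
        if engine is not None:
            engine.close()

    def close(self):
        if self._closed:
            return  # idempotent: safe to call from atexit, __del__, and user code
        self._closed = True
        for callback in self.callbacks:
            callback.close()

    def __del__(self):
        self.close()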
ravi-mosaicml added a commit that referenced this issue Apr 29, 2022