Time Abstraction #146

Closed
11 tasks done
ravi-mosaicml opened this issue Dec 10, 2021 · 8 comments · Fixed by #594
@ravi-mosaicml
Contributor

ravi-mosaicml commented Dec 10, 2021

🚀 Time Abstraction

Motivation

There are various measures of time during training, and we need a common, steppable abstraction to handle conversion between units. In the CV community, it is common to track time in terms of samples and batches. However, in NLP, it is more common to track time in terms of tokens and fractions of the total training duration. Here, we propose a time-tracking solution.

Implementation

After discussion with @abhi-mosaic and @moinnadeem, we are leaning towards the following design:

  • 1. Time objects will simplify time arithmetic. A Time object consists of an integer and a unit, which will be one of epochs, batches, samples, or tokens. Via overloaded operators, Time objects will support comparisons, addition, and subtraction against other Time objects of the same unit and, for backwards compatibility, against raw integers (though in that case a UserWarning will be emitted). They will also have getters for the underlying integer value and the unit. (A minimal sketch of the proposed Time/Timer API appears after this list.)

  • 2. A Timer object, attached to the trainer's state, will track epochs, batches, samples, and tokens. Each field is a Time object, except for tokens, which may be None (for non-NLP training jobs). The Timer will have getters for each of these fields and a single update function that the training loop calls at the end of every batch -- e.g. timer.update(samples=X, tokens=Y).

  • 3. To determine the number of samples and number of tokens, a dataset can provide get_batch_size(batch) and get_num_tokens(batch). If not specified, the default get_batch_size() will be used, and tokens will NOT be tracked.

  • 4. Datasets can optionally provide __len__ and get_num_tokens(). By PyTorch convention, __len__ should be the number of samples in the dataset. get_num_tokens can either return a constant number, perform some computation upon initialization to determine the number of tokens in the dataset, or (by default) return None if the number of tokens is unknown.

  • 5. The max_epochs property in the trainer hparams will be replaced with max_duration, where duration can be specified in terms of epochs, steps, samples, or tokens.

  • 6. The trainer will have a function trainer.get_elapsed_duration() that will query the timer object and return a float on [0,1] representing how much of the training process has been completed (relative to the max_duration parameter).

  • 7. The timing module (NOT the timer object) will have a static method like:

    def convert(time_string: str, desired_unit: str,
                dataset_num_samples: Optional[int] = None,
                dataset_num_tokens: Optional[int] = None,
                max_training_duration: Optional[str] = None,
                batch_size: Optional[int] = None):
        pass

    This static method performs a static conversion between the specified time string and the desired unit. Depending on the conversion being performed, dataset_num_samples, dataset_num_tokens, max_training_duration, and/or batch_size will need to be provided. These parameters must be passed explicitly to emphasize that this is a static conversion, performed at call time, which may become inaccurate if the parameters later change (e.g. an algorithm changes the training duration). The following conversions are allowed (a sketch of this function appears after this list):
    1. epochs <-> batches, if dataset_num_samples and batch_size are defined
    2. epochs <-> samples, if dataset_num_samples is defined
    3. batches <-> samples, if batch_size is defined
    4. epochs <-> tokens, if dataset_num_tokens is defined
    5. duration <-> unit of max_duration: a duration string (e.g. "0.1dur") can be converted into the unit (e.g. ep) of max_duration (e.g. "90ep") -- in this example, the conversion would return 9
    6. duration <-> other units: if a unit other than that of max_duration is specified, the conversion will attempt to chain one or more of the above conversions to perform it

  • 8. We will rewrite all schedulers to query the time object and perform a closed-form calculation to determine the learning rate, using timer.get_elapsed_duration and timer.get_num_XXX calls, so they are compatible with datasets of unknown size or token count. However, this can be done later; for the time being, timer.convert calls can be used to properly initialize schedulers upon creation.
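
To make items 1, 2, and 6 concrete, here is a minimal sketch of how the pieces could hang together. This is only illustrative -- the TimeUnit enum, its abbreviations, and the exact method names are assumptions, not settled API:

from dataclasses import dataclass
from enum import Enum


class TimeUnit(Enum):
    # Hypothetical unit enum; the final abbreviations are still open.
    EPOCH = "ep"
    BATCH = "ba"
    SAMPLE = "sp"
    TOKEN = "tok"


@dataclass(frozen=True)
class Time:
    """Item 1: an integer value plus a unit. Arithmetic and comparisons are only
    defined against Time objects of the same unit (raw-int support with a
    UserWarning is omitted here for brevity)."""
    value: int
    unit: TimeUnit

    def __add__(self, other: "Time") -> "Time":
        if self.unit != other.unit:
            raise ValueError(f"Cannot add {self.unit} and {other.unit}")
        return Time(self.value + other.value, self.unit)

    def __lt__(self, other: "Time") -> bool:
        if self.unit != other.unit:
            raise ValueError(f"Cannot compare {self.unit} and {other.unit}")
        return self.value < other.value


class Timer:
    """Item 2: tracks training progress; the training loop calls update() once per batch.
    Incrementing the epoch counter at epoch boundaries is omitted here, and the token
    count simply stays at 0 when tokens are not tracked."""

    def __init__(self) -> None:
        self.epoch = Time(0, TimeUnit.EPOCH)
        self.batch = Time(0, TimeUnit.BATCH)
        self.sample = Time(0, TimeUnit.SAMPLE)
        self.token = Time(0, TimeUnit.TOKEN)

    def update(self, samples: int, tokens: int = 0) -> None:
        self.batch += Time(1, TimeUnit.BATCH)
        self.sample += Time(samples, TimeUnit.SAMPLE)
        self.token += Time(tokens, TimeUnit.TOKEN)


def get_elapsed_duration(timer: Timer, max_duration: Time) -> float:
    # Item 6, shown only for the simple case where max_duration is in batches.
    return min(timer.batch.value / max_duration.value, 1.0)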
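
And for item 7, a hedged sketch of what a couple of the static conversions could look like. Again, illustrative only -- the time-string parsing and the "ba"/"tok" abbreviations are assumptions, and only a subset of the conversions listed above is shown:

import re
from typing import Optional, Tuple


def _parse(time_string: str) -> Tuple[float, str]:
    # e.g. "90ep" -> (90.0, "ep"); "0.1dur" -> (0.1, "dur")
    match = re.fullmatch(r"([0-9.]+)([a-z]+)", time_string)
    if match is None:
        raise ValueError(f"Invalid time string: {time_string}")
    return float(match.group(1)), match.group(2)


def convert(time_string: str, desired_unit: str,
            dataset_num_samples: Optional[int] = None,
            dataset_num_tokens: Optional[int] = None,
            max_training_duration: Optional[str] = None,
            batch_size: Optional[int] = None) -> int:
    value, unit = _parse(time_string)
    if unit == "dur":
        # Conversion 5: duration -> unit of max_duration, e.g. "0.1dur" of "90ep" -> 9.
        assert max_training_duration is not None, "dur conversions need max_training_duration"
        max_value, max_unit = _parse(max_training_duration)
        if desired_unit == max_unit:
            return int(value * max_value)
        # Conversion 6: otherwise, chain through another conversion.
        return convert(f"{value * max_value}{max_unit}", desired_unit,
                       dataset_num_samples=dataset_num_samples,
                       dataset_num_tokens=dataset_num_tokens,
                       batch_size=batch_size)
    if unit == "ep" and desired_unit == "sp":
        # Conversion 2: epochs -> samples.
        assert dataset_num_samples is not None
        return int(value * dataset_num_samples)
    if unit == "ep" and desired_unit == "ba":
        # Conversion 1: epochs -> batches.
        assert dataset_num_samples is not None and batch_size is not None
        return int(value * dataset_num_samples) // batch_size
    if unit == "ep" and desired_unit == "tok":
        # Conversion 4: epochs -> tokens.
        assert dataset_num_tokens is not None
        return int(value * dataset_num_tokens)
    raise NotImplementedError(f"{unit} -> {desired_unit} is not covered in this sketch")


# e.g. convert("0.1dur", "ep", max_training_duration="90ep") == 9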

TODO

  1. PR 1: Build out the timer, and use the timer to track progress in the training loop. Update the state object. Should be a non-breaking change.
  2. PR 2: Update the rest of the codebase to support timing strings (e.g. in schedulers, the checkpoint interval, flush intervals, etc.). If needed, use timer.convert to stay compatible with existing PyTorch components.
  3. PR 3: Create our own drop-in replacements for the PyTorch schedulers that do not depend on timer.convert.
  4. PR 4 (can be concurrent, or maybe should be done with PR 3): Update the algorithms. Try to avoid using the timer in the functional form.

See also

  1. https://github.com/mosaicml/mosaicml/issues/79
  2. https://github.com/mosaicml/mosaicml/pull/85
  3. https://docs.google.com/presentation/d/1ljMB4gdZ2EEjZ7njRDikU9pwvRx0xyCS-TqAwzFe7EU/edit#slide=id.p
ravi-mosaicml added the enhancement label (New (engineering) enhancements, such as features or API changes) Dec 10, 2021
@moinnadeem
Contributor

The max_epochs property in the trainer hparams will be replaced with max_duration, where duration can be specified in terms of epochs, steps, samples, or tokens.

We should also do this for validate_frequency, where validate_every_n_batches should be replaced with a validate_frequency option that uses the Time abstraction. Ideally, I'd like to run validation every 1% of steps, rather than a fixed number.

@ravi-mosaicml
Contributor Author

Notes from discussion on 12/13:

  • Add in wall clock time to the timer object (e.g. so you can checkpoint every 30 minutes)

@A-Jacobson
Contributor

To determine the number of samples and number of tokens, a dataset can provide get_batch_size(batch) and get_num_tokens(batch). If not specified, the default get_batch_size() will be used, and tokens will NOT be tracked.

What's the reasoning for not just pulling the batch_size arg from the dataloader? To deal with partial batches if drop_last isn't enabled?

We will rewrite all schedulers to query the time object and perform a closed-form calculation to determine the learning rate, using timer.get_elapsed_duration and timer.get_num_XXX calls, so they are compatible with datasets of unknown size or tokens. However, this can be done later, and for the time being, timer.convert calls can be used to properly initialize schedulers upon creation.

Maybe this isn't the place for this discussion, but is it absolutely necessary to write all of our own schedulers? I think it's important to keep in mind here that while there has been some consolidation towards cyclic schedules or linear warmup, there are still methods that use arbitrary schedulers. Maybe some way of wrapping the default PyTorch schedulers, or keeping a very simple API for people to define their own, would be the more sustainable approach here?
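
For instance, one rough way to wrap a stock PyTorch scheduler without rewriting it (illustrative only; it just drives scheduler.step() from the timer instead of letting the scheduler keep its own count):

class PytorchSchedulerWrapper:
    """Hypothetical adapter: replays step() calls so a stock scheduler stays in
    sync with the Timer rather than with its own internal counter."""

    def __init__(self, scheduler):
        self.scheduler = scheduler
        self._last_step = 0

    def update(self, timer) -> None:
        current_step = int(timer.batch)
        # Call step() once per batch that has elapsed since the last update.
        for _ in range(current_step - self._last_step):
            self.scheduler.step()
        self._last_step = current_step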

Other than that, LGTM. It seems we can track whatever we want (as long as that concept works for that type of run) with minimal added deps.

@abhi-mosaic
Contributor

abhi-mosaic commented Dec 14, 2021

Re. batch_size, I think we are planning ahead for variable-batch size algorithms, which are already used in NLP for warmups. So it would be safer to query the current batch size at each step rather than hard-code it at the start.

For schedulers, I think the main concern is that using scheduler.step assumes things about how Time passes, and places time state within the Scheduler (which can fall out of sync), whereas the cleaner way would be to treat the scheduler as a stateless function that returns the decay factor given the current Time, something like scheduler.get_factor(timer).

My hope is that our reimplementations of the common schedulers will actually be cleaner than PyTorch's, almost like one-line functions. And making a custom Scheduler should also be pretty easy. Something like:

from __future__ import annotations  # Time/Timer below are the proposed timing classes, not imported here

from typing import List


class TriangleScheduler:
  def get_factor(self, timer: Timer) -> float:
    frac = timer.get_current_period_frac_elapsed()
    return frac if frac <= 0.5 else (1 - frac)


class MilestoneDecayScheduler:
  def __init__(self, milestones: List[Time], decay: float):
    super().__init__()
    self.milestones = milestones
    self.decay = decay

  def get_factor(self, timer: Timer) -> float:
    period_time = timer.get_current_period_time()
    # count how many milestones have already passed
    num_drops = sum(period_time >= m for m in self.milestones)
    return self.decay ** num_drops
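
For illustration, wiring one of these into a plain PyTorch loop could look like the following (the fake timer is just a stand-in for the real Timer on the trainer's state, and base_lr handling is simplified):

import torch


class _FakeTimer:
  # Stand-in exposing only the method the sketch above needs.
  def get_current_period_frac_elapsed(self) -> float:
    return 0.25  # pretend we are 25% through the current period


model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
base_lr = 0.1
scheduler = TriangleScheduler()

# No scheduler.step() and no state inside the scheduler -- the current Time alone
# determines the factor applied at this step.
factor = scheduler.get_factor(_FakeTimer())
for group in optimizer.param_groups:
  group["lr"] = base_lr * factor  # 0.1 * 0.25 = 0.025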

And YAMLs can look like:

...
max_duration: 90ep
max_duration_base_units: sp  # Need to have base_units in order to handle non-whole-epoch periods 
periods:
  - warmup:
      duration: 0.05dur
      scheduler: LinearWarmup
  - main:
      duration: -1  # equivalent to 0.95dur
      scheduler: MilestoneDecay
        - 0.5dur
        - 0.75dur

whereas our current Trainer is only capable of handling this:

...
max_duration: 90ep
max_duration_base_units: ep  # <<< only whole epochs
periods:
  - warmup:
      duration: 0.05dur # <<< set to int(90*0.05) = 4 epochs :(
      scheduler: LinearWarmup
  - main:
      duration: -1  # equivalent to 0.95dur
      scheduler: MilestoneDecay
        - 0.5dur # <<< set to int(86*0.5) = 43 epochs :(
        - 0.75dur # <<< set to int(86*0.75) = 64 epochs :(

What do you think @A-Jacobson ?

@growlix
Contributor

growlix commented Dec 16, 2021

I don't have a ton to add atm.

The only thing that comes to mind is that it looks like the scale schedule algorithm modifies max_epochs (to become max_duration) in a state object, which has two implications to me:

  • Be mindful about using state's (and not hparam's) max_duration
  • Are there going to be scenarios in which time or timer objects are used prior to scale_schedule being called, and if so, will it cause problems that scale_schedule has changed max_duration between calls/uses of time/timer?

For schedulers, I think the main concern is that using scheduler.step assumes things about how Time passes, and places time state within the Scheduler (which can fall out of sync), whereas the cleaner way would be to treat the scheduler as a stateless function that returns the decay factor given the current Time, something like scheduler.get_factor(timer).

Big +1 here

@abhi-mosaic
Contributor

abhi-mosaic commented Dec 16, 2021

@growlix That's a super good point... basically scale_schedule (maybe renamed to scale_max_duration) has to either:

  • be applied before the Trainer does any fractional dur conversions, which would require some sort of special handling
  • scale *all* time quantities everywhere, which might be hard to track? Not 100% sure about this...

Should scale_max_duration be a trainer flag rather than an algorithm? That would alleviate the issues

@ravi-mosaicml ^^^

@A-Jacobson
Contributor

@abhi-mosaic, love the scheduler idea. Is there an issue/pr for that?

ravi-mosaicml added a commit that referenced this issue Dec 17, 2021
Added the `Time` class, `Timer` class, conversion function, and test cases and docstrings for these classes.

Modified the `StringEnum` class to relax the requirement that the lowercase name be the value, since this does not hold for our timing abbreviations.

Other PRs will integrate the Timer into the training loop.
@ravi-mosaicml
Contributor Author

To keep it as an algorithm, the engine could ensure that scale schedule is applied first, so that anything that uses the max duration sees the scaled value. Thoughts on this?

I would rather keep it as an algorithm than a trainer property, as our simple trainer is already not-so-simple :D. But I don't have a strong preference either way.

This wouldn't fix it if non-duration units are hardcoded elsewhere (e.g. warmup for 2 epochs); that would remain 2 epochs.

ravi-mosaicml added a commit that referenced this issue Dec 18, 2021
For the timing abstraction (#146), the `DataloaderSpec` needed two additional functions -- `get_num_samples_in_batch` and `get_num_tokens_in_batch`. It was getting messy to pass around function pointers in a named tuple, so `DataloaderSpec` was converted from a NamedTuple into a regular class called `DataSpec`. Custom datasets can inherit the base `DataSpec` class and override functionality as needed.

Moved the `DataSpec` class to `composer.core`, as the `DataSpec` is now bound directly to the state. #120 will also need this change.

Renamed `train_dataloader` and `eval_dataloader` in the trainer and state to `train_data` and `eval_data`, since they encompass more than the dataloader.

This PR implements parts 3 and 4 of the timing abstraction (#146). The implementation differs from the GH issue by adding `num_tokens`, `num_samples`, `get_batch_size`, and `get_num_tokens` to the new `DataSpec` rather than the PyTorch dataset class.
ravi-mosaicml added a commit that referenced this issue Dec 28, 2021
Added the `Time` class, `Timer` class, conversion function, and test cases and docstrings for these classes. This functionality encompasses parts 1 and 2 from #146, with the following changes:

1. `tokens` will be 0 (instead of None) if not being tracked.

Also, modified the `StringEnum` class to relax the requirement that the lowercase name be the value, since this does not hold for our timing abbreviations.

Other PRs will integrate the Timer into the training loop.
ravi-mosaicml added a commit that referenced this issue Jan 8, 2022
#178)

For the timing abstraction (#146), the DataloaderSpec needed two additional functions -- get_num_samples_in_batch and get_num_tokens_in_batch.

Moved the DataSpec class to composer.core, as the DataSpec is now bound directly to the state. #120 will also need this change.

This PR implements parts 3 and 4 of the timing abstraction (#146). The implementation differs from the GH issue by adding num_tokens, num_samples, get_batch_size, and get_num_tokens to the new DataSpec rather than the PyTorch dataset class.
ravi-mosaicml added a commit that referenced this issue Jan 14, 2022
Updated layer freezing to use the new timing abstraction from #146
ravi-mosaicml added a commit that referenced this issue Jan 14, 2022
Updated the selective backprop API to use the timing abstraction, as detailed in #146
ravi-mosaicml added this to the Backlog milestone Feb 15, 2022
ravi-mosaicml added a commit that referenced this issue Feb 24, 2022
1. Removed all uses of `state.max_epochs`; replaced with `state.max_duration`
2. Replaced all `state.step` with `int(state.timer.batch)`
3. Replaced all `state.epoch` with `int(state.timer.epoch)`
4. Removed the constraint that max_duration be specified in epochs; it can now be in any unit. Added test cases.
5. Added a helper method to the `Time` class to convert it to a timestring (with a test case)

Closes #146
Closes #229
Closes #512
ravi-mosaicml removed this from the Backlog milestone Feb 28, 2022
ravi-mosaicml added a commit that referenced this issue Mar 1, 2022
1. Removed all uses of `state.max_epochs`; replaced with `state.max_duration` -- specifically in stochastic depth (#229) and sequential length warmup (#226). However, this does not fully fix the latter, as we still need to update the algorithm to support max_duration in terms of tokens or samples.
2. Replaced all `state.step` with `int(state.timer.batch)`
3. Replaced all `state.epoch` with `int(state.timer.epoch)`
4. Replaced all `state.batch_idx` with `int(state.timer.batch_in_epoch)`
5. Removed the constraint that max_duration be specified in epochs; it can now be in any unit. Added test cases for specifying max_duration in batches.
6. Added a helper method to the `Time` class to convert it to a timestring (with a test case)

Closes #146
Closes #229
Closes #512