Add State.eval_timestamp and State.predict_timestamp
- Added `eval_timestamp` and `predict_timestamp` as attributes on the State for tracking evaluation and prediction progress. This lets callbacks and loggers track where we are in the current evaluation or prediction dataloader (for example, logging metrics where the X axis is the evaluation batch number and the Y axis is some per-batch metric, such as accuracy); see the callback sketch below.
- Added new attributes to the state instead of hot-swapping `state.timestamp`, since it is still useful to know the training batch number during evaluation (e.g., tracking how evaluation improves as training progresses).
- Added tests to ensure that the timestamps are properly set.

TODO:
- [ ] Merge mosaicml#948
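
For illustration, a minimal callback sketch of the intended use. This is hedged: the `eval_batch_end` hook follows Composer's convention of naming callback methods after events, and `logger.data_batch` is the logger call used elsewhere in this diff; the metric key names are hypothetical.

    from composer import Callback, Logger, State

    class EvalBatchProgressLogger(Callback):
        """Logs the eval batch number alongside training progress (hypothetical example)."""

        def eval_batch_end(self, state: State, logger: Logger) -> None:
            # state.eval_timestamp is reset before each eval dataloader is run,
            # so `batch` counts batches within the current evaluation.
            logger.data_batch({
                'eval/batch': int(state.eval_timestamp.batch),
                # state.timestamp still reflects training progress during eval.
                'trainer/global_step': int(state.timestamp.batch),
            })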
ravi-mosaicml committed May 9, 2022
1 parent 96fdb93 commit 4c748a8
Showing 4 changed files with 114 additions and 12 deletions.
8 changes: 8 additions & 0 deletions composer/core/state.py
@@ -145,6 +145,12 @@ class State(Serializable):
loss (torch.Tensor | Sequence[torch.Tensor]): The most recently computed loss.
outputs (torch.Tensor | Sequence[torch.Tensor]): The most recently computed output from the model's forward pass.
timestamp (Timestamp): The current training timestamp.
eval_timestamp (Timestamp): The timestamp for the current evaluation dataloader. This timestamp is reset
before the dataloader is evaluated. The :attr:`~Timestamp.epoch` attribute for this timestamp is always
``0``.
predict_timestamp (Timestamp): The timestamp for the current prediction dataloader. This timestamp is reset
before the dataloader is used. The :attr:`~Timestamp.epoch` attribute for this timestamp is always
``0``.
serialized_attributes (List[str]): The names of the attributes that are serialized in a checkpoint.
By default, the following attributes are serialized:
@@ -232,6 +238,8 @@ def __init__(
self._evaluators = list(ensure_tuple(evaluators))

self.timestamp = Timestamp()
self.eval_timestamp = Timestamp()
self.predict_timestamp = Timestamp()
self._precision = Precision(precision)

if optimizers is None:
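
The new attributes follow the existing `Timestamp` API. A sketch of the intended semantics, assuming `Timestamp` is importable from `composer.core.time` and is immutable (as the `to_next_batch` reassignments in this diff suggest):

    from composer.core.time import Timestamp

    ts = Timestamp()                             # all counters start at zero
    ts = ts.to_next_batch(samples=32, tokens=0)  # returns a new Timestamp
    assert int(ts.batch) == 1
    assert int(ts.sample) == 32
    assert int(ts.epoch) == 0  # eval/predict timestamps never advance the epoch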
48 changes: 39 additions & 9 deletions composer/trainer/trainer.py
@@ -1402,6 +1402,12 @@ def _spin_dataloaders(self):
for _ in dataloader:
break

def _accumulate_samples_and_tokens_across_ranks(self, num_samples: int, num_tokens: int) -> Tuple[int, int]:
"""Accumulate the number of samples and tokens across ranks. Returns a (num_samples, num_tokens) tuple."""
tensor = self._device.tensor_to_device(torch.tensor([num_samples, num_tokens], dtype=torch.int))
dist.all_reduce(tensor, reduce_operation="SUM")
return int(tensor[0].cpu().item()), int(tensor[1].cpu().item())

def _train_loop(self) -> None:
"""Run training for the specified number of epochs and log results."""
# print training start
@@ -1477,13 +1483,6 @@ def _train_loop(self) -> None:

self.engine.run_event(Event.AFTER_DATALOADER)

num_samples_in_batch = self._device.tensor_to_device(
torch.tensor([self.state.batch_num_samples], dtype=torch.int))
num_tokens_in_batch = self._device.tensor_to_device(
torch.tensor([self.state.batch_num_tokens], dtype=torch.int))
dist.all_reduce(num_samples_in_batch, reduce_operation="SUM")
dist.all_reduce(num_tokens_in_batch, reduce_operation="SUM")

self.engine.run_event(Event.BATCH_START)
self.logger.data_batch({
"trainer/global_step": int(self.state.timestamp.batch),
@@ -1504,9 +1503,14 @@ def _train_loop(self) -> None:
full_loss = total_loss.cpu().item()
self.logger.data_batch({'loss/train': full_loss / dist.get_world_size()})

total_num_samples, total_num_tokens = self._accumulate_samples_and_tokens_across_ranks(
num_samples=self.state.batch_num_samples,
num_tokens=self.state.batch_num_tokens,
)

self.state.timestamp = self.state.timestamp.to_next_batch(
samples=int(num_samples_in_batch.item()),
tokens=int(num_tokens_in_batch.item()),
samples=total_num_samples,
tokens=total_num_tokens,
)

if self._scheduler_step_frequency == TimeUnit.BATCH:
@@ -1805,6 +1809,9 @@ def predict(self, dataloader: Union[DataLoader, DataSpec], subset_num_batches: i
self.state.set_dataloader(data_spec.dataloader, "predict", subset_num_batches)
assert self.state.dataloader is not None, "Already set the dataloader"

# Reset the predict timestamp
self.state.predict_timestamp = Timestamp()

with torch.no_grad():

self.engine.run_event(Event.PREDICT_START)
@@ -1831,6 +1838,16 @@ def predict(self, dataloader: Union[DataLoader, DataSpec], subset_num_batches: i
self.state.outputs = self.state.model(self.state.batch)
self.engine.run_event(Event.PREDICT_AFTER_FORWARD)

total_num_samples, total_num_tokens = self._accumulate_samples_and_tokens_across_ranks(
num_samples=self.state.batch_num_samples,
num_tokens=self.state.batch_num_tokens,
)

self.state.predict_timestamp = self.state.predict_timestamp.to_next_batch(
samples=total_num_samples,
tokens=total_num_tokens,
)

self.engine.run_event(Event.PREDICT_BATCH_END)

self.engine.run_event(Event.PREDICT_END)
@@ -1885,6 +1902,9 @@ def eval(
dataloader = DataSpec(dataloader)
data_spec = dataloader

# Reset the eval timestamp
self.state.eval_timestamp = Timestamp()

self.state.model.eval()
with torch.no_grad():
self.state.set_dataloader(data_spec.dataloader, dataloader_label, subset_num_batches)
@@ -1924,6 +1944,16 @@ def eval(
metrics.update(self.state.outputs, targets)
self._compute_and_log_metrics(dataloader_label=dataloader_label, metrics=metrics, log_level=log_level)

total_num_samples, total_num_tokens = self._accumulate_samples_and_tokens_across_ranks(
num_samples=self.state.batch_num_samples,
num_tokens=self.state.batch_num_tokens,
)

self.state.eval_timestamp = self.state.eval_timestamp.to_next_batch(
samples=total_num_samples,
tokens=total_num_tokens,
)

self.engine.run_event(Event.EVAL_BATCH_END)

self.logger.data_epoch({"epoch": self.state.timestamp.epoch.value})
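
The `_accumulate_samples_and_tokens_across_ranks` helper packs both counters into one tensor so a single collective covers both. A standalone sketch of the same pattern, assuming `composer.utils.dist` as used above and omitting the device placement that the real helper performs via `self._device.tensor_to_device`:

    from typing import Tuple

    import torch
    from composer.utils import dist

    def sum_across_ranks(num_samples: int, num_tokens: int) -> Tuple[int, int]:
        # One tensor, one all-reduce: cheaper than two separate collectives.
        counters = torch.tensor([num_samples, num_tokens], dtype=torch.int)
        dist.all_reduce(counters, reduce_operation="SUM")  # every rank receives the global sums
        return int(counters[0].item()), int(counters[1].item())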
30 changes: 30 additions & 0 deletions tests/trainer/test_predict.py
@@ -51,3 +51,33 @@ def test_predict(self, subset_num_batches: int):
# Validate that the predict events were called the correct number of times
num_predict_batches = subset_num_batches if subset_num_batches >= 0 else len(predict_dl)
_assert_predict_events_called_expected_number_of_times(event_counter_callback, num_predict_batches)

def test_timestamps(self):
# Construct the trainer
event_counter_callback = EventCounterCallback()
trainer = Trainer(
model=SimpleModel(),
callbacks=[event_counter_callback],
)

# Predict on the model
predict_dataloader = DataLoader(dataset=RandomClassificationDataset())
trainer.predict(predict_dataloader)

# Ensure that the predict timestamp matches the number of prediction events
assert event_counter_callback.event_to_num_calls[
Event.PREDICT_BATCH_START] == trainer.state.predict_timestamp.batch
assert trainer.state.predict_timestamp.batch == trainer.state.predict_timestamp.batch_in_epoch

# Ensure that if we predict again, the predict timestamp was reset

# Reset the event counter callback
event_counter_callback.event_to_num_calls = {k: 0 for k in event_counter_callback.event_to_num_calls}

# Predict again
trainer.predict(predict_dataloader)

# Validate the same invariants
assert event_counter_callback.event_to_num_calls[
Event.PREDICT_BATCH_START] == trainer.state.predict_timestamp.batch
assert trainer.state.predict_timestamp.batch == trainer.state.predict_timestamp.batch_in_epoch
40 changes: 37 additions & 3 deletions tests/trainer/test_trainer_eval.py
@@ -15,10 +15,9 @@
from tests.common.events import EventCounterCallbackHparams


@pytest.mark.filterwarnings(r"ignore:.*No `eval_dataloader` was specified.*")
def test_trainer_eval_only():
# Construct the trainer
trainer = Trainer(model=SimpleModel(),)
trainer = Trainer(model=SimpleModel())

# Evaluate the model
eval_dataloader = DataLoader(dataset=RandomClassificationDataset())
@@ -32,7 +31,6 @@ def test_trainer_eval_only():
assert trainer.state.current_metrics['eval']['Accuracy'] != 0.0


@pytest.mark.filterwarnings(r"ignore:.*No `eval_dataloader` was specified.*")
def test_trainer_eval_subset_num_batches():
# Construct the trainer
event_counter_callback = EventCounterCallback()
@@ -55,6 +53,42 @@ def test_trainer_eval_subset_num_batches():
assert event_counter_callback.event_to_num_calls[Event.EVAL_BATCH_START] == 1


def test_trainer_eval_timestamp():
# Construct the trainer
event_counter_callback = EventCounterCallback()
trainer = Trainer(
model=SimpleModel(),
callbacks=[event_counter_callback],
)

# Evaluate the model
eval_dataloader = DataLoader(dataset=RandomClassificationDataset())
trainer.eval(
dataloader=eval_dataloader,
dataloader_label='eval',
metrics=torchmetrics.Accuracy(),
)

# Ensure that the eval timestamp matches the number of evaluation events
assert event_counter_callback.event_to_num_calls[Event.EVAL_BATCH_START] == trainer.state.eval_timestamp.batch
assert trainer.state.eval_timestamp.batch == trainer.state.eval_timestamp.batch_in_epoch

# Ensure that if we eval again, the eval timestamp was reset

# Reset the event counter callback
event_counter_callback.event_to_num_calls = {k: 0 for k in event_counter_callback.event_to_num_calls}

# Eval again
trainer.eval(
dataloader=eval_dataloader,
dataloader_label='eval',
metrics=torchmetrics.Accuracy(),
)
# Validate the same invariants
assert event_counter_callback.event_to_num_calls[Event.EVAL_BATCH_START] == trainer.state.eval_timestamp.batch
assert trainer.state.eval_timestamp.batch == trainer.state.eval_timestamp.batch_in_epoch


@pytest.mark.parametrize("eval_dataloader", [
DataLoader(dataset=RandomClassificationDataset()),
Evaluator(
