[AIR] Added Ray Logging to MosaicTrainer #29620
Changes from 76 commits
@@ -0,0 +1,56 @@
from typing import Any, Dict, Optional, List
import torch

from composer.loggers import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.core.state import State

from ray.air import session

class RayLogger(LoggerDestination):
    """A logger to relay information logged by composer models to ray.

    This logger allows utilizing all necessary logging and logged data handling provided
    by the Composer library. All the logged information is saved in the data dictionary
    every time new information is logged, but to reduce unnecessary reporting, the
    most up-to-date logged information is reported as metrics at every batch checkpoint
    and epoch checkpoint (see Composer's Event module for more details).

    Because ray's metric dataframe will not include new keys that are reported after the
    very first report call, any logged information with keys not included in the
    first batch checkpoint would not be retrievable after training. In other words, if
    the log level is greater than `LogLevel.BATCH` for some data, it would not be
    present in `Result.metrics_dataframe`. To preserve that information, the
    user can provide keys to always include in the reported data via the `keys`
    argument of the constructor. For `MosaicTrainer`, use
    `trainer_init_config['log_keys']` to populate these keys.

    Note that in the Event callback functions, we remove unused variables, as is the
    practice in Mosaic's composer library.

    Args:
        keys: the key values that will be included in the reported metrics.
    """

Review comment: Are users expected to know what these keys are upfront? Looking at the Mosaic code, it seems that these keys are automatically added by Mosaic algorithms and callbacks, so I don't think users are aware of what these keys are in order to provide them here.
Reply: I think we should fix the underlying bug.

    def __init__(self, keys: List[str] = None) -> None:
        self.data = {}
        if keys:
            for key in keys:
                self.data[key] = None

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        self.data.update(metrics.items())
        for key, val in self.data.items():
            if isinstance(val, torch.Tensor):
                self.data[key] = val.item()

    def epoch_checkpoint(self, state: State, logger: Logger) -> None:
        del logger  # unused
        session.report(self.data)

    def fit_end(self, state: State, logger: Logger) -> None:
        # report at close in case the trainer stops in the middle of an epoch.
        # this may be double counted with epoch checkpoint.
        del logger  # unused
        session.report(self.data)
Review comment: We shouldn't report at both the batch level and the epoch level. Each call to [...] For now, I would say let's just log only at every epoch. We can see in the future if we want to give users the ability to configure this.
Reply: I think we'll definitely have users want to do it on either level - that was the case with HF, where we started with epochs only and had to add steps too.
Reply: Completely agree @Yard1. I'm thinking we can default to epoch for now and then add batch support in a follow-up.
Reply: This has been updated!
Reply: The mentioned change above has been applied.
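To make the `log_keys` behavior discussed above concrete, here is a small self-contained sketch (not part of this diff) of why RayLogger pre-populates `self.data` with the given keys. The `collect_reports` helper is a hypothetical stand-in that imitates the behavior described in the docstring, where the metrics dataframe keeps only the columns present in the very first report.

import pandas as pd


def collect_reports(reports):
    # Toy stand-in for how per-iteration reports become a metrics dataframe:
    # the column set is fixed by the keys of the first reported dict.
    first_keys = list(reports[0].keys())
    return pd.DataFrame([{k: r.get(k) for k in first_keys} for r in reports])


# Without pre-population, an eval metric logged only at epoch end is dropped.
reports = [{"loss": 0.9}, {"loss": 0.5, "metrics/my_evaluator/Accuracy": 0.7}]
print(collect_reports(reports).columns.tolist())  # ['loss']

# With `log_keys`, RayLogger seeds the key with None so it is part of the first report.
seeded = [
    {"loss": 0.9, "metrics/my_evaluator/Accuracy": None},
    {"loss": 0.5, "metrics/my_evaluator/Accuracy": 0.7},
]
print(collect_reports(seeded).columns.tolist())  # ['loss', 'metrics/my_evaluator/Accuracy']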
@@ -71,7 +71,7 @@ def trainer_init_per_worker(config):
         weight_decay=2.0e-3,
     )

-    if config.pop("eval", False):
+    if config.pop("should_eval", False):
         config["eval_dataloader"] = evaluator

     return composer.trainer.Trainer(
@@ -85,9 +85,17 @@ def trainer_init_per_worker(config):
 def test_mosaic_cifar10(ray_start_4_cpus):
     from ray.train.examples.mosaic_cifar10_example import train_mosaic_cifar10

-    _ = train_mosaic_cifar10()
+    result = train_mosaic_cifar10().metrics_dataframe

-    # TODO : add asserts once reporting has been integrated
+    # check the max epoch value
+    assert result["epoch"][result.index[-1]] == 1
+
+    # check train_iterations
+    assert result["_training_iteration"][result.index[-1]] == 3
+
+    # check metrics/train/Accuracy has increased
+    acc = list(result["metrics/train/Accuracy"])
+    assert acc[-1] > acc[0]


 def test_init_errors(ray_start_4_cpus):
@@ -149,6 +157,10 @@ class DummyCallback(Callback):
         def fit_start(self, state: State, logger: Logger) -> None:
             raise ValueError("Composer Callback object exists.")

+    class DummyMonitorCallback(Callback):
+        def fit_start(self, state: State, logger: Logger) -> None:
+            logger.log_metrics({"dummy_callback": "test"})
+
     # DummyLogger should not throw an error since it should be removed before `fit` call
     trainer_init_config = {
         "max_duration": "1ep",
@@ -175,6 +187,86 @@ def fit_start(self, state: State, logger: Logger) -> None:
         trainer.fit()
     assert e == "Composer Callback object exists."

+    trainer_init_config["callbacks"] = DummyMonitorCallback()
+    trainer = MosaicTrainer(
+        trainer_init_per_worker=trainer_init_per_worker,
+        trainer_init_config=trainer_init_config,
+        scaling_config=scaling_config,
+    )
+
+    result = trainer.fit()
+
+    assert "dummy_callback" in result.metrics
+    assert result.metrics["dummy_callback"] == "test"
+
+
+def test_metrics_key(ray_start_4_cpus):
+    from ray.train.mosaic import MosaicTrainer
+
+    """Tests if `log_keys` defined in `trainer_init_config` appear in the result
+    metrics_dataframe.
+    """
+    trainer_init_config = {
+        "max_duration": "1ep",
+        "should_eval": True,
+        "log_keys": ["metrics/my_evaluator/Accuracy"],
+    }
+
+    trainer = MosaicTrainer(
+        trainer_init_per_worker=trainer_init_per_worker,
+        trainer_init_config=trainer_init_config,
+        scaling_config=scaling_config,
+    )
+
+    result = trainer.fit()
+
+    # check if the passed in log key exists
+    assert "metrics/my_evaluator/Accuracy" in result.metrics_dataframe.columns
+
+
+def test_monitor_callbacks(ray_start_4_cpus):
+    from ray.train.mosaic import MosaicTrainer
+
+    # Test Callbacks involving logging (SpeedMonitor, LRMonitor)
+    from composer.callbacks import SpeedMonitor, LRMonitor, GradMonitor
+
+    trainer_init_config = {
+        "max_duration": "1ep",
+        "should_eval": True,
+    }
+    trainer_init_config["log_keys"] = [
+        "grad_l2_norm/step",
+    ]
+    trainer_init_config["callbacks"] = [
+        SpeedMonitor(window_size=3),
+        LRMonitor(),
+        GradMonitor(),
+    ]
+
+    trainer = MosaicTrainer(
+        trainer_init_per_worker=trainer_init_per_worker,
+        trainer_init_config=trainer_init_config,
+        scaling_config=scaling_config,
+    )
+
+    result = trainer.fit()
+
+    assert len(result.metrics_dataframe) == 2
+
+    metrics_columns = result.metrics_dataframe.columns
+    columns_to_check = [
+        "wall_clock/train",
+        "wall_clock/val",
+        "wall_clock/total",
+        "lr-DecoupledSGDW/group0",
+        "grad_l2_norm/step",
+    ]
+    for column in columns_to_check:
+        assert column in metrics_columns, column + " is not found"
+        assert result.metrics_dataframe[column].isnull().sum() == 0, (
+            column + " column has a null value"
+        )

Review comment: Can we make these newly added tests more robust?
Reply: Tests have been updated to check [...]

 if __name__ == "__main__":
     import sys
Review comment: Why is the batch size a constant for training but an inline number here?
Reply: Updated so that it is BATCH_SIZE * 10, just like the train dataset.
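For reference, a hedged sketch of the kind of change described in this thread; the actual example file may differ, and the `BATCH_SIZE` value and dummy datasets here are assumptions for illustration only.

import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 64  # assumed value; the example file defines its own constant

# Dummy stand-ins for the CIFAR-10 train/test datasets used in the example.
train_dataset = TensorDataset(torch.randn(256, 3, 32, 32), torch.zeros(256, dtype=torch.long))
test_dataset = TensorDataset(torch.randn(256, 3, 32, 32), torch.zeros(256, dtype=torch.long))

# Both loaders derive their batch size from the shared constant rather than an
# inline literal; the evaluation loader uses a larger multiple of it.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE * 10)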