Fix eval duplicate logging issue #2018

Merged 2 commits on Mar 1, 2023
9 changes: 6 additions & 3 deletions composer/loggers/console_logger.py
@@ -96,8 +96,10 @@ def epoch_end(self, state: State, logger: Logger) -> None:
 
         if unit == TimeUnit.EPOCH and (cur_epoch % int(self.log_interval) == 0 or cur_epoch == 1):
            self.log_to_console(self.logged_metrics, prefix='Train ', state=state)
-            # Clear logged metrics.
-            self.logged_metrics = {}
+        # Always clear logged metrics so they don't get logged in a subsequent eval call. The
+        # metrics will be recomputed and overridden in future batches so they can be safely
+        # discarded.
+        self.logged_metrics = {}
 
     def batch_end(self, state: State, logger: Logger) -> None:
         cur_batch = int(state.timestamp.batch)
@@ -114,7 +116,8 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
 
     def eval_end(self, state: State, logger: Logger) -> None:
         # Log to the console at the end of eval no matter what log interval is selected.
-        self.log_to_console(state.eval_metric_values, prefix='Eval ', state=state, is_train=False)
+        self.log_to_console(self.logged_metrics, prefix='Eval ', state=state, is_train=False)
+        self.logged_metrics = {}
 
     def fit_start(self, state: State, logger: Logger) -> None:
         if not self.hparams_already_logged_to_console:
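To illustrate the reasoning in the comments above: if the shared `logged_metrics` buffer is not emptied after the train metrics are printed, a later eval flush of that same buffer repeats the stale train entries. Below is a minimal, self-contained sketch of that flow; `ToyConsoleLogger` and its methods are illustrative stand-ins, not Composer's actual `ConsoleLogger` API.

```python
# A simplified, illustrative mock of the console-logger flow; ToyConsoleLogger is NOT
# Composer's ConsoleLogger, it only mimics the shared `logged_metrics` buffer.
class ToyConsoleLogger:

    def __init__(self):
        self.logged_metrics = {}

    def log_metrics(self, metrics):
        # Metrics accumulate in a shared buffer between console flushes.
        self.logged_metrics.update(metrics)

    def _log_to_console(self, metrics, prefix):
        for name, value in metrics.items():
            print(f'{prefix}{name}: {value}')

    def epoch_end(self):
        self._log_to_console(self.logged_metrics, prefix='Train ')
        # The fix: always clear the buffer so stale train metrics cannot leak
        # into a later eval flush.
        self.logged_metrics = {}

    def eval_end(self):
        self._log_to_console(self.logged_metrics, prefix='Eval ')
        self.logged_metrics = {}


logger = ToyConsoleLogger()
logger.log_metrics({'loss/train': 0.7})
logger.epoch_end()    # prints: Train loss/train: 0.7
logger.log_metrics({'metrics/eval/Accuracy': 0.5})
logger.eval_end()     # prints only the eval metric; no repeated train line
```

With the clearing in place, the eval flush prints only the eval metrics, which is the behavior the updated tests below check for.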
129 changes: 98 additions & 31 deletions tests/loggers/test_console_logger.py
@@ -84,39 +84,43 @@ def test_console_logger_interval(console_logger_test_stream, console_logger_test
 @pytest.mark.parametrize('max_duration_unit', ['ba', 'ep'])
 @pytest.mark.parametrize('eval_interval', [2, 3])
 @pytest.mark.parametrize('max_duration', [8, 9])
-def test_console_logger_interval_with_eval(console_logger_test_stream, console_logger_test_file_path, eval_interval,
-                                           max_duration, eval_interval_unit, max_duration_unit):
-
+@pytest.mark.parametrize('pass_in_fit', [True, False])
+def test_console_logger_fit(
+    console_logger_test_stream,
+    console_logger_test_file_path,
+    eval_interval,
+    max_duration,
+    eval_interval_unit,
+    max_duration_unit,
+    pass_in_fit,
+):
     batch_size = 4
     dataset_size = 17
     eval_batch_size = 2
     eval_dataset_size = 25
     batches_per_epoch = math.ceil(dataset_size / batch_size)
 
     model = SimpleModel()
-    trainer = Trainer(model=model,
-                      console_stream=console_logger_test_stream,
-                      eval_interval=f'{eval_interval}{eval_interval_unit}',
-                      log_to_console=True,
-                      progress_bar=False,
-                      train_dataloader=DataLoader(RandomClassificationDataset(size=dataset_size),
-                                                  batch_size=batch_size),
-                      eval_dataloader=DataLoader(RandomClassificationDataset(size=eval_dataset_size),
-                                                 batch_size=eval_batch_size),
-                      max_duration=f'{max_duration}{max_duration_unit}')
-    # 1. Run with empty fit
-    trainer.fit()
-    console_logger_test_stream.flush()
-    # 2. Run again with eval, while passing an eval_dataloader
-    trainer.eval(eval_dataloader=Evaluator(label='trainer.eval_dataloader',
-                                           dataloader=DataLoader(RandomClassificationDataset(size=eval_dataset_size),
-                                                                 batch_size=eval_batch_size)))
-    console_logger_test_stream.flush()
-    # 3. Run again with fit
-    trainer.fit(eval_dataloader=DataLoader(RandomClassificationDataset(size=eval_dataset_size),
-                                           batch_size=eval_batch_size),
-                reset_time=True,
-                eval_interval=f'{eval_interval}{eval_interval_unit}')
+    trainer = Trainer(
+        model=model,
+        console_stream=console_logger_test_stream,
+        eval_interval=f'{eval_interval}{eval_interval_unit}',
+        log_to_console=True,
+        progress_bar=False,
+        train_dataloader=DataLoader(RandomClassificationDataset(size=dataset_size), batch_size=batch_size),
+        eval_dataloader=DataLoader(RandomClassificationDataset(size=eval_dataset_size), batch_size=eval_batch_size),
+        max_duration=f'{max_duration}{max_duration_unit}',
+    )
+    if pass_in_fit:
+        eval_dataloader = DataLoader(RandomClassificationDataset(size=eval_dataset_size), batch_size=eval_batch_size)
+        trainer.fit(
+            eval_dataloader=eval_dataloader,
+            reset_time=True,
+            eval_interval=f'{eval_interval}{eval_interval_unit}',
+        )
+    else:
+        trainer.fit()
+
     console_logger_test_stream.flush()
     console_logger_test_stream.close()
 
@@ -152,12 +156,75 @@ def test_console_logger_interval_with_eval(console_logger_test_stream, console_l
     expected_num_eval_lines = expected_num_eval_logging_events * (num_eval_metrics_and_losses_per_logging_event +
                                                                    num_eval_progress_lines_per_eval_event)
 
-    expected_num_eval_lines *= 2  # Because we run fit twice
     assert actual_num_eval_log_lines == expected_num_eval_lines
 
 
+@pytest.mark.parametrize('eval_interval_unit', ['ba', 'ep'])
+@pytest.mark.parametrize('max_duration_unit', ['ba', 'ep'])
+@pytest.mark.parametrize('eval_interval', [2, 3])
+@pytest.mark.parametrize('max_duration', [8, 9])
+def test_console_logger_eval(
+    console_logger_test_stream,
+    console_logger_test_file_path,
+    eval_interval,
+    max_duration,
+    eval_interval_unit,
+    max_duration_unit,
+):
+    batch_size = 4
+    dataset_size = 17
+    eval_batch_size = 2
+    eval_dataset_size = 25
+    batches_per_epoch = math.ceil(dataset_size / batch_size)
+
+    model = SimpleModel()
+    trainer = Trainer(
+        model=model,
+        console_stream=console_logger_test_stream,
+        eval_interval=f'{eval_interval}{eval_interval_unit}',
+        log_to_console=True,
+        progress_bar=False,
+        train_dataloader=DataLoader(RandomClassificationDataset(size=dataset_size), batch_size=batch_size),
+        eval_dataloader=DataLoader(RandomClassificationDataset(size=eval_dataset_size), batch_size=eval_batch_size),
+        max_duration=f'{max_duration}{max_duration_unit}',
+    )
+
+    trainer.eval(eval_dataloader=Evaluator(label='trainer.eval_dataloader',
+                                           dataloader=DataLoader(RandomClassificationDataset(size=eval_dataset_size),
+                                                                 batch_size=eval_batch_size)),)
+    console_logger_test_stream.flush()
+    console_logger_test_stream.close()
+
+    with open(console_logger_test_file_path, 'r') as f:
+        lines = f.readlines()
+
+    # Make a regular expression for matches for any line that contains "Eval" followed by
+    # a colon.
+    eval_reg_exp = re.compile('Eval *:*')
+    actual_num_eval_log_lines = sum([1 if bool(eval_reg_exp.search(line)) else 0 for line in lines])
+
+    assert model.val_metrics is not None
+    num_eval_metrics_per_event = len(list(model.val_metrics.keys())) if isinstance(model.val_metrics,
+                                                                                   MetricCollection) else 1
+
+    if eval_interval_unit == max_duration_unit:
+        expected_num_eval_logging_events, remainder = divmod(max_duration, eval_interval)
+    elif eval_interval_unit == 'ba' and max_duration_unit == 'ep':
+        expected_num_eval_logging_events, remainder = divmod((batches_per_epoch * max_duration), eval_interval)
+    else:  # for the case where eval_interval_unit == 'ep' and max_duration == 'ba'.
+        batches_per_logging_event = batches_per_epoch * eval_interval
+        expected_num_eval_logging_events, remainder = divmod(max_duration, batches_per_logging_event)
+
+    num_progress_events_due_to_eval_interval = NUM_EVAL_LOGGING_EVENTS
+    num_eval_progress_lines_per_eval_event = num_progress_events_due_to_eval_interval
+    # An eval logging event always happens at fit_end, so if one would not normally fall at
+    # last batch or epoch, then add an extra event to the expected.
+    if remainder:
+        expected_num_eval_logging_events += 1
+
     expected_num_eval_logging_events_for_trainer_eval_call = 1
-    expected_num_eval_lines_in_trainer_eval_call = (
-        expected_num_eval_logging_events_for_trainer_eval_call *
-        (num_eval_progress_lines_per_eval_event + num_eval_metrics_per_event))
-    expected_num_eval_lines += expected_num_eval_lines_in_trainer_eval_call  # because we run trainer.eval
+    expected_num_eval_lines = expected_num_eval_logging_events_for_trainer_eval_call * (
+        num_eval_progress_lines_per_eval_event + num_eval_metrics_per_event)
+
+    assert actual_num_eval_log_lines == expected_num_eval_lines
+
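As a sanity check on the interval arithmetic in the tests above, here is one worked parametrization reproduced standalone. The numbers come straight from the test (dataset_size=17, batch_size=4, eval_interval=2 with unit 'ba', max_duration=9 with unit 'ep'); it exercises the divmod branch for 'ba' intervals over an 'ep' duration.

```python
import math

dataset_size, batch_size = 17, 4
eval_interval, max_duration = 2, 9  # eval_interval_unit='ba', max_duration_unit='ep'

batches_per_epoch = math.ceil(dataset_size / batch_size)  # ceil(17 / 4) = 5
expected_events, remainder = divmod(batches_per_epoch * max_duration, eval_interval)  # divmod(45, 2) = (22, 1)
if remainder:
    # An eval logging event also fires at fit_end, so a leftover partial interval adds one more.
    expected_events += 1

print(expected_events)  # 23 eval logging events expected over the full fit
```

In test_console_logger_fit this event count (times the metric and progress lines per event) feeds the assertion; test_console_logger_eval, by contrast, asserts only on the single trainer.eval call, which is why its expected line count is rebuilt from expected_num_eval_logging_events_for_trainer_eval_call = 1.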