Add IterationTimer callback #2682

Merged
4 changes: 4 additions & 0 deletions src/otx/algo/callbacks/__init__.py
@@ -0,0 +1,4 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Module for OTX custom callbacks."""
137 changes: 137 additions & 0 deletions src/otx/algo/callbacks/iteration_timer.py
@@ -0,0 +1,137 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Timer for logging iteration time for train, val, and test phases."""

from __future__ import annotations

from collections import defaultdict
from time import time
from typing import TYPE_CHECKING, Any

from lightning import Callback

if TYPE_CHECKING:
    from lightning import LightningModule, Trainer
    from lightning.pytorch.utilities.types import STEP_OUTPUT


class IterationTimer(Callback):
    """Timer for logging iteration time for train, val, and test phases."""

    def __init__(
        self,
        prog_bar: bool = True,
        on_step: bool = True,
        on_epoch: bool = True,
    ) -> None:
        super().__init__()
        self.prog_bar = prog_bar
        self.on_step = on_step
        self.on_epoch = on_epoch

        # Per-phase timestamps; defaultdict(float) returns 0.0 before the
        # first batch, which the hooks below treat as "no measurement yet".
        self.start_time: dict[str, float] = defaultdict(float)
        self.end_time: dict[str, float] = defaultdict(float)

    def _on_batch_start(self, pl_module: LightningModule, phase: str) -> None:
        self.start_time[phase] = time()

        # Skip the first batch of a phase: there is no previous batch end
        # to measure the dataloader fetch time against.
        if not self.end_time[phase]:
            return

        name = f"{phase}/data_time"

        # Time between the previous batch end and this batch start,
        # i.e. the time spent fetching the batch from the dataloader.
        data_time = self.start_time[phase] - self.end_time[phase]

        pl_module.log(
            name=name,
            value=data_time,
            prog_bar=self.prog_bar,
            on_step=self.on_step,
            on_epoch=self.on_epoch,
        )

    def _on_batch_end(self, pl_module: LightningModule, phase: str) -> None:
        # First batch: only seed the end timestamp for the next measurement.
        if not self.end_time[phase]:
            self.end_time[phase] = time()
            return

        name = f"{phase}/iter_time"
        curr_end_time = time()
        # Full iteration time: previous batch end to current batch end,
        # covering both data loading and the step itself.
        iter_time = curr_end_time - self.end_time[phase]
        self.end_time[phase] = curr_end_time

        pl_module.log(
            name=name,
            value=iter_time,
            prog_bar=self.prog_bar,
            on_step=self.on_step,
            on_epoch=self.on_epoch,
        )

    def on_train_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
    ) -> None:
        """Log iteration data time on the training batch start."""
        self._on_batch_start(pl_module=pl_module, phase="train")

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
    ) -> None:
        """Log iteration time on the training batch end."""
        self._on_batch_end(pl_module=pl_module, phase="train")

    def on_validation_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Log iteration data time on the validation batch start."""
        self._on_batch_start(pl_module=pl_module, phase="validation")

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Log iteration time on the validation batch end."""
        self._on_batch_end(pl_module=pl_module, phase="validation")

    def on_test_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Log iteration data time on the test batch start."""
        self._on_batch_start(pl_module=pl_module, phase="test")

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,  # noqa: ANN401
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Log iteration time on the test batch end."""
        self._on_batch_end(pl_module=pl_module, phase="test")
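A minimal usage sketch (not part of this diff; the model and datamodule are placeholders) showing the new callback attached directly to a Lightning `Trainer`:

```python
from lightning import Trainer

from otx.algo.callbacks.iteration_timer import IterationTimer

# Same values as the defaults in callbacks/iteration_timer.yaml below.
timer = IterationTimer(prog_bar=True, on_step=True, on_epoch=True)

trainer = Trainer(callbacks=[timer], max_epochs=1)
# trainer.fit(model, datamodule=datamodule)  # placeholders for an OTX model/datamodule
```

From the second batch onward, `train/data_time` and `train/iter_time` appear in the progress bar; the first batch only seeds the timer.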
2 changes: 1 addition & 1 deletion src/otx/config/callbacks/default.yaml
@@ -6,7 +6,7 @@ defaults:
- _self_

model_checkpoint:
-  dirpath: ${paths.output_dir}/checkpoints
+  dirpath: ${base.output_dir}/checkpoints
  filename: "epoch_{epoch:03d}"
  monitor: "val/acc"
  mode: "max"
7 changes: 7 additions & 0 deletions src/otx/config/callbacks/iteration_timer.yaml
@@ -0,0 +1,7 @@
# Config for the custom OTX IterationTimer callback
# (otx.algo.callbacks.iteration_timer.IterationTimer).

iteration_timer:
  _target_: otx.algo.callbacks.iteration_timer.IterationTimer
  prog_bar: True
  on_step: True
  on_epoch: True
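For context, a sketch of how Hydra would turn this group file into a live callback via the standard `_target_` mechanism; the config shown is a hand-built stand-in for what Hydra composes:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Stand-in for the composed callbacks config; in practice Hydra builds
# this from the group file above.
cfg = OmegaConf.create(
    {
        "iteration_timer": {
            "_target_": "otx.algo.callbacks.iteration_timer.IterationTimer",
            "prog_bar": True,
            "on_step": True,
            "on_epoch": True,
        },
    },
)

# instantiate() imports the _target_ class and passes the remaining keys as kwargs.
timer = instantiate(cfg.iteration_timer)  # -> IterationTimer instance
```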
1 change: 1 addition & 0 deletions src/otx/config/train.yaml
@@ -1,6 +1,7 @@
defaults:
  - base_config
  - base: default
+  - callbacks: iteration_timer
  - data: default
  - trainer: default
  - model: mmdet
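A hedged sketch of what the new defaults entry composes to; the `config_path` here is hypothetical and depends on where the calling script lives:

```python
from hydra import compose, initialize

# Assumes the script sits next to src/otx/config; adjust config_path as needed.
with initialize(version_base=None, config_path="src/otx/config"):
    cfg = compose(config_name="train")

# The callbacks group selected above contributes a dict keyed by callback name.
print(cfg.callbacks["iteration_timer"]["_target_"])
# otx.algo.callbacks.iteration_timer.IterationTimer
```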
4 changes: 2 additions & 2 deletions src/otx/core/config/__init__.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
#
"""Config data type objects."""
-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import Optional

from .base import BaseConfig
@@ -16,14 +16,14 @@ class TrainConfig:
"""DTO for training."""

base: BaseConfig
callbacks: dict
data: DataModuleConfig
trainer: TrainerConfig
model: ModelConfig
logger: dict
recipe: Optional[str] # noqa: FA100
train: bool
test: bool
callbacks: list = field(default_factory=list)


def register_configs() -> None:
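Since `callbacks` is now a name-keyed dict instead of a list, the training entrypoint presumably flattens it into Lightning callback instances before building the Trainer. A hypothetical helper (the real wiring is outside this diff):

```python
from hydra.utils import instantiate
from lightning import Callback


def instantiate_callbacks(callbacks_cfg: dict) -> list[Callback]:
    """Turn a name-keyed callbacks config dict into Lightning callback instances.

    Hypothetical sketch; only entries carrying a _target_ key are instantiated.
    """
    return [
        instantiate(cb_cfg)
        for cb_cfg in callbacks_cfg.values()
        if isinstance(cb_cfg, dict) and "_target_" in cb_cfg
    ]
```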