Add IterationTimer callback
Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
vinnamkim committed Nov 30, 2023
1 parent e538298 commit 2a9313e
Showing 19 changed files with 611 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/otx/algo/callbacks/__init__.py
@@ -0,0 +1,4 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Module for OTX custom callbacks."""
137 changes: 137 additions & 0 deletions src/otx/algo/callbacks/iteration_timer.py
@@ -0,0 +1,137 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Timer for logging iteration time for train, val, and test phases."""

from __future__ import annotations

from collections import defaultdict
from time import time
from typing import TYPE_CHECKING, Any

from lightning import Callback

if TYPE_CHECKING:
from lightning import LightningModule, Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT


class IterationTimer(Callback):
"""Timer for logging iteration time for train, val, and test phases."""

def __init__(
self,
prog_bar: bool = True,
on_step: bool = True,
on_epoch: bool = True,
) -> None:
super().__init__()
self.prog_bar = prog_bar
self.on_step = on_step
self.on_epoch = on_epoch

self.start_time: dict[str, float] = defaultdict(float)
self.end_time: dict[str, float] = defaultdict(float)

def _on_batch_start(self, pl_module: LightningModule, phase: str) -> None:
self.start_time[phase] = time()

        # The first batch of a phase has no previous batch-end timestamp,
        # so data-loading time cannot be measured yet.
        if not self.end_time[phase]:
            return

        name = f"{phase}/data_time"

        # Data time is the gap between the previous batch ending and this
        # batch starting, i.e. the time spent waiting on the dataloader.
        data_time = self.start_time[phase] - self.end_time[phase]

pl_module.log(
name=name,
value=data_time,
prog_bar=self.prog_bar,
on_step=self.on_step,
on_epoch=self.on_epoch,
)

def _on_batch_end(self, pl_module: LightningModule, phase: str) -> None:
        # On the first batch there is no previous timestamp yet, so only
        # record the end time and skip logging.
        if not self.end_time[phase]:
            self.end_time[phase] = time()
            return

        name = f"{phase}/iter_time"
        # Iteration time is the elapsed time between consecutive batch
        # ends, covering both data loading and the step itself.
        curr_end_time = time()
        iter_time = curr_end_time - self.end_time[phase]
        self.end_time[phase] = curr_end_time

pl_module.log(
name=name,
value=iter_time,
prog_bar=self.prog_bar,
on_step=self.on_step,
on_epoch=self.on_epoch,
)

def on_train_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch: Any, # noqa: ANN401
batch_idx: int,
) -> None:
"""Log iteration data time on the training batch start."""
self._on_batch_start(pl_module=pl_module, phase="train")

def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: STEP_OUTPUT,
batch: Any, # noqa: ANN401
batch_idx: int,
) -> None:
"""Log iteration time on the training batch start."""
self._on_batch_end(pl_module=pl_module, phase="train")

def on_validation_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch: Any, # noqa: ANN401
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
"""Log iteration data time on the validation batch start."""
self._on_batch_start(pl_module=pl_module, phase="validation")

def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: STEP_OUTPUT,
batch: Any, # noqa: ANN401
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
"""Log iteration time on the validation batch start."""
self._on_batch_end(pl_module=pl_module, phase="validation")

def on_test_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch: Any, # noqa: ANN401
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
"""Log iteration data time on the test batch start."""
self._on_batch_start(pl_module=pl_module, phase="test")

def on_test_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: STEP_OUTPUT,
batch: Any, # noqa: ANN401
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
"""Log iteration time on the test batch start."""
self._on_batch_end(pl_module=pl_module, phase="test")
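
For reference, a minimal sketch of attaching this callback to a Lightning Trainer; the model and datamodule in the commented line are placeholders supplied by the caller, not part of this commit:

from lightning import Trainer

from otx.algo.callbacks.iteration_timer import IterationTimer

# The timer logs "<phase>/data_time" and "<phase>/iter_time" for the
# train, validation, and test phases.
trainer = Trainer(
    max_epochs=1,
    callbacks=[IterationTimer(prog_bar=True, on_step=True, on_epoch=True)],
)
# trainer.fit(model, datamodule=datamodule)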
2 changes: 1 addition & 1 deletion src/otx/config/callbacks/default.yaml
@@ -6,7 +6,7 @@ defaults:
- _self_

model_checkpoint:
-  dirpath: ${paths.output_dir}/checkpoints
+  dirpath: ${base.output_dir}/checkpoints
filename: "epoch_{epoch:03d}"
monitor: "val/acc"
mode: "max"
7 changes: 7 additions & 0 deletions src/otx/config/callbacks/iteration_timer.yaml
@@ -0,0 +1,7 @@
# Custom OTX callback: otx.algo.callbacks.iteration_timer.IterationTimer

iteration_timer:
_target_: otx.algo.callbacks.iteration_timer.IterationTimer
prog_bar: True
on_step: True
on_epoch: True
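
Because the config follows Hydra's _target_ convention, the callback can be instantiated directly from it. A minimal sketch, assuming hydra-core and omegaconf are available (the inline dict mirrors the YAML above; the actual wiring inside OTX may differ):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "otx.algo.callbacks.iteration_timer.IterationTimer",
        "prog_bar": True,
        "on_step": True,
        "on_epoch": True,
    }
)
timer = instantiate(cfg)  # builds an IterationTimer from the config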
1 change: 1 addition & 0 deletions src/otx/config/train.yaml
@@ -1,6 +1,7 @@
defaults:
- base_config
- base: default
+ - callbacks: iteration_timer
- data: default
- trainer: default
- model: mmdet
4 changes: 2 additions & 2 deletions src/otx/core/config/__init__.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
#
"""Config data type objects."""
-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import Optional

from .base import BaseConfig
@@ -16,14 +16,14 @@ class TrainConfig:
"""DTO for training."""

base: BaseConfig
+    callbacks: dict
data: DataModuleConfig
trainer: TrainerConfig
model: ModelConfig
logger: dict
recipe: Optional[str] # noqa: FA100
train: bool
test: bool
-    callbacks: list = field(default_factory=list)


def register_configs() -> None:
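
The callbacks field changes from a defaulted list to a required dict, matching the YAML layout in which each callback config sits under a name key such as iteration_timer. A sketch of how such a dict could be turned into callback instances; instantiate_callbacks is an illustrative helper, not part of this commit:

from hydra.utils import instantiate
from lightning import Callback


def instantiate_callbacks(callbacks_cfg: dict) -> list[Callback]:
    """Build one Callback from each named config entry."""
    # Keys name the callbacks ("iteration_timer", "model_checkpoint", ...);
    # each value carries a _target_ plus keyword arguments.
    return [instantiate(cfg) for cfg in callbacks_cfg.values()]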
