Use Lightning gradient clipping
eugene123tw committed Oct 4, 2024
1 parent c3e5f64 commit 8665486
Showing 3 changed files with 5 additions and 23 deletions.
26 changes: 3 additions & 23 deletions src/otx/algo/instance_segmentation/maskdino.py
@@ -6,8 +6,7 @@
 from __future__ import annotations
 
 import copy
-import itertools
-from typing import Any, Callable
+from typing import Any
 
 import torch
 from torch import Tensor, nn
@@ -352,12 +351,8 @@ def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict[s
         """Configure an optimizer and learning-rate schedulers."""
         param_groups = self._get_optim_params(self.model)
         optimizer = self.optimizer_callable(param_groups)
-        optimizer_with_grad_clip = MaskDINOR50._add_grad_clipping(optimizer, clip_gradient_value=0.01)(
-            param_groups,
-            optimizer.defaults["lr"],
-        )
 
-        schedulers = self.scheduler_callable(optimizer_with_grad_clip)
+        schedulers = self.scheduler_callable(optimizer)
 
         def ensure_list(item: Any) -> list:  # noqa: ANN401
             return item if isinstance(item, list) else [item]
@@ -371,7 +366,7 @@ def ensure_list(item: Any) -> list:  # noqa: ANN401
                 lr_scheduler_config["monitor"] = scheduler.monitor
             lr_scheduler_configs.append(lr_scheduler_config)
 
-        return [optimizer_with_grad_clip], lr_scheduler_configs
+        return [optimizer], lr_scheduler_configs
 
     def _get_optim_params(self, model: nn.Module) -> list[dict[str, Any]]:
         """Get optimizer parameters."""
@@ -426,21 +421,6 @@ def _get_optim_params(self, model: nn.Module) -> list[dict[str, Any]]:
             params.append({"params": [value], **hyperparams})
         return params
 
-    @staticmethod
-    def _add_grad_clipping(
-        optim: torch.optim.Optimizer,
-        clip_gradient_value: float = 0.01,
-    ) -> torch.optim.Optimizer:
-        """Add gradient clipping to the optimizer."""
-
-        class GradientClippingOptimizer(optim.__class__):
-            def step(self, closure: Callable | None = None) -> None:
-                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
-                torch.nn.utils.clip_grad_norm_(all_params, clip_gradient_value)
-                super().step(closure=closure)
-
-        return GradientClippingOptimizer
-
     def _customize_outputs(
         self,
         outputs: dict[str, Tensor],  # type: ignore[override]
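For context: the deleted _add_grad_clipping helper subclassed the optimizer and called torch.nn.utils.clip_grad_norm_ on every step(). The gradient_clip_val: 0.01 override added to the recipes hands that job to the Lightning Trainer, which clips gradient norms between backward() and optimizer.step() under automatic optimization. Below is a minimal sketch of the Trainer-level equivalent, assuming lightning 2.x (imported as lightning.pytorch); ToyModule, its data, and the AdamW settings are illustrative stand-ins, not OTX code.

# Sketch (not OTX code): Lightning-native gradient clipping, equivalent in effect to
# the removed GradientClippingOptimizer wrapper.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl


class ToyModule(pl.LightningModule):  # illustrative stand-in, not MaskDINOR50
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    data = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
    trainer = pl.Trainer(
        max_epochs=1,
        gradient_clip_val=0.01,          # the value the recipes now set via overrides
        gradient_clip_algorithm="norm",  # norm clipping, Lightning's default when a clip value is set
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(ToyModule(), DataLoader(data, batch_size=8))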
1 change: 1 addition & 0 deletions src/otx/recipe/instance_segmentation/maskdino_r50.yaml
@@ -30,6 +30,7 @@ callback_monitor: val/map_50
 
 data: ../_base_/data/instance_segmentation.yaml
 overrides:
+  gradient_clip_val: 0.01
   callbacks:
     - class_path: otx.algo.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling
       init_args:
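The recipe override maps to the Trainer's gradient_clip_val argument; with norm-based clipping, each training step is roughly equivalent to calling torch.nn.utils.clip_grad_norm_ with max_norm=0.01 on the trained parameters before the optimizer step. A small sketch of that per-step operation (the linear model below is a hypothetical stand-in, not the MaskDINO model):

# Sketch: what gradient_clip_val: 0.01 amounts to once per optimizer step.
import torch
from torch import nn

model = nn.Linear(4, 1)  # stand-in model for illustration
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Rescales gradients so their total L2 norm is at most 0.01; returns the pre-clip norm.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.01)
print(f"gradient norm before clipping: {total_norm.item():.4f}")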
1 change: 1 addition & 0 deletions (third changed file; path not shown)
@@ -30,6 +30,7 @@ callback_monitor: val/map_50
 
 data: ../_base_/data/instance_segmentation.yaml
 overrides:
+  gradient_clip_val: 0.01
   callbacks:
     - class_path: otx.algo.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling
       init_args:
