Semi-SL Semantic Segmentation. Prototype View. (#2156)
* added cutmix, filter loss

* added semisl to segnext, added ohem loss

* added ProtoNet

* changed loss handling

* proto_head_debugged

* changes for experiments

* added ham-seg based head for proto

* cleaned code

* continue to refactor

* added larger models configs

* small fix

* minor

* remove ohem loss

* fix pre-commit tests

* delete loss

* merge conflict

* black files back

* revert configuration back

* minor EMA change

* revert semisl recipe back

* fix configure test

* fix dual model ema test

* fix tests

* minor revert back

* fix classes in aux head

* reply to comments

* added docstrings

* added articles to description

* added unit tests, integration and e2e

* fix convert via ignore
kprokofi authored May 24, 2023
1 parent 25b8776 commit e1004eb
Showing 64 changed files with 1,509 additions and 215 deletions.
33 changes: 6 additions & 27 deletions otx/algorithms/common/adapters/mmcv/hooks/dual_model_ema_hook.py
@@ -71,21 +71,16 @@ def before_run(self, runner):
self.src_model = getattr(model, self.src_model_name, None)
self.dst_model = getattr(model, self.dst_model_name, None)
if self.src_model and self.dst_model:
self.enabled = True
self.src_params = self.src_model.state_dict(keep_vars=True)
self.dst_params = self.dst_model.state_dict(keep_vars=True)
if runner.epoch == 0 and runner.iter == 0:
# If it's not resuming from a checkpoint
# initialize student model by teacher model
# (teacher model is main target of load/save)
# (if it's resuming there will be student weights in checkpoint. No need to copy)
self._sync_model()
logger.info("Initialized student model by teacher model")
logger.info(f"model_s model_t diff: {self._diff_model()}")

def before_train_epoch(self, runner):
"""Momentum update."""
if self.epoch_momentum > 0.0:
if runner.epoch == self.start_epoch:
self._copy_model()
self.enabled = True

if self.epoch_momentum > 0.0 and self.enabled:
iter_per_epoch = len(runner.data_loader)
epoch_decay = 1 - self.epoch_momentum
iter_decay = math.pow(epoch_decay, self.interval / iter_per_epoch)
@@ -95,16 +90,7 @@ def before_train_epoch(self, runner):

def after_train_iter(self, runner):
"""Update ema parameter every self.interval iterations."""
if not self.enabled:
return

if runner.iter % self.interval != 0:
# Skip update
return

if runner.epoch + 1 < self.start_epoch:
# Just copy parameters before start epoch
self._copy_model()
if not self.enabled or (runner.iter % self.interval != 0):
return

# EMA
@@ -121,12 +107,6 @@ def _get_model(self, runner):
model = model.module
return model

def _sync_model(self):
with torch.no_grad():
for name, src_param in self.src_params.items():
dst_param = self.dst_params[name]
src_param.data.copy_(dst_param.data)

def _copy_model(self):
with torch.no_grad():
for name, src_param in self.src_params.items():
@@ -138,7 +118,6 @@ def _ema_model(self):
with torch.no_grad():
for name, src_param in self.src_params.items():
dst_param = self.dst_params[name]
# dst_param.data.mul_(1 - momentum).add_(src_param.data, alpha=momentum)
dst_param.data.copy_(dst_param.data * (1 - momentum) + src_param.data * momentum)

def _diff_model(self):
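As context for the hunks above: the hook derives a per-update momentum from the configured per-epoch momentum, then blends student (source) weights into the teacher (destination) copy. A minimal standalone sketch of both steps; the helper names are illustrative, not the hook's actual methods:

import math

import torch


def per_iter_momentum(epoch_momentum: float, interval: int, iters_per_epoch: int) -> float:
    # Convert a per-epoch EMA momentum into a per-update momentum: applying
    # the resulting decay every `interval` iterations compounds to a decay of
    # (1 - epoch_momentum) over one epoch, matching before_train_epoch() above.
    epoch_decay = 1 - epoch_momentum
    iter_decay = math.pow(epoch_decay, interval / iters_per_epoch)
    return 1 - iter_decay


@torch.no_grad()
def ema_update(dst_params, src_params, momentum: float) -> None:
    # dst <- (1 - m) * dst + m * src, the same blend as _ema_model() above.
    for name, src_param in src_params.items():
        dst_param = dst_params[name]
        dst_param.data.copy_(dst_param.data * (1 - momentum) + src_param.data * momentum)

For example, with epoch_momentum=0.99, interval=1, and 100 iterations per epoch, per_iter_momentum returns roughly 0.045, and 100 compounded updates leave the original teacher weights with a coefficient of exactly 0.01, the configured per-epoch decay.
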
32 changes: 22 additions & 10 deletions otx/algorithms/segmentation/adapters/mmseg/configurer.py
@@ -128,7 +128,11 @@ def configure_model(
Patch for OMZ backbones
"""
if ir_options is None:
ir_options = {"ir_model_path": None, "ir_weight_path": None, "ir_weight_init": False}
ir_options = {
"ir_model_path": None,
"ir_weight_path": None,
"ir_weight_init": False,
}

cfg.model_task = cfg.model.pop("task", "segmentation")
if cfg.model_task != "segmentation":
@@ -153,7 +157,11 @@ def is_mmov_model(key: str, value: Any) -> bool:
recursively_update_cfg(
cfg,
is_mmov_model,
{"model_path": ir_model_path, "weight_path": ir_weight_path, "init_weight": ir_weight_init},
{
"model_path": ir_model_path,
"weight_path": ir_weight_path,
"init_weight": ir_weight_init,
},
)

def configure_data(
@@ -196,19 +204,17 @@ def configure_task(
def configure_decode_head(self, cfg: Config) -> None:
"""Change to incremental loss (ignore mode) and substitute head with otx universal head."""
ignore = cfg.get("ignore", False)
if ignore:
cfg_loss_decode = ConfigDict(
type="CrossEntropyLossWithIgnore",
use_sigmoid=False,
loss_weight=1.0,
)

for head in ("decode_head", "auxiliary_head"):
decode_head = cfg.model.get(head, None)
if decode_head is not None:
decode_head.base_type = decode_head.type
decode_head.type = otx_head_factory
if ignore:
cfg_loss_decode = ConfigDict(
type="CrossEntropyLossWithIgnore",
use_sigmoid=False,
loss_weight=1.0,
)
decode_head.loss_decode = cfg_loss_decode

# pylint: disable=too-many-branches
@@ -252,6 +258,9 @@ def configure_classes(self, cfg: Config) -> None:
if "SupConDetCon" in cfg.model.type:
cfg.model.num_classes = len(model_classes)

if "auxiliary_head" in cfg.model:
cfg.model.auxiliary_head.num_classes = len(model_classes)

# Task classes
self.org_model_classes = org_model_classes
self.model_classes = model_classes
@@ -598,7 +607,10 @@ def configure_unlabeled_dataloader(cfg: ConfigDict) -> None:
updated = False
for custom_hook in custom_hooks:
if custom_hook["type"] == "ComposedDataLoadersHook":
custom_hook["data_loaders"] = [*custom_hook["data_loaders"], unlabeled_dataloader]
custom_hook["data_loaders"] = [
*custom_hook["data_loaders"],
unlabeled_dataloader,
]
updated = True
if not updated:
custom_hooks.append(
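The net effect of the configure_decode_head change above: heads are now always routed through otx_head_factory, and the ignore-aware loss is attached only when requested. A condensed sketch under those assumptions (the function name is illustrative; otx_head_factory is assumed to be the factory imported in this file):

from mmcv.utils import ConfigDict


def patch_decode_heads(cfg, ignore: bool = False) -> None:
    # Every configured head keeps its original type under `base_type` and is
    # rebuilt through the OTX universal head factory; the ignore-aware loss
    # replaces the default only when `ignore` is set.
    for head_name in ("decode_head", "auxiliary_head"):
        head = cfg.model.get(head_name, None)
        if head is None:
            continue
        head.base_type = head.type
        head.type = otx_head_factory  # assumed import from the OTX mmseg adapters
        if ignore:
            head.loss_decode = ConfigDict(
                type="CrossEntropyLossWithIgnore",
                use_sigmoid=False,
                loss_weight=1.0,
            )
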
11 changes: 9 additions & 2 deletions otx/algorithms/segmentation/adapters/mmseg/datasets/dataset.py
@@ -31,7 +31,9 @@

# pylint: disable=invalid-name, too-many-locals, too-many-instance-attributes, super-init-not-called
def get_annotation_mmseg_format(
dataset_item: DatasetItemEntity, labels: List[LabelEntity], use_otx_adapter: bool = True
dataset_item: DatasetItemEntity,
labels: List[LabelEntity],
use_otx_adapter: bool = True,
) -> dict:
"""Function to convert a OTX annotation to mmsegmentation format.
@@ -259,7 +261,12 @@ def __init__(self, **kwargs):
classes = ["background"] + classes
else:
classes = []
super().__init__(otx_dataset=otx_dataset, pipeline=pipeline, classes=classes, use_otx_adapter=use_otx_adapter)
super().__init__(
otx_dataset=otx_dataset,
pipeline=pipeline,
classes=classes,
use_otx_adapter=use_otx_adapter,
)

self.CLASSES = [label.name for label in self.project_labels]
if "background" not in self.CLASSES:
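As a side note on the class handling in this hunk, a hypothetical helper (not part of the dataset API) that mirrors the convention used in __init__ above:

def build_class_list(project_labels, add_background: bool = True):
    # OTX project labels become an mmsegmentation class list, with an
    # explicit "background" entry prepended when the labels lack one.
    classes = [label.name for label in project_labels]
    if add_background and "background" not in classes:
        classes = ["background"] + classes
    return classes
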
@@ -18,4 +18,10 @@
from .loads import LoadAnnotationFromOTXDataset, LoadImageFromOTXDataset
from .transforms import TwoCropTransform

__all__ = ["MaskCompose", "ProbCompose", "LoadImageFromOTXDataset", "LoadAnnotationFromOTXDataset", "TwoCropTransform"]
__all__ = [
"MaskCompose",
"ProbCompose",
"LoadImageFromOTXDataset",
"LoadAnnotationFromOTXDataset",
"TwoCropTransform",
]
@@ -217,7 +217,15 @@ def __call__(self, results: Dict[str, Any]):
results["img_shape"] = img.size
for key in results.get("seg_fields", []):
results[key] = np.array(
F.resized_crop(Image.fromarray(results[key]), i, j, height, width, self.size, self.interpolation)
F.resized_crop(
Image.fromarray(results[key]),
i,
j,
height,
width,
self.size,
self.interpolation,
)
)

# change order because of difference between numpy and PIL
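The reflowed call above is the heart of keeping image and mask geometry in sync: the same crop parameters are applied to the image and to every field in seg_fields. A self-contained sketch of that pattern with torchvision (the helper name is illustrative):

import numpy as np
import torchvision.transforms.functional as F
from PIL import Image


def resized_crop_with_masks(img, masks, top, left, height, width, size, interpolation):
    # Apply identical crop-and-resize parameters to the image and to every
    # segmentation map so pixels and labels stay aligned.
    img = F.resized_crop(img, top, left, height, width, size, interpolation)
    masks = [
        np.array(F.resized_crop(Image.fromarray(m), top, left, height, width, size, interpolation))
        for m in masks
    ]
    return img, masks

Note that the transform above reuses self.interpolation for the masks; nearest-neighbour interpolation is the usual choice when label values must not be blended.
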
@@ -46,7 +46,15 @@
class NeighbourSupport(nn.Module):
"""Neighbour support module."""

def __init__(self, channels, kernel_size=3, key_ratio=8, value_ratio=8, conv_cfg=None, norm_cfg=None):
def __init__(
self,
channels,
kernel_size=3,
key_ratio=8,
value_ratio=8,
conv_cfg=None,
norm_cfg=None,
):
super().__init__()

self.in_channels = channels
@@ -127,7 +135,12 @@ class CrossResolutionWeighting(nn.Module):
"""Cross resolution weighting."""

def __init__(
self, channels, ratio=16, conv_cfg=None, norm_cfg=None, act_cfg=(dict(type="ReLU"), dict(type="Sigmoid"))
self,
channels,
ratio=16,
conv_cfg=None,
norm_cfg=None,
act_cfg=(dict(type="ReLU"), dict(type="Sigmoid")),
):
super().__init__()

@@ -175,7 +188,14 @@ def forward(self, x):
class SpatialWeighting(nn.Module):
"""Spatial weighting."""

def __init__(self, channels, ratio=16, conv_cfg=None, act_cfg=(dict(type="ReLU"), dict(type="Sigmoid")), **kwargs):
def __init__(
self,
channels,
ratio=16,
conv_cfg=None,
act_cfg=(dict(type="ReLU"), dict(type="Sigmoid")),
**kwargs,
):
super().__init__()

if isinstance(act_cfg, dict):
@@ -213,7 +233,15 @@ def forward(self, x):
class SpatialWeightingV2(nn.Module):
"""The original repo: https://github.com/DeLightCMU/PSA."""

def __init__(self, channels, ratio=16, conv_cfg=None, norm_cfg=None, enable_norm=False, **kwargs):
def __init__(
self,
channels,
ratio=16,
conv_cfg=None,
norm_cfg=None,
enable_norm=False,
**kwargs,
):
super().__init__()

self.in_channels = channels
@@ -367,7 +395,11 @@ def __init__(
self.spatial_weighting = nn.ModuleList(
[
spatial_weighting_module(
channels=channel, ratio=4, conv_cfg=conv_cfg, norm_cfg=norm_cfg, enable_norm=True
channels=channel,
ratio=4,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
enable_norm=True,
)
for channel in branch_channels
]
@@ -378,7 +410,12 @@ def __init__(
self.neighbour_weighting = nn.ModuleList(
[
NeighbourSupport(
channel, kernel_size=3, key_ratio=8, value_ratio=4, conv_cfg=conv_cfg, norm_cfg=norm_cfg
channel,
kernel_size=3,
key_ratio=8,
value_ratio=4,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
)
for channel in branch_channels
]
@@ -1172,10 +1209,18 @@ def __init__(
num_channels = self.stages_spec["num_channels"][i]
num_channels = [num_channels[i] for i in range(len(num_channels))]

setattr(self, f"transition{i}", self._make_transition_layer(num_channels_last, num_channels))
setattr(
self,
f"transition{i}",
self._make_transition_layer(num_channels_last, num_channels),
)

stage, num_channels_last = self._make_stage(
self.stages_spec, i, num_channels, multiscale_output=True, dropout=dropout
self.stages_spec,
i,
num_channels,
multiscale_output=True,
dropout=dropout,
)
setattr(self, f"stage{i}", stage)

@@ -1316,7 +1361,13 @@ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer)
),
build_norm_layer(self.norm_cfg, in_channels)[1],
build_conv_layer(
self.conv_cfg, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False
self.conv_cfg,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
build_norm_layer(self.norm_cfg, out_channels)[1],
nn.ReLU(),
Expand All @@ -1326,7 +1377,14 @@ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer)

return nn.ModuleList(transition_layers)

def _make_stage(self, stages_spec, stage_index, in_channels, multiscale_output=True, dropout=None):
def _make_stage(
self,
stages_spec,
stage_index,
in_channels,
multiscale_output=True,
dropout=None,
):
num_modules = stages_spec["num_modules"][stage_index]
num_branches = stages_spec["num_branches"][stage_index]
num_blocks = stages_spec["num_blocks"][stage_index]
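For orientation, the SpatialWeighting module whose signature is reflowed above follows the squeeze-and-excitation pattern. A plain-PyTorch approximation, assuming the default (ReLU, Sigmoid) act_cfg pair and omitting the conv_cfg/norm plumbing:

import torch
from torch import nn


class SpatialWeightingSketch(nn.Module):
    # Global context is pooled to 1x1, squeezed by `ratio`, re-expanded, and
    # turned into per-channel gates that rescale the input feature map.
    def __init__(self, channels: int, ratio: int = 16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // ratio, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // ratio, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.fc(self.pool(x))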