Skip to content

Commit

Permalink
[OTX] Rebase latest changes for MPA merge (#1468)
Browse files Browse the repository at this point in the history
* make logfile saved in save-model-to directory

* enable train

* make main process train also

* bugfix

* refactor multi gpu training

* make all processs have same output path

* prevent child process from being termated by fokred main process

* refactor multigpu implementation

* refactor multi gpu implementation

* modify argument help sentence

* add multi gpu test code

* align with pre-commit test

* separate multi GPU manager class

* modify train cli argument 'save-logs-to' to 'output-path'

* remove tray excpet during killing child process

* apply output_path to all tasks

* change print to logger

* skip multi gpu test if number of gpu is insufficient

* fix typo

* multi gpu test bugfix

* isort fix

* test case bugfix

* fix typo and change some variable name

* [OTX] Apply changes in develop to feature/otx branch (#1436)

* Add tiling module (#1200)

* Update submodule branch (#1222)

* Enhance training schedule for multi-label classification (#1212)

* [CVS-88098] Remove initialize from export functions (#1226)

* Train graph added (#1211)

* Add @attrs decorator for base configs (#1229)

* Pretrained weight download error in MobilenetV3-large-1 of deep-object-reid in SC (#1233)

* [Anomaly Task] Revert hpo template (#1230)

* 🐞 [Anomaly Task] Fix progress bar (#1223)

* [CVS-90555] Fix NaN value in classification (#1244)

* update hpo_config.yaml (#1240)

* [CVS-90400, CVS-91015] NNCF pruning supported tweaks (#1248)

* [Anomaly Task] 🐞 Fix inference when model backbone changes (#1242)

* [CVS-91472] Add pruning_supported value (#1263)

* Pruning supported tweaks (#1256)

* [CVS-90400, CVS-91015] NNCF pruning supported tweaks (#1248)

* Revert "[CVS-90400, CVS-91015] NNCF pruning supported tweaks (#1248)" (#1269)

* [OTE-TEST] Disable obsolete test cases (#1220)

* [OTE-TEST] hot-fix for MPA performance tests (#1273)

* [Anomaly Task] ✨ Upgrade anomalib (#1243)

* Expose early stopping hyper-parameters for all tasks (#1241)

* Resolve pre-commit issues (#1272)

* Remove LazyEarlyStopHook in model_multilabel.py (#1281)

* Removed xfail (#1239)

* Implement IB loss for incremental learning in multi-class classification (#1289)

* Edit num_workers and change MPA repo as a latest (#1314)

* fix annotation bug (#1320)

* Valid POT configs for small HRNet models (#1313)

* Disable NNCF optimization for FP16 models (#1312)

* fliter object less than 1 pixel  (#1305)

* Fix some tests (#1322)

* [Develop] Move drop_last into MPA (#1357)

* Apply changes from releases/v0.3.1-geti1.0.0 (#1337)

* anomaly save_model bugfix (#1300)

* upgrade networkx module version (#1303)

* Forward CVS-94422 size bug fix PR to release branch (#1326)

* Valid POT configs for small HRNet models (#1317)

* [Release branch] Disable NNCF optimization for FP16 models  (#1319)

* [RELEASE] CVS-95549 - Hierarchical classification training failed without obvious reason (#1329)

* Fix h-label: per-group softmax (#1332)

* Fix dataset length bug in mpa task (#1338)

* Fix drop_last key issue for det/set (#1340)

* Hot-fix for OV inference for iseg output (#1345)

* Fix nncf model export bug (#1346)

* Fixed merge error (#1359)

* Update evaluation iou_thr of ins-seg (#1354)

* fix pre-commit test (#1366)

* Fix dataset item tests (#1360)

* Fix OV Inference issues (tiling tests & detection tests) (#1361)

* fix black & add xfail test cases (#1367)

* Update check_nncf_graph. (#1330)

* [Develop] Hot-fix OV inference issue in rotated detection (#1375)

* [Develop] updated documents (#1383)

* [CVS-94911] Fix difference between train and validation normalization pipeline (#1310)

* Update configs for padim model (#1378)

* updated QUICK_START_GUIDE.md (#1397)

* Change ote threshold of openvino test for cls (#1401)

* Normalize top-1 metrics to [0, 1] (#1394)

* Tiling deployment (#1387)

* Replace current saliency map generation with Recipro-CAM for cls (#1363)

* Class-wise saliency map generation for the detection task (#1402)

* Change submodule to develop (#1410)

* Send full dataset to POT optimization function (#1379) & Convert NaN to num to make visible in geti UI (#1413)

* Add active score evaluation to the classification task

* [release/0.4.0][OTX] Enabling GPU execution for exported code (#1416)

* [OTE][Release][XAI] Detection fix two stage bbox_head error (#1414)

* Update SDK commit for exportable code (#1423)

* HRNet-x and HRNe-18--mod2 configs update (#1419)

* [Release] Enable tiling oriented detection for v0.4.0/geti1.1.0 (#1427)

* [OTE][Releases v0.4.0][XAI] Hot-fix for Detection fix two stage error (#1433)

* Temporary MPA branch while dev->otx merge process

* Update doc & install for dev->otx changes

* Update ote_sdk -> otx.api

* Update ote_cli -> otx.cli

* Update external/mmsegmentation -> otx/algorithms/segmentation

* Align saliency map media instantiation over tasks (#1447)

* Update external/d-o-r -> otx/algorithms/classification

* Update external/mmdetection -> otx/algorithms/detection

* Update external/mpa -> otx/algorithms/*

* Fix CLI test run for better error message

* Numpy constraint for deprecated np.bool error

* Capture stderr only

* Align numpy requirement

* [OTX/Anomaly] Add changes from external to otx (#1452)

* Add changes from external to otx

* Address PR comments

* Update config files + remove backbone from base

* Fix pre-merge checks

* Fix pre-commit issues

* Update exportable code commit

* Fix indent error

* Fix flake8 issue

* Resolve softmax issue w/ FIXME for future work

* Add tiling tests

* Revert MPA branch to otx

Signed-off-by: Songki Choi <songki.choi@intel.com>
Co-authored-by: Eugene Liu <eugene.liu@intel.com>
Co-authored-by: Ashwin Vaidya <ashwin.vaidya@intel.com>
Co-authored-by: Jaeguk Hyun <jaeguk.hyun@intel.com>
Co-authored-by: Nikita Savelyev <nikita.savelyev@intel.com>
Co-authored-by: Jihwan Eom <jihwan.eom@intel.com>
Co-authored-by: Harim Kang <harim.kang@intel.com>
Co-authored-by: Soobee Lee <soobee.lee@intel.com>
Co-authored-by: Lee, Soobee <soobeele@intel.com>
Co-authored-by: Emily Chun <emily.chun@intel.com>
Co-authored-by: ljcornel <ludo.cornelissen@intel.com>
Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>
Co-authored-by: dlyakhov <daniil.lyakhov@intel.com>
Co-authored-by: kprokofi <kirill.prokofiev@intel.com>
Co-authored-by: Sungman Cho <sungman.cho@intel.com>
Co-authored-by: Yunchu Lee <yunchu.lee@intel.com>
Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
Co-authored-by: Alexander Dokuchaev <alexander.dokuchaev@intel.com>
Co-authored-by: Vladislav Sovrasov <vladislav.sovrasov@intel.com>
Co-authored-by: Evgeny Tsykunov <e.tsykunov@gmail.com>
Co-authored-by: Galina Zalesskaya <galina.zalesskaya@intel.com>
Co-authored-by: dongkwan-kim <dongkwan.kim@intel.com>

* Apply latest MPA openvinotoolkit/model_preparation_algorithm#105

Signed-off-by: Songki Choi <songki.choi@intel.com>

Signed-off-by: Songki Choi <songki.choi@intel.com>
Co-authored-by: eunwoosh <eunwoo.shin@intel.com>
Co-authored-by: Eugene Liu <eugene.liu@intel.com>
Co-authored-by: Ashwin Vaidya <ashwin.vaidya@intel.com>
Co-authored-by: Jaeguk Hyun <jaeguk.hyun@intel.com>
Co-authored-by: Nikita Savelyev <nikita.savelyev@intel.com>
Co-authored-by: Jihwan Eom <jihwan.eom@intel.com>
Co-authored-by: Harim Kang <harim.kang@intel.com>
Co-authored-by: Soobee Lee <soobee.lee@intel.com>
Co-authored-by: Lee, Soobee <soobeele@intel.com>
Co-authored-by: Emily Chun <emily.chun@intel.com>
Co-authored-by: ljcornel <ludo.cornelissen@intel.com>
Co-authored-by: dlyakhov <daniil.lyakhov@intel.com>
Co-authored-by: kprokofi <kirill.prokofiev@intel.com>
Co-authored-by: Sungman Cho <sungman.cho@intel.com>
Co-authored-by: Yunchu Lee <yunchu.lee@intel.com>
Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
Co-authored-by: Alexander Dokuchaev <alexander.dokuchaev@intel.com>
Co-authored-by: Vladislav Sovrasov <vladislav.sovrasov@intel.com>
Co-authored-by: Evgeny Tsykunov <e.tsykunov@gmail.com>
Co-authored-by: Galina Zalesskaya <galina.zalesskaya@intel.com>
Co-authored-by: dongkwan-kim <dongkwan.kim@intel.com>
  • Loading branch information
22 people committed Dec 29, 2022
1 parent 91bef3a commit 2b77d44
Show file tree
Hide file tree
Showing 25 changed files with 115 additions and 162 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ repos:
rev: v2.7.1
hooks:
- id: prettier
exclude: "external/deep-object-reid|otx/mpa|otx/recipes"
exclude: "external|otx/mpa|otx/recipes"

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v0.971"
Expand Down
18 changes: 1 addition & 17 deletions otx/mpa/cls/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,6 @@ def is_mmov_model(k, v):
if cfg.model.get("multilabel", False) or cfg.model.get("hierarchical", False):
cfg.model.head.pop("topk", None)

# Other hyper-parameters
if cfg.get("hyperparams", False):
self.configure_hyperparams(cfg, training, **kwargs)

return cfg

@staticmethod
Expand Down Expand Up @@ -173,7 +169,7 @@ def configure_task(cfg, training, model_meta=None, **kwargs):

model_tasks, dst_classes = None, None
model_classes, data_classes = [], []
train_data_cfg = Stage.get_train_data_cfg(cfg)
train_data_cfg = Stage.get_data_cfg(cfg, "train")
if isinstance(train_data_cfg, list):
train_data_cfg = train_data_cfg[0]

Expand Down Expand Up @@ -294,18 +290,6 @@ def configure_task(cfg, training, model_meta=None, **kwargs):
cfg.model.head.num_old_classes = len(old_classes)
return model_tasks, dst_classes

@staticmethod
def configure_hyperparams(cfg, training, **kwargs):
hyperparams = kwargs.get("hyperparams", None)
if hyperparams is not None:
bs = hyperparams.get("bs", None)
if bs is not None:
cfg.data.samples_per_gpu = bs

lr = hyperparams.get("lr", None)
if lr is not None:
cfg.optimizer.lr = lr

def _put_model_on_gpu(self, model, cfg):
if torch.cuda.is_available():
model = model.cuda()
Expand Down
12 changes: 8 additions & 4 deletions otx/mpa/cls/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,14 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):

# prepare data loaders
datasets = datasets if isinstance(datasets, (list, tuple)) else [datasets]
train_data_cfg = Stage.get_train_data_cfg(cfg)
drop_last = train_data_cfg.drop_last if train_data_cfg.get("drop_last", False) else False

# updated to adapt list of datasets for the 'train'
train_data_cfg = Stage.get_data_cfg(cfg, "train")
otx_dataset = train_data_cfg.get("otx_dataset", None)
drop_last = False
dataset_len = len(otx_dataset) if otx_dataset else 0
# if task == h-label & dataset size is bigger than batch size
if train_data_cfg.get("hierarchical_info", None) and dataset_len > cfg.data.get("samples_per_gpu", 2):
drop_last = True
# updated to adapt list of dataset for the 'train'
data_loaders = []
sub_loaders = []
for ds in datasets:
Expand Down
8 changes: 4 additions & 4 deletions otx/mpa/det/incremental/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def configure_classes(self, cfg, task_adapt_type, task_adapt_op):

def configure_task_data_pipeline(self, cfg, model_classes, data_classes):
# Trying to alter class indices of training data according to model class order
tr_data_cfg = self.get_train_data_cfg(cfg)
tr_data_cfg = self.get_data_cfg(cfg, "train")
class_adapt_cfg = dict(type="AdaptClassLabels", src_classes=data_classes, dst_classes=model_classes)
pipeline_cfg = tr_data_cfg.pipeline
for i, op in enumerate(pipeline_cfg):
Expand All @@ -120,7 +120,7 @@ def configure_task_cls_incr(self, cfg, task_adapt_type, org_model_classes, model
self.configure_ema(cfg)
self.configure_val_interval(cfg)
else:
src_data_cfg = self.get_train_data_cfg(cfg)
src_data_cfg = self.get_data_cfg(cfg, "train")
src_data_cfg.pop("old_new_indices", None)

def configure_bbox_head(self, cfg, model_classes):
Expand All @@ -135,7 +135,7 @@ def configure_bbox_head(self, cfg, model_classes):
# TODO Remove this part
# This is not related with patching bbox head
# This might be useless when semisl using MPADetDataset
tr_data_cfg = self.get_train_data_cfg(cfg)
tr_data_cfg = self.get_data_cfg(cfg, "train")
if tr_data_cfg.type != "MPADetDataset":
tr_data_cfg.img_ids_dict = self.get_img_ids_for_incr(cfg, org_model_classes, model_classes)
tr_data_cfg.org_type = tr_data_cfg.type
Expand Down Expand Up @@ -225,7 +225,7 @@ def get_img_ids_for_incr(cfg, org_model_classes, model_classes):
new_classes = np.setdiff1d(model_classes, org_model_classes).tolist()
old_classes = np.intersect1d(org_model_classes, model_classes).tolist()

src_data_cfg = self.get_train_data_cfg(cfg)
src_data_cfg = self.get_data_cfg(cfg, "train")

ids_old, ids_new = [], []
data_cfg = cfg.data.test.copy()
Expand Down
8 changes: 4 additions & 4 deletions otx/mpa/det/inferrer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
"""
self._init_logger()
mode = kwargs.get("mode", "train")
eval = kwargs.get("eval", False)
dump_features = kwargs.get("dump_features", False)
dump_saliency_map = kwargs.get("dump_saliency_map", False)
eval = kwargs.pop("eval", False)
dump_features = kwargs.pop("dump_features", False)
dump_saliency_map = kwargs.pop("dump_saliency_map", False)
if mode not in self.mode:
return {}

Expand Down Expand Up @@ -90,7 +90,7 @@ def infer(self, cfg, eval=False, dump_features=False, dump_saliency_map=False):
input_source = cfg.get("input_source")
logger.info(f"Inferring on input source: data.{input_source}")
if input_source == "train":
src_data_cfg = self.get_train_data_cfg(cfg)
src_data_cfg = self.get_data_cfg(cfg, input_source)
else:
src_data_cfg = cfg.data[input_source]
data_cfg.test_mode = src_data_cfg.get("test_mode", False)
Expand Down
2 changes: 1 addition & 1 deletion otx/mpa/det/semisl/inferrer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def infer(self, cfg, eval=False, dump_features=False, dump_saliency_map=False):
input_source = cfg.get("input_source")
logger.info(f"Inferring on input source: data.{input_source}")
if input_source == "train":
src_data_cfg = self.get_train_data_cfg(cfg)
src_data_cfg = self.get_data_cfg(cfg, "train")
else:
src_data_cfg = cfg.data[input_source]
data_cfg.test_mode = src_data_cfg.get("test_mode", False)
Expand Down
2 changes: 1 addition & 1 deletion otx/mpa/det/semisl/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def configure_task_cls_incr(self, cfg, task_adapt_type, org_model_classes, model
self.configure_task_adapt_hook(cfg, org_model_classes, model_classes)
self.configure_val_interval(cfg)
else:
src_data_cfg = self.get_train_data_cfg(cfg)
src_data_cfg = self.get_data_cfg(cfg, "train")
src_data_cfg.pop("old_new_indices", None)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions otx/mpa/det/semisl/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ def train_worker(gpu, target_classes, datasets, cfg, distributed=False, validate
model = build_detector(cfg.model)
model.CLASSES = target_classes
# Do clustering for SSD model
# TODO[JAEGUK]: Temporal Disable cluster_anchors for SSD model
# TODO[JAEGUK]: Temporary disable cluster_anchors for SSD model
# if hasattr(cfg.model, 'bbox_head') and hasattr(cfg.model.bbox_head, 'anchor_generator'):
# if getattr(cfg.model.bbox_head.anchor_generator, 'reclustering_anchors', False):
# train_cfg = Stage.get_train_data_cfg(cfg)
# train_cfg = Stage.get_data_cfg(cfg, "train")
# train_dataset = train_cfg.get('otx_dataset', None)
# cfg, model = cluster_anchors(cfg, train_dataset, model)
train_detector(model, datasets, cfg, distributed=distributed, validate=True, timestamp=timestamp, meta=meta)
2 changes: 1 addition & 1 deletion otx/mpa/det/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def configure_data(self, cfg, data_cfg, training, **kwargs):
cfg.data.train.type = super_type
if training:
if "dataset" in cfg.data.train:
train_cfg = self.get_train_data_cfg(cfg)
train_cfg = self.get_data_cfg(cfg, "train")
train_cfg.otx_dataset = cfg.data.train.pop("otx_dataset", None)
train_cfg.labels = cfg.data.train.get("labels", None)
train_cfg.data_classes = cfg.data.train.pop("data_classes", None)
Expand Down
11 changes: 9 additions & 2 deletions otx/mpa/det/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):

# Data
datasets = [build_dataset(cfg.data.train)]
cfg.data.val.samples_per_gpu = cfg.data.get("samples_per_gpu", 1)

# FIXME: scale_factors is fixed at 1 even batch_size > 1 in simple_test_mask
# Need to investigate, possibly due to OpenVINO
if "roi_head" in model_cfg.model:
if "mask_head" in model_cfg.model.roi_head:
cfg.data.val.samples_per_gpu = 1

if hasattr(cfg, "hparams"):
if cfg.hparams.get("adaptive_anchor", False):
Expand Down Expand Up @@ -109,10 +116,10 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
self._modify_cfg_for_distributed(model, cfg)

# Do clustering for SSD model
# TODO[JAEGUK]: Temporal Disable cluster_anchors for SSD model
# TODO[JAEGUK]: Temporary disable cluster_anchors for SSD model
# if hasattr(cfg.model, 'bbox_head') and hasattr(cfg.model.bbox_head, 'anchor_generator'):
# if getattr(cfg.model.bbox_head.anchor_generator, 'reclustering_anchors', False):
# train_cfg = Stage.get_train_data_cfg(cfg)
# train_cfg = Stage.get_data_cfg(cfg, "train")
# train_dataset = train_cfg.get('otx_dataset', None)
# cfg, model = cluster_anchors(cfg, train_dataset, model)

Expand Down
22 changes: 11 additions & 11 deletions otx/mpa/modules/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,17 @@ def __init__(
def after_train_epoch(self, runner):
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
return
if hasattr(runner, "save_ckpt"):
if runner.save_ckpt:
if runner.save_ema_model:
backup_model = runner.model
runner.model = runner.ema_model
runner.logger.info(f"Saving checkpoint at {runner.epoch + 1} epochs")
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)
if runner.save_ema_model:
runner.model = backup_model
if hasattr(runner, "save_ckpt") and runner.save_ckpt:
if hasattr(runner, "save_ema_model") and runner.save_ema_model:
backup_model = runner.model
runner.model = runner.ema_model
runner.logger.info(f"Saving checkpoint at {runner.epoch + 1} epochs")
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)
if hasattr(runner, "save_ema_model") and runner.save_ema_model:
runner.model = backup_model
runner.save_ema_model = False
runner.save_ckpt = False

@master_only
Expand Down
8 changes: 8 additions & 0 deletions otx/mpa/modules/hooks/recording_forward_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class BaseRecordingForwardHook(ABC):
print(hook.records)
Args:
module (torch.nn.Module): The PyTorch module to be registered in forward pass
fpn_idx (int, optional): The layer index to be processed if the model is a FPN.
Defaults to 0 which uses the largest feature map from FPN.
"""

def __init__(self, module: torch.nn.Module, fpn_idx: int = 0) -> None:
Expand Down Expand Up @@ -241,6 +243,12 @@ def func(self, feature_map: Union[torch.Tensor, Sequence[torch.Tensor]], fpn_idx
"""
Generate the class-wise saliency maps using Recipro-CAM and then normalizing to (0, 255).
Args:
feature_map (Union[torch.Tensor, List[torch.Tensor]]): feature maps from backbone or list of feature maps
from FPN.
fpn_idx (int, optional): The layer index to be processed if the model is a FPN.
Defaults to 0 which uses the largest feature map from FPN.
Returns:
torch.Tensor: Class-wise Saliency Maps. One saliency map per each class - [batch, class_id, H, W]
"""
Expand Down
10 changes: 5 additions & 5 deletions otx/mpa/modules/models/classifiers/sam_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def state_dict_hook(module, state_dict, *args, **kwargs):
if backbone_type == "OTXMobileNetV3":
for k, v in state_dict.items():
if k.startswith("backbone"):
k = k.replace("backbone.", "")
k = k.replace("backbone.", "", 1)
elif k.startswith("head"):
k = k.replace("head.", "")
k = k.replace("head.", "", 1)
if "3" in k: # MPA uses "classifier.3", OTX uses "classifier.4". Convert for OTX compatibility.
k = k.replace("3", "4")
if module.multilabel and not module.is_export:
Expand All @@ -119,9 +119,9 @@ def state_dict_hook(module, state_dict, *args, **kwargs):
elif backbone_type == "OTXEfficientNet":
for k, v in state_dict.items():
if k.startswith("backbone"):
k = k.replace("backbone.", "")
k = k.replace("backbone.", "", 1)
elif k.startswith("head"):
k = k.replace("head", "output")
k = k.replace("head", "output", 1)
if not module.hierarchical and not module.is_export:
k = k.replace("fc", "asl")
v = v.t()
Expand All @@ -130,7 +130,7 @@ def state_dict_hook(module, state_dict, *args, **kwargs):
elif backbone_type == "OTXEfficientNetV2":
for k, v in state_dict.items():
if k.startswith("backbone"):
k = k.replace("backbone.", "")
k = k.replace("backbone.", "", 1)
elif k == "head.fc.weight":
k = k.replace("head.fc", "model.classifier")
if not module.hierarchical and not module.is_export:
Expand Down
13 changes: 13 additions & 0 deletions otx/mpa/modules/models/heads/custom_cls_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
#

import torch
import torch.nn.functional as F
from mmcls.models.builder import HEADS
from mmcls.models.heads import LinearClsHead

Expand Down Expand Up @@ -72,6 +74,17 @@ def loss(self, cls_score, gt_label, feature=None):
losses["loss"] = loss
return losses

def simple_test(self, img):
"""Test without augmentation."""
cls_score = self.fc(img)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
if torch.onnx.is_in_onnx_export():
return cls_score
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None

return self.post_process(pred)

def forward_train(self, x, gt_label):
cls_score = self.fc(x)
losses = self.loss(cls_score, gt_label, feature=x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(
self.hierarchical_info = kwargs.pop("hierarchical_info", None)
assert self.hierarchical_info
super(CustomHierarchicalLinearClsHead, self).__init__(loss=loss)
if self.hierarchical_info["num_multiclass_heads"] + self.hierarchical_info["num_multilabel_classes"] == 0:
raise ValueError("Invalid classification heads configuration")
self.compute_multilabel_loss = False
if self.hierarchical_info["num_multilabel_classes"] > 0:
self.compute_multilabel_loss = build_loss(multilabel_loss)
Expand Down Expand Up @@ -115,14 +117,21 @@ def simple_test(self, img):
"head_idx_to_logits_range"
][i][1],
]
if not torch.onnx.is_in_onnx_export():
multiclass_logit = torch.softmax(multiclass_logit, dim=1)
multiclass_logits.append(multiclass_logit)
multiclass_logits = torch.cat(multiclass_logits, dim=1)
multiclass_pred = torch.softmax(multiclass_logits, dim=1) if multiclass_logits is not None else None
multiclass_pred = torch.cat(multiclass_logits, dim=1) if multiclass_logits else None

if self.compute_multilabel_loss:
multilabel_logits = cls_score[:, self.hierarchical_info["num_single_label_classes"] :]
multilabel_pred = torch.sigmoid(multilabel_logits) if multilabel_logits is not None else None
pred = torch.cat([multiclass_pred, multilabel_pred], axis=1)
if not torch.onnx.is_in_onnx_export():
multilabel_pred = torch.sigmoid(multilabel_logits) if multilabel_logits is not None else None
else:
multilabel_pred = multilabel_logits
if multiclass_pred is not None:
pred = torch.cat([multiclass_pred, multilabel_pred], axis=1)
else:
pred = multilabel_pred
else:
pred = multiclass_pred

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def __init__(
self.hierarchical_info = kwargs.pop("hierarchical_info", None)
assert self.hierarchical_info
super(CustomHierarchicalNonLinearClsHead, self).__init__(loss=loss)
if self.hierarchical_info["num_multiclass_heads"] + self.hierarchical_info["num_multilabel_classes"] == 0:
raise ValueError("Invalid classification heads configuration")
self.compute_multilabel_loss = False
if self.hierarchical_info["num_multilabel_classes"] > 0:
self.compute_multilabel_loss = build_loss(multilabel_loss)
Expand Down Expand Up @@ -142,14 +144,21 @@ def simple_test(self, img):
"head_idx_to_logits_range"
][i][1],
]
if not torch.onnx.is_in_onnx_export():
multiclass_logit = torch.softmax(multiclass_logit, dim=1)
multiclass_logits.append(multiclass_logit)
multiclass_logits = torch.cat(multiclass_logits, dim=1)
multiclass_pred = torch.softmax(multiclass_logits, dim=1) if multiclass_logits is not None else None
multiclass_pred = torch.cat(multiclass_logits, dim=1) if multiclass_logits else None

if self.compute_multilabel_loss:
multilabel_logits = cls_score[:, self.hierarchical_info["num_single_label_classes"] :]
multilabel_pred = torch.sigmoid(multilabel_logits) if multilabel_logits is not None else None
pred = torch.cat([multiclass_pred, multilabel_pred], axis=1)
if not torch.onnx.is_in_onnx_export():
multilabel_pred = torch.sigmoid(multilabel_logits) if multilabel_logits is not None else None
else:
multilabel_pred = multilabel_logits
if multiclass_pred is not None:
pred = torch.cat([multiclass_pred, multilabel_pred], axis=1)
else:
pred = multilabel_pred
else:
pred = multiclass_pred

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def simple_test(self, img):
cls_score = self.fc(img) * self.scale
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = torch.sigmoid(cls_score) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
return pred
return cls_score
pred = torch.sigmoid(cls_score) if cls_score is not None else None
pred = list(pred.detach().cpu().numpy())
return pred

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def simple_test(self, img):
cls_score = self.classifier(img) * self.scale
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = torch.sigmoid(cls_score) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
return pred
return cls_score
pred = torch.sigmoid(cls_score) if cls_score is not None else None
pred = list(pred.detach().cpu().numpy())
return pred

Expand Down
Loading

0 comments on commit 2b77d44

Please sign in to comment.