Add optimize for visual prompting to 2.0 #3040

Status: Merged (69 commits, merged Mar 11, 2024)
The changes shown below are from 1 commit of the 69.

Commits (69):
f111cf1
Fix return's shape when all predicted masks are zero
sungchul2 Feb 6, 2024
722918c
Swap order of prompts in segment_anything.py
sungchul2 Feb 20, 2024
f9e1254
Merge branch 'v2' into v2_add_vpm_export
sungchul2 Feb 20, 2024
50accde
Enable to export segment anything modules respectively
sungchul2 Feb 21, 2024
d3b9a9a
Update prompt order and refactoring
sungchul2 Feb 22, 2024
9b4b42e
Fix `return_single_mask` to True
sungchul2 Feb 22, 2024
54f31bd
Add OVIR inference logic
sungchul2 Feb 22, 2024
4bc4cfa
Updates for integration test
sungchul2 Feb 23, 2024
7671271
precommit
sungchul2 Feb 23, 2024
0155526
Update prompt getter
sungchul2 Feb 23, 2024
b1da2dc
(WIP) Update zero-shot
sungchul2 Feb 23, 2024
1da36e2
(WIP) Internalize preprocess into model
sungchul2 Feb 27, 2024
eed5e2c
Update reference infos to buffer
sungchul2 Feb 27, 2024
33f6dd1
Update reference_info path
sungchul2 Feb 27, 2024
2d48108
Enable standalone infer logic
sungchul2 Feb 27, 2024
af901e4
Update location of `_load_latest_reference_info` for OVModel
sungchul2 Feb 27, 2024
52a8068
Update recipes
sungchul2 Feb 28, 2024
6ac1ec5
Enable `infer` ov inference
sungchul2 Feb 28, 2024
20ba18e
Enable standalone `infer` logic on OVModel
sungchul2 Feb 28, 2024
9f4c3cd
precommit
sungchul2 Feb 28, 2024
a754c84
Update unittest
sungchul2 Feb 29, 2024
6d34e89
Update for integration test
sungchul2 Feb 29, 2024
81c1672
Merge branch 'v2' into v2_add_vpm_export
sungchul2 Feb 29, 2024
a41feb8
precommit
sungchul2 Feb 29, 2024
eccf7ae
Fix
sungchul2 Feb 29, 2024
2de2fd7
Fix
sungchul2 Feb 29, 2024
fc89558
Fix
sungchul2 Feb 29, 2024
6744dfe
Fix intg tests
sungchul2 Feb 29, 2024
2708d7b
Enable to update `export_args` to `deploy_cfg`
sungchul2 Feb 29, 2024
731447b
Update with walrus
sungchul2 Mar 4, 2024
03d9795
Update to use dict labels
sungchul2 Mar 4, 2024
c7f5f1e
Change openvino model names
sungchul2 Mar 4, 2024
b8d1b51
Update compatibility with zero-shot
sungchul2 Mar 4, 2024
f1aa67c
Refactoring for unnecessary assigned variables
sungchul2 Mar 4, 2024
180990a
Avoid repeatedly executing `torch.cat`
sungchul2 Mar 4, 2024
4220ded
precommit
sungchul2 Mar 4, 2024
a07f4c6
Fix unit test
sungchul2 Mar 4, 2024
27a05cd
Update variable name
sungchul2 Mar 4, 2024
979a01b
Add `example_inputs` in anomaly
sungchul2 Mar 4, 2024
14fa2c7
Fix unit test
sungchul2 Mar 4, 2024
25bb988
Fix
sungchul2 Mar 4, 2024
05bfdfa
Update `model_names` for visual prompting
sungchul2 Mar 5, 2024
05104a6
precommit
sungchul2 Mar 5, 2024
73f081a
Not to include other params in `example_inputs`
sungchul2 Mar 5, 2024
d9335de
Disable condition for visual prompting
sungchul2 Mar 5, 2024
97d674b
Update to `example_inputs`
sungchul2 Mar 5, 2024
0a9eaa1
Remove unused kwargs
sungchul2 Mar 5, 2024
9393dd4
Update
sungchul2 Mar 5, 2024
c86ddca
Remove unused parts
sungchul2 Mar 5, 2024
84ed9f4
Update exported models' names
sungchul2 Mar 5, 2024
2d75684
Remove `example_inputs`
sungchul2 Mar 5, 2024
a80aae2
Add `OTXVisualPromptingModelExporter`
sungchul2 Mar 5, 2024
6a405fe
Update overlapped region refinement
sungchul2 Feb 29, 2024
d4c9a50
Update `export_params`
sungchul2 Mar 6, 2024
1217721
Merge branch 'v2' into v2_add_vpm_optimize
sungchul2 Mar 6, 2024
af662f4
Enable optimize
sungchul2 Mar 6, 2024
e361b74
Add exportable code, but updating `demo.py` is required
sungchul2 Mar 7, 2024
6d515c4
Update model_type
sungchul2 Mar 7, 2024
033826b
Fix integration test
sungchul2 Mar 7, 2024
96f8147
Merge branch 'v2' into v2_add_vpm_optimize
sungchul2 Mar 7, 2024
38c8ba2
Add unit test
sungchul2 Mar 7, 2024
d248f81
Refactoring redundant parts
sungchul2 Mar 7, 2024
5ff1ff5
Revert exportable_code
sungchul2 Mar 7, 2024
9f812a7
Update unit test
sungchul2 Mar 7, 2024
7ceee75
Merge branch 'v2' into v2_add_vpm_optimize
sungchul2 Mar 7, 2024
9da32eb
Update unit test
sungchul2 Mar 7, 2024
3ffb11a
Fix unit test
sungchul2 Mar 7, 2024
c79ad8e
Temporarily fix integration test
sungchul2 Mar 8, 2024
e3d29ae
Revert to disable opening subprocess & add xfail for vpm tasks
sungchul2 Mar 8, 2024
Update reference infos to buffer
sungchul2 committed Feb 27, 2024
commit eed5e2c15bd89e7ed5b1b994c952e3d9664f440e
34 changes: 15 additions & 19 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
@@ -14,7 +14,6 @@
 import numpy as np

 import torch
-from torch.nn import Parameter, ParameterDict
 from datumaro import Polygon as dmPolygon
 from torch import LongTensor, Tensor, nn
 from torch.nn import functional as F  # noqa: N812
@@ -214,8 +213,7 @@ def __init__(
         self.save_outputs = kwargs.pop("save_outputs", True)
         self.path_reference_info = kwargs.pop("path_reference_info", "vpm_zsl_reference_infos/{}/reference_info.pt")
         super().__init__(*args, **kwargs)

-        self.reference_info: ParameterDict = ParameterDict()

         self.initialize_reference_info()

         self.prompt_getter = PromptGetter(image_size=self.image_size)
@@ -243,15 +241,14 @@ def set_default_config(self, **kwargs) -> dict[str, Any]:

     def initialize_reference_info(self) -> None:
         """Initialize reference information."""
-        self.reference_info["reference_feats"] = Parameter(torch.zeros(0, 1, self.embed_dim), requires_grad=False)
-        self.reference_info["used_indices"] = Parameter(torch.tensor([], dtype=torch.int64), requires_grad=False)
+        self.register_buffer("reference_feats", torch.zeros(0, 1, self.embed_dim), False)
+        self.register_buffer("used_indices", torch.tensor([], dtype=torch.int64), False)

     def expand_reference_info(self, new_largest_label: int) -> None:
         """Expand reference info dimensions if newly given processed prompts have more lables."""
-        if new_largest_label > (cur_largest_label := len(self.reference_info["reference_feats"]) - 1):
+        if new_largest_label > (cur_largest_label := len(self.reference_feats) - 1):
             diff = new_largest_label - cur_largest_label
-            padded_reference_feats = F.pad(self.reference_info["reference_feats"], (0, 0, 0, 0, 0, diff), value=0.0)
-            self.reference_info["reference_feats"] = Parameter(padded_reference_feats, requires_grad=False)
+            self.reference_feats = F.pad(self.reference_feats, (0, 0, 0, 0, 0, diff), value=0.0)

     @torch.no_grad()
     def learn(
@@ -260,7 +257,7 @@ def learn(
         processed_prompts: list[dict[int, list[tv_tensors.TVTensor]]],
         ori_shapes: list[Tensor],
         reset_feat: bool = False,
-    ) -> tuple[nn.ParameterDict, list[Tensor]] | None:
+    ) -> tuple[dict[str, Tensor], list[Tensor]] | None:
         """Get reference features.

         Using given images, get reference features.
@@ -339,14 +336,11 @@ def learn(
                     )
                     default_threshold_reference -= 0.05

-                self.reference_info["reference_feats"][label] = ref_feat.detach().cpu()
-                self.reference_info["used_indices"] = Parameter(
-                    torch.cat((self.reference_info["used_indices"], torch.tensor([label])), dim=0),
-                    requires_grad=False,
-                )
+                self.reference_feats[label] = ref_feat.detach().cpu()
+                self.used_indices = torch.cat((self.used_indices, torch.tensor([label])), dim=0)
                 ref_masks[label] = ref_mask.detach().cpu()
             reference_masks.append(ref_masks)
-        return self.reference_info, reference_masks
+        return {"reference_feats": self.reference_feats, "used_indices": self.used_indices}, reference_masks

     @torch.no_grad()
     def infer(
@@ -627,11 +621,13 @@ def _find_latest_reference_info(self, root: str = "vpm_zsl_reference_infos") ->
             return stamps[0]
         return None

-    def _load_latest_reference_info(self) -> None:
+    def _load_latest_reference_info(self, device: str | torch.device = "cpu") -> None:
         """Load latest reference info to be used."""
         if (latest_stamp := self._find_latest_reference_info()) is not None:
             latest_reference_info = self.path_reference_info.format(latest_stamp)
-            self.reference_info = torch.load(latest_reference_info)
+            reference_info = torch.load(latest_reference_info)
+            self.register_buffer("reference_feats", reference_info.get("reference_feats", torch.zeros(0, 1, self.embed_dim)).to(device), False)
+            self.register_buffer("used_indices", reference_info.get("used_indices", torch.tensor([], dtype=torch.int64)).to(device), False)
             log.info(f"reference info saved at {latest_reference_info} was successfully loaded.")

@@ -687,8 +683,8 @@ def _customize_inputs(self, inputs: ZeroShotVisualPromptingBatchDataEntity) -> d
         # infer
         return {
             "images": [tv_tensors.wrap(image.unsqueeze(0), like=image) for image in inputs.images],
-            "reference_feats": self.model.reference_info["reference_feats"],
-            "used_indices": self.model.reference_info["used_indices"],
+            "reference_feats": self.model.reference_feats,
+            "used_indices": self.model.used_indices,
             "ori_shapes": [torch.tensor(info.ori_shape) for info in inputs.imgs_info],
             "is_cascade": self.model.is_cascade,
         }
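
Two details in the hunks above are easy to miss: the third positional argument to `register_buffer` is `persistent`, so passing `False` keeps `reference_feats` and `used_indices` out of `state_dict()` while still letting them follow `.to(device)`, and `F.pad` with `(0, 0, 0, 0, 0, diff)` grows only the first dimension. Here is a minimal, self-contained sketch of both behaviors (the `RefStore` class and toy `embed_dim` are illustrative, not the OTX model):

    import torch
    from torch import nn
    from torch.nn import functional as F  # noqa: N812


    class RefStore(nn.Module):
        """Toy module mirroring the non-persistent buffer pattern above."""

        def __init__(self, embed_dim: int = 4) -> None:
            super().__init__()
            # persistent=False keeps these tensors out of state_dict(),
            # which is what lets the custom checkpoint hooks be deleted.
            self.register_buffer("reference_feats", torch.zeros(0, 1, embed_dim), persistent=False)
            self.register_buffer("used_indices", torch.tensor([], dtype=torch.int64), persistent=False)

        def expand(self, new_largest_label: int) -> None:
            # F.pad pads dimensions right-to-left: (0, 0, 0, 0, 0, diff)
            # leaves the last two dims alone and appends `diff` zero rows
            # to the first dim, as expand_reference_info() does above.
            if new_largest_label > (cur_largest_label := len(self.reference_feats) - 1):
                diff = new_largest_label - cur_largest_label
                self.reference_feats = F.pad(self.reference_feats, (0, 0, 0, 0, 0, diff), value=0.0)


    store = RefStore()
    store.expand(2)
    print(store.reference_feats.shape)  # torch.Size([3, 1, 4])
    print(store.state_dict())           # OrderedDict(): non-persistent buffers are excluded

Reassigning `self.reference_feats` to the padded tensor keeps the buffer registration, so the new tensor still moves with the module on later `.to(device)` calls.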
23 changes: 0 additions & 23 deletions src/otx/core/model/entity/visual_prompting.py
@@ -320,26 +320,3 @@ class OTXZeroShotVisualPromptingModel(

     def __init__(self, num_classes: int = 0) -> None:
         super().__init__(num_classes=num_classes)
-
-        self._register_load_state_dict_pre_hook(self.load_state_dict_pre_hook)
-
-    def state_dict(
-        self,
-        *args,
-        destination: dict[str, Any] | None = None,
-        prefix: str = "",
-        keep_vars: bool = False,
-    ) -> dict[str, Any] | None:
-        """Return state dictionary of model entity with reference features, masks, and used indices."""
-        super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
-
-        if isinstance(destination, dict):
-            # to save reference_info instead of reference_feats only
-            destination.pop(prefix + "model.reference_info.reference_feats")
-            destination.update({prefix + "model.reference_info": self.model.reference_info})
-        return destination
-
-    def load_state_dict_pre_hook(self, state_dict: dict[str, Any], prefix: str = "", *args, **kwargs) -> None:
-        """Load reference info manually."""
-        self.model.reference_info = state_dict.get(prefix + "model.reference_info", self.model.reference_info)
-        state_dict[prefix + "model.reference_info.reference_feats"] = self.model.reference_info["reference_feats"]
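
The deleted `state_dict` override and pre-hook existed only to tunnel the whole `reference_info` `ParameterDict` through checkpoints under a single key; with non-persistent buffers nothing reference-related lands in the checkpoint, so both can go. For readers unfamiliar with the mechanism, a simplified, hypothetical sketch of the removed pattern (`LegacyRefModel` is illustrative, and `_register_load_state_dict_pre_hook` is a private PyTorch API):

    import torch
    from torch import nn


    class LegacyRefModel(nn.Module):
        """Illustrative stand-in for the deleted checkpoint plumbing."""

        def __init__(self) -> None:
            super().__init__()
            self.reference_info = nn.ParameterDict(
                {"reference_feats": nn.Parameter(torch.zeros(2, 1, 4), requires_grad=False)}
            )
            # Runs before load_state_dict() matches keys, so the custom
            # entry can be swapped back for the key PyTorch expects.
            self._register_load_state_dict_pre_hook(self._load_pre_hook)

        def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
            destination = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
            # Store the whole ParameterDict under one key instead of its members.
            destination.pop(prefix + "reference_info.reference_feats")
            destination[prefix + "reference_info"] = self.reference_info
            return destination

        def _load_pre_hook(self, state_dict, prefix, *args) -> None:
            if (info := state_dict.pop(prefix + "reference_info", None)) is not None:
                self.reference_info = info
                # Re-insert the flat key load_state_dict() will look for.
                state_dict[prefix + "reference_info.reference_feats"] = info["reference_feats"]


    model = LegacyRefModel()
    model.load_state_dict(model.state_dict())  # round-trips through both overrides

One difference from the deleted code: the original never reassigned the result of `super().state_dict(...)`, so it relied on the caller always supplying `destination`; the sketch assigns the returned dict so it also works standalone.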
16 changes: 9 additions & 7 deletions src/otx/core/model/module/visual_prompting.py
@@ -263,27 +263,29 @@ def on_train_start(self) -> None:

     def on_test_start(self) -> None:
         """Load previously saved reference info."""
-        self.model.model._load_latest_reference_info()
+        self.model.model._load_latest_reference_info(self.device)

     def on_predict_start(self) -> None:
         """Load previously saved reference info."""
-        self.model.model._load_latest_reference_info()
+        self.model.model._load_latest_reference_info(self.device)

     def on_train_epoch_start(self) -> None:
         """Skip on_train_epoch_start unused in zero-shot visual prompting."""

     def on_train_epoch_end(self) -> None:
         """Skip on_train_epoch_end unused in zero-shot visual prompting."""
-        self.model.model.reference_info["used_indices"] = Parameter(
-            self.model.model.reference_info["used_indices"].unique().unsqueeze(0), requires_grad=False
-        )
+        self.model.model.used_indices = self.model.model.used_indices.unique()
         if self.model.model.save_outputs:
+            reference_info = {
+                "reference_feats": self.model.model.reference_feats,
+                "used_indices": self.model.model.used_indices,
+            }
             # save reference info
             path_reference_info = self.model.model.path_reference_info.format(time.strftime("%Y%m%d_%H%M%S"))
             os.makedirs(os.path.dirname(path_reference_info), exist_ok=True)
-            torch.save(self.model.model.reference_info, path_reference_info)
+            torch.save(reference_info, path_reference_info)
             pickle.dump(
-                {k: v.numpy() for k, v in self.model.model.reference_info.items()},
+                {k: v.numpy() for k, v in reference_info.items()},
                 open(path_reference_info.replace(".pt", ".pickle"), "wb"),
             )
             log.info(f"Saved reference info at {path_reference_info}.")