Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix reference info path for zero-shot learning #3354

Merged
merged 10 commits into from
Apr 22, 2024
49 changes: 33 additions & 16 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import logging as log
import os
import pickle
from collections import defaultdict
from copy import deepcopy
from itertools import product
Expand Down Expand Up @@ -624,15 +624,16 @@ def _decide_cascade_results(
class OTXZeroShotSegmentAnything(OTXZeroShotVisualPromptingModel):
"""Zero-Shot Visual Prompting model."""

def __init__(
def __init__( # noqa: PLR0913
self,
backbone: Literal["tiny_vit", "vit_b"],
num_classes: int = 0,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = VisualPromptingMetricCallable,
torch_compile: bool = False,
root_reference_info: Path | str = "vpm_zsl_reference_infos",
reference_info_dir: Path | str = "reference_infos",
infer_reference_info_root: Path | str = "../.latest/train",
save_outputs: bool = True,
pixel_mean: list[float] | None = [123.675, 116.28, 103.53], # noqa: B006
pixel_std: list[float] | None = [58.395, 57.12, 57.375], # noqa: B006
Expand Down Expand Up @@ -668,7 +669,8 @@ def __init__(
)

self.save_outputs = save_outputs
self.root_reference_info: Path = Path(root_reference_info)
self.reference_info_dir: Path = Path(reference_info_dir)
vinnamkim marked this conversation as resolved.
Show resolved Hide resolved
self.infer_reference_info_root: Path = Path(infer_reference_info_root)

self.register_buffer("pixel_mean", Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", Tensor(pixel_std).view(-1, 1, 1), False)
Expand Down Expand Up @@ -876,21 +878,36 @@ def initialize_reference_info(self) -> None:
self.register_buffer("reference_feats", torch.zeros(0, 1, self.model.embed_dim), False)
self.register_buffer("used_indices", torch.tensor([], dtype=torch.int64), False)

def _find_latest_reference_info(self, root: Path) -> str | None:
"""Find latest reference info to be used."""
if not Path.is_dir(root):
return None
if len(stamps := sorted(os.listdir(root), reverse=True)) > 0:
return stamps[0]
return None
def save_reference_info(self, default_root_dir: Path | str) -> None:
"""Save reference info."""
reference_info = {
"reference_feats": self.reference_feats,
"used_indices": self.used_indices,
}
# save reference info
path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
Path.mkdir(Path(path_reference_info).parent, parents=True, exist_ok=True)
eunwoosh marked this conversation as resolved.
Show resolved Hide resolved
torch.save(reference_info, path_reference_info)
pickle.dump(
eunwoosh marked this conversation as resolved.
Show resolved Hide resolved
{k: v.numpy() for k, v in reference_info.items()},
Path.open(Path(str(path_reference_info).replace(".pt", ".pickle")), "wb"),
eunwoosh marked this conversation as resolved.
Show resolved Hide resolved
)
log.info(f"Saved reference info at {path_reference_info}.")

def load_latest_reference_info(self, device: str | torch.device = "cpu") -> bool:
def load_reference_info(self, default_root_dir: Path | str, device: str | torch.device = "cpu") -> bool:
"""Load latest reference info to be used."""
if (latest_stamp := self._find_latest_reference_info(self.root_reference_info)) is not None:
latest_reference_info = self.root_reference_info / latest_stamp / "reference_info.pt"
reference_info = torch.load(latest_reference_info)
_infer_reference_info_root = (
self.infer_reference_info_root
if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
else Path(default_root_dir) / self.infer_reference_info_root
)

if Path.is_file(
path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pt",
):
eunwoosh marked this conversation as resolved.
Show resolved Hide resolved
reference_info = torch.load(path_reference_info)
retval = True
log.info(f"reference info saved at {latest_reference_info} was successfully loaded.")
log.info(f"reference info saved at {path_reference_info} was successfully loaded.")
else:
reference_info = {}
retval = False
Expand Down
Loading
Loading