Fix zero-shot learn for ov model #3601

Merged · 7 commits · Jun 13, 2024
Changes from all commits
70 changes: 49 additions & 21 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
@@ -714,12 +714,16 @@ def infer(
inputs: ZeroShotVisualPromptingBatchDataEntity,
reference_feats: Tensor | None = None,
used_indices: Tensor | None = None,
threshold: float = 0.0,
num_bg_points: int = 1,
is_cascade: bool = True,
) -> ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity:
"""Infer to directly connect to the model."""
self.training = False
outputs = self.model.infer(
**self._customize_inputs(inputs, reference_feats=reference_feats, used_indices=used_indices),
threshold=threshold,
num_bg_points=num_bg_points,
is_cascade=is_cascade,
)
return self._customize_outputs(outputs, inputs)
@@ -774,7 +778,7 @@ def _customize_outputs( # type: ignore[override]
),
)
scores.append(torch.stack([p[2] for p in used_points[label]], dim=0))
labels.append(torch.stack([LongTensor([label]) for _ in range(scores[-1].shape[0])], dim=0))
labels.append(torch.cat([LongTensor([label]) for _ in range(scores[-1].shape[0])], dim=0))
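For context, the switch from `torch.stack` to `torch.cat` changes the shape of the assembled label tensor; a standalone sketch of the difference (toy label values, not taken from this PR):

import torch
from torch import LongTensor

# Illustrative sketch, not part of the diff.
parts = [LongTensor([3]) for _ in range(4)]  # four single-element label tensors
stacked = torch.stack(parts, dim=0)          # shape (4, 1): needs an extra squeeze downstream
catted = torch.cat(parts, dim=0)             # shape (4,): already a flat label vector
print(stacked.shape, catted.shape)           # torch.Size([4, 1]) torch.Size([4])

This is consistent with the metric conversion in src/otx/core/model/visual_prompting.py below dropping its squeeze(1) on labels.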

return ZeroShotVisualPromptingBatchPredEntity(
batch_size=len(outputs),
@@ -886,33 +890,57 @@ def save_reference_info(self, default_root_dir: Path | str) -> None:
"used_indices": self.used_indices,
}
# save reference info
path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
path_reference_info.parent.mkdir(parents=True, exist_ok=True)
self.saved_reference_info_path: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
self.saved_reference_info_path.parent.mkdir(parents=True, exist_ok=True)
# TODO (sungchul): ticket no. 139210
torch.save(reference_info, path_reference_info)
torch.save(reference_info, self.saved_reference_info_path)
pickle.dump(
{k: v.numpy() for k, v in reference_info.items()},
path_reference_info.with_suffix(".pickle").open("wb"),
)
log.info(f"Saved reference info at {path_reference_info}.")

def load_reference_info(self, default_root_dir: Path | str, device: str | torch.device = "cpu") -> bool:
"""Load latest reference info to be used."""
_infer_reference_info_root: Path = (
self.infer_reference_info_root
if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
else Path(default_root_dir) / self.infer_reference_info_root
self.saved_reference_info_path.with_suffix(".pickle").open("wb"),
)
log.info(f"Saved reference info at {self.saved_reference_info_path}.")

def load_reference_info(
self,
default_root_dir: Path | str,
device: str | torch.device = "cpu",
path_to_directly_load: Path | None = None,
) -> bool:
"""Load latest reference info to be used.

if (
path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pt"
).is_file():
reference_info = torch.load(path_reference_info)
Args:
default_root_dir (Path | str): Default root directory, used when an
inappropriate infer_reference_info_root is given.
device (str | torch.device): Device to which the reference info will be attached.
path_to_directly_load (Path | None): Reference info path to be loaded directly.
Normally, it is produced by `learn`, which runs in `on_test_start` or
`on_predict_start` when `infer` is attempted without reference features.

Returns:
(bool): Whether the reference info was loaded successfully.
"""
if path_to_directly_load is not None:
# if `path_to_directly_load` is given, forcibly load it
reference_info = torch.load(path_to_directly_load)
retval = True
log.info(f"reference info saved at {path_reference_info} was successfully loaded.")
log.info(f"reference info saved at {path_to_directly_load} was successfully loaded.")

else:
reference_info = {}
retval = False
_infer_reference_info_root: Path = (
self.infer_reference_info_root
if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
else Path(default_root_dir) / self.infer_reference_info_root
)

if (
path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pt"
).is_file():
reference_info = torch.load(path_reference_info)
retval = True
log.info(f"reference info saved at {path_reference_info} was successfully loaded.")
else:
reference_info = {}
retval = False

self.register_buffer(
"reference_feats",
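A central change in this PR is remembering where save_reference_info wrote its output (self.saved_reference_info_path) so the hooks can reload it right after `learn`. A minimal round-trip sketch of that save-then-direct-load path, using a hypothetical temp directory and toy tensors:

import tempfile
from pathlib import Path

import torch

# Illustrative sketch, not part of the diff; sizes and paths are made up.
reference_info = {
    "reference_feats": torch.zeros(2, 1, 256),
    "used_indices": torch.tensor([0, 1]),
}

root = Path(tempfile.mkdtemp())  # stand-in for default_root_dir
saved_reference_info_path = root / "reference_infos" / "reference_info.pt"  # "reference_infos" stands in for self.reference_info_dir
saved_reference_info_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(reference_info, saved_reference_info_path)

# What load_reference_info(..., path_to_directly_load=...) boils down to:
loaded = torch.load(saved_reference_info_path)
assert torch.equal(loaded["used_indices"], reference_info["used_indices"])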
98 changes: 70 additions & 28 deletions src/otx/core/model/visual_prompting.py
@@ -132,10 +132,7 @@ def _inference_step_for_zero_shot(
if _name == "mAP":
# MeanAveragePrecision
_preds = [
{
k: v > 0.5 if k == "masks" else v.squeeze(1).to(model.device) if k == "labels" else v
for k, v in ett.items()
}
{k: v > 0.5 if k == "masks" else v.to(model.device) if k == "labels" else v for k, v in ett.items()}
for ett in converted_entities["preds"]
]
_target = converted_entities["target"]
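Because labels now come out of _customize_outputs as 1-D tensors (see the torch.cat change above), the per-entity conversion no longer needs squeeze(1). A runnable sketch of what the comprehension produces, with toy shapes and keys assumed for illustration:

import torch

# Illustrative sketch, not part of the diff; one element of converted_entities["preds"].
ett = {
    "masks": torch.rand(4, 32, 32),        # soft masks in [0, 1]
    "labels": torch.tensor([0, 1, 1, 2]),  # already 1-D after the torch.cat change
    "scores": torch.rand(4),
}
device = torch.device("cpu")               # stand-in for model.device

pred = {k: v > 0.5 if k == "masks" else v.to(device) if k == "labels" else v for k, v in ett.items()}
print(pred["masks"].dtype, pred["labels"].shape)  # torch.bool torch.Size([4])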
@@ -351,7 +348,11 @@ def on_test_start(self) -> None:
# to set _combined_loader
self.trainer._evaluation_loop.setup_data() # noqa: SLF001
self.trainer._evaluation_loop.reset() # noqa: SLF001
self.load_reference_info(self.trainer.default_root_dir, self.device)
self.load_reference_info(
self.trainer.default_root_dir,
self.device,
path_to_directly_load=self.saved_reference_info_path,
)

def on_predict_start(self) -> None:
"""Load previously saved reference info."""
@@ -364,7 +365,11 @@ def on_predict_start(self) -> None:
# to set _combined_loader
self.trainer._evaluation_loop.setup_data() # noqa: SLF001
self.trainer._evaluation_loop.reset() # noqa: SLF001
self.load_reference_info(self.trainer.default_root_dir, self.device)
self.load_reference_info(
self.trainer.default_root_dir,
self.device,
path_to_directly_load=self.saved_reference_info_path,
)

def on_train_epoch_start(self) -> None:
"""Skip on_train_epoch_start unused in zero-shot visual prompting."""
@@ -828,7 +833,7 @@ def learn(
for label, input_prompts in prompts.items():
ref_mask: np.ndarray = np.zeros(original_shape, dtype=np.uint8)
for inputs_decoder in input_prompts:
label = inputs_decoder.pop("label") # noqa: PLW2901
inputs_decoder.pop("label")
if "point_coords" in inputs_decoder:
# bboxes and points
inputs_decoder.update(image_embeddings)
@@ -853,7 +858,7 @@
cur_default_threshold_reference -= 0.05

self.reference_feats[label] = ref_feat
self.used_indices: np.ndarray = np.concatenate((self.used_indices, label))
self.used_indices: np.ndarray = np.concatenate((self.used_indices, [label]))
ref_masks[label] = ref_mask
reference_masks.append(ref_masks)
self.used_indices = np.unique(self.used_indices)
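Wrapping `label` in a list matters because np.concatenate rejects 0-d inputs; a standalone check with toy values:

import numpy as np

# Illustrative sketch, not part of the diff.
used_indices = np.array([0, 2], dtype=np.int64)
label = 3

# np.concatenate((used_indices, label)) raises
#   ValueError: zero-dimensional arrays cannot be concatenated
used_indices = np.concatenate((used_indices, [label]))  # wrapping the scalar makes it 1-D
print(np.unique(used_indices))                          # [0 2 3]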
@@ -1038,7 +1043,7 @@ def _customize_outputs( # type: ignore[override]
)
scores.append(torch.stack([torch.as_tensor(p[2]) for p in used_points[label]], dim=0).to(self.device))
labels.append(
torch.stack([torch.LongTensor([label]) for _ in range(len(scores[-1]))], dim=0).to(self.device),
torch.cat([torch.LongTensor([label]) for _ in range(len(scores[-1]))], dim=0).to(self.device),
)

return ZeroShotVisualPromptingBatchPredEntity(
@@ -1263,12 +1268,17 @@ def save_reference_info(self, default_root_dir: Path | str) -> None:
"used_indices": self.used_indices,
}
# save reference info
path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
path_reference_info.parent.mkdir(parents=True, exist_ok=True)
self.saved_reference_info_path: Path = (
Path(default_root_dir) / self.reference_info_dir / "reference_info.pickle"
)
self.saved_reference_info_path.parent.mkdir(parents=True, exist_ok=True)
# TODO (sungchul): ticket no. 139210
torch.save({k: torch.as_tensor(v) for k, v in reference_info.items()}, path_reference_info)
pickle.dump(reference_info, path_reference_info.with_suffix(".pickle").open("wb"))
log.info(f"Saved reference info at {path_reference_info}.")
torch.save(
{k: torch.as_tensor(v) for k, v in reference_info.items()},
self.saved_reference_info_path.with_suffix(".pt"),
)
pickle.dump(reference_info, self.saved_reference_info_path.open("wb"))
log.info(f"Saved reference info at {self.saved_reference_info_path}.")

def _generate_masked_features(
self,
@@ -1322,8 +1332,40 @@ def _pad_to_square(self, x: np.ndarray, image_size: int = 1024) -> np.ndarray:
######################################
# Infer #
######################################
def load_reference_info(self, default_root_dir: Path | str, *args, **kwargs) -> bool:
"""Load latest reference info to be used."""
def load_reference_info(
self,
default_root_dir: Path | str,
*args,
path_to_directly_load: Path | None = None,
**kwargs,
) -> bool:
"""Load latest reference info to be used.

Args:
default_root_dir (Path | str): Default root directory, used when an
inappropriate infer_reference_info_root is given.
path_to_directly_load (Path | None): Reference info path to be loaded directly.
Normally, it is produced by `learn`, which runs in `on_test_start` or
`on_predict_start` when `infer` is attempted without reference features.

Returns:
(bool): Whether the reference info was loaded successfully.
"""

def _load_and_assign_reference_info(path: Path) -> bool:
reference_info: dict[str, np.ndarray] = pickle.load(path.open("rb")) # noqa: S301 # nosec: B301
self.reference_feats = reference_info.get(
"reference_feats",
np.zeros((0, 1, self.model["decoder"].embed_dim), dtype=np.float32),
)
self.used_indices = reference_info.get("used_indices", np.array([], dtype=np.int64))
log.info(f"reference info saved at {path} was successfully loaded.")
return True

if path_to_directly_load is not None:
# if `path_to_directly_load` is given, forcibly load it
return _load_and_assign_reference_info(path_to_directly_load)

_infer_reference_info_root: Path = (
self.infer_reference_info_root
if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
@@ -1333,14 +1375,8 @@ def load_reference_info(self, default_root_dir: Path | str, *args, **kwargs) ->
if (
path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pickle"
).is_file():
reference_info: dict[str, np.ndarray] = pickle.load(path_reference_info.open("rb")) # noqa: S301 # nosec: B301
self.reference_feats = reference_info.get(
"reference_feats",
np.zeros((0, 1, self.model["decoder"].embed_dim), dtype=np.float32),
)
self.used_indices = reference_info.get("used_indices", np.array([], dtype=np.int64))
log.info(f"reference info saved at {path_reference_info} was successfully loaded.")
return True
return _load_and_assign_reference_info(path_reference_info)

return False
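For the OV model, reference info is stored as numpy arrays in a pickle, and _load_and_assign_reference_info falls back to empty defaults when keys are missing. A runnable sketch of those defaults; the embed dim of 256 and the temp path are assumptions for illustration:

import pickle
import tempfile
from pathlib import Path

import numpy as np

# Illustrative sketch, not part of the diff.
embed_dim = 256                                          # stand-in for self.model["decoder"].embed_dim
path = Path(tempfile.mkdtemp()) / "reference_info.pickle"

# Write a pickle that only contains used_indices to exercise the default for reference_feats.
pickle.dump({"used_indices": np.array([0, 1], dtype=np.int64)}, path.open("wb"))

info = pickle.load(path.open("rb"))
reference_feats = info.get("reference_feats", np.zeros((0, 1, embed_dim), dtype=np.float32))
used_indices = info.get("used_indices", np.array([], dtype=np.int64))
print(reference_feats.shape, used_indices)               # (0, 1, 256) [0 1]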

def _get_prompt_candidates(
Expand Down Expand Up @@ -1527,7 +1563,7 @@ def on_train_start(self) -> None:
def on_test_start(self) -> None:
"""Load previously saved reference info."""
super().on_test_start()
if not self.load_reference_info(self.trainer.default_root_dir, self.device):
if not self.load_reference_info(self.trainer.default_root_dir):
log.warning("No reference info found. `Learn` will be automatically executed first.")
self.trainer.lightning_module.automatic_optimization = False
self.trainer.fit_loop.run()
@@ -1536,11 +1572,14 @@ def on_test_start(self) -> None:
# to set _combined_loader
self.trainer._evaluation_loop.setup_data() # noqa: SLF001
self.trainer._evaluation_loop.reset() # noqa: SLF001
self.load_reference_info(self.trainer.default_root_dir, self.device)
self.load_reference_info(
self.trainer.default_root_dir,
path_to_directly_load=self.saved_reference_info_path,
)

def on_predict_start(self) -> None:
"""Load previously saved reference info."""
if not self.load_reference_info(self.trainer.default_root_dir, self.device):
if not self.load_reference_info(self.trainer.default_root_dir):
log.warning("No reference info found. `Learn` will be automatically executed first.")
self.trainer.lightning_module.automatic_optimization = False
self.trainer.fit_loop.run()
@@ -1549,7 +1588,10 @@ def on_predict_start(self) -> None:
# to set _combined_loader
self.trainer._evaluation_loop.setup_data() # noqa: SLF001
self.trainer._evaluation_loop.reset() # noqa: SLF001
self.load_reference_info(self.trainer.default_root_dir, self.device)
self.load_reference_info(
self.trainer.default_root_dir,
path_to_directly_load=self.saved_reference_info_path,
)

def on_train_epoch_start(self) -> None:
"""Skip on_train_epoch_start unused in zero-shot visual prompting."""