Skip to content

Commit

Permalink
Refactoring using Path
Browse files Browse the repository at this point in the history
  • Loading branch information
sungchul2 committed Apr 19, 2024
1 parent 653ac02 commit d90ab5d
Showing 2 changed files with 13 additions and 13 deletions.
12 changes: 6 additions & 6 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
Original file line number Diff line number Diff line change
@@ -886,25 +886,25 @@ def save_reference_info(self, default_root_dir: Path | str) -> None:
}
# save reference info
path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
Path.mkdir(Path(path_reference_info).parent, parents=True, exist_ok=True)
path_reference_info.parent.mkdir(parents=True, exist_ok=True)
torch.save(reference_info, path_reference_info)
pickle.dump(
{k: v.numpy() for k, v in reference_info.items()},
Path.open(Path(str(path_reference_info).replace(".pt", ".pickle")), "wb"),
path_reference_info.with_suffix(".pickle").open("wb"),
)
log.info(f"Saved reference info at {path_reference_info}.")

def load_reference_info(self, default_root_dir: Path | str, device: str | torch.device = "cpu") -> bool:
"""Load latest reference info to be used."""
_infer_reference_info_root = (
_infer_reference_info_root: Path = (
self.infer_reference_info_root
if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
else Path(default_root_dir) / self.infer_reference_info_root
)

if Path.is_file(
path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pt",
):
if (
path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pt"
).is_file():
reference_info = torch.load(path_reference_info)
retval = True
log.info(f"reference info saved at {path_reference_info} was successfully loaded.")
14 changes: 7 additions & 7 deletions src/otx/core/model/visual_prompting.py
Original file line number Diff line number Diff line change
@@ -1254,9 +1254,9 @@ def save_reference_info(self, default_root_dir: Path | str) -> None:
}
# save reference info
path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
Path.mkdir(Path(path_reference_info).parent, parents=True, exist_ok=True)
path_reference_info.parent.mkdir(parents=True, exist_ok=True)
torch.save({k: torch.as_tensor(v) for k, v in reference_info.items()}, path_reference_info)
pickle.dump(reference_info, Path.open(Path(str(path_reference_info).replace(".pt", ".pickle")), "wb"))
pickle.dump(reference_info, path_reference_info.with_suffix(".pickle").open("wb"))
log.info(f"Saved reference info at {path_reference_info}.")

def _generate_masked_features(
@@ -1313,16 +1313,16 @@ def _pad_to_square(self, x: np.ndarray, image_size: int = 1024) -> np.ndarray:
######################################
def load_reference_info(self, default_root_dir: Path | str, *args, **kwargs) -> bool:
"""Load latest reference info to be used."""
_infer_reference_info_root = (
_infer_reference_info_root: Path = (
self.infer_reference_info_root
if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
else Path(default_root_dir) / self.infer_reference_info_root
)

if Path.is_file(
path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pickle",
):
reference_info: dict[str, np.ndarray] = pickle.load(Path.open(path_reference_info, "rb")) # noqa: S301
if (
path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pickle"
).is_file():
reference_info: dict[str, np.ndarray] = pickle.load(path_reference_info.open("rb")) # noqa: S301
self.reference_feats = reference_info.get(
"reference_feats",
np.zeros((0, 1, self.model["decoder"].embed_dim), dtype=np.float32),

0 comments on commit d90ab5d

Please sign in to comment.