Skip to content

Commit

Permalink
Set path to save pseudo masks into workspace
Browse files Browse the repository at this point in the history
  • Loading branch information
sungchul2 committed May 24, 2023
1 parent 9fb782d commit 76f752b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
5 changes: 5 additions & 0 deletions otx/cli/manager/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,11 @@ def get_dataset_config(self, subsets: List[str], hyper_parameters: Optional[Conf
if learning_parameters:
num_workers = getattr(learning_parameters, "num_workers", 0)
dataset_config["cache_config"]["num_workers"] = num_workers
# Only self-supervised segmentation needs a directory for pseudo masks.
# Both values are normalised with str(...).upper() *before* comparing; wrapping the
# comparison itself in str() would yield "TRUE"/"FALSE", which is always truthy and
# would make the task-type check a no-op.
if str(self.task_type).upper() == "SEGMENTATION" and str(self.train_type).upper() == "SELFSUPERVISED":
    # FIXME: manually set a path to save pseudo masks in workspace
    train_type_rel_path = TASK_TYPE_TO_SUB_DIR_NAME[self.train_type]
    train_type_dir = self.workspace_root / train_type_rel_path
    # Pseudo masks are stored under <workspace>/<train-type sub dir>/detcon_mask.
    dataset_config["pseudo_mask_dir"] = train_type_dir / "detcon_mask"
return dataset_config

def update_data_config(self, data_yaml: dict) -> None:
Expand Down
2 changes: 2 additions & 0 deletions otx/core/data/adapter/base_dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
unlabeled_data_roots: Optional[str] = None,
unlabeled_file_list: Optional[str] = None,
cache_config: Optional[Dict[str, Any]] = None,
**kwargs
):
self.task_type = task_type
self.domain = task_type.domain
Expand All @@ -97,6 +98,7 @@ def __init__(
test_ann_files=test_ann_files,
unlabeled_data_roots=unlabeled_data_roots,
unlabeled_file_list=unlabeled_file_list,
**kwargs
)

cache_config = cache_config if cache_config is not None else {}
Expand Down
25 changes: 11 additions & 14 deletions otx/core/data/adapter/segmentation_dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
unlabeled_data_roots: Optional[str] = None,
unlabeled_file_list: Optional[str] = None,
cache_config: Optional[Dict[str, Any]] = None,
**kwargs
):
super().__init__(
task_type,
Expand All @@ -65,6 +66,7 @@ def __init__(
unlabeled_data_roots,
unlabeled_file_list,
cache_config,
**kwargs
)
self.updated_label_id: Dict[int, int] = {}

Expand Down Expand Up @@ -199,23 +201,20 @@ def _import_dataset(
self.is_train_phase = True

# Load pseudo masks
img_dir = None
total_labels = []
os.makedirs(pseudo_mask_dir, exist_ok=True)
for item in dataset[Subset.TRAINING]:
img_path = item.media.path
if img_dir is None:
# Get image directory
img_dir = train_data_roots.split("/")[-1]
pseudo_mask_path = img_path.replace(img_dir, pseudo_mask_dir)
if pseudo_mask_path.endswith(".jpg"):
pseudo_mask_path = pseudo_mask_path.replace(".jpg", ".png")
pseudo_mask_path = pseudo_mask_dir / os.path.basename(img_path)
if pseudo_mask_path.suffix == ".jpg":
pseudo_mask_path = pseudo_mask_path.with_name(f"{pseudo_mask_path.stem}.png")

if not os.path.isfile(pseudo_mask_path):
# Create pseudo mask
pseudo_mask = self.create_pseudo_masks(item.media.data, pseudo_mask_path) # type: ignore
else:
# Load created pseudo mask
pseudo_mask = cv2.imread(pseudo_mask_path, cv2.IMREAD_GRAYSCALE)
pseudo_mask = cv2.imread(str(pseudo_mask_path), cv2.IMREAD_GRAYSCALE)

# Set annotations into each item
annotations = []
Expand All @@ -229,18 +228,17 @@ def _import_dataset(
)
item.annotations = annotations

pseudo_mask_roots = train_data_roots.replace(img_dir, pseudo_mask_dir) # type: ignore
if not os.path.isfile(os.path.join(pseudo_mask_roots, "dataset_meta.json")):
if not os.path.isfile(os.path.join(pseudo_mask_dir, "dataset_meta.json")):
# Save dataset_meta.json for newly created pseudo masks
# FIXME: Because background class is ignored when generating polygons, meta is set with len(labels)-1.
# It must be considered to set the whole labels later.
# (-> {i: f"target{i+1}" for i in range(max(total_labels)+1)})
meta = {"label_map": {i + 1: f"target{i+1}" for i in range(max(total_labels))}}
with open(os.path.join(pseudo_mask_roots, "dataset_meta.json"), "w", encoding="UTF-8") as f:
with open(os.path.join(pseudo_mask_dir, "dataset_meta.json"), "w", encoding="UTF-8") as f:
json.dump(meta, f, indent=4)

# Make categories for pseudo masks
label_map = parse_meta_file(os.path.join(pseudo_mask_roots, "dataset_meta.json"))
label_map = parse_meta_file(os.path.join(pseudo_mask_dir, "dataset_meta.json"))
dataset[Subset.TRAINING].define_categories(make_categories(label_map))

return dataset
Expand All @@ -261,7 +259,6 @@ def create_pseudo_masks(self, img: np.array, pseudo_mask_path: str, mode: str =
else:
raise ValueError((f'{mode} is not supported to create pseudo masks for DetCon. Choose one of ["FH"].'))

os.makedirs(os.path.dirname(pseudo_mask_path), exist_ok=True)
cv2.imwrite(pseudo_mask_path, pseudo_mask.astype(np.uint8))
cv2.imwrite(str(pseudo_mask_path), pseudo_mask.astype(np.uint8))

return pseudo_mask

0 comments on commit 76f752b

Please sign in to comment.