diff --git a/bundle/python_bundle_workflow/scripts/inference.py b/bundle/python_bundle_workflow/scripts/inference.py index 24810a0dfb..a98b47bb64 100644 --- a/bundle/python_bundle_workflow/scripts/inference.py +++ b/bundle/python_bundle_workflow/scripts/inference.py @@ -38,6 +38,7 @@ ScaleIntensityd, ) from monai.utils import BundleProperty +from scripts.train import prepare_data class InferenceWorkflow(BundleWorkflow): @@ -46,7 +47,7 @@ class InferenceWorkflow(BundleWorkflow): """ - def __init__(self, dataset_dir: str = "."): + def __init__(self, dataset_dir: str = "./infer"): super().__init__(workflow="inference") print_config() # set root log level to INFO and init a evaluation logger, will be used in `StatsHandler` @@ -54,20 +55,14 @@ def __init__(self, dataset_dir: str = "."): get_logger("eval_log") # create a temporary directory and 40 random image, mask pairs - print(f"generating synthetic data to {dataset_dir} (this may take a while)") - for i in range(5): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz")) - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz")) + prepare_data(dataset_dir=dataset_dir) self._props = {} self._set_props = {} self.dataset_dir = dataset_dir def initialize(self): - self.props = {} + self._props = {} def run(self): self.evaluator.run() @@ -76,6 +71,7 @@ def finalize(self): pass def _set_property(self, name, property, value): + # stores user-reset initialized objects that should not be re-initialized. self._set_props[name] = value def _get_property(self, name, property): @@ -88,11 +84,11 @@ def _get_property(self, name, property): """ value = None - if name in self._props: - value = self._props[name] - elif name in self._set_props: + if name in self._set_props: value = self._set_props[name] self._props[name] = value + elif name in self._props: + value = self._props[name] else: try: value = getattr(self, f"get_{name}")() @@ -112,7 +108,7 @@ def get_device(self): return torch.device("cuda" if torch.cuda.is_available() else "cpu") def get_dataset_dir(self): - return "." + return self.dataset_dir def get_network_def(self): return UNet( diff --git a/bundle/python_bundle_workflow/scripts/train.py b/bundle/python_bundle_workflow/scripts/train.py index ccd409e969..29a53b46a1 100644 --- a/bundle/python_bundle_workflow/scripts/train.py +++ b/bundle/python_bundle_workflow/scripts/train.py @@ -12,6 +12,7 @@ import logging import os import sys +from pathlib import Path from glob import glob import nibabel as nib @@ -48,13 +49,24 @@ from monai.utils import BundleProperty, set_determinism +def prepare_data(dataset_dir): + Path(dataset_dir).mkdir(exist_ok=True) + print(f"generating synthetic data to {dataset_dir} (this may take a while)") + for i in range(40): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz")) + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz")) + + class TrainWorkflow(BundleWorkflow): """ Test class simulates the bundle training workflow defined by Python script directly. """ - def __init__(self, dataset_dir: str = "."): + def __init__(self, dataset_dir: str = "./train"): super().__init__(workflow="train") print_config() # set root log level to INFO and init a train logger, will be used in `StatsHandler` @@ -62,13 +74,7 @@ def __init__(self, dataset_dir: str = "."): get_logger("train_log") # create a temporary directory and 40 random image, mask pairs - print(f"generating synthetic data to {dataset_dir} (this may take a while)") - for i in range(40): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz")) - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz")) + prepare_data(dataset_dir=dataset_dir) # define buckets to store the generated properties and set properties self._props = {} @@ -82,7 +88,7 @@ def __init__(self, dataset_dir: str = "."): def initialize(self): set_determinism(0) - self.props = {} + self._props = {} def run(self): self.trainer.run() @@ -91,6 +97,7 @@ def finalize(self): set_determinism(None) def _set_property(self, name, property, value): + # stores user-reset initialized objects that should not be re-initialized. self._set_props[name] = value def _get_property(self, name, property): @@ -103,11 +110,11 @@ def _get_property(self, name, property): """ value = None - if name in self._props: - value = self._props[name] - elif name in self._set_props: + if name in self._set_props: value = self._set_props[name] self._props[name] = value + elif name in self._props: + value = self._props[name] else: try: value = getattr(self, f"get_{name}")() @@ -127,7 +134,7 @@ def get_device(self): return torch.device("cuda" if torch.cuda.is_available() else "cpu") def get_dataset_dir(self): - return "." + return self.dataset_dir def get_network(self): return UNet(