From 570b19e678e06aba26f53df8bc73f848ae2984ba Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:58:36 +0800 Subject: [PATCH] Update `python_bundle_workflow` (#1656) Fixes # . ### Description - add a comment describing `self._set_props` in the Python `python_bundle_workflow` scripts. - move synthetic data generation out of the Workflow classes into a standalone `prepare_data` function. ### Checks - [ ] Avoid including large-size files in the PR. - [ ] Clean up long text outputs from code cells in the notebook. - [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [ ] Notebook runs automatically `./runner.sh -t ` --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../scripts/inference.py | 22 +++++-------- .../python_bundle_workflow/scripts/train.py | 33 +++++++++++-------- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/bundle/python_bundle_workflow/scripts/inference.py b/bundle/python_bundle_workflow/scripts/inference.py index 24810a0dfb..a98b47bb64 100644 --- a/bundle/python_bundle_workflow/scripts/inference.py +++ b/bundle/python_bundle_workflow/scripts/inference.py @@ -38,6 +38,7 @@ ScaleIntensityd, ) from monai.utils import BundleProperty +from scripts.train import prepare_data class InferenceWorkflow(BundleWorkflow): @@ -46,7 +47,7 @@ class InferenceWorkflow(BundleWorkflow): """ - def __init__(self, dataset_dir: str = "."): + def __init__(self, dataset_dir: str = "./infer"): super().__init__(workflow="inference") print_config() # set root log level to INFO and init a evaluation logger, will be used in `StatsHandler` @@ -54,20 +55,14 @@ def __init__(self, dataset_dir: str = "."): get_logger("eval_log") # create a 
temporary directory and 40 random image, mask pairs - print(f"generating synthetic data to {dataset_dir} (this may take a while)") - for i in range(5): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz")) - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz")) + prepare_data(dataset_dir=dataset_dir) self._props = {} self._set_props = {} self.dataset_dir = dataset_dir def initialize(self): - self.props = {} + self._props = {} def run(self): self.evaluator.run() @@ -76,6 +71,7 @@ def finalize(self): pass def _set_property(self, name, property, value): + # stores user-reset initialized objects that should not be re-initialized. self._set_props[name] = value def _get_property(self, name, property): @@ -88,11 +84,11 @@ def _get_property(self, name, property): """ value = None - if name in self._props: - value = self._props[name] - elif name in self._set_props: + if name in self._set_props: value = self._set_props[name] self._props[name] = value + elif name in self._props: + value = self._props[name] else: try: value = getattr(self, f"get_{name}")() @@ -112,7 +108,7 @@ def get_device(self): return torch.device("cuda" if torch.cuda.is_available() else "cpu") def get_dataset_dir(self): - return "." 
+ return self.dataset_dir def get_network_def(self): return UNet( diff --git a/bundle/python_bundle_workflow/scripts/train.py b/bundle/python_bundle_workflow/scripts/train.py index ccd409e969..29a53b46a1 100644 --- a/bundle/python_bundle_workflow/scripts/train.py +++ b/bundle/python_bundle_workflow/scripts/train.py @@ -12,6 +12,7 @@ import logging import os import sys +from pathlib import Path from glob import glob import nibabel as nib @@ -48,13 +49,24 @@ from monai.utils import BundleProperty, set_determinism +def prepare_data(dataset_dir): + Path(dataset_dir).mkdir(exist_ok=True) + print(f"generating synthetic data to {dataset_dir} (this may take a while)") + for i in range(40): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz")) + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz")) + + class TrainWorkflow(BundleWorkflow): """ Test class simulates the bundle training workflow defined by Python script directly. 
""" - def __init__(self, dataset_dir: str = "."): + def __init__(self, dataset_dir: str = "./train"): super().__init__(workflow="train") print_config() # set root log level to INFO and init a train logger, will be used in `StatsHandler` @@ -62,13 +74,7 @@ def __init__(self, dataset_dir: str = "."): get_logger("train_log") # create a temporary directory and 40 random image, mask pairs - print(f"generating synthetic data to {dataset_dir} (this may take a while)") - for i in range(40): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz")) - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz")) + prepare_data(dataset_dir=dataset_dir) # define buckets to store the generated properties and set properties self._props = {} @@ -82,7 +88,7 @@ def __init__(self, dataset_dir: str = "."): def initialize(self): set_determinism(0) - self.props = {} + self._props = {} def run(self): self.trainer.run() @@ -91,6 +97,7 @@ def finalize(self): set_determinism(None) def _set_property(self, name, property, value): + # stores user-reset initialized objects that should not be re-initialized. self._set_props[name] = value def _get_property(self, name, property): @@ -103,11 +110,11 @@ def _get_property(self, name, property): """ value = None - if name in self._props: - value = self._props[name] - elif name in self._set_props: + if name in self._set_props: value = self._set_props[name] self._props[name] = value + elif name in self._props: + value = self._props[name] else: try: value = getattr(self, f"get_{name}")() @@ -127,7 +134,7 @@ def get_device(self): return torch.device("cuda" if torch.cuda.is_available() else "cpu") def get_dataset_dir(self): - return "." + return self.dataset_dir def get_network(self): return UNet(