Skip to content

Commit

Permalink
Update python_bundle_workflow (Project-MONAI#1656)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

- add description for `self._set_prop` in python
`python_bundle_workflow`.
- remove generate data outside of the Workflow class.

### Checks
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [ ] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, please check the contents and remove any
sensitive info such as user names and private key.
- [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use
relative paths for tutorial repo files (3) put figure and graphs in the
`./figure` folder
- [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
KumoLiu and pre-commit-ci[bot] authored Mar 6, 2024
1 parent 8186739 commit 570b19e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 26 deletions.
22 changes: 9 additions & 13 deletions bundle/python_bundle_workflow/scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
ScaleIntensityd,
)
from monai.utils import BundleProperty
from scripts.train import prepare_data


class InferenceWorkflow(BundleWorkflow):
Expand All @@ -46,28 +47,22 @@ class InferenceWorkflow(BundleWorkflow):
"""

def __init__(self, dataset_dir: str = "."):
def __init__(self, dataset_dir: str = "./infer"):
super().__init__(workflow="inference")
print_config()
# set root log level to INFO and init a evaluation logger, will be used in `StatsHandler`
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
get_logger("eval_log")

# create a temporary directory and 40 random image, mask pairs
print(f"generating synthetic data to {dataset_dir} (this may take a while)")
for i in range(5):
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
n = nib.Nifti1Image(im, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz"))
n = nib.Nifti1Image(seg, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz"))
prepare_data(dataset_dir=dataset_dir)

self._props = {}
self._set_props = {}
self.dataset_dir = dataset_dir

def initialize(self):
self.props = {}
self._props = {}

def run(self):
self.evaluator.run()
Expand All @@ -76,6 +71,7 @@ def finalize(self):
pass

def _set_property(self, name, property, value):
# stores user-reset initialized objects that should not be re-initialized.
self._set_props[name] = value

def _get_property(self, name, property):
Expand All @@ -88,11 +84,11 @@ def _get_property(self, name, property):
"""
value = None
if name in self._props:
value = self._props[name]
elif name in self._set_props:
if name in self._set_props:
value = self._set_props[name]
self._props[name] = value
elif name in self._props:
value = self._props[name]
else:
try:
value = getattr(self, f"get_{name}")()
Expand All @@ -112,7 +108,7 @@ def get_device(self):
return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_dataset_dir(self):
return "."
return self.dataset_dir

def get_network_def(self):
return UNet(
Expand Down
33 changes: 20 additions & 13 deletions bundle/python_bundle_workflow/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import os
import sys
from pathlib import Path
from glob import glob

import nibabel as nib
Expand Down Expand Up @@ -48,27 +49,32 @@
from monai.utils import BundleProperty, set_determinism


def prepare_data(dataset_dir):
Path(dataset_dir).mkdir(exist_ok=True)
print(f"generating synthetic data to {dataset_dir} (this may take a while)")
for i in range(40):
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
n = nib.Nifti1Image(im, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz"))
n = nib.Nifti1Image(seg, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz"))


class TrainWorkflow(BundleWorkflow):
"""
Test class simulates the bundle training workflow defined by Python script directly.
"""

def __init__(self, dataset_dir: str = "."):
def __init__(self, dataset_dir: str = "./train"):
super().__init__(workflow="train")
print_config()
# set root log level to INFO and init a train logger, will be used in `StatsHandler`
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
get_logger("train_log")

# create a temporary directory and 40 random image, mask pairs
print(f"generating synthetic data to {dataset_dir} (this may take a while)")
for i in range(40):
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
n = nib.Nifti1Image(im, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz"))
n = nib.Nifti1Image(seg, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz"))
prepare_data(dataset_dir=dataset_dir)

# define buckets to store the generated properties and set properties
self._props = {}
Expand All @@ -82,7 +88,7 @@ def __init__(self, dataset_dir: str = "."):

def initialize(self):
set_determinism(0)
self.props = {}
self._props = {}

def run(self):
self.trainer.run()
Expand All @@ -91,6 +97,7 @@ def finalize(self):
set_determinism(None)

def _set_property(self, name, property, value):
# stores user-reset initialized objects that should not be re-initialized.
self._set_props[name] = value

def _get_property(self, name, property):
Expand All @@ -103,11 +110,11 @@ def _get_property(self, name, property):
"""
value = None
if name in self._props:
value = self._props[name]
elif name in self._set_props:
if name in self._set_props:
value = self._set_props[name]
self._props[name] = value
elif name in self._props:
value = self._props[name]
else:
try:
value = getattr(self, f"get_{name}")()
Expand All @@ -127,7 +134,7 @@ def get_device(self):
return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_dataset_dir(self):
return "."
return self.dataset_dir

def get_network(self):
return UNet(
Expand Down

0 comments on commit 570b19e

Please sign in to comment.