Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update python_bundle_workflow #1656

Merged
merged 11 commits into from
Mar 6, 2024
16 changes: 6 additions & 10 deletions bundle/python_bundle_workflow/scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
ScaleIntensityd,
)
from monai.utils import BundleProperty
from scripts.train import prepare_data


class InferenceWorkflow(BundleWorkflow):
Expand All @@ -46,28 +47,22 @@ class InferenceWorkflow(BundleWorkflow):

"""

def __init__(self, dataset_dir: str = "."):
def __init__(self, dataset_dir: str = "./infer"):
super().__init__(workflow="inference")
print_config()
# set root log level to INFO and init a evaluation logger, will be used in `StatsHandler`
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
get_logger("eval_log")

# create a temporary directory and 40 random image, mask pairs
print(f"generating synthetic data to {dataset_dir} (this may take a while)")
for i in range(5):
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
n = nib.Nifti1Image(im, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz"))
n = nib.Nifti1Image(seg, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz"))
prepare_data(dataset_dir=dataset_dir)

self._props = {}
self._set_props = {}
self.dataset_dir = dataset_dir

def initialize(self):
self.props = {}
self._props = {}

def run(self):
self.evaluator.run()
Expand All @@ -76,6 +71,7 @@ def finalize(self):
pass

def _set_property(self, name, property, value):
# stores user-reset initialized objects that should not be re-initialized.
self._set_props[name] = value

def _get_property(self, name, property):
Expand Down Expand Up @@ -112,7 +108,7 @@ def get_device(self):
return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_dataset_dir(self):
return "."
return "./infer"

def get_network_def(self):
return UNet(
Expand Down
28 changes: 17 additions & 11 deletions bundle/python_bundle_workflow/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import os
import sys
from pathlib import Path
from glob import glob

import nibabel as nib
Expand Down Expand Up @@ -48,32 +49,36 @@
from monai.utils import BundleProperty, set_determinism


def prepare_data(dataset_dir):
Path(dataset_dir).mkdir(exist_ok=True)
print(f"generating synthetic data to {dataset_dir} (this may take a while)")
for i in range(40):
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
n = nib.Nifti1Image(im, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz"))
n = nib.Nifti1Image(seg, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz"))


class TrainWorkflow(BundleWorkflow):
"""
Test class simulates the bundle training workflow defined by Python script directly.

"""

def __init__(self, dataset_dir: str = "."):
def __init__(self, dataset_dir: str = "./train"):
super().__init__(workflow="train")
print_config()
# set root log level to INFO and init a train logger, will be used in `StatsHandler`
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
get_logger("train_log")

# create a temporary directory and 40 random image, mask pairs
print(f"generating synthetic data to {dataset_dir} (this may take a while)")
for i in range(40):
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
n = nib.Nifti1Image(im, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz"))
n = nib.Nifti1Image(seg, np.eye(4))
nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz"))
prepare_data(dataset_dir=dataset_dir)

# define buckets to store the generated properties and set properties
self._props = {}
self._set_props = {}
self.dataset_dir = dataset_dir

# besides the predefined properties, this bundle workflow can also provide `network`, `loss`, `optimizer`
self.add_property(name="network", required=True, desc="network for the training.")
Expand All @@ -82,7 +87,7 @@ def __init__(self, dataset_dir: str = "."):

def initialize(self):
set_determinism(0)
self.props = {}
self._props = {}

def run(self):
self.trainer.run()
Expand All @@ -91,6 +96,7 @@ def finalize(self):
set_determinism(None)

def _set_property(self, name, property, value):
# stores user-reset initialized objects that should not be re-initialized.
self._set_props[name] = value

def _get_property(self, name, property):
Expand Down Expand Up @@ -127,7 +133,7 @@ def get_device(self):
return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_dataset_dir(self):
return "."
return "./train"

def get_network(self):
return UNet(
Expand Down
Loading