diff --git a/bundle/05_spleen_segmentation_lightning.ipynb b/bundle/05_spleen_segmentation_lightning.ipynb new file mode 100644 index 0000000000..77ac86184c --- /dev/null +++ b/bundle/05_spleen_segmentation_lightning.ipynb @@ -0,0 +1,1077 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5e8ae3d7-3e2e-4755-a0b6-709ef4180719", + "metadata": {}, + "source": [ + "Copyright (c) MONAI Consortium \n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); \n", + "you may not use this file except in compliance with the License. \n", + "You may obtain a copy of the License at \n", + "    http://www.apache.org/licenses/LICENSE-2.0 \n", + "Unless required by applicable law or agreed to in writing, software \n", + "distributed under the License is distributed on an \"AS IS\" BASIS, \n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n", + "See the License for the specific language governing permissions and \n", + "limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "191c5d77-8ae5-49ab-be22-45f5ba41641f", + "metadata": {}, + "source": [ + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "886952c4-0be4-459d-9c53-b81b29199c76", + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-16T13:48:44.235392252Z", + "start_time": "2023-10-16T13:48:28.253469477Z" + } + }, + "outputs": [], + "source": [ + "!python -c \"import monai\" || pip install -q \"monai-weekly[ignite,pyyaml]\"\n", + "!pip install -q pytorch-lightning~=2.0.0" + ] + }, + { + "cell_type": "markdown", + "id": "a20e1274-0a27-4e37-95d7-fb813243c34c", + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-06T15:07:19.730871161Z", + "start_time": "2023-10-06T15:07:11.317018521Z" + } + }, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b1144d87-ec2f-4b9b-907a-16ea2da279c4", + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-06T15:11:48.797015283Z", + "start_time": "2023-10-06T15:11:42.300276550Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MONAI version: 1.3.dev2340\n", + "Numpy version: 1.26.0\n", + "Pytorch version: 2.0.1+cu117\n", + "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n", + "MONAI rev id: 8d89083eeb8005babd7b5f76df83c1c80276cc10\n", + "MONAI __file__: /home//miniconda3/envs/monai_tutorial/lib/python3.9/site-packages/monai/__init__.py\n", + "\n", + "Optional dependencies:\n", + "Pytorch Ignite version: 0.4.11\n", + "ITK version: 5.3.0\n", + "Nibabel version: 5.1.0\n", + "scikit-image version: 0.21.0\n", + "scipy version: 1.11.3\n", + "Pillow version: 10.0.1\n", + "Tensorboard version: 2.14.1\n", + "gdown version: 4.7.1\n", + "TorchVision version: 0.15.2+cu117\n", + "tqdm version: 4.66.1\n", + "lmdb version: 1.4.1\n", + "psutil version: 5.9.0\n", + "pandas version: 2.1.1\n", + "einops version: 0.7.0\n", + "transformers version: 4.21.3\n", + "mlflow version: 2.7.1\n", + "pynrrd version: 1.0.0\n", + "clearml version: 1.13.1\n", + "\n", + "For details about installing the optional dependencies, please visit:\n", + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n", + "\n" + ] + } + ], + "source": [ + "from monai.config import print_config\n", + "\n", + "print_config()" + ] + }, + { + "cell_type": "markdown", + "id": "c572d8b6-3dca-4487-80ad-928090b3e8ab", + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-06T15:07:34.380130283Z", + "start_time": "2023-10-06T15:07:34.330086596Z" + } + }, + "source": [ + "# Spleen Segmentation Lightning Bundle\n", + "\n", + "In this tutorial we'll describe how to create a bundle for a segmentation network. This will include how to train and apply the network on the command line. Medical will be used as the dataset with the bundle based off the [Spleen 3D segmentation with MONAI](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/spleen_segmentation_3d_lightning.ipynb) from Spleen segmentation using Task_09 subset from the Medical Segmentation Decathlon.\n", + "\n", + "This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by-sa/4.0/.\n" + ] + }, + { + "cell_type": "markdown", + "id": "1a18d5cd-6338-4b41-87fd-4e119723bfee", + "metadata": {}, + "source": [ + "Let's start by initialising a bundle directory structure and create a python module `scripts`:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e00b5416-dfab-4043-9293-ec2acdf5842d", + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-16T14:44:10.513242253Z", + "start_time": "2023-10-16T14:44:02.546210817Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/usr/bin/tree\n", + "SpleenSegLightning\n", + "├── configs\n", + "│   └── metadata.json\n", + "├── docs\n", + "│   └── README.md\n", + "├── LICENSE\n", + "├── models\n", + "└── scripts\n", + " └── __init__.py\n", + "\n", + "4 directories, 4 files\n" + ] + } + ], + "source": [ + "%%bash\n", + "\n", + "python -m monai.bundle init_bundle SpleenSegLightning\n", + "rm SpleenSegLightning/configs/inference.json\n", + "mkdir SpleenSegLightning/scripts\n", + "touch SpleenSegLightning/scripts/__init__.py\n", + "which tree && tree SpleenSegLightning || true" + ] + }, + { + "cell_type": "markdown", + "id": "5888c9bd-5022-40b5-9dec-84d9f737f868", + "metadata": {}, + "source": [ + "## Metadata\n", + "\n", + "We'll first replace the `metadata.json` file with our description of what the network will do:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b29f053b-cf16-4ffc-bbe7-d9433fdfa872", + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-16T14:45:11.617630093Z", + "start_time": "2023-10-16T14:45:11.573340254Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting SpleenSegLightning/configs/metadata.json\n" + ] + } + ], + "source": [ + "%%writefile SpleenSegLightning/configs/metadata.json\n", + "\n", + "{\n", + " \"version\": \"0.0.1\",\n", + " \"changelog\": {\n", + " \"0.0.1\": \"Initial version\"\n", + " },\n", + " \"monai_version\": \"1.2.0\",\n", + " \"pytorch_version\": \"2.0.0\",\n", + " \"numpy_version\": \"1.23.5\",\n", + " \"optional_packages_version\": {},\n", + " \"name\": \"SpleenSegLightning\",\n", + " \"task\": \"3D Spleen segmentation network using MONAI and Pytorch Lightning\",\n", + " \"description\": \"This is a demo network for segmentation of the spleen from 3D MRI images.\",\n", + " \"authors\": \"Your Name Here\",\n", + " \"copyright\": \"Copyright (c) Your Name Here\",\n", + " \"data_source\": \"Task_09 subset from the Medical Segmentation Decathlon\",\n", + " \"data_type\": \"Nifti\",\n", + " \"intended_use\": \"This is suitable for demonstration only\",\n", + " \"network_data_format\": {\n", + " \"inputs\": {\n", + " \"image\": {\n", + " \"type\": \"image\",\n", + " \"format\": \"magnitude\",\n", + " \"modality\": \"MR\",\n", + " \"num_channels\": 1,\n", + " \"spatial_shape\": [160, 160, 160],\n", + " \"dtype\": \"float32\",\n", + " \"value_range\": [0, 1],\n", + " \"is_patch_data\": false,\n", + " \"channel_def\": {\"0\": \"image\"}\n", + " }\n", + " },\n", + " \"outputs\": {\n", + " \"pred\": {\n", + " \"type\": \"image\",\n", + " \"format\": \"labels\",\n", + " \"num_channels\": 2,\n", + " \"spatial_shape\": [160, 160, 160],\n", + " \"dtype\": \"float32\",\n", + " \"value_range\": [],\n", + " \"is_patch_data\": false,\n", + " \"channel_def\": {\"0\": \"background\", \"1\": \"spleen\"}\n", + " }\n", + " }\n", + " }\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "3f208bf8-0c3a-4def-ab0f-6091cebcd532", + "metadata": {}, + "source": [ + "\n", + "## Common Definitions\n", + "\n", + "What we'll now do is construct the bundle configuration scripts to implement training, testing, and inference based off the original script file given above. Common definitions should be placed in a common file used with other scripts to reduce duplication. In our original script, the network definition and transform sequence will be used in multiple places so should go in this common file:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d11681af-3210-4b2b-b7bd-8ad8dedfe230", + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-16T14:56:36.558682685Z", + "start_time": "2023-10-16T14:56:36.528064430Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing SpleenSegLightning/configs/common.yaml\n" + ] + } + ], + "source": [ + "%%writefile SpleenSegLightning/configs/common.yaml\n", + "\n", + "# common imports\n", + "imports: \n", + "- $import glob\n", + "- $import os\n", + "\n", + "# define a default root directory value, this can \n", + "# overridden on the command line\n", + "bundle_dir: .\n", + "data_dir: .\n", + "\n", + "# use constants from MONAI instead of hard-coding names\n", + "image: $monai.utils.CommonKeys.IMAGE\n", + "label: $monai.utils.CommonKeys.LABEL\n", + "\n", + "# define a train and validation files from the data directory\n", + "train_images: '$sorted(glob.glob(os.path.join(@data_dir, ''imagesTr'', ''*.nii.gz'')))'\n", + "train_labels: '$sorted(glob.glob(os.path.join(@data_dir, ''labelsTr'', ''*.nii.gz'')))'\n", + "\n", + "data_dicts: '$[{''image'': img, ''label'': lbl} for img, lbl in zip(@train_images, @train_labels)]'\n", + "\n", + "train_files: '$@data_dicts[:-9]'\n", + "val_files: '$@data_dicts[-9:]'" + ] + }, + { + "cell_type": "markdown", + "id": "60ee968cb538d983", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "# Scripts for training and evaluation\n", + "\n", + "We'll define the training and evaluation yaml files and scripts contained the Pytorch Lightning-based network. First, in the Python module `scripts`, we'll add `model.py` file containing the network definition:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2c15149785c2192", + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-16T14:45:39.050765311Z", + "start_time": "2023-10-16T14:45:39.031364249Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing SpleenSegLightning/scripts/model.py\n" + ] + } + ], + "source": [ + "%%writefile SpleenSegLightning/scripts/model.py\n", + "\n", + "import pytorch_lightning\n", + "from monai.utils import set_determinism\n", + "from monai.transforms import (\n", + " AsDiscrete,\n", + " Compose,\n", + " EnsureType,\n", + ")\n", + "from monai.networks.nets import UNet\n", + "from monai.networks.layers import Norm\n", + "from monai.metrics import DiceMetric\n", + "from monai.losses import DiceLoss\n", + "from monai.inferers import sliding_window_inference\n", + "from monai.data import decollate_batch\n", + "import torch\n", + "\n", + "\n", + "class MySegNet(pytorch_lightning.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self._model = UNet(\n", + " spatial_dims=3,\n", + " in_channels=1,\n", + " out_channels=2,\n", + " channels=(16, 32, 64, 128, 256),\n", + " strides=(2, 2, 2, 2),\n", + " num_res_units=2,\n", + " norm=Norm.BATCH,\n", + " )\n", + " self.learning_rate = 1e-4\n", + " self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)\n", + " self.post_pred = Compose([EnsureType(\"tensor\", device=\"cpu\"),\n", + " AsDiscrete(argmax=True, to_onehot=2)])\n", + " self.post_label = Compose([EnsureType(\"tensor\", device=\"cpu\"),\n", + " AsDiscrete(to_onehot=2)])\n", + " self.dice_metric = DiceMetric(include_background=False, reduction=\"mean\",\n", + " get_not_nans=False)\n", + " self.best_val_dice = 0\n", + " self.best_val_epoch = 0\n", + " self.validation_step_outputs = []\n", + "\n", + " def forward(self, x):\n", + " return self._model(x)\n", + "\n", + " def configure_optimizers(self):\n", + " print(\"configure_optimizers\", self.learning_rate)\n", + " optimizer = torch.optim.Adam(self._model.parameters(), self.learning_rate)\n", + " return optimizer\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " images, labels = batch[\"image\"], batch[\"label\"]\n", + " output = self.forward(images)\n", + " loss = self.loss_function(output, labels)\n", + " tensorboard_logs = {\"train_loss\": loss.item()}\n", + " return {\"loss\": loss, \"log\": tensorboard_logs}\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " images, labels = batch[\"image\"], batch[\"label\"]\n", + " roi_size = (160, 160, 160)\n", + " sw_batch_size = 4\n", + " outputs = sliding_window_inference(images, roi_size, sw_batch_size, self.forward)\n", + " loss = self.loss_function(outputs, labels)\n", + " outputs = [self.post_pred(i) for i in decollate_batch(outputs)]\n", + " labels = [self.post_label(i) for i in decollate_batch(labels)]\n", + " self.dice_metric(y_pred=outputs, y=labels)\n", + " d = {\"val_loss\": loss, \"val_number\": len(outputs)}\n", + " self.validation_step_outputs.append(d)\n", + " return d\n", + "\n", + " def on_validation_epoch_end(self):\n", + " val_loss, num_items = 0, 0\n", + " for output in self.validation_step_outputs:\n", + " val_loss += output[\"val_loss\"].sum().item()\n", + " num_items += output[\"val_number\"]\n", + " mean_val_dice = self.dice_metric.aggregate().item()\n", + " self.dice_metric.reset()\n", + " mean_val_loss = torch.tensor(val_loss / num_items)\n", + " tensorboard_logs = {\n", + " \"val_dice\": mean_val_dice,\n", + " \"val_loss\": mean_val_loss,\n", + " }\n", + " if mean_val_dice > self.best_val_dice:\n", + " self.best_val_dice = mean_val_dice\n", + " self.best_val_epoch = self.current_epoch\n", + " print(\n", + " f\"current epoch: {self.current_epoch} \"\n", + " f\"current mean dice: {mean_val_dice:.4f}\"\n", + " f\"\\nbest mean dice: {self.best_val_dice:.4f} \"\n", + " f\"at epoch: {self.best_val_epoch}\"\n", + " )\n", + " self.validation_step_outputs.clear() # free memory\n", + " return {\"log\": tensorboard_logs}" + ] + }, + { + "cell_type": "markdown", + "id": "92e303feb8d4edca", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "lines_to_next_cell": 0 + }, + "source": [ + "Next, we'll create a `main.py` file to house the training and evaluation scripts. In this example, we use the `lightning_param` dictionary to customize some default arguments in the PyTorch Lightning `Trainer` class. We've set `num_nodes` and `devices` to 1, turned off the sanity checking (`num_sanity_val_steps=0`), and logged the training for every 3 steps (`log_every_n_steps=3`) for demonstration purposes. For more information about the PyTorch Lightning `Trainer` arguments, please refer to the following [link](https://lightning.ai/docs/pytorch/stable/common/trainer.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d49daec7d4ce0b75", + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-16T14:45:48.900509600Z", + "start_time": "2023-10-16T14:45:48.887211086Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing SpleenSegLightning/scripts/main.py\n" + ] + } + ], + "source": [ + "%%writefile SpleenSegLightning/scripts/main.py\n", + "from scripts.model import MySegNet\n", + "import pytorch_lightning\n", + "\n", + "def train(lightninig_param, train_dl, val_dl):\n", + " net = MySegNet()\n", + " trainer = pytorch_lightning.Trainer(max_epochs=lightninig_param['max_epochs'], \n", + " default_root_dir=lightninig_param['default_root_dir'],\n", + " check_val_every_n_epoch=lightninig_param['check_val_every_n_epoch'],\n", + " devices=1, num_nodes=1, log_every_n_steps=3, num_sanity_val_steps=0)\n", + " trainer.fit(model=net, train_dataloaders=train_dl, val_dataloaders=val_dl)\n", + "\n", + "\n", + "def evaluate(lightninig_param, ckpt_file, val_dl):\n", + " net = MySegNet()\n", + " trainer = pytorch_lightning.Trainer(default_root_dir=lightninig_param['default_root_dir'],\n", + " devices=1, num_nodes=1)\n", + " trainer.validate(model=net, dataloaders=val_dl, ckpt_path=ckpt_file)" + ] + }, + { + "cell_type": "markdown", + "id": "eaf81ea7-9ea3-4548-a32e-992f0b9bc0ab", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "## Training\n", + "Now, we'll define a `train.yaml` file to be used to set the configurations for the training stage:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4dfd052e-abe7-473a-bbf4-25674a3b20ea", + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-16T14:55:44.372511589Z", + "start_time": "2023-10-16T14:55:44.304832953Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing SpleenSegLightning/configs/train.yaml\n" + ] + } + ], + "source": [ + "%%writefile SpleenSegLightning/configs/train.yaml\n", + "\n", + "imports:\n", + "- $from scripts.main import train\n", + "- $import glob\n", + "- $import os\n", + "\n", + "# define a default root directory value, this can overridden on the command line\n", + "bundle_dir: .\n", + "data_dir: .\n", + "\n", + "# define hyperparameters for the lightning trainer\n", + "max_epochs: 50\n", + "default_root_dir: $@bundle_dir+\"/lightning_logs\"\n", + "check_val_every_n_epoch: 1\n", + "\n", + "lightninig_param: '${\n", + " ''max_epochs'': @max_epochs,\n", + " ''default_root_dir'': @default_root_dir,\n", + " ''check_val_every_n_epoch'': @check_val_every_n_epoch,\n", + "}'\n", + "\n", + "\n", + "# define a transform sequence by instantiating a Compose instance with a transform sequence\n", + "train_transform:\n", + " _target_: Compose\n", + " transforms:\n", + " - _target_: LoadImaged\n", + " keys: ['@image','@label']\n", + " image_only: true\n", + " - _target_: EnsureChannelFirstd\n", + " keys: ['@image','@label']\n", + " - _target_: Orientationd\n", + " keys: ['@image','@label']\n", + " axcodes: 'RAS'\n", + " - _target_: Spacingd\n", + " keys: ['@image','@label']\n", + " pixdim: [1.5, 1.5, 2.0]\n", + " - _target_: ScaleIntensityRanged\n", + " keys: '@image'\n", + " a_min: -57\n", + " a_max: 164\n", + " b_min: 0.0\n", + " b_max: 1.0\n", + " clip: True\n", + " - _target_: CropForegroundd\n", + " keys: ['@image','@label']\n", + " allow_smaller: False\n", + " source_key: '@image'\n", + " - _target_: RandCropByPosNegLabeld\n", + " keys: ['@image','@label']\n", + " label_key: '@label'\n", + " spatial_size: [96, 96, 96]\n", + " pos: 1\n", + " neg: 1\n", + " num_samples: 4\n", + " image_key: '@image'\n", + " image_threshold: 0\n", + "\n", + "val_transform:\n", + " _target_: Compose\n", + " transforms:\n", + " - _target_: LoadImaged\n", + " keys: ['@image','@label']\n", + " image_only: true\n", + " - _target_: EnsureChannelFirstd\n", + " keys: ['@image','@label']\n", + " - _target_: Orientationd\n", + " keys: ['@image','@label']\n", + " axcodes: 'RAS'\n", + " - _target_: Spacingd\n", + " keys: ['@image','@label']\n", + " pixdim: [1.5, 1.5, 2.0]\n", + " - _target_: ScaleIntensityRanged\n", + " keys: '@image'\n", + " a_min: -57\n", + " a_max: 164\n", + " b_min: 0.0\n", + " b_max: 1.0\n", + " clip: True\n", + " - _target_: CropForegroundd\n", + " keys: ['@image','@label']\n", + " source_key: '@image'\n", + " allow_smaller: False\n", + "\n", + "val_dataset:\n", + " _target_: CacheDataset\n", + " data: '@val_files'\n", + " transform: '@val_transform'\n", + " cache_rate: 1.0\n", + " num_workers: 4\n", + "\n", + "train_dataset:\n", + " _target_: CacheDataset\n", + " data: '@train_files'\n", + " transform: '@train_transform'\n", + " cache_rate: 1.0\n", + " num_workers: 4\n", + " \n", + "train_dl:\n", + " _target_: DataLoader\n", + " dataset: '@train_dataset'\n", + " batch_size: 1\n", + " shuffle: true\n", + " num_workers: 4\n", + " \n", + "val_dl:\n", + " _target_: DataLoader\n", + " dataset: '@val_dataset'\n", + " batch_size: 1\n", + " shuffle: false\n", + " num_workers: 4\n", + "\n", + "train:\n", + "- '$train(@lightninig_param, @train_dl, @val_dl)'" + ] + }, + { + "cell_type": "markdown", + "id": "de752181-80b1-4221-9e4a-315e5f7f22a6", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "source": [ + "We can now train as normal to replicate the original code. For demonstration purpose, we set `max_epochs=1`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1d8ac6fd81493874", + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-16T14:57:13.955241252Z", + "start_time": "2023-10-16T14:57:05.329343826Z" + }, + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "workflow_name None\n", + "config_file ['./SpleenSegLightning/configs/common.yaml', './SpleenSegLightning/configs/train.yaml']\n", + "meta_file ./SpleenSegLightning/configs/metadata.json\n", + "logging_file None\n", + "init_id None\n", + "run_id train\n", + "final_id None\n", + "tracking None\n", + "bundle_dir ./SpleenSegLightning\n", + "data_dir ./Task09_Spleen\n", + "max_epochs 1\n", + "2023-10-18 11:36:59,810 - INFO - --- input summary of monai.bundle.scripts.run ---\n", + "2023-10-18 11:36:59,810 - INFO - > config_file: ['./SpleenSegLightning/configs/common.yaml',\n", + " './SpleenSegLightning/configs/train.yaml']\n", + "2023-10-18 11:36:59,810 - INFO - > meta_file: './SpleenSegLightning/configs/metadata.json'\n", + "2023-10-18 11:36:59,811 - INFO - > run_id: 'train'\n", + "2023-10-18 11:36:59,811 - INFO - > bundle_dir: './SpleenSegLightning'\n", + "2023-10-18 11:36:59,811 - INFO - > data_dir: './Task09_Spleen'\n", + "2023-10-18 11:36:59,811 - INFO - > max_epochs: 1\n", + "2023-10-18 11:36:59,811 - INFO - ---\n", + "\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "monai.bundle.workflows ConfigWorkflow.__init__:workflow_type: Current default value of argument `workflow_type=None` has been deprecated since version 1.2. It will be changed to `workflow_type=train` in version 1.4.\n", + "Default logging file in SpleenSegLightning/configs/logging.conf does not exist, skipping logging.\n", + "Loading dataset: 100%|██████████| 32/32 [00:49<00:00, 1.53s/it]\n", + "Loading dataset: 100%|██████████| 9/9 [00:10<00:00, 1.17s/it]\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "Missing logger folder: SpleenSegLightning/lightning_logs/lightning_logs\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "\n", + " | Name | Type | Params\n", + "-------------------------------------------\n", + "0 | _model | UNet | 4.8 M \n", + "1 | loss_function | DiceLoss | 0 \n", + "-------------------------------------------\n", + "4.8 M Trainable params\n", + "0 Non-trainable params\n", + "4.8 M Total params\n", + "19.236 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "configure_optimizers 0.0001\n", + "Epoch 0: 100%|██████████| 32/32 [00:09<00:00, 3.24it/s, v_num=0]\n", + "Validation: 0it [00:00, ?it/s]\u001b[A\n", + "Validation: 0%| | 0/9 [00:00 config_file: ['./SpleenSegLightning/configs/common.yaml',\n", + " './SpleenSegLightning/configs/evaluate.yaml']\n", + "2023-10-18 11:38:38,049 - INFO - > meta_file: './SpleenSegLightning/configs/metadata.json'\n", + "2023-10-18 11:38:38,049 - INFO - > run_id: 'evaluate'\n", + "2023-10-18 11:38:38,049 - INFO - > bundle_dir: './SpleenSegLightning'\n", + "2023-10-18 11:38:38,049 - INFO - > data_dir: './Task09_Spleen'\n", + "2023-10-18 11:38:38,049 - INFO - > ckpt_file: './epoch=599-step=9600.ckpt'\n", + "2023-10-18 11:38:38,049 - INFO - ---\n", + "\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "monai.bundle.workflows ConfigWorkflow.__init__:workflow_type: Current default value of argument `workflow_type=None` has been deprecated since version 1.2. It will be changed to `workflow_type=train` in version 1.4.\n", + "Default logging file in SpleenSegLightning/configs/logging.conf does not exist, skipping logging.\n", + "Loading dataset: 100%|██████████| 9/9 [00:10<00:00, 1.18s/it]\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "Restoring states from the checkpoint path at ./epoch=599-step=9600.ckpt\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "Loaded model weights from the checkpoint at ./epoch=599-step=9600.ckpt\n" + ] + } + ], + "source": [ + "%%bash\n", + "\n", + "DATA_DIR=\"./Task09_Spleen\"\n", + "BUNDLE=\"./SpleenSegLightning\"\n", + "export PYTHONPATH=\"$BUNDLE\"\n", + "\n", + "python -m monai.bundle run evaluate \\\n", + " --bundle_dir \"$BUNDLE\" \\\n", + " --data_dir \"$DATA_DIR\" \\\n", + " --meta_file \"$BUNDLE/configs/metadata.json\" \\\n", + " --config_file \"['$BUNDLE/configs/common.yaml','$BUNDLE/configs/evaluate.yaml']\" \\\n", + " --ckpt_file \"./epoch=599-step=9600.ckpt\"" + ] + }, + { + "cell_type": "markdown", + "id": "6fd62905-4ea8-4f08-bcea-823074fc4ce4", + "metadata": {}, + "source": [ + "## Summary and Next\n", + "\n", + "This tutorial has covered:\n", + "* Creating full training and evaluation scripts in bundles using MONAI and Pytorch Lightning\n", + "* Training a network then evaluating its performance with scripts." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/bundle/README.md b/bundle/README.md index 7cb30e8895..35dd011310 100644 --- a/bundle/README.md +++ b/bundle/README.md @@ -14,6 +14,7 @@ Start the tutorial notebooks on constructing bundles: 2. [MedNIST Classification](./02_mednist_classification.ipynb): train a network using the bundle for doing a real task. 3. [MedNIST Classification With Best Practices](./03_mednist_classification_v2.ipynb): do the same again but better. 4. [Integrating Existing Code](./04_integrating_code.ipynb): discussion on how to integrate existing, possible non-MONAI, code into a bundle. +5. [Spleen Segmentation using Pytorch Lightning](./05_spleen_segmentation.ipynb): train a segmentation network using MONAI bundle and Pytorch Lightning. More advanced topics are covered in this directory: