diff --git a/research/lamp-automated-model-parallelism/README.md b/research/lamp-automated-model-parallelism/README.md
new file mode 100644
index 0000000000..321d7a2cdf
--- /dev/null
+++ b/research/lamp-automated-model-parallelism/README.md
@@ -0,0 +1,53 @@
+# LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation
+
+![acc_speed_han_0_5hor](./fig/acc_speed_han_0_5hor.png)
+
+> If you use this work in your research, please cite the paper.
+
+A reimplementation of the LAMP system originally proposed by:
+
+Wentao Zhu, Can Zhao, Wenqi Li, Holger Roth, Ziyue Xu, and Daguang Xu (2020)
+"LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation."
+MICCAI 2020 (Early Accept, paper link: https://arxiv.org/abs/2006.12575)
+
+
+## Running the demo
+
+### Prerequisites
+- install the latest version of MONAI: `git clone https://github.com/Project-MONAI/MONAI`, then `cd MONAI` and `pip install -e .`
+- `pip install torchgpipe`
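+
+To verify the setup, a quick sanity check (assuming at least one CUDA device is visible):
+
+```python
+import monai
+import torch
+import torchgpipe  # noqa: F401  # only checking that the import succeeds
+
+print(monai.__version__)          # MONAI installed from source
+print(torch.cuda.device_count())  # number of GPUs available for model parallelism
+```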
+
+### Data
+```bash
+mkdir ./data;
+cd ./data;
+```
+The demo uses a Head and Neck (HaN) CT dataset.
+
+Please download the archive below and unzip the images into the `./data` folder.
+
+- `HaN.zip`: https://drive.google.com/file/d/1A2zpVlR3CkvtkJPvtAF3-MH0nr1WZ2Mn/view?usp=sharing
+```bash
+unzip HaN.zip;
+```
+
+After unzipping, `train.py` expects the training and validation cases under `./data/HaN/train/` and `./data/HaN/test/` respectively, with each case folder containing `img_crp_v2.npy` and a `structures/` subfolder of `<StructureName>_crp_v2.npy` masks.
+
+Please find more details about the dataset at https://github.com/wentaozhu/AnatomyNet-for-anatomical-segmentation.git
+
+
+### Minimal hardware requirements for full-image training
+- U-Net (`n_feat=32`): 2x 16 GB GPUs
+- U-Net (`n_feat=64`): 4x 16 GB GPUs
+- U-Net (`n_feat=128`): 2x 32 GB GPUs
+
+
+### Commands
+The number of features in the first block (`--n_feat`) can be 32, 64, or 128. The commands below train with progressively larger crops; `--crop_size='-1,-1,-1'` trains on full images, and `--pretrain` initializes a run from a model saved by an earlier stage. A sketch of the underlying model-parallel setup follows the commands.
+```bash
+mkdir ./log;
+python train.py --n_feat=128 --crop_size='64,64,64' --bs=16 --ep=4800 --lr=0.001 > ./log/YOURLOG.log
+python train.py --n_feat=128 --crop_size='128,128,128' --bs=4 --ep=1200 --lr=0.001 --pretrain='./HaN_32_16_1200_64,64,64_0.001_*' > ./log/YOURLOG.log
+python train.py --n_feat=128 --crop_size='-1,-1,-1' --bs=1 --ep=300 --lr=0.001 --pretrain='./HaN_32_16_1200_64,64,64_0.001_*' > ./log/YOURLOG.log
+```
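+
+Under the hood, `train.py` builds the model-parallel pipeline roughly as sketched below (a simplified excerpt; the sample input shape is only an example and at least one GPU is assumed):
+
+```python
+import torch
+from torchgpipe import GPipe
+from torchgpipe.balance import balance_by_size
+
+from unet_pipe import UNetPipe, flatten_sequential
+
+# build the UNet and flatten it into a single nn.Sequential so it can be partitioned
+model = UNetPipe(spatial_dims=3, in_channels=1, out_channels=10, n_feat=32)
+model = flatten_sequential(model)
+
+# profile one sample batch to decide how to split the layers across the available GPUs,
+# then wrap the model with GPipe for pipelined (micro-batched) model parallelism
+sample = torch.randn(1, 1, 64, 64, 64).cuda()
+partitions = torch.cuda.device_count()
+balance = balance_by_size(partitions, model, sample)
+model = GPipe(model, balance, chunks=4, checkpoint="always")
+```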
diff --git a/research/lamp-automated-model-parallelism/__init__.py b/research/lamp-automated-model-parallelism/__init__.py
new file mode 100644
index 0000000000..d0044e3563
--- /dev/null
+++ b/research/lamp-automated-model-parallelism/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/research/lamp-automated-model-parallelism/data_utils.py b/research/lamp-automated-model-parallelism/data_utils.py
new file mode 100644
index 0000000000..b4825c1910
--- /dev/null
+++ b/research/lamp-automated-model-parallelism/data_utils.py
@@ -0,0 +1,66 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+from monai.transforms import DivisiblePad
+
+STRUCTURES = (
+ "BrainStem",
+ "Chiasm",
+ "Mandible",
+ "OpticNerve_L",
+ "OpticNerve_R",
+ "Parotid_L",
+ "Parotid_R",
+ "Submandibular_L",
+ "Submandibular_R",
+)
+
+
+def get_filenames(path, maskname=STRUCTURES):
+ """
+    Create file names according to the predefined folder structure.
+
+ Args:
+ path: data folder name
+ maskname: target structure names
+ """
+ maskfiles = []
+ for seg in maskname:
+        seg_file = os.path.join(path, "structures", seg + "_crp_v2.npy")
+        if os.path.exists(seg_file):
+            maskfiles.append(seg_file)
+        else:
+            # the corresponding mask is missing for this case; keep a None placeholder
+            maskfiles.append(None)
+ return os.path.join(path, "img_crp_v2.npy"), maskfiles
+
+
+def load_data_and_mask(data, mask_data):
+ """
+    Load the image file `data` and the mask files `mask_data` (a list of file names)
+    into a dictionary of {"image": array, "label": list of arrays, "name": str}.
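+
+    A usage sketch (the case folder name here is hypothetical; see the README for the layout):
+
+        img_file, mask_files = get_filenames("./data/HaN/train/case_folder")
+        item = load_data_and_mask(img_file, mask_files)
+        item["image"].shape  # each spatial dim padded to a multiple of 32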
+ """
+ pad_xform = DivisiblePad(k=32)
+ img = np.load(data) # z y x
+ img = pad_xform(img[None])[0]
+ item = dict(image=img, label=[])
+ for idx, maskfnm in enumerate(mask_data):
+ if maskfnm is None:
+ ms = np.zeros(img.shape, np.uint8)
+ else:
+ ms = np.load(maskfnm).astype(np.uint8)
+ assert ms.min() == 0 and ms.max() == 1
+ mask = pad_xform(ms[None])[0]
+ item["label"].append(mask)
+ assert len(item["label"]) == 9
+ item["name"] = str(data)
+ return item
diff --git a/research/lamp-automated-model-parallelism/fig/acc_speed_han_0_5hor.png b/research/lamp-automated-model-parallelism/fig/acc_speed_han_0_5hor.png
new file mode 100644
index 0000000000..f8a8254832
Binary files /dev/null and b/research/lamp-automated-model-parallelism/fig/acc_speed_han_0_5hor.png differ
diff --git a/research/lamp-automated-model-parallelism/test_unet_pipe.py b/research/lamp-automated-model-parallelism/test_unet_pipe.py
new file mode 100644
index 0000000000..6783996480
--- /dev/null
+++ b/research/lamp-automated-model-parallelism/test_unet_pipe.py
@@ -0,0 +1,52 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from unet_pipe import UNetPipe
+
+TEST_CASES = [
+ [ # 1-channel 3D, batch 12
+ {"spatial_dims": 3, "out_channels": 2, "in_channels": 1, "depth": 3, "n_feat": 8},
+ torch.randn(12, 1, 32, 64, 48),
+ (12, 2, 32, 64, 48),
+ ],
+ [ # 1-channel 3D, batch 16
+ {"spatial_dims": 3, "out_channels": 2, "in_channels": 1, "depth": 3},
+ torch.randn(16, 1, 32, 64, 48),
+ (16, 2, 32, 64, 48),
+ ],
+    [  # 2-channel 3D, batch 16
+ {"spatial_dims": 3, "out_channels": 3, "in_channels": 2},
+ torch.randn(16, 2, 64, 64, 64),
+ (16, 3, 64, 64, 64),
+ ],
+]
+
+
+class TestUNETPipe(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_shape(self, input_param, input_data, expected_shape):
+ net = UNetPipe(**input_param)
+ if torch.cuda.is_available():
+ net = net.to(torch.device("cuda"))
+ input_data = input_data.to(torch.device("cuda"))
+ net.eval()
+ with torch.no_grad():
+ result = net.forward(input_data.float())
+ self.assertEqual(result.shape, expected_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/research/lamp-automated-model-parallelism/train.py b/research/lamp-automated-model-parallelism/train.py
new file mode 100644
index 0000000000..1f6f578591
--- /dev/null
+++ b/research/lamp-automated-model-parallelism/train.py
@@ -0,0 +1,242 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from argparse import ArgumentParser
+import os
+
+import numpy as np
+import torch
+from monai.transforms import AddChannelDict, Compose, RandCropByPosNegLabeld, Rand3DElasticd, SpatialPadd
+from monai.losses import DiceLoss, FocalLoss
+from monai.metrics import compute_meandice
+from monai.data import Dataset, list_data_collate
+from monai.utils import first
+from torchgpipe import GPipe
+from torchgpipe.balance import balance_by_size
+
+from unet_pipe import UNetPipe, flatten_sequential
+from data_utils import get_filenames, load_data_and_mask
+
+N_CLASSES = 10
+TRAIN_PATH = "./data/HaN/train/" # training data folder
+VAL_PATH = "./data/HaN/test/" # validation data folder
+
+torch.backends.cudnn.enabled = True
+
+
+class ImageLabelDataset:
+ """
+ Load image and multi-class labels based on the predefined folder structure.
+ """
+
+ def __init__(self, path, n_class=10):
+ self.path = path
+ self.data = sorted(os.listdir(path))
+ self.n_class = n_class
+
+ def __getitem__(self, index):
+ data = os.path.join(self.path, self.data[index])
+ train_data, train_masks_data = get_filenames(data)
+ data = load_data_and_mask(train_data, train_masks_data) # read into a data dict
+ # loading image
+ data["image"] = data["image"].astype(np.float32) # shape (H W D)
+ # loading labels
+ class_shape = (1,) + data["image"].shape
+ mask0 = np.zeros(class_shape)
+ mask_list = []
+ flagvect = np.ones((self.n_class,), np.float32)
+ for i, mask in enumerate(data["label"]):
+ if mask is None:
+ mask = np.zeros(class_shape)
+ flagvect[0] = 0
+ flagvect[i + 1] = 0
+ mask0 = np.logical_or(mask0, mask)
+ mask_list.append(mask.reshape(class_shape))
+ mask0 = 1 - mask0
+ data["label"] = np.concatenate([mask0] + mask_list, axis=0).astype(np.uint8) # shape (C H W D)
+ # setting flags
+ data["with_complete_groundtruth"] = flagvect # flagvec is a boolean indicator for complete annotation
+ return data
+
+ def __len__(self):
+ return len(self.data)
+
+
+def train(n_feat, crop_size, bs, ep, optimizer="rmsprop", lr=5e-4, pretrain=None):
+ model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
+ print(f"save the best model as '{model_name}' during training.")
+
+ crop_size = [int(cz) for cz in crop_size.split(",")]
+ print(f"input image crop_size: {crop_size}")
+
+ # starting training set loader
+ train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
+ if np.any([cz == -1 for cz in crop_size]): # using full image
+ train_transform = Compose(
+ [
+ AddChannelDict(keys="image"),
+ Rand3DElasticd(
+ keys=("image", "label"),
+ spatial_size=crop_size,
+ sigma_range=(10, 50), # 30
+ magnitude_range=[600, 1200], # 1000
+ prob=0.8,
+ rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
+ shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
+ translate_range=(sz * 0.05 for sz in crop_size),
+ scale_range=(0.2, 0.2, 0.2),
+ mode=("bilinear", "nearest"),
+ padding_mode=("border", "zeros"),
+ ),
+ ]
+ )
+ train_dataset = Dataset(train_images, transform=train_transform)
+ # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, num_workers=4, batch_size=bs, shuffle=True)
+ else:
+ # draw balanced foreground/background window samples according to the ground truth label
+ train_transform = Compose(
+ [
+ AddChannelDict(keys="image"),
+ SpatialPadd(keys=("image", "label"), spatial_size=crop_size), # ensure image size >= crop_size
+ RandCropByPosNegLabeld(
+ keys=("image", "label"), label_key="label", spatial_size=crop_size, num_samples=bs
+ ),
+ Rand3DElasticd(
+ keys=("image", "label"),
+ spatial_size=crop_size,
+ sigma_range=(10, 50), # 30
+ magnitude_range=[600, 1200], # 1000
+ prob=0.8,
+ rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
+ shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
+ translate_range=(sz * 0.05 for sz in crop_size),
+ scale_range=(0.2, 0.2, 0.2),
+ mode=("bilinear", "nearest"),
+ padding_mode=("border", "zeros"),
+ ),
+ ]
+ )
+ train_dataset = Dataset(train_images, transform=train_transform) # each dataset item is a list of windows
+ train_dataloader = torch.utils.data.DataLoader( # stack each dataset item into a single tensor
+ train_dataset, num_workers=4, batch_size=1, shuffle=True, collate_fn=list_data_collate
+ )
+ first_sample = first(train_dataloader)
+ print(first_sample["image"].shape)
+
+ # starting validation set loader
+ val_transform = Compose([AddChannelDict(keys="image")])
+ val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES), transform=val_transform)
+ val_dataloader = torch.utils.data.DataLoader(val_dataset, num_workers=1, batch_size=1)
+ print(val_dataset[0]["image"].shape)
+ print(f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}")
+
+ model = UNetPipe(spatial_dims=3, in_channels=1, out_channels=N_CLASSES, n_feat=n_feat)
+ model = flatten_sequential(model)
+ lossweight = torch.from_numpy(np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98], np.float32))
+
+ if optimizer.lower() == "rmsprop":
+ optimizer = torch.optim.RMSprop(model.parameters(), lr=lr) # lr = 5e-4
+ elif optimizer.lower() == "momentum":
+ optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) # lr = 1e-4 for finetuning
+ else:
+ raise ValueError(f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum').")
+
+    # configure GPipe: profile one sample batch to balance the flattened model across the available GPUs
+ x = first_sample["image"].float()
+ x = torch.autograd.Variable(x.cuda())
+ partitions = torch.cuda.device_count()
+ print(f"partition: {partitions}, input: {x.size()}")
+ balance = balance_by_size(partitions, model, x)
+ model = GPipe(model, balance, chunks=4, checkpoint="always")
+
+ # config loss functions
+ dice_loss_func = DiceLoss(softmax=True, reduction="none")
+    # use the same pipeline and loss as in
+    # AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy,
+ # Medical Physics, 2018.
+ focal_loss_func = FocalLoss(reduction="none")
+
+ if pretrain:
+ print(f"loading from {pretrain}.")
+ pretrained_dict = torch.load(pretrain)["weight"]
+ model_dict = model.state_dict()
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+ model_dict.update(pretrained_dict)
+        model.load_state_dict(model_dict)  # load the merged dict so missing keys keep their current values
+
+ b_time = time.time()
+ best_val_loss = [0] * (N_CLASSES - 1) # foreground
+ best_ave = -1
+ for epoch in range(ep):
+ model.train()
+ trainloss = 0
+ for b_idx, data_dict in enumerate(train_dataloader):
+ x_train = data_dict["image"]
+ y_train = data_dict["label"]
+ flagvec = data_dict["with_complete_groundtruth"]
+
+ x_train = torch.autograd.Variable(x_train.cuda())
+ y_train = torch.autograd.Variable(y_train.cuda().float())
+ optimizer.zero_grad()
+ o = model(x_train).to(0, non_blocking=True).float()
+
+ loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
+ loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
+ loss.backward()
+ optimizer.step()
+ trainloss += loss.item()
+
+ if b_idx % 20 == 0:
+ print(f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}")
+ print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")
+
+ if epoch % 10 == 0:
+ model.eval()
+ # check validation dice
+ val_loss = [0] * (N_CLASSES - 1)
+ n_val = [0] * (N_CLASSES - 1)
+ for data_dict in val_dataloader:
+ x_val = data_dict["image"]
+ y_val = data_dict["label"]
+ with torch.no_grad():
+ x_val = torch.autograd.Variable(x_val.cuda())
+ o = model(x_val).to(0, non_blocking=True)
+ loss = compute_meandice(o, y_val.to(o), mutually_exclusive=True, include_background=False)
+ val_loss = [l.item() + tl if l == l else tl for l, tl in zip(loss[0], val_loss)]
+ n_val = [n + 1 if l == l else n for l, n in zip(loss[0], n_val)]
+ val_loss = [l / n for l, n in zip(val_loss, n_val)]
+ print("validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f" % tuple(val_loss))
+ for c in range(1, 10):
+ if best_val_loss[c - 1] < val_loss[c - 1]:
+ best_val_loss[c - 1] = val_loss[c - 1]
+ state = {"epoch": epoch, "weight": model.state_dict(), "score_" + str(c): best_val_loss[c - 1]}
+ torch.save(state, f"{model_name}" + str(c))
+ print("best validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f" % tuple(best_val_loss))
+
+ print("total time", time.time() - b_time)
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument("--n_feat", type=int, default=32, dest="n_feat")
+ parser.add_argument("--crop_size", type=str, default="-1,-1,-1", dest="crop_size")
+ parser.add_argument("--bs", type=int, default=1, dest="bs") # batch size
+ parser.add_argument("--ep", type=int, default=150, dest="ep") # number of epochs
+ parser.add_argument("--lr", type=float, default=5e-4, dest="lr") # learning rate
+ parser.add_argument("--optimizer", type=str, default="rmsprop", dest="optimizer") # type of optimizer
+ parser.add_argument("--pretrain", type=str, default=None, dest="pretrain")
+ args = parser.parse_args()
+
+ input_dict = vars(args)
+ print(input_dict)
+ train(**input_dict)
diff --git a/research/lamp-automated-model-parallelism/unet_pipe.py b/research/lamp-automated-model-parallelism/unet_pipe.py
new file mode 100644
index 0000000000..d563de8257
--- /dev/null
+++ b/research/lamp-automated-model-parallelism/unet_pipe.py
@@ -0,0 +1,171 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+from typing import List
+
+import torch
+from monai.networks.blocks import Convolution, UpSample
+from monai.networks.layers.factories import Act, Conv, Norm
+from torch import nn
+from torchgpipe.skip import Namespace, pop, skippable, stash
+
+
+@skippable(stash=["skip"], pop=[])
+class Stash(nn.Module):
+ def forward(self, input: torch.Tensor):
+ yield stash("skip", input)
+ return input # noqa using yield together with return
+
+
+@skippable(stash=[], pop=["skip"])
+class PopCat(nn.Module):
+ def forward(self, input: torch.Tensor):
+ skip = yield pop("skip")
+ if skip is not None:
+ input = torch.cat([input, skip], dim=1)
+ return input
+
+
+def flatten_sequential(module: nn.Sequential):
+ """
+ Recursively make all the submodules sequential.
+
+ Args:
+ module: a torch sequential model.
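+
+    For example, a child registered as ``block`` that itself contains a child ``conv``
+    appears in the flattened model under the name ``block_conv``.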
+ """
+ if not isinstance(module, nn.Sequential):
+ raise TypeError("module must be a nn.Sequential instance.")
+
+ def _flatten(module):
+ for name, child in module.named_children():
+ if isinstance(child, nn.Sequential):
+ for sub_name, sub_child in _flatten(child):
+ yield f"{name}_{sub_name}", sub_child
+ else:
+ yield name, child
+
+ return nn.Sequential(OrderedDict(_flatten(module)))
+
+
+class DoubleConv(nn.Module):
+ def __init__(
+ self,
+ spatial_dims,
+ in_channels,
+ out_channels,
+ stride=2,
+ act_1=Act.LEAKYRELU,
+ norm_1=Norm.BATCH,
+ act_2=Act.LEAKYRELU,
+ norm_2=Norm.BATCH,
+ conv_only=True,
+ ):
+ """
+ A sequence of Conv_1 + Norm_1 + Act_1 + Conv_2 (+ Norm_2 + Act_2).
+
+ `norm_2` and `act_2` are ignored when `conv_only` is True.
+ `stride` is for `Conv_1`, typically stride=2 for 2x spatial downsampling.
+
+ Args:
+ spatial_dims: number of the input spatial dimension.
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+ stride: stride of the first conv., mainly used for 2x downsampling when stride=2.
+ act_1: activation type of the first convolution.
+ norm_1: normalization type of the first convolution.
+ act_2: activation type of the second convolution.
+ norm_2: normalization type of the second convolution.
+ conv_only: whether the second conv is convolution layer only. Default to True,
+ indicates that `act_2` and `norm_2` are not in use.
+ """
+ super(DoubleConv, self).__init__()
+ self.conv = nn.Sequential(
+ Convolution(spatial_dims, in_channels, out_channels, strides=stride, act=act_1, norm=norm_1, bias=False,),
+ Convolution(spatial_dims, out_channels, out_channels, act=act_2, norm=norm_2, conv_only=conv_only),
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class UNetPipe(nn.Sequential):
+ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, n_feat: int = 32, depth: int = 4):
+ """
+ A UNet-like architecture for model parallelism.
+
+ Args:
+ spatial_dims: number of input spatial dimensions,
+ 2 for (B, in_channels, H, W), 3 for (B, in_channels, H, W, D).
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+ n_feat: number of features in the first convolution.
+ depth: number of downsampling stages.
+ """
+ super(UNetPipe, self).__init__()
+ n_enc_filter: List[int] = [n_feat]
+ for i in range(1, depth + 1):
+ n_enc_filter.append(min(n_enc_filter[-1] * 2, 1024))
+ namespaces = [Namespace() for _ in range(depth)]
+
+ # construct the encoder
+ encoder_layers: List[nn.Module] = []
+ init_conv = Convolution(
+ spatial_dims, in_channels, n_enc_filter[0], strides=2, act=Act.LEAKYRELU, norm=Norm.BATCH, bias=False,
+ )
+ encoder_layers.append(
+ nn.Sequential(OrderedDict([("Conv", init_conv,), ("skip", Stash().isolate(namespaces[0]))]))
+ )
+ for i in range(1, depth + 1):
+ down_conv = DoubleConv(spatial_dims, n_enc_filter[i - 1], n_enc_filter[i])
+ if i == depth:
+ layer_dict = OrderedDict([("Down", down_conv)])
+ else:
+ layer_dict = OrderedDict([("Down", down_conv), ("skip", Stash().isolate(namespaces[i]))])
+ encoder_layers.append(nn.Sequential(layer_dict))
+ encoder = nn.Sequential(*encoder_layers)
+
+ # construct the decoder
+ decoder_layers: List[nn.Module] = []
+ for i in reversed(range(1, depth + 1)):
+ in_ch, out_ch = n_enc_filter[i], n_enc_filter[i - 1]
+ layer_dict = OrderedDict(
+ [
+ ("Up", UpSample(spatial_dims, in_ch, out_ch, 2, True)),
+ ("skip", PopCat().isolate(namespaces[i - 1])),
+ ("Conv1x1x1", Conv[Conv.CONV, spatial_dims](out_ch * 2, in_ch, kernel_size=1)),
+ ("Conv", DoubleConv(spatial_dims, in_ch, out_ch, stride=1, conv_only=True)),
+ ]
+ )
+ decoder_layers.append(nn.Sequential(layer_dict))
+ in_ch = min(n_enc_filter[0] // 2, 32)
+ layer_dict = OrderedDict(
+ [
+ ("Up", UpSample(spatial_dims, n_feat, in_ch, 2, True)),
+ ("RELU", Act[Act.LEAKYRELU](inplace=False)),
+ ("out", Conv[Conv.CONV, spatial_dims](in_ch, out_channels, kernel_size=3, padding=1),),
+ ]
+ )
+ decoder_layers.append(nn.Sequential(layer_dict))
+ decoder = nn.Sequential(*decoder_layers)
+
+ # making a sequential model
+ self.add_module("encoder", encoder)
+ self.add_module("decoder", decoder)
+
+ for m in self.modules():
+ if isinstance(m, Conv[Conv.CONV, spatial_dims]):
+ nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, Norm[Norm.BATCH, spatial_dims]):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, Conv[Conv.CONVTRANS, spatial_dims]):
+ nn.init.kaiming_normal_(m.weight)
diff --git a/setup.py b/setup.py
index 83372856ea..5158fa1fb9 100644
--- a/setup.py
+++ b/setup.py
@@ -17,7 +17,7 @@
setup(
version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass(),
- packages=find_packages(exclude=("docs", "examples", "tests")),
+ packages=find_packages(exclude=("docs", "examples", "tests", "research")),
zip_safe=False,
package_data={"monai": ["py.typed"]},
)