From 530cc1f35e40e50d4b35de30f75b75ef6fb78d08 Mon Sep 17 00:00:00 2001
From: Eloi
Date: Thu, 7 Nov 2024 09:30:40 +0100
Subject: [PATCH 01/23] Fix ImageFilter to allow Gaussian filter without
 filter_size (#8189)

Fixes #8127

Update `ImageFilter` to handle the Gaussian filter without requiring `filter_size`.

* Modify `monai/transforms/utility/array.py` to allow the Gaussian filter without `filter_size`.
  - Adjust the `_check_filter_format` method to skip the `filter_size` check for the Gaussian
    filter, which is the only filter in the list that does not require one.
* Add a unit test in `tests/test_image_filter.py` for the Gaussian filter without `filter_size`.
  - Verify that the output shape matches the input shape.

Note that this fix also covers the dictionary version of the transform, since it loads the fixed version.

Signed-off-by: Eloi

---------

Signed-off-by: Eloi Navet
Signed-off-by: Eloi
Signed-off-by: Eloi eloi.navet@gmail.com
---
 monai/transforms/utility/array.py | 4 ++--
 tests/test_image_filter.py        | 6 ++++++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index 72dd189009..1b3c59afdb 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -1609,9 +1609,9 @@ def _check_all_values_uneven(self, x: tuple) -> None:
 
     def _check_filter_format(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None) -> None:
         if isinstance(filter, str):
-            if not filter_size:
+            if filter != "gauss" and not filter_size:  # Gauss is the only filter that does not require `filter_size`
                 raise ValueError("`filter_size` must be specified when specifying filters by string.")
-            if filter_size % 2 == 0:
+            if filter_size and filter_size % 2 == 0:
                 raise ValueError("`filter_size` should be a single uneven integer.")
             if filter not in self.supported_filters:
                 raise NotImplementedError(f"{filter}. Supported filters are {self.supported_filters}.")

diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py
index 76e38d94f4..fb08b2295d 100644
--- a/tests/test_image_filter.py
+++ b/tests/test_image_filter.py
@@ -134,6 +134,12 @@ def test_pass_empty_metadata_dict(self):
         out_tensor = filter(image)
         self.assertTrue(isinstance(out_tensor, MetaTensor))
 
+    def test_gaussian_filter_without_filter_size(self):
+        "Test Gaussian filter without specifying filter_size"
+        filter = ImageFilter("gauss", sigma=2)
+        out_tensor = filter(SAMPLE_IMAGE_2D)
+        self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:])
+
 
 class TestImageFilterDict(unittest.TestCase):

From 0bb20a88ec7869f6453aa58890df50ad6b2b6271 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Mon, 11 Nov 2024 10:51:20 +0800
Subject: [PATCH 02/23] Update base image to 2410 (#8164)

Fixes # .

### Description

Update the NGC PyTorch base image to 24.10 in the Dockerfile and the cron CI workflow, and add a PT240+CUDA126 entry to the test matrix.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
--------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .github/workflows/cron.yml | 16 ++++++++++------ Dockerfile | 2 +- tests/test_trt_compile.py | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index 6732ab7256..516e2d4743 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -16,7 +16,8 @@ jobs: - "PT110+CUDA113" - "PT113+CUDA118" - "PT210+CUDA121" - - "PTLATEST+CUDA124" + - "PT240+CUDA126" + - "PTLATEST+CUDA126" include: # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - environment: PT110+CUDA113 @@ -28,9 +29,12 @@ jobs: - environment: PT210+CUDA121 pytorch: "pytorch==2.1.0 torchvision==0.16.0 --extra-index-url https://download.pytorch.org/whl/cu121" base: "nvcr.io/nvidia/pytorch:23.08-py3" # CUDA 12.1 - - environment: PTLATEST+CUDA124 + - environment: PT240+CUDA126 + pytorch: "pytorch==2.4.0 torchvision==0.19.0 --extra-index-url https://download.pytorch.org/whl/cu121" + base: "nvcr.io/nvidia/pytorch:24.08-py3" # CUDA 12.6 + - environment: PTLATEST+CUDA126 pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121" - base: "nvcr.io/nvidia/pytorch:24.08-py3" # CUDA 12.4 + base: "nvcr.io/nvidia/pytorch:24.10-py3" # CUDA 12.6 container: image: ${{ matrix.base }} options: "--gpus all" @@ -80,7 +84,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:23.08", "pytorch:24.08"] + container: ["pytorch:23.08", "pytorch:24.08", "pytorch:24.10"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -129,7 +133,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:24.08"] + container: ["pytorch:24.10"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -233,7 +237,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' needs: cron-gpu # so that monai itself is verified first container: - image: nvcr.io/nvidia/pytorch:24.08-py3 # testing with the latest pytorch base image + image: nvcr.io/nvidia/pytorch:24.10-py3 # testing with the latest pytorch base image options: "--gpus all --ipc=host" runs-on: [self-hosted, linux, x64, integration] steps: diff --git a/Dockerfile b/Dockerfile index e45932c6bb..5fcfcf274d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. 
-ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.08-py3 +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.10-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 5a56f0e4a2..6df5d520bd 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -46,7 +46,7 @@ def tearDown(self): if current_device != self.gpu_device: torch.cuda.set_device(self.gpu_device) - @SkipIfAtLeastPyTorchVersion((2, 5, 0)) + @SkipIfAtLeastPyTorchVersion((2, 4, 1)) def test_handler(self): from ignite.engine import Engine From b6663b90f52b41eecd7e412ef5d8937ca640c6ad Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 13 Nov 2024 13:16:39 +0800 Subject: [PATCH 03/23] Add SM architecture version check (#8199) Fixes #8198 NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5. Review the [TensorRT Support Matrix](https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html) for which GPUs are supported by this release. Add SM architecture version check to skip trt test before 7.0. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- monai/bundle/scripts.py | 2 + monai/networks/trt_compiler.py | 4 +- monai/utils/__init__.py | 1 + monai/utils/module.py | 41 +++++++++++++++++++ tests/test_bundle_trt_export.py | 9 +++- tests/test_convert_to_trt.py | 3 +- tests/test_trt_compile.py | 9 +++- ...version_after.py => test_version_after.py} | 22 ++++++++-- tests/utils.py | 16 +++++++- 9 files changed, 99 insertions(+), 8 deletions(-) rename tests/{test_pytorch_version_after.py => test_version_after.py} (72%) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 884723ed68..131c78008b 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1589,6 +1589,8 @@ def trt_export( """ Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript. Currently, this API only supports converting models whose inputs are all tensors. + Note: NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5. + Review the TensorRT Support Matrix for which GPUs are supported. There are two ways to export a model: 1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript. diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index a360f63dbd..d2d05fae22 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -505,7 +505,9 @@ def trt_compile( ) -> torch.nn.Module: """ Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. - Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x + Note: TRT 10.3 is recommended for best performance. 
Some nets may even fail to work with TRT 8.x. + NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5. + Review the TensorRT Support Matrix for which GPUs are supported. Args: model: module to patch with TrtCompiler object. base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 79dc1f2304..8f2f400b5d 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -107,6 +107,7 @@ InvalidPyTorchVersionError, OptionalImportError, allow_missing_reference, + compute_capabilities_after, damerau_levenshtein_distance, exact_version, get_full_type_name, diff --git a/monai/utils/module.py b/monai/utils/module.py index 1f7f8aecfc..d3f2ff09f2 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -634,3 +634,44 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st if is_prerelease: return False return True + + +@functools.lru_cache(None) +def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool: + """ + Compute whether the current system GPU CUDA compute capability is after or equal to the specified version. + The current system GPU CUDA compute capability is determined by the first GPU in the system. + The compared version is a string in the form of "major.minor". + + Args: + major: major version number to be compared with. + minor: minor version number to be compared with. Defaults to 0. + current_ver_string: if None, the current system GPU CUDA compute capability will be used. + + Returns: + True if the current system GPU CUDA compute capability is greater than or equal to the specified version. + """ + if current_ver_string is None: + cuda_available = torch.cuda.is_available() + pynvml, has_pynvml = optional_import("pynvml") + if not has_pynvml: # assuming that the user has Ampere and later GPU + return True + if not cuda_available: + return False + else: + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) # get the first GPU + major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle) + current_ver_string = f"{major_c}.{minor_c}" + pynvml.nvmlShutdown() + + ver, has_ver = optional_import("packaging.version", name="parse") + if has_ver: + return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}") # type: ignore + parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2) + while len(parts) < 2: + parts += ["0"] + c_major, c_minor = parts[:2] + c_mn = int(c_major), int(c_minor) + mn = int(major), int(minor) + return c_mn > mn diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py index 833a0ca1dc..835c8e5c1d 100644 --- a/tests/test_bundle_trt_export.py +++ b/tests/test_bundle_trt_export.py @@ -22,7 +22,13 @@ from monai.data import load_net_with_metadata from monai.networks import save_state from monai.utils import optional_import -from tests.utils import command_line_tests, skip_if_no_cuda, skip_if_quick, skip_if_windows +from tests.utils import ( + SkipIfBeforeComputeCapabilityVersion, + command_line_tests, + skip_if_no_cuda, + skip_if_quick, + skip_if_windows, +) _, has_torchtrt = optional_import( "torch_tensorrt", @@ -47,6 +53,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick +@SkipIfBeforeComputeCapabilityVersion((7, 0)) class TestTRTExport(unittest.TestCase): def setUp(self): diff --git a/tests/test_convert_to_trt.py b/tests/test_convert_to_trt.py index 5579539764..712d887c3b 100644 --- 
a/tests/test_convert_to_trt.py +++ b/tests/test_convert_to_trt.py @@ -20,7 +20,7 @@ from monai.networks import convert_to_trt from monai.networks.nets import UNet from monai.utils import optional_import -from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows +from tests.utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows _, has_torchtrt = optional_import( "torch_tensorrt", @@ -38,6 +38,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick +@SkipIfBeforeComputeCapabilityVersion((7, 0)) class TestConvertToTRT(unittest.TestCase): def setUp(self): diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 6df5d520bd..49404fdbbe 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -21,7 +21,13 @@ from monai.networks import trt_compile from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 from monai.utils import min_version, optional_import -from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows +from tests.utils import ( + SkipIfAtLeastPyTorchVersion, + SkipIfBeforeComputeCapabilityVersion, + skip_if_no_cuda, + skip_if_quick, + skip_if_windows, +) trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) polygraphy, polygraphy_imported = optional_import("polygraphy") @@ -36,6 +42,7 @@ @skip_if_quick @unittest.skipUnless(trt_imported, "tensorrt is required") @unittest.skipUnless(polygraphy_imported, "polygraphy is required") +@SkipIfBeforeComputeCapabilityVersion((7, 0)) class TestTRTCompile(unittest.TestCase): def setUp(self): diff --git a/tests/test_pytorch_version_after.py b/tests/test_version_after.py similarity index 72% rename from tests/test_pytorch_version_after.py rename to tests/test_version_after.py index 147707d2c0..b6cb741382 100644 --- a/tests/test_pytorch_version_after.py +++ b/tests/test_version_after.py @@ -15,9 +15,9 @@ from parameterized import parameterized -from monai.utils import pytorch_after +from monai.utils import compute_capabilities_after, pytorch_after -TEST_CASES = ( +TEST_CASES_PT = ( (1, 5, 9, "1.6.0"), (1, 6, 0, "1.6.0"), (1, 6, 1, "1.6.0", False), @@ -36,14 +36,30 @@ (1, 6, 1, "1.6.0+cpu", False), ) +TEST_CASES_SM = [ + # (major, minor, sm, expected) + (6, 1, "6.1", True), + (6, 1, "6.0", False), + (6, 0, "8.6", True), + (7, 0, "8", True), + (8, 6, "8", False), +] + class TestPytorchVersionCompare(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TEST_CASES_PT) def test_compare(self, a, b, p, current, expected=True): """Test pytorch_after with a and b""" self.assertEqual(pytorch_after(a, b, p, current), expected) +class TestComputeCapabilitiesAfter(unittest.TestCase): + + @parameterized.expand(TEST_CASES_SM) + def test_compute_capabilities_after(self, major, minor, sm, expected): + self.assertEqual(compute_capabilities_after(major, minor, sm), expected) + + if __name__ == "__main__": unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 77b53cebb8..2a00af50e9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -47,7 +47,7 @@ from monai.networks import convert_to_onnx, convert_to_torchscript from monai.utils import optional_import from monai.utils.misc import MONAIEnvVars -from monai.utils.module import pytorch_after +from monai.utils.module import compute_capabilities_after, pytorch_after from monai.utils.tf32 import detect_default_tf32 from monai.utils.type_conversion import convert_data_type @@ -286,6 +286,20 @@ def __call__(self, obj): )(obj) +class 
SkipIfBeforeComputeCapabilityVersion:
+    """Decorator to be used if test should be skipped
+    with Compute Capability older than that given."""
+
+    def __init__(self, compute_capability_tuple):
+        self.min_version = compute_capability_tuple
+        self.version_too_old = not compute_capabilities_after(*compute_capability_tuple)
+
+    def __call__(self, obj):
+        return unittest.skipIf(
+            self.version_too_old, f"Skipping tests that fail on Compute Capability versions before: {self.min_version}"
+        )(obj)
+
+
 def is_main_test_process():
     ps = torch.multiprocessing.current_process()
     if not ps or not hasattr(ps, "name"):

From 941e739c933691a2da2d111d3a693f47d6330939 Mon Sep 17 00:00:00 2001
From: Suraj Pai
Date: Wed, 13 Nov 2024 10:44:17 -0500
Subject: [PATCH 04/23] Add MedNext implementation (#8004)

Fixes #7786

### Description

Added a MedNeXt architecture implementation for MONAI. Since a lot of the code is heavily sourced from the original MedNeXt repo, https://github.com/MIC-DKFZ/MedNeXt, I wanted to check whether there is an attribution policy with regard to borrowed source code. I've added a derivative notice below the MONAI copyright comment. Let me know if this needs to be changed.

The blocks have been taken almost as-is, but the network implementation has been largely reworked to allow flexible blocks and to follow MONAI SegResNet styling.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
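A minimal usage sketch for reviewers (illustrative only, not part of the diff; the `MedNeXtS` alias and the tensor shapes follow `monai/networks/nets/mednext.py` and `tests/test_mednext.py` below):

```python
import torch

from monai.networks.nets import MedNeXtS  # alias for create_mednext("S", ...), added by this PR

# Small (S) variant: 3D, 1 input channel, 2 output classes, no deep supervision
net = MedNeXtS(spatial_dims=3, in_channels=1, out_channels=2)
x = torch.randn(2, 1, 16, 16, 16)  # (batch, channel, D, H, W), as in the unit tests
with torch.no_grad():
    y = net(x)
print(y.shape)  # expected: torch.Size([2, 2, 16, 16, 16])
```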
---------

Signed-off-by: Suraj Pai
Signed-off-by: Robin CREMESE
Co-authored-by: Robin CREMESE
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/networks/blocks/__init__.py      |   1 +
 monai/networks/blocks/mednext_block.py | 309 +++++++++++++++++++++
 monai/networks/nets/__init__.py        |  19 ++
 monai/networks/nets/mednext.py         | 354 +++++++++++++++++++++++++
 tests/test_mednext.py                  | 122 +++++++++
 5 files changed, 805 insertions(+)
 create mode 100644 monai/networks/blocks/mednext_block.py
 create mode 100644 monai/networks/nets/mednext.py
 create mode 100644 tests/test_mednext.py

diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py
index 47abc4a1c4..499caf2e0f 100644
--- a/monai/networks/blocks/__init__.py
+++ b/monai/networks/blocks/__init__.py
@@ -26,6 +26,7 @@ from .fcn import FCN, GCN, MCFCN, Refine
 from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
 from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
+from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
 from .mlp import MLPBlock
 from .patchembedding import PatchEmbed, PatchEmbeddingBlock
 from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock

diff --git a/monai/networks/blocks/mednext_block.py b/monai/networks/blocks/mednext_block.py
new file mode 100644
index 0000000000..0aa2bb6b58
--- /dev/null
+++ b/monai/networks/blocks/mednext_block.py
@@ -0,0 +1,309 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this code are derived from the original repository at:
+# https://github.com/MIC-DKFZ/MedNeXt
+# and are used under the terms of the Apache License, Version 2.0.
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+__all__ = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]
+
+
+def get_conv_layer(spatial_dim: int = 3, transpose: bool = False):
+    if spatial_dim == 2:
+        return nn.ConvTranspose2d if transpose else nn.Conv2d
+    else:  # spatial_dim == 3
+        return nn.ConvTranspose3d if transpose else nn.Conv3d
+
+
+class MedNeXtBlock(nn.Module):
+    """
+    MedNeXtBlock class for the MedNeXt model.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
+        kernel_size (int): Kernel size for convolutions. Defaults to 7.
+        use_residual_connection (int): Whether to use residual connection. Defaults to True.
+        norm_type (str): Type of normalization to use. Defaults to "group".
+        dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
+        global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion_ratio: int = 4, + kernel_size: int = 7, + use_residual_connection: int = True, + norm_type: str = "group", + dim="3d", + global_resp_norm=False, + ): + + super().__init__() + + self.do_res = use_residual_connection + + self.dim = dim + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3) + global_resp_norm_param_shape = (1,) * (2 if dim == "2d" else 3) + # First convolution layer with DepthWise Convolutions + self.conv1 = conv( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=in_channels, + ) + + # Normalization Layer. GroupNorm is used by default. + if norm_type == "group": + self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) # type: ignore + elif norm_type == "layer": + self.norm = nn.LayerNorm( + normalized_shape=[in_channels] + [kernel_size] * (2 if dim == "2d" else 3) # type: ignore + ) + # Second convolution (Expansion) layer with Conv3D 1x1x1 + self.conv2 = conv( + in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0 + ) + + # GeLU activations + self.act = nn.GELU() + + # Third convolution (Compression) layer with Conv3D 1x1x1 + self.conv3 = conv( + in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 + ) + + self.global_resp_norm = global_resp_norm + if self.global_resp_norm: + global_resp_norm_param_shape = (1, expansion_ratio * in_channels) + global_resp_norm_param_shape + self.global_resp_beta = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True) + self.global_resp_gamma = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True) + + def forward(self, x): + """ + Forward pass of the MedNeXtBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + x1 = x + x1 = self.conv1(x1) + x1 = self.act(self.conv2(self.norm(x1))) + + if self.global_resp_norm: + # gamma, beta: learnable affine transform parameters + # X: input of shape (N,C,H,W,D) + if self.dim == "2d": + gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True) + else: + gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True) + nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6) + x1 = self.global_resp_gamma * (x1 * nx) + self.global_resp_beta + x1 + x1 = self.conv3(x1) + if self.do_res: + x1 = x + x1 + return x1 + + +class MedNeXtDownBlock(MedNeXtBlock): + """ + MedNeXtDownBlock class for downsampling in the MedNeXt model. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (int): Expansion ratio for the block. Defaults to 4. + kernel_size (int): Kernel size for convolutions. Defaults to 7. + use_residual_connection (bool): Whether to use residual connection. Defaults to False. + norm_type (str): Type of normalization to use. Defaults to "group". + dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d". + global_resp_norm (bool): Whether to use global response normalization. Defaults to False. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion_ratio: int = 4, + kernel_size: int = 7, + use_residual_connection: bool = False, + norm_type: str = "group", + dim: str = "3d", + global_resp_norm: bool = False, + ): + + super().__init__( + in_channels, + out_channels, + expansion_ratio, + kernel_size, + use_residual_connection=False, + norm_type=norm_type, + dim=dim, + global_resp_norm=global_resp_norm, + ) + + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3) + self.resample_do_res = use_residual_connection + if use_residual_connection: + self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) + + self.conv1 = conv( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=2, + padding=kernel_size // 2, + groups=in_channels, + ) + + def forward(self, x): + """ + Forward pass of the MedNeXtDownBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + x1 = super().forward(x) + + if self.resample_do_res: + res = self.res_conv(x) + x1 = x1 + res + + return x1 + + +class MedNeXtUpBlock(MedNeXtBlock): + """ + MedNeXtUpBlock class for upsampling in the MedNeXt model. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (int): Expansion ratio for the block. Defaults to 4. + kernel_size (int): Kernel size for convolutions. Defaults to 7. + use_residual_connection (bool): Whether to use residual connection. Defaults to False. + norm_type (str): Type of normalization to use. Defaults to "group". + dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d". + global_resp_norm (bool): Whether to use global response normalization. Defaults to False. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion_ratio: int = 4, + kernel_size: int = 7, + use_residual_connection: bool = False, + norm_type: str = "group", + dim: str = "3d", + global_resp_norm: bool = False, + ): + super().__init__( + in_channels, + out_channels, + expansion_ratio, + kernel_size, + use_residual_connection=False, + norm_type=norm_type, + dim=dim, + global_resp_norm=global_resp_norm, + ) + + self.resample_do_res = use_residual_connection + + self.dim = dim + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True) + if use_residual_connection: + self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) + + self.conv1 = conv( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=2, + padding=kernel_size // 2, + groups=in_channels, + ) + + def forward(self, x): + """ + Forward pass of the MedNeXtUpBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + x1 = super().forward(x) + # Asymmetry but necessary to match shape + + if self.dim == "2d": + x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0)) + else: + x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0, 1, 0)) + + if self.resample_do_res: + res = self.res_conv(x) + if self.dim == "2d": + res = torch.nn.functional.pad(res, (1, 0, 1, 0)) + else: + res = torch.nn.functional.pad(res, (1, 0, 1, 0, 1, 0)) + x1 = x1 + res + + return x1 + + +class MedNeXtOutBlock(nn.Module): + """ + MedNeXtOutBlock class for the output block in the MedNeXt model. + + Args: + in_channels (int): Number of input channels. + n_classes (int): Number of output classes. + dim (str): Dimension of the input. 
Can be "2d" or "3d". + """ + + def __init__(self, in_channels, n_classes, dim): + super().__init__() + + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True) + self.conv_out = conv(in_channels, n_classes, kernel_size=1) + + def forward(self, x): + """ + Forward pass of the MedNeXtOutBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + return self.conv_out(x) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 0570c9fcc1..b876e6a3fc 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -53,6 +53,25 @@ from .generator import Generator from .highresnet import HighResBlock, HighResNet from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet +from .mednext import ( + MedNeXt, + MedNext, + MedNextB, + MedNeXtB, + MedNextBase, + MedNextL, + MedNeXtL, + MedNeXtLarge, + MedNextLarge, + MedNextM, + MedNeXtM, + MedNeXtMedium, + MedNextMedium, + MedNextS, + MedNeXtS, + MedNeXtSmall, + MedNextSmall, +) from .milmodel import MILModel from .netadapter import NetAdapter from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py new file mode 100644 index 0000000000..427572ba60 --- /dev/null +++ b/monai/networks/nets/mednext.py @@ -0,0 +1,354 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Portions of this code are derived from the original repository at: +# https://github.com/MIC-DKFZ/MedNeXt +# and are used under the terms of the Apache License, Version 2.0. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn + +from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock + +__all__ = [ + "MedNeXt", + "MedNeXtSmall", + "MedNeXtBase", + "MedNeXtMedium", + "MedNeXtLarge", + "MedNext", + "MedNextS", + "MedNeXtS", + "MedNextSmall", + "MedNextB", + "MedNeXtB", + "MedNextBase", + "MedNextM", + "MedNeXtM", + "MedNextMedium", + "MedNextL", + "MedNeXtL", + "MedNextLarge", +] + + +class MedNeXt(nn.Module): + """ + MedNeXt model class from paper: https://arxiv.org/pdf/2303.09975 + + Args: + spatial_dims: spatial dimension of the input data. Defaults to 3. + init_filters: number of output channels for initial convolution layer. Defaults to 32. + in_channels: number of input channels for the network. Defaults to 1. + out_channels: number of output channels for the network. Defaults to 2. + encoder_expansion_ratio: expansion ratio for encoder blocks. Defaults to 2. + decoder_expansion_ratio: expansion ratio for decoder blocks. Defaults to 2. + bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2. + kernel_size: kernel size for convolutions. Defaults to 7. + deep_supervision: whether to use deep supervision. Defaults to False. 
+ use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False. + blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2]. + blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2. + blocks_up: number of blocks in each decoder stage. Defaults to [2, 2, 2, 2]. + norm_type: type of normalization layer. Defaults to 'group'. + global_resp_norm: whether to use Global Response Normalization. Defaults to False. Refer: https://arxiv.org/abs/2301.00808 + """ + + def __init__( + self, + spatial_dims: int = 3, + init_filters: int = 32, + in_channels: int = 1, + out_channels: int = 2, + encoder_expansion_ratio: Sequence[int] | int = 2, + decoder_expansion_ratio: Sequence[int] | int = 2, + bottleneck_expansion_ratio: int = 2, + kernel_size: int = 7, + deep_supervision: bool = False, + use_residual_connection: bool = False, + blocks_down: Sequence[int] = (2, 2, 2, 2), + blocks_bottleneck: int = 2, + blocks_up: Sequence[int] = (2, 2, 2, 2), + norm_type: str = "group", + global_resp_norm: bool = False, + ): + """ + Initialize the MedNeXt model. + + This method sets up the architecture of the model, including: + - Stem convolution + - Encoder stages and downsampling blocks + - Bottleneck blocks + - Decoder stages and upsampling blocks + - Output blocks for deep supervision (if enabled) + """ + super().__init__() + + self.do_ds = deep_supervision + assert spatial_dims in [2, 3], "`spatial_dims` can only be 2 or 3." + spatial_dims_str = f"{spatial_dims}d" + enc_kernel_size = dec_kernel_size = kernel_size + + if isinstance(encoder_expansion_ratio, int): + encoder_expansion_ratio = [encoder_expansion_ratio] * len(blocks_down) + + if isinstance(decoder_expansion_ratio, int): + decoder_expansion_ratio = [decoder_expansion_ratio] * len(blocks_up) + + conv = nn.Conv2d if spatial_dims_str == "2d" else nn.Conv3d + + self.stem = conv(in_channels, init_filters, kernel_size=1) + + enc_stages = [] + down_blocks = [] + + for i, num_blocks in enumerate(blocks_down): + enc_stages.append( + nn.Sequential( + *[ + MedNeXtBlock( + in_channels=init_filters * (2**i), + out_channels=init_filters * (2**i), + expansion_ratio=encoder_expansion_ratio[i], + kernel_size=enc_kernel_size, + use_residual_connection=use_residual_connection, + norm_type=norm_type, + dim=spatial_dims_str, + global_resp_norm=global_resp_norm, + ) + for _ in range(num_blocks) + ] + ) + ) + + down_blocks.append( + MedNeXtDownBlock( + in_channels=init_filters * (2**i), + out_channels=init_filters * (2 ** (i + 1)), + expansion_ratio=encoder_expansion_ratio[i], + kernel_size=enc_kernel_size, + use_residual_connection=use_residual_connection, + norm_type=norm_type, + dim=spatial_dims_str, + ) + ) + + self.enc_stages = nn.ModuleList(enc_stages) + self.down_blocks = nn.ModuleList(down_blocks) + + self.bottleneck = nn.Sequential( + *[ + MedNeXtBlock( + in_channels=init_filters * (2 ** len(blocks_down)), + out_channels=init_filters * (2 ** len(blocks_down)), + expansion_ratio=bottleneck_expansion_ratio, + kernel_size=dec_kernel_size, + use_residual_connection=use_residual_connection, + norm_type=norm_type, + dim=spatial_dims_str, + global_resp_norm=global_resp_norm, + ) + for _ in range(blocks_bottleneck) + ] + ) + + up_blocks = [] + dec_stages = [] + for i, num_blocks in enumerate(blocks_up): + up_blocks.append( + MedNeXtUpBlock( + in_channels=init_filters * (2 ** (len(blocks_up) - i)), + out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), + 
expansion_ratio=decoder_expansion_ratio[i], + kernel_size=dec_kernel_size, + use_residual_connection=use_residual_connection, + norm_type=norm_type, + dim=spatial_dims_str, + global_resp_norm=global_resp_norm, + ) + ) + + dec_stages.append( + nn.Sequential( + *[ + MedNeXtBlock( + in_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), + out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), + expansion_ratio=decoder_expansion_ratio[i], + kernel_size=dec_kernel_size, + use_residual_connection=use_residual_connection, + norm_type=norm_type, + dim=spatial_dims_str, + global_resp_norm=global_resp_norm, + ) + for _ in range(num_blocks) + ] + ) + ) + + self.up_blocks = nn.ModuleList(up_blocks) + self.dec_stages = nn.ModuleList(dec_stages) + + self.out_0 = MedNeXtOutBlock(in_channels=init_filters, n_classes=out_channels, dim=spatial_dims_str) + + if deep_supervision: + out_blocks = [ + MedNeXtOutBlock(in_channels=init_filters * (2**i), n_classes=out_channels, dim=spatial_dims_str) + for i in range(1, len(blocks_up) + 1) + ] + + out_blocks.reverse() + self.out_blocks = nn.ModuleList(out_blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]: + """ + Forward pass of the MedNeXt model. + + This method performs the forward pass through the model, including: + - Stem convolution + - Encoder stages and downsampling + - Bottleneck blocks + - Decoder stages and upsampling with skip connections + - Output blocks for deep supervision (if enabled) + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor or Sequence[torch.Tensor]: Output tensor(s). + """ + # Apply stem convolution + x = self.stem(x) + + # Encoder forward pass + enc_outputs = [] + for enc_stage, down_block in zip(self.enc_stages, self.down_blocks): + x = enc_stage(x) + enc_outputs.append(x) + x = down_block(x) + + # Bottleneck forward pass + x = self.bottleneck(x) + + # Initialize deep supervision outputs if enabled + if self.do_ds: + ds_outputs = [] + + # Decoder forward pass with skip connections + for i, (up_block, dec_stage) in enumerate(zip(self.up_blocks, self.dec_stages)): + if self.do_ds and i < len(self.out_blocks): + ds_outputs.append(self.out_blocks[i](x)) + + x = up_block(x) + x = x + enc_outputs[-(i + 1)] + x = dec_stage(x) + + # Final output block + x = self.out_0(x) + + # Return output(s) + if self.do_ds and self.training: + return (x, *ds_outputs[::-1]) + else: + return x + + +# Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975 +def create_mednext( + variant: str, + spatial_dims: int = 3, + in_channels: int = 1, + out_channels: int = 2, + kernel_size: int = 3, + deep_supervision: bool = False, +) -> MedNeXt: + """ + Factory method to create MedNeXt variants. + + Args: + variant (str): The MedNeXt variant to create ('S', 'B', 'M', or 'L'). + spatial_dims (int): Number of spatial dimensions. Defaults to 3. + in_channels (int): Number of input channels. Defaults to 1. + out_channels (int): Number of output channels. Defaults to 2. + kernel_size (int): Kernel size for convolutions. Defaults to 3. + deep_supervision (bool): Whether to use deep supervision. Defaults to False. + + Returns: + MedNeXt: The specified MedNeXt variant. + + Raises: + ValueError: If an invalid variant is specified. 
+ """ + common_args = { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "kernel_size": kernel_size, + "deep_supervision": deep_supervision, + "use_residual_connection": True, + "norm_type": "group", + "global_resp_norm": False, + "init_filters": 32, + } + + if variant.upper() == "S": + return MedNeXt( + encoder_expansion_ratio=2, + decoder_expansion_ratio=2, + bottleneck_expansion_ratio=2, + blocks_down=(2, 2, 2, 2), + blocks_bottleneck=2, + blocks_up=(2, 2, 2, 2), + **common_args, # type: ignore + ) + elif variant.upper() == "B": + return MedNeXt( + encoder_expansion_ratio=(2, 3, 4, 4), + decoder_expansion_ratio=(4, 4, 3, 2), + bottleneck_expansion_ratio=4, + blocks_down=(2, 2, 2, 2), + blocks_bottleneck=2, + blocks_up=(2, 2, 2, 2), + **common_args, # type: ignore + ) + elif variant.upper() == "M": + return MedNeXt( + encoder_expansion_ratio=(2, 3, 4, 4), + decoder_expansion_ratio=(4, 4, 3, 2), + bottleneck_expansion_ratio=4, + blocks_down=(3, 4, 4, 4), + blocks_bottleneck=4, + blocks_up=(4, 4, 4, 3), + **common_args, # type: ignore + ) + elif variant.upper() == "L": + return MedNeXt( + encoder_expansion_ratio=(3, 4, 8, 8), + decoder_expansion_ratio=(8, 8, 4, 3), + bottleneck_expansion_ratio=8, + blocks_down=(3, 4, 8, 8), + blocks_bottleneck=8, + blocks_up=(8, 8, 4, 3), + **common_args, # type: ignore + ) + else: + raise ValueError(f"Invalid MedNeXt variant: {variant}") + + +MedNext = MedNeXt +MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall = lambda **kwargs: create_mednext("S", **kwargs) +MedNextB = MedNeXtB = MedNextBase = MedNeXtBase = lambda **kwargs: create_mednext("B", **kwargs) +MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium = lambda **kwargs: create_mednext("M", **kwargs) +MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge = lambda **kwargs: create_mednext("L", **kwargs) diff --git a/tests/test_mednext.py b/tests/test_mednext.py new file mode 100644 index 0000000000..b4ba4f9939 --- /dev/null +++ b/tests/test_mednext.py @@ -0,0 +1,122 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import MedNeXt, MedNeXtL, MedNeXtM, MedNeXtS + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASE_MEDNEXT = [] +for spatial_dims in range(2, 4): + for init_filters in [8, 16]: + for deep_supervision in [False, True]: + for do_res in [False, True]: + test_case = [ + { + "spatial_dims": spatial_dims, + "init_filters": init_filters, + "deep_supervision": deep_supervision, + "use_residual_connection": do_res, + }, + (2, 1, *([16] * spatial_dims)), + (2, 2, *([16] * spatial_dims)), + ] + TEST_CASE_MEDNEXT.append(test_case) + +TEST_CASE_MEDNEXT_2 = [] +for spatial_dims in range(2, 4): + for out_channels in [1, 2]: + for deep_supervision in [False, True]: + test_case = [ + { + "spatial_dims": spatial_dims, + "init_filters": 8, + "out_channels": out_channels, + "deep_supervision": deep_supervision, + }, + (2, 1, *([16] * spatial_dims)), + (2, out_channels, *([16] * spatial_dims)), + ] + TEST_CASE_MEDNEXT_2.append(test_case) + +TEST_CASE_MEDNEXT_VARIANTS = [] +for model in [MedNeXtS, MedNeXtM, MedNeXtL]: + for spatial_dims in range(2, 4): + for out_channels in [1, 2]: + test_case = [ + model, # type: ignore + {"spatial_dims": spatial_dims, "in_channels": 1, "out_channels": out_channels}, + (2, 1, *([16] * spatial_dims)), + (2, out_channels, *([16] * spatial_dims)), + ] + TEST_CASE_MEDNEXT_VARIANTS.append(test_case) + + +class TestMedNeXt(unittest.TestCase): + + @parameterized.expand(TEST_CASE_MEDNEXT) + def test_shape(self, input_param, input_shape, expected_shape): + net = MedNeXt(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + if input_param["deep_supervision"] and net.training: + assert isinstance(result, tuple) + self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + else: + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + @parameterized.expand(TEST_CASE_MEDNEXT_2) + def test_shape2(self, input_param, input_shape, expected_shape): + net = MedNeXt(**input_param).to(device) + + net.train() + result = net(torch.randn(input_shape).to(device)) + if input_param["deep_supervision"]: + assert isinstance(result, tuple) + self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + else: + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + net.eval() + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + def test_ill_arg(self): + with self.assertRaises(AssertionError): + MedNeXt(spatial_dims=4) + + @parameterized.expand(TEST_CASE_MEDNEXT_VARIANTS) + def test_mednext_variants(self, model, input_param, input_shape, expected_shape): + net = model(**input_param).to(device) + + net.train() + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + net.eval() + with torch.no_grad(): + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + +if __name__ == "__main__": + unittest.main() From 746a97abdb6dc27f070e0d72061cf3ba087ac8cd Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 13 Nov 2024 21:12:53 -0800 
Subject: [PATCH 05/23] TRT support for MAISI (#8153) ### Description Added trt_compile() support for Lists and Tuples in arguments for forward() - needed for MAISI. Did not add support for grouping return results yet - MAISI worked with explicit workaround unrolling the return results. ### Notes To successfully export MAISI, either latest Torch nightly is needed, or this patch needs to be applied to 24.09-based container: ``` --- /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.bak 2024-10-09 01:38:04.920316673 +0000 +++ /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py 2024-10-09 01:38:25.228053951 +0000 @@ -148,7 +148,6 @@ is_causal and symbolic_helper._is_none(attn_mask) ), "is_causal and attn_mask cannot be set at the same time" - scale = symbolic_helper._maybe_get_const(scale, "f") if symbolic_helper._is_none(scale): scale = _attention_scale(g, query) ``` ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). --------- Signed-off-by: Boris Fomitchev Signed-off-by: Yiheng Wang Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Yiheng Wang Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com> --- Dockerfile | 6 + monai/networks/nets/vista3d.py | 1 - monai/networks/trt_compiler.py | 212 ++++++++++++++++++++++++--------- monai/networks/utils.py | 11 +- monai/torch.patch | 9 ++ monai/utils/module.py | 6 +- tests/test_trt_compile.py | 66 +++++----- tests/test_version_after.py | 2 +- 8 files changed, 215 insertions(+), 98 deletions(-) create mode 100644 monai/torch.patch diff --git a/Dockerfile b/Dockerfile index 5fcfcf274d..d538fd3145 100644 --- a/Dockerfile +++ b/Dockerfile @@ -41,6 +41,10 @@ RUN cp /tmp/requirements.txt /tmp/req.bak \ COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./ COPY tests ./tests COPY monai ./monai + +# TODO: remove this line and torch.patch for 24.11 +RUN patch -R -d /usr/local/lib/python3.10/dist-packages/torch/onnx/ < ./monai/torch.patch + RUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \ && rm -rf build __pycache__ @@ -57,4 +61,6 @@ RUN apt-get update \ # append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations ENV PATH=${PATH}:/opt/tools ENV POLYGRAPHY_AUTOINSTALL_DEPS=1 + + WORKDIR /opt/monai diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 6313b7812d..6ecb664b85 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -641,7 +641,6 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor): # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension. 
masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d) masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1) - return masks_embedding, class_embedding diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index d2d05fae22..d96b712003 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -18,12 +18,12 @@ from collections import OrderedDict from pathlib import Path from types import MethodType -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Tuple, Union import torch from monai.apps.utils import get_logger -from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript, get_profile_shapes +from monai.networks.utils import add_casts_around_norms, convert_to_onnx, get_profile_shapes from monai.utils.module import optional_import polygraphy, polygraphy_imported = optional_import("polygraphy") @@ -125,6 +125,7 @@ def __init__(self, plan_path, logger=None): self.output_names = [] self.dtypes = [] self.cur_profile = 0 + self.input_table = {} dtype_dict = trt_to_torch_dtype_dict() for idx in range(self.engine.num_io_tensors): binding = self.engine[idx] @@ -134,6 +135,9 @@ def __init__(self, plan_path, logger=None): self.output_names.append(binding) dtype = dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) + self.logger.info( + f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}" + ) def allocate_buffers(self, device): """ @@ -163,7 +167,8 @@ def set_inputs(self, feed_dict, stream): last_profile = self.cur_profile def try_set_inputs(): - for binding, t in feed_dict.items(): + for binding in self.input_names: + t = feed_dict.get(self.input_table[binding], None) if t is not None: t = t.contiguous() shape = t.shape @@ -180,7 +185,8 @@ def try_set_inputs(): raise self.cur_profile = next_profile ctx.set_optimization_profile_async(self.cur_profile, stream) - + except Exception: + raise left = ctx.infer_shapes() assert len(left) == 0 @@ -217,6 +223,74 @@ def infer(self, stream, use_cuda_graph=False): return self.tensors +def make_tensor(d): + return d if isinstance(d, torch.Tensor) else torch.tensor(d).cuda() + + +def unroll_input(input_names, input_example): + # Simulate list/tuple unrolling during ONNX export + unrolled_input = {} + for name in input_names: + val = input_example[name] + if val is not None: + if isinstance(val, list) or isinstance(val, tuple): + for i in range(len(val)): + unrolled_input[f"{name}_{i}"] = make_tensor(val[i]) + else: + unrolled_input[name] = make_tensor(val) + return unrolled_input + + +def parse_groups( + ret: List[torch.Tensor], output_lists: List[List[int]] +) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]: + """ + Implements parsing of 'output_lists' arg of trt_compile(). + + Args: + ret: plain list of Tensors + + output_lists: list of output group sizes: to form some Lists/Tuples out of 'ret' List, this will be a list + of group dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list. + Format: [[group_n] | [], ...] + [] or group_n == 0 : next output from ret is a scalar + group_n > 0 : next output from ret is a list of group_n length + group_n == -1: next output is a dynamic list. This entry can be at any + position in output_lists, but can appear only once. 
+ Returns: + Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists + + """ + groups: Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] = tuple() + cur = 0 + for l in range(len(output_lists)): + gl = output_lists[l] + assert len(gl) == 0 or len(gl) == 1 + if len(gl) == 0 or gl[0] == 0: + groups = (*groups, ret[cur]) + cur = cur + 1 + elif gl[0] > 0: + groups = (*groups, ret[cur : cur + gl[0]]) + cur = cur + gl[0] + elif gl[0] == -1: + rev_groups: Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] = tuple() + rcur = len(ret) + for rl in range(len(output_lists) - 1, l, -1): + rgl = output_lists[rl] + assert len(rgl) == 0 or len(rgl) == 1 + if len(rgl) == 0 or rgl[0] == 0: + rcur = rcur - 1 + rev_groups = (*rev_groups, ret[rcur]) + elif rgl[0] > 0: + rcur = rcur - rgl[0] + rev_groups = (*rev_groups, ret[rcur : rcur + rgl[0]]) + else: + raise ValueError("Two -1 lists in output") + groups = (*groups, ret[cur:rcur], *rev_groups[::-1]) + break + return groups + + class TrtCompiler: """ This class implements: @@ -233,6 +307,7 @@ def __init__( method="onnx", input_names=None, output_names=None, + output_lists=None, export_args=None, build_args=None, input_profiles=None, @@ -240,6 +315,7 @@ def __init__( use_cuda_graph=False, timestamp=None, fallback=False, + forward_override=None, logger=None, ): """ @@ -255,6 +331,8 @@ def __init__( 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work. input_names: Optional list of input names. If None, will be read from the function signature. output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. + output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list + of their dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list. export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. input_profiles: Optional list of profiles for TRT builder and ONNX export. 
@@ -279,6 +357,7 @@ def __init__( self.method = method self.return_dict = output_names is not None self.output_names = output_names or [] + self.output_lists = output_lists or [] self.profiles = input_profiles or [] self.dynamic_batchsize = dynamic_batchsize self.export_args = export_args or {} @@ -289,11 +368,19 @@ def __init__( self.disabled = False self.logger = logger or get_logger("monai.networks.trt_compiler") + self.argspec = inspect.getfullargspec(model.forward) # Normally we read input_names from forward() but can be overridden if input_names is None: - argspec = inspect.getfullargspec(model.forward) - input_names = argspec.args[1:] + input_names = self.argspec.args[1:] + self.defaults = {} + if self.argspec.defaults is not None: + for i in range(len(self.argspec.defaults)): + d = self.argspec.defaults[-i - 1] + if d is not None: + d = make_tensor(d) + self.defaults[self.argspec.args[-i - 1]] = d + self.input_names = input_names self.old_forward = model.forward @@ -314,9 +401,18 @@ def _load_engine(self): """ try: self.engine = TRTEngine(self.plan_path, self.logger) - self.input_names = self.engine.input_names + # Make sure we have names correct + input_table = {} + for name in self.engine.input_names: + if name.startswith("__") and name not in self.input_names: + orig_name = name[2:] + else: + orig_name = name + input_table[name] = orig_name + self.engine.input_table = input_table + self.logger.info(f"Engine loaded, inputs:{self.engine.input_table}") except Exception as e: - self.logger.debug(f"Exception while loading the engine:\n{e}") + self.logger.info(f"Exception while loading the engine:\n{e}") def forward(self, model, argv, kwargs): """ @@ -329,6 +425,11 @@ def forward(self, model, argv, kwargs): Returns: Passing through wrapped module's forward() return value(s) """ + args = self.defaults + args.update(kwargs) + if len(argv) > 0: + args.update(self._inputs_to_dict(argv)) + if self.engine is None and not self.disabled: # Restore original forward for export new_forward = model.forward @@ -336,11 +437,10 @@ def forward(self, model, argv, kwargs): try: self._load_engine() if self.engine is None: - build_args = kwargs.copy() - if len(argv) > 0: - build_args.update(self._inputs_to_dict(argv)) - self._build_and_save(model, build_args) - # This will reassign input_names from the engine + build_args = args.copy() + with torch.no_grad(): + self._build_and_save(model, build_args) + # This will reassign input_names from the engine self._load_engine() assert self.engine is not None except Exception as e: @@ -355,19 +455,16 @@ def forward(self, model, argv, kwargs): del param # Call empty_cache to release GPU memory torch.cuda.empty_cache() + # restore TRT hook model.forward = new_forward # Run the engine try: - if len(argv) > 0: - kwargs.update(self._inputs_to_dict(argv)) - argv = () - if self.engine is not None: # forward_trt is not thread safe as we do not use per-thread execution contexts with lock_sm: device = torch.cuda.current_device() stream = torch.cuda.Stream(device=device) - self.engine.set_inputs(kwargs, stream.cuda_stream) + self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream) self.engine.allocate_buffers(device=device) # Need this to synchronize with Torch stream stream.wait_stream(torch.cuda.current_stream()) @@ -375,11 +472,13 @@ def forward(self, model, argv, kwargs): # if output_names is not None, return dictionary if not self.return_dict: ret = list(ret.values()) - if len(ret) == 1: + if self.output_lists: + ret = parse_groups(ret, 
self.output_lists) + elif len(ret) == 1: ret = ret[0] return ret except Exception as e: - if model is not None: + if self.fallback: self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...") else: raise e @@ -391,16 +490,11 @@ def _onnx_to_trt(self, onnx_path): """ profiles = [] - if self.profiles: - for input_profile in self.profiles: - if isinstance(input_profile, Profile): - profiles.append(input_profile) - else: - p = Profile() - for name, dims in input_profile.items(): - assert len(dims) == 3 - p.add(name, min=dims[0], opt=dims[1], max=dims[2]) - profiles.append(p) + for profile in self.profiles: + p = Profile() + for id, val in profile.items(): + p.add(id, min=val[0], opt=val[1], max=val[2]) + profiles.append(p) build_args = self.build_args.copy() build_args["tf32"] = self.precision != "fp32" @@ -425,7 +519,7 @@ def _build_and_save(self, model, input_example): return export_args = self.export_args - + engine_bytes = None add_casts_around_norms(model) if self.method == "torch_trt": @@ -435,7 +529,6 @@ def _build_and_save(self, model, input_example): elif self.precision == "bf16": enabled_precisions.append(torch.bfloat16) inputs = list(input_example.values()) - ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True) def get_torch_trt_input(input_shape, dynamic_batchsize): min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) @@ -445,12 +538,7 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs] engine_bytes = torch_tensorrt.convert_method_to_trt_engine( - ir_model, - "forward", - inputs=tt_inputs, - ir="torchscript", - enabled_precisions=enabled_precisions, - **export_args, + model, "forward", arg_inputs=tt_inputs, enabled_precisions=enabled_precisions, **export_args ) else: dbs = self.dynamic_batchsize @@ -459,33 +547,47 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!") if len(dbs) != 3: raise ValueError("dynamic_batchsize has to have len ==3 ") - profiles = {} + profile = {} for id, val in input_example.items(): - sh = val.shape[1:] - profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] - self.profiles = [profiles] - if len(self.profiles) > 0: - export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) + def add_profile(id, val): + sh = val.shape + if len(sh) > 0: + sh = sh[1:] + profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + + if isinstance(val, list) or isinstance(val, tuple): + for i in range(len(val)): + add_profile(f"{id}_{i}", val[i]) + elif isinstance(val, torch.Tensor): + add_profile(id, val) + self.profiles = [profile] + + self.dynamic_axes = get_dynamic_axes(self.profiles) + + if len(self.dynamic_axes) > 0: + export_args.update({"dynamic_axes": self.dynamic_axes}) # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: - onnx_path = Path(tmpdir) / "model.onnx" + unrolled_input = unroll_input(self.input_names, input_example) + onnx_path = str(Path(tmpdir) / "model.onnx") self.logger.info( - f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}" + f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\n" + + f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}" ) convert_to_onnx( model, input_example, - 
filename=str(onnx_path), - input_names=self.input_names, + filename=onnx_path, + input_names=list(unrolled_input.keys()), output_names=self.output_names, **export_args, ) self.logger.info("Export to ONNX successful.") - engine_bytes = self._onnx_to_trt(str(onnx_path)) - - open(self.plan_path, "wb").write(engine_bytes) + engine_bytes = self._onnx_to_trt(onnx_path) + if engine_bytes: + open(self.plan_path, "wb").write(engine_bytes) def trt_forward(self, *argv, **kwargs): @@ -542,9 +644,11 @@ def trt_compile( args["timestamp"] = timestamp def wrap(model, path): - wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) - model._trt_compiler = wrapper - model.forward = MethodType(trt_forward, model) + if not hasattr(model, "_trt_compiler"): + model.orig_forward = model.forward + wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) + model._trt_compiler = wrapper + model.forward = MethodType(trt_forward, model) def find_sub(parent, submodule): idx = submodule.find(".") diff --git a/monai/networks/utils.py b/monai/networks/utils.py index cfad0364c3..05627f9c00 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -632,7 +632,6 @@ def convert_to_onnx( use_trace: bool = True, do_constant_folding: bool = True, constant_size_threshold: int = 16 * 1024 * 1024 * 1024, - dynamo=False, **kwargs, ): """ @@ -673,6 +672,9 @@ def convert_to_onnx( # let torch.onnx.export to trace the model. mode_to_export = model torch_versioned_kwargs = kwargs + if "dynamo" in kwargs and kwargs["dynamo"] and verify: + torch_versioned_kwargs["verify"] = verify + verify = False else: if not pytorch_after(1, 10): if "example_outputs" not in kwargs: @@ -695,13 +697,13 @@ def convert_to_onnx( f = temp_file.name else: f = filename - + print(f"torch_versioned_kwargs={torch_versioned_kwargs}") torch.onnx.export( mode_to_export, onnx_inputs, f=f, input_names=input_names, - output_names=output_names, + output_names=output_names or None, dynamic_axes=dynamic_axes, opset_version=opset_version, do_constant_folding=do_constant_folding, @@ -715,6 +717,9 @@ def convert_to_onnx( fold_constants(onnx_model, size_threshold=constant_size_threshold) if verify: + if isinstance(inputs, dict): + inputs = list(inputs.values()) + if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/monai/torch.patch b/monai/torch.patch new file mode 100644 index 0000000000..e53980968b --- /dev/null +++ b/monai/torch.patch @@ -0,0 +1,9 @@ +--- /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py 2024-10-31 06:09:21.139938791 +0000 ++++ /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.bak 2024-10-31 06:01:50.207462739 +0000 +@@ -150,6 +150,7 @@ + ), "is_causal and attn_mask cannot be set at the same time" + assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + ++ scale = symbolic_helper._maybe_get_const(scale, "f") + if symbolic_helper._is_none(scale): + scale = _attention_scale(g, query) diff --git a/monai/utils/module.py b/monai/utils/module.py index d3f2ff09f2..1ad001fc87 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -649,7 +649,7 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s current_ver_string: if None, the current system GPU CUDA compute capability will be used. Returns: - True if the current system GPU CUDA compute capability is greater than or equal to the specified version. 
+ True if the current system GPU CUDA compute capability is greater than the specified version. """ if current_ver_string is None: cuda_available = torch.cuda.is_available() @@ -667,11 +667,11 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s ver, has_ver = optional_import("packaging.version", name="parse") if has_ver: - return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}") # type: ignore + return ver(".".join((f"{major}", f"{minor}"))) < ver(f"{current_ver_string}") # type: ignore parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2) while len(parts) < 2: parts += ["0"] c_major, c_minor = parts[:2] c_mn = int(c_major), int(c_minor) mn = int(major), int(minor) - return c_mn > mn + return c_mn >= mn diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 49404fdbbe..9716a4a715 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -19,17 +19,12 @@ from monai.handlers import TrtHandler from monai.networks import trt_compile -from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 +from monai.networks.nets import cell_sam_wrapper, vista3d132 from monai.utils import min_version, optional_import -from tests.utils import ( - SkipIfAtLeastPyTorchVersion, - SkipIfBeforeComputeCapabilityVersion, - skip_if_no_cuda, - skip_if_quick, - skip_if_windows, -) +from tests.utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) +torch_tensorrt, torch_trt_imported = optional_import("torch_tensorrt") polygraphy, polygraphy_imported = optional_import("polygraphy") build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") @@ -37,6 +32,19 @@ TEST_CASE_2 = ["fp16"] +class ListAdd(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = 0.1): + y1 = y.clone() + x1 = x.copy() + z1 = z + y + for xi in x: + y1 = y1 + xi + bs + return x1, [y1, z1], y1 + z1 + + @skip_if_windows @skip_if_no_cuda @skip_if_quick @@ -53,7 +61,7 @@ def tearDown(self): if current_device != self.gpu_device: torch.cuda.set_device(self.gpu_device) - @SkipIfAtLeastPyTorchVersion((2, 4, 1)) + @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required") def test_handler(self): from ignite.engine import Engine @@ -74,29 +82,19 @@ def test_handler(self): net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda")) self.assertIsNotNone(net1._trt_compiler.engine) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_unet_value(self, precision): - model = UNet( - spatial_dims=3, - in_channels=1, - out_channels=2, - channels=(2, 2, 4, 8, 4), - strides=(2, 2, 2, 2), - num_res_units=2, - norm="batch", - ).cuda() + def test_lists(self): + model = ListAdd().cuda() + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: - model.eval() - input_example = torch.randn(2, 1, 96, 96, 96).cuda() - output_example = model(input_example) - args: dict = {"builder_optimization_level": 1} - trt_compile( - model, - f"{tmpdir}/test_unet_trt_compile", - args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]}, - ) + args = {"output_lists": [[-1], [2], []], "export_args": {"dynamo": False, "verbose": True}} + x = torch.randn(1, 16).to("cuda") + y = torch.randn(1, 16).to("cuda") + z = torch.randn(1, 16).to("cuda") + input_example = ([x, y, z], y.clone(), 
z.clone()) + output_example = model(*input_example) + trt_compile(model, f"{tmpdir}/test_lists", args=args) self.assertIsNone(model._trt_compiler.engine) - trt_output = model(input_example) + trt_output = model(*input_example) # Check that lazy TRT build succeeded self.assertIsNotNone(model._trt_compiler.engine) torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) @@ -109,11 +107,7 @@ def test_cell_sam_wrapper_value(self, precision): model.eval() input_example = torch.randn(1, 3, 128, 128).to("cuda") output_example = model(input_example) - trt_compile( - model, - f"{tmpdir}/test_cell_sam_wrapper_trt_compile", - args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, - ) + trt_compile(model, f"{tmpdir}/test_cell_sam_wrapper_trt_compile", args={"precision": precision}) self.assertIsNone(model._trt_compiler.engine) trt_output = model(input_example) # Check that lazy TRT build succeeded @@ -130,7 +124,7 @@ def test_vista3d(self, precision): model = trt_compile( model, f"{tmpdir}/test_vista3d_trt_compile", - args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + args={"precision": precision, "dynamic_batchsize": [1, 2, 4]}, submodule=["image_encoder.encoder", "class_head"], ) self.assertIsNotNone(model.image_encoder.encoder._trt_compiler) diff --git a/tests/test_version_after.py b/tests/test_version_after.py index b6cb741382..34a5054974 100644 --- a/tests/test_version_after.py +++ b/tests/test_version_after.py @@ -38,7 +38,7 @@ TEST_CASES_SM = [ # (major, minor, sm, expected) - (6, 1, "6.1", True), + (6, 1, "6.1", False), (6, 1, "6.0", False), (6, 0, "8.6", True), (7, 0, "8", True), From 13b96aedc48ad2da16149490b06a1a6bd8361335 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 15 Nov 2024 00:48:02 -0800 Subject: [PATCH 06/23] Fixed fold_constants, test_handler switched to onnx (#8211) Fixed fold_constants: the result was not saved. test_handler switched to onnx as torch-tensorrt is causing issues with CI on various Torch versions and is not used anyway. ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
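For illustration, a minimal sketch of the corrected constant-folding flow (the file path is a placeholder; `onnx` and `polygraphy` are assumed to be installed):

```python
import onnx
from polygraphy.backend.onnx.loader import fold_constants, save_onnx

f = "model.onnx"  # hypothetical path to an exported ONNX model
onnx_model = onnx.load(f)
# fold_constants returns the folded graph; the old code discarded this return
# value, so the folded model was never saved back to disk.
onnx_model = fold_constants(onnx_model, size_threshold=16 * 1024 * 1024 * 1024)
save_onnx(onnx_model, f)  # persist the folded model before the TRT build
```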
Signed-off-by: Boris Fomitchev --- monai/networks/utils.py | 5 +++-- tests/test_trt_compile.py | 10 +++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 05627f9c00..1b4cb220ae 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -712,9 +712,10 @@ def convert_to_onnx( onnx_model = onnx.load(f) if do_constant_folding and polygraphy_imported: - from polygraphy.backend.onnx.loader import fold_constants + from polygraphy.backend.onnx.loader import fold_constants, save_onnx - fold_constants(onnx_model, size_threshold=constant_size_threshold) + onnx_model = fold_constants(onnx_model, size_threshold=constant_size_threshold) + save_onnx(onnx_model, f) if verify: if isinstance(inputs, dict): diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 9716a4a715..e1323c201f 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -61,7 +61,7 @@ def tearDown(self): if current_device != self.gpu_device: torch.cuda.set_device(self.gpu_device) - @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required") + # @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required") def test_handler(self): from ignite.engine import Engine @@ -74,7 +74,7 @@ def test_handler(self): with tempfile.TemporaryDirectory() as tempdir: engine = Engine(lambda e, b: None) - args = {"method": "torch_trt"} + args = {"method": "onnx", "dynamic_batchsize": [1, 4, 8]} TrtHandler(net1, tempdir + "/trt_handler", args=args).attach(engine) engine.run([0] * 8, max_epochs=1) self.assertIsNotNone(net1._trt_compiler) @@ -86,7 +86,11 @@ def test_lists(self): model = ListAdd().cuda() with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: - args = {"output_lists": [[-1], [2], []], "export_args": {"dynamo": False, "verbose": True}} + args = { + "output_lists": [[-1], [2], []], + "export_args": {"dynamo": False, "verbose": True}, + "dynamic_batchsize": [1, 4, 8], + } x = torch.randn(1, 16).to("cuda") y = torch.randn(1, 16).to("cuda") z = torch.randn(1, 16).to("cuda") From b1e915c323a8065cfe9e92de3013476f2f67c1b2 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Thu, 21 Nov 2024 09:48:03 +0000 Subject: [PATCH 07/23] Fixing a minor bug in a test (#8223) ### Description There is a minor bug in a test which causes a second fail to occur when one does. In `test_module_list.py`, there is a list of classes to check which must have aliases removed from it. This must be done before the test assert so that in the event the assert fails this removal isn't skipped. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
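For illustration, a toy reproduction of the ordering issue (names are hypothetical, not the real transform list):

```python
remained = {"FooD", "Food", "FooDict"}  # aliases still to be accounted for

def check(basename, documented):
    # remove the aliases first, so a failing assertion below cannot skip this step
    for postfix in ("D", "d", "Dict"):
        remained.discard(f"{basename}{postfix}")
    assert documented, f"please add `{basename}` to docs/source/transforms.rst"

try:
    check("Foo", documented=False)  # the assertion fails...
except AssertionError:
    pass
print(remained)  # ...but prints set(): the aliases were still removed
```

With the old ordering the removal was skipped whenever the assertion failed, so the final `assertFalse(remained)` reported a second, spurious failure.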
--------- Signed-off-by: Eric Kerfoot Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/test_module_list.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_module_list.py b/tests/test_module_list.py index d21ba53b7c..833441cbca 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -58,13 +58,17 @@ def test_transform_api(self): continue with self.subTest(n=n): basename = n[:-1] # Transformd basename is Transform + + # remove aliases to check, do this before the assert below so that a failed assert does skip this + for postfix in ("D", "d", "Dict"): + remained.remove(f"{basename}{postfix}") + for docname in (f"{basename}", f"{basename}d"): if docname in to_exclude_docs: continue if (contents is not None) and f"`{docname}`" not in f"{contents}": self.assertTrue(False, f"please add `{docname}` to docs/source/transforms.rst") - for postfix in ("D", "d", "Dict"): - remained.remove(f"{basename}{postfix}") + self.assertFalse(remained) From d94df3fbefd7fc8dccb41c2f629710530d9454f0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 24 Nov 2024 15:38:24 +0800 Subject: [PATCH 08/23] 8134 Add unit test for responsive inference (#8146) Fixes #8134 . ### Description This PR added unit test to cover the realtime inference with bundles. And updated `BundleWorkflow` to support cyclically calling the `run` function with all components instantiated. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Nic Ma Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/bundle/reference_resolver.py | 10 ++ monai/bundle/workflows.py | 19 +++- tests/test_bundle_workflow.py | 35 ++++++- tests/testing_data/responsive_inference.json | 101 +++++++++++++++++++ 4 files changed, 162 insertions(+), 3 deletions(-) create mode 100644 tests/testing_data/responsive_inference.json diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index df69b021e1..b55c62174b 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -192,6 +192,16 @@ def get_resolved_content(self, id: str, **kwargs: Any) -> ConfigExpression | str """ return self._resolve_one_item(id=id, **kwargs) + def remove_resolved_content(self, id: str) -> Any | None: + """ + Remove the resolved ``ConfigItem`` by id. + + Args: + id: id name of the expected item. 
+ + """ + return self.resolved_content.pop(id) if id in self.resolved_content else None + @classmethod def normalize_id(cls, id: str | int) -> str: """ diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 3ecd5dfbc5..dbfa6bb54c 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -394,8 +394,23 @@ def check_properties(self) -> list[str] | None: ret.extend(wrong_props) return ret - def _run_expr(self, id: str, **kwargs: dict) -> Any: - return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None + def _run_expr(self, id: str, **kwargs: dict) -> list[Any]: + """ + Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored, + allowing this to be evaluated repeatedly (eg. in streaming applications) without restarting the hosting process. + """ + ret = [] + if id in self.parser: + # suppose all the expressions are in a list, run and reset the expressions + if isinstance(self.parser[id], list): + for i in range(len(self.parser[id])): + sub_id = f"{id}{ID_SEP_KEY}{i}" + ret.append(self.parser.get_parsed_content(sub_id, **kwargs)) + self.parser.ref_resolver.remove_resolved_content(sub_id) + else: + ret.append(self.parser.get_parsed_content(id, **kwargs)) + self.parser.ref_resolver.remove_resolved_content(id) + return ret def _get_prop_id(self, name: str, property: dict) -> Any: prop_id = property[BundlePropertyConfig.ID] diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py index 1727fcdf53..b10f6448ff 100644 --- a/tests/test_bundle_workflow.py +++ b/tests/test_bundle_workflow.py @@ -26,7 +26,7 @@ from monai.data import Dataset from monai.inferers import SimpleInferer, SlidingWindowInferer from monai.networks.nets import UNet -from monai.transforms import Compose, LoadImage +from monai.transforms import Compose, LoadImage, LoadImaged, SaveImaged from tests.nonconfig_workflow import NonConfigWorkflow TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")] @@ -35,6 +35,8 @@ TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")] +TEST_CASE_4 = [os.path.join(os.path.dirname(__file__), "testing_data", "responsive_inference.json")] + TEST_CASE_NON_CONFIG_WRONG_LOG = [None, "logging.conf", "Cannot find the logging config file: logging.conf."] @@ -45,7 +47,9 @@ def setUp(self): self.expected_shape = (128, 128, 128) test_image = np.random.rand(*self.expected_shape) self.filename = os.path.join(self.data_dir, "image.nii") + self.filename1 = os.path.join(self.data_dir, "image1.nii") nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename1) def tearDown(self): shutil.rmtree(self.data_dir) @@ -115,6 +119,35 @@ def test_inference_config(self, config_file): self._test_inferer(inferer) self.assertEqual(inferer.workflow_type, None) + @parameterized.expand([TEST_CASE_4]) + def test_responsive_inference_config(self, config_file): + input_loader = LoadImaged(keys="image") + output_saver = SaveImaged(keys="pred", output_dir=self.data_dir, output_postfix="seg") + + # test standard MONAI model-zoo config workflow + inferer = ConfigWorkflow( + workflow_type="infer", + config_file=config_file, + logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"), + ) + # FIXME: temp add the property for test, we should add it to some formal realtime infer properties + inferer.add_property(name="dataflow", required=True, 
config_id="dataflow") + + inferer.initialize() + inferer.dataflow.update(input_loader({"image": self.filename})) + inferer.run() + output_saver(inferer.dataflow) + self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image", "image_seg.nii.gz"))) + + # bundle is instantiated and idle, just change the input for next inference + inferer.dataflow.clear() + inferer.dataflow.update(input_loader({"image": self.filename1})) + inferer.run() + output_saver(inferer.dataflow) + self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image1", "image1_seg.nii.gz"))) + + inferer.finalize() + @parameterized.expand([TEST_CASE_3]) def test_train_config(self, config_file): # test standard MONAI model-zoo config workflow diff --git a/tests/testing_data/responsive_inference.json b/tests/testing_data/responsive_inference.json new file mode 100644 index 0000000000..16d953d38e --- /dev/null +++ b/tests/testing_data/responsive_inference.json @@ -0,0 +1,101 @@ +{ + "imports": [ + "$from collections import defaultdict" + ], + "bundle_root": "will override", + "device": "$torch.device('cpu')", + "network_def": { + "_target_": "UNet", + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 2, + "channels": [ + 2, + 2, + 4, + 8, + 4 + ], + "strides": [ + 2, + 2, + 2, + 2 + ], + "num_res_units": 2, + "norm": "batch" + }, + "network": "$@network_def.to(@device)", + "dataflow": "$defaultdict()", + "preprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "EnsureChannelFirstd", + "keys": "image" + }, + { + "_target_": "ScaleIntensityd", + "keys": "image" + }, + { + "_target_": "RandRotated", + "_disabled_": true, + "keys": "image" + } + ] + }, + "dataset": { + "_target_": "Dataset", + "data": [ + "@dataflow" + ], + "transform": "@preprocessing" + }, + "dataloader": { + "_target_": "DataLoader", + "dataset": "@dataset", + "batch_size": 1, + "shuffle": false, + "num_workers": 0 + }, + "inferer": { + "_target_": "SlidingWindowInferer", + "roi_size": [ + 64, + 64, + 32 + ], + "sw_batch_size": 4, + "overlap": 0.25 + }, + "postprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "Activationsd", + "keys": "pred", + "softmax": true + }, + { + "_target_": "AsDiscreted", + "keys": "pred", + "argmax": true + } + ] + }, + "evaluator": { + "_target_": "SupervisedEvaluator", + "device": "@device", + "val_data_loader": "@dataloader", + "network": "@network", + "inferer": "@inferer", + "postprocessing": "@postprocessing", + "amp": false, + "epoch_length": 1 + }, + "run": [ + "$@evaluator.run()", + "$@dataflow.update(@evaluator.state.output[0])" + ] +} From 3ee4cd22a8cc7b6b4cb3c5fd228dfa9ef153e60c Mon Sep 17 00:00:00 2001 From: Eloi Date: Mon, 25 Nov 2024 07:32:25 +0100 Subject: [PATCH 09/23] SwinUNETR refactor to accept additional parameters (#8212) ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. 
--------- Signed-off-by: Eloi Navet eloi.navet@labri.fr Signed-off-by: Eloi Navet Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/swin_unetr.py | 36 ++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 832135ad06..32b817d584 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -13,7 +13,6 @@ import itertools from collections.abc import Sequence -from typing import Final import numpy as np import torch @@ -51,8 +50,6 @@ class SwinUNETR(nn.Module): " """ - patch_size: Final[int] = 2 - @deprecated_arg( name="img_size", since="1.3", @@ -65,18 +62,24 @@ def __init__( img_size: Sequence[int] | int, in_channels: int, out_channels: int, + patch_size: int = 2, depths: Sequence[int] = (2, 2, 2, 2), num_heads: Sequence[int] = (3, 6, 12, 24), + window_size: Sequence[int] | int = 7, + qkv_bias: bool = True, + mlp_ratio: float = 4.0, feature_size: int = 24, norm_name: tuple | str = "instance", drop_rate: float = 0.0, attn_drop_rate: float = 0.0, dropout_path_rate: float = 0.0, normalize: bool = True, + norm_layer: type[LayerNorm] = nn.LayerNorm, + patch_norm: bool = True, use_checkpoint: bool = False, spatial_dims: int = 3, - downsample="merging", - use_v2=False, + downsample: str | nn.Module = "merging", + use_v2: bool = False, ) -> None: """ Args: @@ -86,14 +89,20 @@ def __init__( It will be removed in an upcoming version. in_channels: dimension of input channels. out_channels: dimension of output channels. + patch_size: size of the patch token. feature_size: dimension of network feature size. depths: number of layers in each stage. num_heads: number of attention heads. + window_size: local window size. + qkv_bias: add a learnable bias to query, key, value. + mlp_ratio: ratio of mlp hidden dim to embedding dim. norm_name: feature normalization type and arguments. drop_rate: dropout rate. attn_drop_rate: attention dropout rate. dropout_path_rate: drop path rate. normalize: normalize output intermediate features in each stage. + norm_layer: normalization layer. + patch_norm: whether to apply normalization to the patch embedding. use_checkpoint: use gradient checkpointing for reduced memory usage. spatial_dims: number of spatial dims. 
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a @@ -116,13 +125,15 @@ def __init__( super().__init__() - img_size = ensure_tuple_rep(img_size, spatial_dims) - patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims) - window_size = ensure_tuple_rep(7, spatial_dims) - if spatial_dims not in (2, 3): raise ValueError("spatial dimension should be 2 or 3.") + self.patch_size = patch_size + + img_size = ensure_tuple_rep(img_size, spatial_dims) + patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims) + window_size = ensure_tuple_rep(window_size, spatial_dims) + self._check_input_size(img_size) if not (0 <= drop_rate <= 1): @@ -146,12 +157,13 @@ def __init__( patch_size=patch_sizes, depths=depths, num_heads=num_heads, - mlp_ratio=4.0, - qkv_bias=True, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=dropout_path_rate, - norm_layer=nn.LayerNorm, + norm_layer=norm_layer, + patch_norm=patch_norm, use_checkpoint=use_checkpoint, spatial_dims=spatial_dims, downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample, From 649c7c8b2fdc4c68e10360437ae233d3696d6833 Mon Sep 17 00:00:00 2001 From: Lucas Robinet Date: Tue, 26 Nov 2024 04:02:35 +0100 Subject: [PATCH 10/23] Allow an arbitrary mask to be used in the self attention (#8235) ### Description The aim of this PR is to enable the use of an arbitrary mask in the self attention module, which is very useful in the case of missing data or masked modeling. Official torch implementations allow the use of an arbitrary mask, and in MONAI the use of a mask is also made possible with the `causal` argument. Here, it's just a generalization directly in the forward pass. In the `SABlock` and `TransformerBlock`, it is now possible to input a boolean mask of size `(BS, Seq_length)`. Only the columns of the masked token are set to `-inf` and not the rows, as is rarely the case in common implementations. Masked tokens don't contribute to the gradient anyway. In cases where causal attention is required, inputting a mask is not supported to avoid masks overlapping. I haven't implemented the addition mask to the attention matrix, which allows you to use values other than `-inf` in certain cases, as may be the case here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html If you think it's relevant, it could be added. ### Types of changes - [ ] Non-breaking change (fix or new feature that would not break existing functionality). - [x] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
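For illustration, a short sketch of the new interface assuming this patch is applied: the mask is a boolean tensor of shape `(B, seq_len)`, and positions set to `False` cannot be attended to:

```python
import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=4)
x = torch.randn(2, 64, 128)
attn_mask = torch.randint(0, 2, (2, 64)).bool()  # False = masked-out token
out = block(x, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 64, 128])

# causal attention and an explicit mask are mutually exclusive
causal_block = SABlock(hidden_size=128, num_heads=4, causal=True, sequence_length=64)
try:
    causal_block(x, attn_mask=attn_mask)
except ValueError as e:
    print(e)  # Causal attention does not support attention masks.
```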
--------- Signed-off-by: Lucas Robinet Signed-off-by: Lucas Robinet Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/networks/blocks/selfattention.py | 22 ++++++++++++++++++---- monai/networks/blocks/transformerblock.py | 6 ++++-- tests/test_selfattention.py | 18 ++++++++++++++++++ 3 files changed, 40 insertions(+), 6 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index ac96b077bd..86e1b1d3ae 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn @@ -154,10 +154,12 @@ def __init__( ) self.input_size = input_size - def forward(self, x): + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): """ Args: x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C + attn_mask (torch.Tensor, optional): mask to apply to the attention matrix. + B x (s_dim_1 * ... * s_dim_n). Defaults to None. Return: torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C @@ -176,7 +178,13 @@ def forward(self, x): if self.use_flash_attention: x = F.scaled_dot_product_attention( - query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + query=q, + key=k, + value=v, + attn_mask=attn_mask, + scale=self.scale, + dropout_p=self.dropout_rate, + is_causal=self.causal, ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale @@ -186,10 +194,16 @@ def forward(self, x): att_mat = self.rel_positional_embedding(x, att_mat, q) if self.causal: + if attn_mask is not None: + raise ValueError("Causal attention does not support attention masks.") att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf")) - att_mat = att_mat.softmax(dim=-1) + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(1).unsqueeze(2) + attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1) + att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf")) + att_mat = att_mat.softmax(dim=-1) if self.save_attn: # no gradients and new tensor; # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 05eb3b07ab..6f0da73e7b 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -90,8 +90,10 @@ def __init__( use_flash_attention=use_flash_attention, ) - def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: - x = x + self.attn(self.norm1(x)) + def forward( + self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = x + self.attn(self.norm1(x), attn_mask=attn_mask) if self.with_cross_attention: x = x + self.cross_attn(self.norm_cross_attn(x), context=context) x = x + self.mlp(self.norm2(x)) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 88919fd8b1..338f1bf840 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -122,6 +122,24 @@ def test_causal(self): # check upper triangular part of the attention matrix is zero assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + def test_masked_selfattention(self): + n = 64 + block = SABlock(hidden_size=128, num_heads=1, 
dropout_rate=0.1, sequence_length=16, save_attn=True) + input_shape = (1, n, 128) + # generate a mask randomly with zeros and ones of shape (1, n) + mask = torch.randint(0, 2, (1, n)).bool() + block(torch.randn(input_shape), attn_mask=mask) + att_mat = block.att_mat.squeeze() + # ensure all masked columns are zeros + assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)])) + + def test_causal_and_mask(self): + with self.assertRaises(ValueError): + block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64) + inputs = torch.randn(2, 64, 128) + mask = torch.randint(0, 2, (2, 64)).bool() + block(inputs, attn_mask=mask) + @skipUnless(has_einops, "Requires einops") def test_access_attn_matrix(self): # input format From e73257caa79309dcce1e93abf1632f4bfd75b11f Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 27 Nov 2024 18:25:45 +0800 Subject: [PATCH 11/23] Add PythonicWorkflow (#8151) Fixes # . ### Description Add PythonicWorkflow ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Boris Fomitchev Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Boris Fomitchev Co-authored-by: Boris Fomitchev Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/bundle/__init__.py | 2 +- monai/bundle/workflows.py | 190 ++++++++++++++++-- monai/utils/module.py | 6 +- tests/nonconfig_workflow.py | 64 +++++- tests/test_bundle_trt_export.py | 2 +- tests/test_bundle_workflow.py | 74 ++++++- tests/test_convert_to_trt.py | 2 +- tests/test_trt_compile.py | 2 +- tests/test_version_after.py | 2 +- tests/testing_data/fl_infer_properties.json | 135 +++++++------ .../python_workflow_properties.json | 26 +++ 11 files changed, 409 insertions(+), 96 deletions(-) create mode 100644 tests/testing_data/python_workflow_properties.json diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index a4a2176f14..3f3c8d545e 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -43,4 +43,4 @@ MACRO_KEY, load_bundle_config, ) -from .workflows import BundleWorkflow, ConfigWorkflow +from .workflows import BundleWorkflow, ConfigWorkflow, PythonicWorkflow diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index dbfa6bb54c..75cf7b0b09 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -44,12 +44,18 @@ class BundleWorkflow(ABC): workflow_type: specifies the workflow type: "train" or "training" for a training workflow, or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. - default to `train` for train workflow. + default to `None` for only using meta properties. 
workflow: specifies the workflow type: "train" or "training" for a training workflow, or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. - properties_path: the path to the JSON file of properties. + properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be + loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified, + properties will default to loading from "meta". If `properties_path` is None, default properties + will be sourced from "monai/bundle/properties.py" based on the workflow_type: + For a training workflow, properties load from `TrainProperties` and `MetaProperties`. + For a inference workflow, properties load from `InferProperties` and `MetaProperties`. + For workflow_type = None : only `MetaProperties` will be loaded. meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order. logging_file: config file for `logging` module in the program. for more details: https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. @@ -97,29 +103,50 @@ def __init__( meta_file = None workflow_type = workflow if workflow is not None else workflow_type - if workflow_type is None and properties_path is None: - self.properties = copy(MetaProperties) - self.workflow_type = None - self.meta_file = meta_file - return + if workflow_type is not None: + if workflow_type.lower() in self.supported_train_type: + workflow_type = "train" + elif workflow_type.lower() in self.supported_infer_type: + workflow_type = "infer" + else: + raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") + if properties_path is not None: properties_path = Path(properties_path) if not properties_path.is_file(): raise ValueError(f"Property file {properties_path} does not exist.") with open(properties_path) as json_file: - self.properties = json.load(json_file) - self.workflow_type = None - self.meta_file = meta_file - return - if workflow_type.lower() in self.supported_train_type: # type: ignore[union-attr] - self.properties = {**TrainProperties, **MetaProperties} - self.workflow_type = "train" - elif workflow_type.lower() in self.supported_infer_type: # type: ignore[union-attr] - self.properties = {**InferProperties, **MetaProperties} - self.workflow_type = "infer" + try: + properties = json.load(json_file) + self.properties: dict = {} + if workflow_type is not None and workflow_type in properties: + self.properties = properties[workflow_type] + if "meta" in properties: + self.properties.update(properties["meta"]) + elif workflow_type is None: + if "meta" in properties: + self.properties = properties["meta"] + logger.info( + "No workflow type specified, default to load meta properties from property file." 
+ ) + else: + logger.warning("No 'meta' key found in properties while workflow_type is None.") + except KeyError as e: + raise ValueError(f"{workflow_type} not found in property file {properties_path}") from e + except json.JSONDecodeError as e: + raise ValueError(f"Error decoding JSON from property file {properties_path}") from e else: - raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") + if workflow_type == "train": + self.properties = {**TrainProperties, **MetaProperties} + elif workflow_type == "infer": + self.properties = {**InferProperties, **MetaProperties} + elif workflow_type is None: + self.properties = copy(MetaProperties) + logger.info("No workflow type and property file specified, default to 'meta' properties.") + else: + raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") + self.workflow_type = workflow_type self.meta_file = meta_file @abstractmethod @@ -226,6 +253,124 @@ def check_properties(self) -> list[str] | None: return [n for n, p in self.properties.items() if p.get(BundleProperty.REQUIRED, False) and not hasattr(self, n)] +class PythonicWorkflow(BundleWorkflow): + """ + Base class for the pythonic workflow specification in bundle, it can be a training, evaluation or inference workflow. + It defines the basic interfaces for the bundle workflow behavior: `initialize`, `finalize`, etc. + This also provides the interface to get / set public properties to interact with a bundle workflow through + defined `get_` accessor methods or directly defining members of the object. + For how to set the properties, users can define the `_set_` methods or directly set the members of the object. + The `initialize` method is called to set up the workflow before running. This method sets up internal state + and prepares properties. If properties are modified after the workflow has been initialized, `self._is_initialized` + is set to `False`. Before running the workflow again, `initialize` should be called to ensure that the workflow is + properly set up with the new property values. + + Args: + workflow_type: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for only using meta properties. + workflow: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for common workflow. + properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be + loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified, + properties will default to loading from "meta". If `properties_path` is None, default properties + will be sourced from "monai/bundle/properties.py" based on the workflow_type: + For a training workflow, properties load from `TrainProperties` and `MetaProperties`. + For a inference workflow, properties load from `InferProperties` and `MetaProperties`. + For workflow_type = None : only `MetaProperties` will be loaded. + config_file: path to the config file, typically used to store hyperparameters. + meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order. + logging_file: config file for `logging` module in the program. 
for more details: + https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. + + """ + + supported_train_type: tuple = ("train", "training") + supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation") + + def __init__( + self, + workflow_type: str | None = None, + properties_path: PathLike | None = None, + config_file: str | Sequence[str] | None = None, + meta_file: str | Sequence[str] | None = None, + logging_file: str | None = None, + **override: Any, + ): + meta_file = str(Path(os.getcwd()) / "metadata.json") if meta_file is None else meta_file + super().__init__( + workflow_type=workflow_type, properties_path=properties_path, meta_file=meta_file, logging_file=logging_file + ) + self._props_vals: dict = {} + self._set_props_vals: dict = {} + self.parser = ConfigParser() + if config_file is not None: + self.parser.read_config(f=config_file) + if self.meta_file is not None: + self.parser.read_meta(f=self.meta_file) + + # the rest key-values in the _args are to override config content + self.parser.update(pairs=override) + self._is_initialized: bool = False + + def initialize(self, *args: Any, **kwargs: Any) -> Any: + """ + Initialize the bundle workflow before running. + """ + self._props_vals = {} + self._is_initialized = True + + def _get_property(self, name: str, property: dict) -> Any: + """ + With specified property name and information, get the expected property value. + If the property is already generated, return from the bucket directly. + If user explicitly set the property, return it directly. + Otherwise, generate the expected property as a class private property with prefix "_". + + Args: + name: the name of target property. + property: other information for the target property, defined in `TrainProperties` or `InferProperties`. + """ + if not self._is_initialized: + raise RuntimeError("Please execute 'initialize' before getting any properties.") + value = None + if name in self._set_props_vals: + value = self._set_props_vals[name] + elif name in self._props_vals: + value = self._props_vals[name] + elif name in self.parser.config[self.parser.meta_key]: # type: ignore[index] + id = self.properties.get(name, None).get(BundlePropertyConfig.ID, None) + value = self.parser[id] + else: + try: + value = getattr(self, f"get_{name}")() + except AttributeError as e: + if property[BundleProperty.REQUIRED]: + raise ValueError( + f"unsupported property '{name}' is required in the bundle properties," + f"need to implement a method 'get_{name}' to provide the property." + ) from e + self._props_vals[name] = value + return value + + def _set_property(self, name: str, property: dict, value: Any) -> Any: + """ + With specified property name and information, set value for the expected property. + Stores user-reset initialized objects that should not be re-initialized and marks the workflow as not initialized. + + Args: + name: the name of target property. + property: other information for the target property, defined in `TrainProperties` or `InferProperties`. + value: value to set for the property. + + """ + self._set_props_vals[name] = value + self._is_initialized = False + + class ConfigWorkflow(BundleWorkflow): """ Specification for the config-based bundle workflow. @@ -262,7 +407,13 @@ class ConfigWorkflow(BundleWorkflow): or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. - properties_path: the path to the JSON file of properties. 
+ properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be + loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified, + properties will default to loading from "train". If `properties_path` is None, default properties + will be sourced from "monai/bundle/properties.py" based on the workflow_type: + For a training workflow, properties load from `TrainProperties` and `MetaProperties`. + For a inference workflow, properties load from `InferProperties` and `MetaProperties`. + For workflow_type = None : only `MetaProperties` will be loaded. override: id-value pairs to override or add the corresponding config content. e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg`` @@ -324,7 +475,6 @@ def __init__( self.parser.read_config(f=config_file) if self.meta_file is not None: self.parser.read_meta(f=self.meta_file) - # the rest key-values in the _args are to override config content self.parser.update(pairs=override) self.init_id = init_id diff --git a/monai/utils/module.py b/monai/utils/module.py index 1ad001fc87..d3f2ff09f2 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -649,7 +649,7 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s current_ver_string: if None, the current system GPU CUDA compute capability will be used. Returns: - True if the current system GPU CUDA compute capability is greater than the specified version. + True if the current system GPU CUDA compute capability is greater than or equal to the specified version. """ if current_ver_string is None: cuda_available = torch.cuda.is_available() @@ -667,11 +667,11 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s ver, has_ver = optional_import("packaging.version", name="parse") if has_ver: - return ver(".".join((f"{major}", f"{minor}"))) < ver(f"{current_ver_string}") # type: ignore + return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}") # type: ignore parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2) while len(parts) < 2: parts += ["0"] c_major, c_minor = parts[:2] c_mn = int(c_major), int(c_minor) mn = int(major), int(minor) - return c_mn >= mn + return c_mn > mn diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py index b2c44c12c6..fcfc5b2951 100644 --- a/tests/nonconfig_workflow.py +++ b/tests/nonconfig_workflow.py @@ -13,7 +13,7 @@ import torch -from monai.bundle import BundleWorkflow +from monai.bundle import BundleWorkflow, PythonicWorkflow from monai.data import DataLoader, Dataset from monai.engines import SupervisedEvaluator from monai.inferers import SlidingWindowInferer @@ -26,8 +26,9 @@ LoadImaged, SaveImaged, ScaleIntensityd, + ScaleIntensityRanged, ) -from monai.utils import BundleProperty, set_determinism +from monai.utils import BundleProperty, CommonKeys, set_determinism class NonConfigWorkflow(BundleWorkflow): @@ -176,3 +177,62 @@ def _set_property(self, name, property, value): self._numpy_version = value elif property[BundleProperty.REQUIRED]: raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") + + +class PythonicWorkflowImpl(PythonicWorkflow): + """ + Test class simulates the bundle workflow defined by Python script directly. 
+ """ + + def __init__( + self, + workflow_type: str = "inference", + config_file: str | None = None, + properties_path: str | None = None, + meta_file: str | None = None, + ): + super().__init__( + workflow_type=workflow_type, properties_path=properties_path, config_file=config_file, meta_file=meta_file + ) + self.dataflow: dict = {} + + def initialize(self): + self._props_vals = {} + self._is_initialized = True + self.net = UNet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=(16, 32, 64, 128), + strides=(2, 2, 2), + num_res_units=2, + ).to(self.device) + preprocessing = Compose( + [ + EnsureChannelFirstd(keys=["image"]), + ScaleIntensityd(keys="image"), + ScaleIntensityRanged(keys="image", a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True), + ] + ) + self.dataset = Dataset(data=[self.dataflow], transform=preprocessing) + self.postprocessing = Compose([Activationsd(keys="pred", softmax=True), AsDiscreted(keys="pred", argmax=True)]) + + def run(self): + data = self.dataset[0] + inputs = data[CommonKeys.IMAGE].unsqueeze(0).to(self.device) + self.net.eval() + with torch.no_grad(): + data[CommonKeys.PRED] = self.inferer(inputs, self.net) + self.dataflow.update({CommonKeys.PRED: self.postprocessing(data)[CommonKeys.PRED]}) + + def finalize(self): + pass + + def get_bundle_root(self): + return "." + + def get_device(self): + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def get_inferer(self): + return SlidingWindowInferer(roi_size=self.parser.roi_size, sw_batch_size=1, overlap=0) diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py index 835c8e5c1d..27e1ee97a8 100644 --- a/tests/test_bundle_trt_export.py +++ b/tests/test_bundle_trt_export.py @@ -53,7 +53,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick -@SkipIfBeforeComputeCapabilityVersion((7, 0)) +@SkipIfBeforeComputeCapabilityVersion((7, 5)) class TestTRTExport(unittest.TestCase): def setUp(self): diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py index b10f6448ff..893b9dc991 100644 --- a/tests/test_bundle_workflow.py +++ b/tests/test_bundle_workflow.py @@ -13,6 +13,7 @@ import os import shutil +import sys import tempfile import unittest from copy import deepcopy @@ -22,12 +23,12 @@ import torch from parameterized import parameterized -from monai.bundle import ConfigWorkflow +from monai.bundle import ConfigWorkflow, create_workflow from monai.data import Dataset from monai.inferers import SimpleInferer, SlidingWindowInferer from monai.networks.nets import UNet from monai.transforms import Compose, LoadImage, LoadImaged, SaveImaged -from tests.nonconfig_workflow import NonConfigWorkflow +from tests.nonconfig_workflow import NonConfigWorkflow, PythonicWorkflowImpl TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")] @@ -112,12 +113,13 @@ def test_inference_config(self, config_file): # test property path inferer = ConfigWorkflow( config_file=config_file, + workflow_type="infer", properties_path=os.path.join(os.path.dirname(__file__), "testing_data", "fl_infer_properties.json"), logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"), **override, ) self._test_inferer(inferer) - self.assertEqual(inferer.workflow_type, None) + self.assertEqual(inferer.workflow_type, "infer") @parameterized.expand([TEST_CASE_4]) def test_responsive_inference_config(self, config_file): @@ -197,6 +199,72 @@ def test_non_config_wrong_log_cases(self, meta_file, logging_file, expected_erro with 
self.assertRaisesRegex(FileNotFoundError, expected_error): NonConfigWorkflow(self.filename, self.data_dir, meta_file, logging_file) + def test_pythonic_workflow(self): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + config_file = {"roi_size": (64, 64, 32)} + meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json") + property_path = os.path.join(os.path.dirname(__file__), "testing_data", "python_workflow_properties.json") + workflow = PythonicWorkflowImpl( + workflow_type="infer", config_file=config_file, meta_file=meta_file, properties_path=property_path + ) + workflow.initialize() + # Load input data + input_loader = LoadImaged(keys="image") + workflow.dataflow.update(input_loader({"image": self.filename})) + self.assertEqual(workflow.bundle_root, ".") + self.assertEqual(workflow.device, device) + self.assertEqual(workflow.version, "0.1.0") + # check config override correctly + self.assertEqual(workflow.inferer.roi_size, (64, 64, 32)) + workflow.run() + # update input data and run again + workflow.dataflow.update(input_loader({"image": self.filename1})) + workflow.run() + pred = workflow.dataflow["pred"] + self.assertEqual(pred.shape[2:], self.expected_shape) + self.assertEqual(pred.meta["filename_or_obj"], self.filename1) + workflow.finalize() + + def test_create_pythonic_workflow(self): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + config_file = {"roi_size": (64, 64, 32)} + meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json") + property_path = os.path.join(os.path.dirname(__file__), "testing_data", "python_workflow_properties.json") + sys.path.append(os.path.dirname(__file__)) + workflow = create_workflow( + "nonconfig_workflow.PythonicWorkflowImpl", + workflow_type="infer", + config_file=config_file, + meta_file=meta_file, + properties_path=property_path, + ) + # Load input data + input_loader = LoadImaged(keys="image") + workflow.dataflow.update(input_loader({"image": self.filename})) + self.assertEqual(workflow.bundle_root, ".") + self.assertEqual(workflow.device, device) + self.assertEqual(workflow.version, "0.1.0") + # check config override correctly + self.assertEqual(workflow.inferer.roi_size, (64, 64, 32)) + + # check set property override correctly + workflow.inferer = SlidingWindowInferer(roi_size=config_file["roi_size"], sw_batch_size=1, overlap=0.5) + workflow.initialize() + self.assertEqual(workflow.inferer.overlap, 0.5) + + workflow.run() + # update input data and run again + workflow.dataflow.update(input_loader({"image": self.filename1})) + workflow.run() + pred = workflow.dataflow["pred"] + self.assertEqual(pred.shape[2:], self.expected_shape) + self.assertEqual(pred.meta["filename_or_obj"], self.filename1) + + # test add properties + workflow.add_property(name="net", required=True, desc="network for the training.") + self.assertIn("net", workflow.properties) + workflow.finalize() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_convert_to_trt.py b/tests/test_convert_to_trt.py index 712d887c3b..a7b1edec3c 100644 --- a/tests/test_convert_to_trt.py +++ b/tests/test_convert_to_trt.py @@ -38,7 +38,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick -@SkipIfBeforeComputeCapabilityVersion((7, 0)) +@SkipIfBeforeComputeCapabilityVersion((7, 5)) class TestConvertToTRT(unittest.TestCase): def setUp(self): diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index e1323c201f..f7779fec9b 100644 --- a/tests/test_trt_compile.py +++ 
b/tests/test_trt_compile.py @@ -50,7 +50,7 @@ def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: f @skip_if_quick @unittest.skipUnless(trt_imported, "tensorrt is required") @unittest.skipUnless(polygraphy_imported, "polygraphy is required") -@SkipIfBeforeComputeCapabilityVersion((7, 0)) +@SkipIfBeforeComputeCapabilityVersion((7, 5)) class TestTRTCompile(unittest.TestCase): def setUp(self): diff --git a/tests/test_version_after.py b/tests/test_version_after.py index 34a5054974..b6cb741382 100644 --- a/tests/test_version_after.py +++ b/tests/test_version_after.py @@ -38,7 +38,7 @@ TEST_CASES_SM = [ # (major, minor, sm, expected) - (6, 1, "6.1", False), + (6, 1, "6.1", True), (6, 1, "6.0", False), (6, 0, "8.6", True), (7, 0, "8", True), diff --git a/tests/testing_data/fl_infer_properties.json b/tests/testing_data/fl_infer_properties.json index 72e97cd2c6..6b40edd2ab 100644 --- a/tests/testing_data/fl_infer_properties.json +++ b/tests/testing_data/fl_infer_properties.json @@ -1,67 +1,76 @@ { - "bundle_root": { - "description": "root path of the bundle.", - "required": true, - "id": "bundle_root" + "infer": { + "bundle_root": { + "description": "root path of the bundle.", + "required": true, + "id": "bundle_root" + }, + "device": { + "description": "target device to execute the bundle workflow.", + "required": true, + "id": "device" + }, + "dataset_dir": { + "description": "directory path of the dataset.", + "required": true, + "id": "dataset_dir" + }, + "dataset": { + "description": "PyTorch dataset object for the inference / evaluation logic.", + "required": true, + "id": "dataset" + }, + "evaluator": { + "description": "inference / evaluation workflow engine.", + "required": true, + "id": "evaluator" + }, + "network_def": { + "description": "network module for the inference.", + "required": true, + "id": "network_def" + }, + "inferer": { + "description": "MONAI Inferer object to execute the model computation in inference.", + "required": true, + "id": "inferer" + }, + "dataset_data": { + "description": "data source for the inference / evaluation dataset.", + "required": false, + "id": "dataset::data", + "refer_id": null + }, + "handlers": { + "description": "event-handlers for the inference / evaluation logic.", + "required": false, + "id": "handlers", + "refer_id": "evaluator::val_handlers" + }, + "preprocessing": { + "description": "preprocessing for the input data.", + "required": false, + "id": "preprocessing", + "refer_id": "dataset::transform" + }, + "postprocessing": { + "description": "postprocessing for the model output data.", + "required": false, + "id": "postprocessing", + "refer_id": "evaluator::postprocessing" + }, + "key_metric": { + "description": "the key metric during evaluation.", + "required": false, + "id": "key_metric", + "refer_id": "evaluator::key_val_metric" + } }, - "device": { - "description": "target device to execute the bundle workflow.", - "required": true, - "id": "device" - }, - "dataset_dir": { - "description": "directory path of the dataset.", - "required": true, - "id": "dataset_dir" - }, - "dataset": { - "description": "PyTorch dataset object for the inference / evaluation logic.", - "required": true, - "id": "dataset" - }, - "evaluator": { - "description": "inference / evaluation workflow engine.", - "required": true, - "id": "evaluator" - }, - "network_def": { - "description": "network module for the inference.", - "required": true, - "id": "network_def" - }, - "inferer": { - "description": "MONAI Inferer object to execute 
the model computation in inference.", - "required": true, - "id": "inferer" - }, - "dataset_data": { - "description": "data source for the inference / evaluation dataset.", - "required": false, - "id": "dataset::data", - "refer_id": null - }, - "handlers": { - "description": "event-handlers for the inference / evaluation logic.", - "required": false, - "id": "handlers", - "refer_id": "evaluator::val_handlers" - }, - "preprocessing": { - "description": "preprocessing for the input data.", - "required": false, - "id": "preprocessing", - "refer_id": "dataset::transform" - }, - "postprocessing": { - "description": "postprocessing for the model output data.", - "required": false, - "id": "postprocessing", - "refer_id": "evaluator::postprocessing" - }, - "key_metric": { - "description": "the key metric during evaluation.", - "required": false, - "id": "key_metric", - "refer_id": "evaluator::key_val_metric" + "meta": { + "version": { + "description": "version of the inference configuration.", + "required": true, + "id": "_meta_::version" + } } } diff --git a/tests/testing_data/python_workflow_properties.json b/tests/testing_data/python_workflow_properties.json new file mode 100644 index 0000000000..cd4295839a --- /dev/null +++ b/tests/testing_data/python_workflow_properties.json @@ -0,0 +1,26 @@ +{ + "infer": { + "bundle_root": { + "description": "root path of the bundle.", + "required": true, + "id": "bundle_root" + }, + "device": { + "description": "target device to execute the bundle workflow.", + "required": true, + "id": "device" + }, + "inferer": { + "description": "MONAI Inferer object to execute the model computation in inference.", + "required": true, + "id": "inferer" + } + }, + "meta": { + "version": { + "description": "version of the inference configuration.", + "required": true, + "id": "_meta_::version" + } + } +} From 20372f0188fc981994cb1d87f3a9136679809d84 Mon Sep 17 00:00:00 2001 From: Lucas Robinet Date: Wed, 27 Nov 2024 16:18:14 +0100 Subject: [PATCH 12/23] Implementation of a Masked Autoencoder for representation learning (#8152) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This follows a previous PR (#7598). In the previous PR, the official implementation was under a non-compatible license. This is a clean-sheet implementation I developed. The code is fairly straightforward, involving a transformer, encoder, and decoder. The primary changes are in how masks are selected and how patches are organized as they pass through the model. In the official masked autoencoder implementation, noise is first generated and then sorted twice using `torch.argsort`. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices. In our implementation, we use `torch.multinomial` to generate mask indices, followed by simple boolean indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder. **Let me know if you need a detailed, line-by-line explanation of the new code, including how it works and how it differs from the previous version.** ### Description Implementation of the Masked Autoencoder as described in the paper: [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/pdf/2111.06377.pdf) from Kaiming et al. 
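To make the selection mechanics concrete, here is a minimal standalone sketch of the `torch.multinomial` masking described above (toy shapes only, mirroring the `_masking` method added in this PR):

```python
import torch

batch_size, num_tokens, hidden_size = 2, 8, 16
x = torch.rand(batch_size, num_tokens, hidden_size)  # embedded patch tokens
masking_ratio = 0.75
num_keep = int((1 - masking_ratio) * num_tokens)

# sample the token indices to keep, uniformly and without replacement
selected_indices = torch.multinomial(torch.ones(batch_size, num_tokens), num_keep, replacement=False)
x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices]  # (batch, num_keep, hidden)

# integer mask over all tokens: 1 = masked (to be reconstructed), 0 = kept
mask = torch.ones(batch_size, num_tokens, dtype=torch.int)
mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0
```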
Its effectiveness has already been demonstrated in the literature for medical tasks in the paper [Self Pre-training with Masked Autoencoders for Medical Image Classification and Segmentation](https://arxiv.org/abs/2203.05573). The PR contains the architecture and associated unit tests. **Note:** The output includes the prediction, which is a tensor of size: ($BS$, $N_{tokens}$, $D$), and the associated mask ($BS$, $N_{tokens}$). The mask is used to apply loss only to masked patches, but I'm not sure it's the “best” output format, what do you think? ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Lucas Robinet Signed-off-by: Lucas Robinet Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/masked_autoencoder_vit.py | 211 ++++++++++++++++++ tests/test_masked_autoencoder_vit.py | 160 +++++++++++++ 4 files changed, 377 insertions(+) create mode 100644 monai/networks/nets/masked_autoencoder_vit.py create mode 100644 tests/test_masked_autoencoder_vit.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 64a3a4c9d1..e2e509a99b 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -630,6 +630,11 @@ Nets .. autoclass:: ViTAutoEnc :members: +`MaskedAutoEncoderViT` +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: MaskedAutoEncoderViT + :members: + `FullyConnectedNet` ~~~~~~~~~~~~~~~~~~~ .. autoclass:: FullyConnectedNet diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index b876e6a3fc..c1917e5293 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -53,6 +53,7 @@ from .generator import Generator from .highresnet import HighResBlock, HighResNet from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet +from .masked_autoencoder_vit import MaskedAutoEncoderViT from .mednext import ( MedNeXt, MedNext, diff --git a/monai/networks/nets/masked_autoencoder_vit.py b/monai/networks/nets/masked_autoencoder_vit.py new file mode 100644 index 0000000000..e76f097346 --- /dev/null +++ b/monai/networks/nets/masked_autoencoder_vit.py @@ -0,0 +1,211 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from collections.abc import Sequence + +import numpy as np +import torch +import torch.nn as nn + +from monai.networks.blocks.patchembedding import PatchEmbeddingBlock +from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding +from monai.networks.blocks.transformerblock import TransformerBlock +from monai.networks.layers import trunc_normal_ +from monai.utils import ensure_tuple_rep +from monai.utils.module import look_up_option + +SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"} + +__all__ = ["MaskedAutoEncoderViT"] + + +class MaskedAutoEncoderViT(nn.Module): + """ + Masked Autoencoder (ViT), based on: "Kaiming et al., + Masked Autoencoders Are Scalable Vision Learners " + Only a subset of the patches passes through the encoder. The decoder tries to reconstruct + the masked patches, resulting in improved training speed. + """ + + def __init__( + self, + in_channels: int, + img_size: Sequence[int] | int, + patch_size: Sequence[int] | int, + hidden_size: int = 768, + mlp_dim: int = 512, + num_layers: int = 12, + num_heads: int = 12, + masking_ratio: float = 0.75, + decoder_hidden_size: int = 384, + decoder_mlp_dim: int = 512, + decoder_num_layers: int = 4, + decoder_num_heads: int = 12, + proj_type: str = "conv", + pos_embed_type: str = "sincos", + decoder_pos_embed_type: str = "sincos", + dropout_rate: float = 0.0, + spatial_dims: int = 3, + qkv_bias: bool = False, + save_attn: bool = False, + ) -> None: + """ + Args: + in_channels: dimension of input channels or the number of channels for input. + img_size: dimension of input image. + patch_size: dimension of patch size + hidden_size: dimension of hidden layer. Defaults to 768. + mlp_dim: dimension of feedforward layer. Defaults to 512. + num_layers: number of transformer blocks. Defaults to 12. + num_heads: number of attention heads. Defaults to 12. + masking_ratio: ratio of patches to be masked. Defaults to 0.75. + decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384. + decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512. + decoder_num_layers: number of transformer blocks for decoder. Defaults to 4. + decoder_num_heads: number of attention heads for decoder. Defaults to 12. + proj_type: position embedding layer type. Defaults to "conv". + pos_embed_type: position embedding layer type. Defaults to "sincos". + decoder_pos_embed_type: position embedding layer type for decoder. Defaults to "sincos". + dropout_rate: fraction of the input units to drop. Defaults to 0.0. + spatial_dims: number of spatial dimensions. Defaults to 3. + qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False. + save_attn: to make accessible the attention in self attention block. Defaults to False. 
+        Examples::
+            # for single channel input with image size of (96,96,96), and sin-cos positional encoding
+            >>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16),
+            pos_embed_type='sincos')
+            # for 3-channel with image size of (128,128,128) and a learnable positional encoding
+            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable')
+            # for 3-channel with image size of (224,224) and a masking ratio of 0.25
+            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25,
+            spatial_dims=2)
+        """
+
+        super().__init__()
+
+        if not (0 <= dropout_rate <= 1):
+            raise ValueError(f"dropout_rate should be between 0 and 1, got {dropout_rate}.")
+
+        if hidden_size % num_heads != 0:
+            raise ValueError("hidden_size should be divisible by num_heads.")
+
+        if decoder_hidden_size % decoder_num_heads != 0:
+            raise ValueError("decoder_hidden_size should be divisible by decoder_num_heads.")
+
+        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
+        self.img_size = ensure_tuple_rep(img_size, spatial_dims)
+        self.spatial_dims = spatial_dims
+        for m, p in zip(self.img_size, self.patch_size):
+            if m % p != 0:
+                raise ValueError(f"img_size={img_size} should be divisible by patch_size={patch_size}.")
+
+        self.decoder_hidden_size = decoder_hidden_size
+
+        if masking_ratio <= 0 or masking_ratio >= 1:
+            raise ValueError(f"masking_ratio should be in the range (0, 1), got {masking_ratio}.")
+
+        self.masking_ratio = masking_ratio
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
+
+        self.patch_embedding = PatchEmbeddingBlock(
+            in_channels=in_channels,
+            img_size=img_size,
+            patch_size=patch_size,
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            proj_type=proj_type,
+            pos_embed_type=pos_embed_type,
+            dropout_rate=dropout_rate,
+            spatial_dims=self.spatial_dims,
+        )
+        blocks = [
+            TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
+            for _ in range(num_layers)
+        ]
+        self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size))
+
+        # decoder
+        self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size)
+
+        self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))
+
+        self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)
+        self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size))
+
+        decoder_blocks = [
+            TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn)
+            for _ in range(decoder_num_layers)
+        ]
+        self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size))
+        self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        """
+        Initialize weights similarly to monai/networks/blocks/patchembedding.py for the decoder positional
+        encoding and for the mask and classification tokens.
+        """
+        if self.decoder_pos_embed_type == "none":
+            pass
+        elif self.decoder_pos_embed_type == "learnable":
+            trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0)
+        elif self.decoder_pos_embed_type == "sincos":
+            grid_size = []
+            for in_size, pa_size in zip(self.img_size, self.patch_size):
+                grid_size.append(in_size // pa_size)
+
+            self.decoder_pos_embedding = build_sincos_position_embedding(
+                grid_size, self.decoder_hidden_size, self.spatial_dims
+            )
+
+        else:
+            raise 
ValueError(f"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.") + + # initialize patch_embedding like nn.Linear (instead of nn.Conv2d) + trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0) + trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0) + + def _masking(self, x, masking_ratio: float | None = None): + batch_size, num_tokens, _ = x.shape + percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio + selected_indices = torch.multinomial( + torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False + ) + x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices] # gather the selected tokens + mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device) + mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0 + + return x_masked, selected_indices, mask + + def forward(self, x, masking_ratio: float | None = None): + x = self.patch_embedding(x) + x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio) + + cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + x = self.blocks(x) + + # decoder + x = self.decoder_embed(x) + + x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1) + x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :] # no cls token + x_ = x_ + self.decoder_pos_embedding + x = torch.cat([x[:, :1, :], x_], dim=1) + x = self.decoder_blocks(x) + x = self.decoder_pred(x) + + x = x[:, 1:, :] + return x, mask diff --git a/tests/test_masked_autoencoder_vit.py b/tests/test_masked_autoencoder_vit.py new file mode 100644 index 0000000000..f8f6977cc2 --- /dev/null +++ b/tests/test_masked_autoencoder_vit.py @@ -0,0 +1,160 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT +from tests.utils import skip_if_quick + +TEST_CASE_MaskedAutoEncoderViT = [] +for masking_ratio in [0.5]: + for dropout_rate in [0.6]: + for in_channels in [4]: + for hidden_size in [768]: + for img_size in [96, 128]: + for patch_size in [16]: + for num_heads in [12]: + for mlp_dim in [3072]: + for num_layers in [4]: + for decoder_hidden_size in [384]: + for decoder_mlp_dim in [512]: + for decoder_num_layers in [4]: + for decoder_num_heads in [16]: + for pos_embed_type in ["sincos", "learnable"]: + for proj_type in ["conv", "perceptron"]: + for nd in (2, 3): + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * nd, + "patch_size": (patch_size,) * nd, + "hidden_size": hidden_size, + "mlp_dim": mlp_dim, + "num_layers": num_layers, + "decoder_hidden_size": decoder_hidden_size, + "decoder_mlp_dim": decoder_mlp_dim, + "decoder_num_layers": decoder_num_layers, + "decoder_num_heads": decoder_num_heads, + "pos_embed_type": pos_embed_type, + "masking_ratio": masking_ratio, + "decoder_pos_embed_type": pos_embed_type, + "num_heads": num_heads, + "proj_type": proj_type, + "dropout_rate": dropout_rate, + }, + (2, in_channels, *([img_size] * nd)), + ( + 2, + (img_size // patch_size) ** nd, + in_channels * (patch_size**nd), + ), + ] + if nd == 2: + test_case[0]["spatial_dims"] = 2 # type: ignore + TEST_CASE_MaskedAutoEncoderViT.append(test_case) + +TEST_CASE_ill_args = [ + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (16, 16, 16), "dropout_rate": 5.0}], + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "pos_embed_type": "sin"}], + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "decoder_pos_embed_type": "sin"}], + [{"in_channels": 1, "img_size": (32, 32, 32), "patch_size": (64, 64, 64)}], + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "num_layers": 12, "num_heads": 14}], + [{"in_channels": 1, "img_size": (97, 97, 97), "patch_size": (16, 16, 16)}], + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": 1.1}], + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": -0.1}], +] + + +@skip_if_quick +class TestMaskedAutoencoderViT(unittest.TestCase): + + @parameterized.expand(TEST_CASE_MaskedAutoEncoderViT) + def test_shape(self, input_param, input_shape, expected_shape): + net = MaskedAutoEncoderViT(**input_param) + with eval_mode(net): + result, _ = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_frozen_pos_embedding(self): + net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16)) + + self.assertEqual(net.decoder_pos_embedding.requires_grad, False) + + @parameterized.expand(TEST_CASE_ill_args) + def test_ill_arg(self, input_param): + with self.assertRaises(ValueError): + MaskedAutoEncoderViT(**input_param) + + def test_access_attn_matrix(self): + # input format + in_channels = 1 + img_size = (96, 96, 96) + patch_size = (16, 16, 16) + in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2]) + + # no data in the matrix + no_matrix_acess_blk = MaskedAutoEncoderViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size) + no_matrix_acess_blk(torch.randn(in_shape)) + assert 
isinstance(no_matrix_acess_blk.blocks[0].attn.att_mat, torch.Tensor) + # no of elements is zero + assert no_matrix_acess_blk.blocks[0].attn.att_mat.nelement() == 0 + + # be able to acess the attention matrix + matrix_acess_blk = MaskedAutoEncoderViT( + in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True + ) + matrix_acess_blk(torch.randn(in_shape)) + + assert matrix_acess_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 55, 55) + + def test_masking_ratio(self): + # input format + in_channels = 1 + img_size = (96, 96, 96) + patch_size = (16, 16, 16) + in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2]) + + # masking ratio 0.25 + masking_ratio_blk = MaskedAutoEncoderViT( + in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.25, save_attn=True + ) + masking_ratio_blk(torch.randn(in_shape)) + desired_num_tokens = int( + (img_size[0] // patch_size[0]) + * (img_size[1] // patch_size[1]) + * (img_size[2] // patch_size[2]) + * (1 - 0.25) + ) + assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens + + # masking ratio 0.33 + masking_ratio_blk = MaskedAutoEncoderViT( + in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.33, save_attn=True + ) + masking_ratio_blk(torch.randn(in_shape)) + desired_num_tokens = int( + (img_size[0] // patch_size[0]) + * (img_size[1] // patch_size[1]) + * (img_size[2] // patch_size[2]) + * (1 - 0.33) + ) + + assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens + + +if __name__ == "__main__": + unittest.main() From 44e249d7d492d858199acfca1c948faa5aa33763 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Thu, 28 Nov 2024 08:35:29 +0100 Subject: [PATCH 13/23] =?UTF-8?q?Implement=20TorchIO=20transforms=20wrappe?= =?UTF-8?q?r=20analogous=20to=20TorchVision=20transfo=E2=80=A6=20(#7579)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …rms wrapper and test case Fixes #7499 . ### Description As discussed in the issue, this PR implements a wrapper class for TorchIO transforms, analogous to the TorchVision transforms wrapper. The test cases just check that transforms are callable and that after applying a transform, the result is different from the inputs. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. 
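For illustration, here is a minimal usage sketch of the new wrappers (the transform names and arguments come from the TorchIO package, mirroring the added tests):

```python
import torch
from monai.transforms import TorchIO, TorchIOd

img = torch.rand(1, 64, 64, 64)  # TorchIO works on channel-first 4D tensors

# array version: wrap a non-randomized TorchIO transform by its class name
rescale = TorchIO(name="RescaleIntensity", out_min_max=(0, 1))
out = rescale(img)

# dictionary version: only the listed keys are transformed ("include" is set internally)
rescale_d = TorchIOd(keys="img", name="RescaleIntensity", out_min_max=(0, 1))
out_d = rescale_d({"img": img})
```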
--------- Signed-off-by: Fabian Klopfer Signed-off-by: Fabian Klopfer Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Fabian Klopfer Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- .gitignore | 3 + docs/source/transforms.rst | 24 ++++++ environment-dev.yml | 1 + monai/transforms/__init__.py | 9 ++ monai/transforms/utility/array.py | 109 +++++++++++++++++++++++-- monai/transforms/utility/dictionary.py | 67 +++++++++++++++ requirements-dev.txt | 1 + setup.cfg | 3 + tests/test_rand_torchio.py | 54 ++++++++++++ tests/test_rand_torchiod.py | 44 ++++++++++ tests/test_torchio.py | 41 ++++++++++ tests/test_torchiod.py | 47 +++++++++++ 12 files changed, 397 insertions(+), 6 deletions(-) create mode 100644 tests/test_rand_torchio.py create mode 100644 tests/test_rand_torchiod.py create mode 100644 tests/test_torchio.py create mode 100644 tests/test_torchiod.py diff --git a/.gitignore b/.gitignore index 437677d2bb..76c6ab0d12 100644 --- a/.gitignore +++ b/.gitignore @@ -149,6 +149,9 @@ tests/testing_data/nrrd_example.nrrd # clang format tool .clang-format-bin/ +# ctags +tags + # VSCode .vscode/ *.zip diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 41bb4ae79a..d2585daf63 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1180,6 +1180,18 @@ Utility :members: :special-members: __call__ +`TorchIO` +""""""""" +.. autoclass:: TorchIO + :members: + :special-members: __call__ + +`RandTorchIO` +""""""""""""" +.. autoclass:: RandTorchIO + :members: + :special-members: __call__ + `MapLabelValue` """"""""""""""" .. autoclass:: MapLabelValue @@ -2253,6 +2265,18 @@ Utility (Dict) :members: :special-members: __call__ +`TorchIOd` +"""""""""" +.. autoclass:: TorchIOd + :members: + :special-members: __call__ + +`RandTorchIOd` +"""""""""""""" +.. autoclass:: RandTorchIOd + :members: + :special-members: __call__ + `MapLabelValued` """""""""""""""" .. 
autoclass:: MapLabelValued diff --git a/environment-dev.yml b/environment-dev.yml index a4651ec7e4..4a1723e8a5 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -7,6 +7,7 @@ channels: dependencies: - numpy>=1.24,<2.0 - pytorch>=1.9 + - torchio - torchvision - pytorch-cuda>=11.6 - pip diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2cdd965c91..d15042181b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -531,6 +531,8 @@ RandIdentity, RandImageFilter, RandLambda, + RandTorchIO, + RandTorchVision, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, @@ -540,6 +542,7 @@ ToDevice, ToNumpy, ToPIL, + TorchIO, TorchVision, ToTensor, Transpose, @@ -620,6 +623,9 @@ RandLambdad, RandLambdaD, RandLambdaDict, + RandTorchIOd, + RandTorchIOD, + RandTorchIODict, RandTorchVisiond, RandTorchVisionD, RandTorchVisionDict, @@ -653,6 +659,9 @@ ToPILd, ToPILD, ToPILDict, + TorchIOd, + TorchIOD, + TorchIODict, TorchVisiond, TorchVisionD, TorchVisionDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 1b3c59afdb..84422a9ee5 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -18,10 +18,10 @@ import sys import time import warnings -from collections.abc import Mapping, Sequence +from collections.abc import Hashable, Mapping, Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable +from typing import Any, Callable, Union import numpy as np import torch @@ -99,11 +99,14 @@ "ConvertToMultiChannelBasedOnBratsClasses", "AddExtremePointsChannel", "TorchVision", + "TorchIO", "MapLabelValue", "IntensityStats", "ToDevice", "CuCIM", "RandCuCIM", + "RandTorchIO", + "RandTorchVision", "ToCupy", "ImageFilter", "RandImageFilter", @@ -1136,12 +1139,44 @@ def __call__( return concatenate((img, points_image), axis=0) -class TorchVision: +class TorchVision(Transform): """ - This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args. - As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input - data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor. + This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args. + Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchVision package. + args: parameters for the TorchVision transform. + kwargs: parameters for the TorchVision transform. + + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: PyTorch Tensor data for the TorchVision transform. + """ + img_t, *_ = convert_data_type(img, torch.Tensor) + + out = self.trans(img_t) + out, *_ = convert_to_dst_type(src=out, dst=img) + return out + + +class RandTorchVision(Transform, RandomizableTrait): + """ + This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args. 
+ Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. """ backend = [TransformBackends.TORCH] @@ -1172,6 +1207,68 @@ def __call__(self, img: NdarrayOrTensor): return out +class TorchIO(Transform): + """ + This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args. + See https://torchio.readthedocs.io/transforms/transforms.html for more details. + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchIO package. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]): + """ + Args: + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, + or dict containing 4D tensors as values + + """ + return self.trans(img) + + +class RandTorchIO(Transform, RandomizableTrait): + """ + This is a wrapper for TorchIO randomized transforms based on the specified transform name and args. + See https://torchio.readthedocs.io/transforms/transforms.html for more details. + Use this wrapper for all TorchIO transform inheriting from RandomTransform: + https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchIO package. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]): + """ + Args: + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, + or dict containing 4D tensors as values + + """ + return self.trans(img) + + class MapLabelValue: """ Utility to map label values to another set of values. diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 65c721e48e..7dd2397a74 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -60,6 +60,7 @@ ToDevice, ToNumpy, ToPIL, + TorchIO, TorchVision, ToTensor, Transpose, @@ -136,6 +137,9 @@ "RandLambdaD", "RandLambdaDict", "RandLambdad", + "RandTorchIOd", + "RandTorchIOD", + "RandTorchIODict", "RandTorchVisionD", "RandTorchVisionDict", "RandTorchVisiond", @@ -172,6 +176,9 @@ "ToTensorD", "ToTensorDict", "ToTensord", + "TorchIOD", + "TorchIODict", + "TorchIOd", "TorchVisionD", "TorchVisionDict", "TorchVisiond", @@ -1445,6 +1452,64 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class TorchIOd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms. + For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`. 
+ """ + + backend = TorchIO.backend + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in TorchIO package. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + + """ + super().__init__(keys, allow_missing_keys) + self.name = name + kwargs["include"] = self.keys + + self.trans = TorchIO(name, *args, **kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + return dict(self.trans(data)) + + +class RandTorchIOd(MapTransform, RandomizableTrait): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms. + For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`. + """ + + backend = TorchIO.backend + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in TorchIO package. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + + """ + super().__init__(keys, allow_missing_keys) + self.name = name + kwargs["include"] = self.keys + + self.trans = TorchIO(name, *args, **kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + return dict(self.trans(data)) + + class MapLabelValued(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`. 
@@ -1871,8 +1936,10 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch ConvertToMultiChannelBasedOnBratsClassesd ) AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld +TorchIOD = TorchIODict = TorchIOd TorchVisionD = TorchVisionDict = TorchVisiond RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond +RandTorchIOD = RandTorchIODict = RandTorchIOd RandLambdaD = RandLambdaDict = RandLambdad MapLabelValueD = MapLabelValueDict = MapLabelValued IntensityStatsD = IntensityStatsDict = IntensityStatsd diff --git a/requirements-dev.txt b/requirements-dev.txt index 72654d3534..bffe304df4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -24,6 +24,7 @@ pytype>=2020.6.1; platform_system != "Windows" types-setuptools mypy>=1.5.0, <1.12.0 ninja +torchio torchvision psutil cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10" diff --git a/setup.cfg b/setup.cfg index 694dc969d9..ecfd717aff 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,6 +55,7 @@ all = tensorboard gdown>=4.7.3 pytorch-ignite==0.4.11 + torchio torchvision itk>=5.2 tqdm>=4.47.0 @@ -102,6 +103,8 @@ gdown = gdown>=4.7.3 ignite = pytorch-ignite==0.4.11 +torchio = + torchio torchvision = torchvision itk = diff --git a/tests/test_rand_torchio.py b/tests/test_rand_torchio.py new file mode 100644 index 0000000000..ab212d4a11 --- /dev/null +++ b/tests/test_rand_torchio.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandTorchIO +from monai.utils import optional_import, set_determinism + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TESTS = [ + [{"name": "RandomAffine"}, torch.rand(TEST_DIMS)], + [{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)], + [{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)], + [{"name": "RandomMotion"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSpike"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBlur"}, torch.rand(TEST_DIMS)], + [{"name": "RandomNoise"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSwap"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGamma"}, torch.rand(TEST_DIMS)], +] + + +@skipUnless(has_torchio, "Requires torchio") +class TestRandTorchIO(unittest.TestCase): + + @parameterized.expand(TESTS) + def test_value(self, input_param, input_data): + set_determinism(seed=0) + result = RandTorchIO(**input_param)(input_data) + self.assertIsNotNone(result) + self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_torchiod.py b/tests/test_rand_torchiod.py new file mode 100644 index 0000000000..52bcf7c576 --- /dev/null +++ b/tests/test_rand_torchiod.py @@ -0,0 +1,44 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandTorchIOd +from monai.utils import optional_import, set_determinism +from tests.utils import assert_allclose + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TEST_TENSOR = torch.rand(TEST_DIMS) +TEST_PARAMS = [[{"keys": ["img1", "img2"], "name": "RandomAffine"}, {"img1": TEST_TENSOR, "img2": TEST_TENSOR}]] + + +@skipUnless(has_torchio, "Requires torchio") +class TestRandTorchIOd(unittest.TestCase): + + @parameterized.expand(TEST_PARAMS) + def test_random_transform(self, input_param, input_data): + set_determinism(seed=0) + result = RandTorchIOd(**input_param)(input_data) + self.assertFalse(np.allclose(input_data["img1"], result["img1"], atol=1e-6, rtol=1e-6)) + assert_allclose(result["img1"], result["img2"], atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_torchio.py b/tests/test_torchio.py new file mode 100644 index 0000000000..d2d598ca4c --- /dev/null +++ b/tests/test_torchio.py @@ -0,0 +1,41 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import TorchIO +from monai.utils import optional_import + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TESTS = [[{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)], [{"name": "ZNormalization"}, torch.rand(TEST_DIMS)]] + + +@skipUnless(has_torchio, "Requires torchio") +class TestTorchIO(unittest.TestCase): + + @parameterized.expand(TESTS) + def test_value(self, input_param, input_data): + result = TorchIO(**input_param)(input_data) + self.assertIsNotNone(result) + self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_torchiod.py b/tests/test_torchiod.py new file mode 100644 index 0000000000..892287461c --- /dev/null +++ b/tests/test_torchiod.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.transforms import TorchIOd +from monai.utils import optional_import +from tests.utils import assert_allclose + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TEST_TENSOR = torch.rand(TEST_DIMS) +TEST_PARAMS = [ + [ + {"keys": "img", "name": "RescaleIntensity", "out_min_max": (0, 42)}, + {"img": TEST_TENSOR}, + ((TEST_TENSOR - TEST_TENSOR.min()) / (TEST_TENSOR.max() - TEST_TENSOR.min())) * 42, + ] +] + + +@skipUnless(has_torchio, "Requires torchio") +class TestTorchIOd(unittest.TestCase): + + @parameterized.expand(TEST_PARAMS) + def test_value(self, input_param, input_data, expected_value): + result = TorchIOd(**input_param)(input_data) + assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() From e6cae1ced23f32feb2f42943d15b7dd49481c794 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 1 Dec 2024 16:26:08 +0000 Subject: [PATCH 14/23] Bump codecov/codecov-action from 4 to 5 (#8245) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 4 to 5.
Release notes

Sourced from codecov/codecov-action's releases.

v5.0.0

v5 Release

v5 of the Codecov GitHub Action will use the Codecov Wrapper to encapsulate the CLI. This will help ensure that the Action gets updates quicker.

Migration Guide

The v5 release also coincides with the opt-out feature for tokens for public repositories. In the Global Upload Token section of the settings page of an organization in codecov.io, you can set the ability for Codecov to receive coverage reports from any source. This will allow contributors or other members of a repository to upload without needing access to the Codecov token. For more details see how to upload without a token.

[!WARNING]
The following arguments have been changed

  • file (this has been deprecated in favor of files)
  • plugin (this has been deprecated in favor of plugins)

The following arguments have been added:

  • binary
  • gcov_args
  • gcov_executable
  • gcov_ignore
  • gcov_include
  • report_type
  • skip_validation
  • swift_project

You can see their usage in the action.yml file.

What's Changed

... (truncated)

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=codecov/codecov-action&package-manager=github_actions&previous-version=4&new-version=5)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`.
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/cron.yml | 6 +++--- .github/workflows/pythonapp-gpu.yml | 2 +- .github/workflows/setupapp.yml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index 516e2d4743..e13848f8fc 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -75,7 +75,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: fail_ci_if_error: false files: ./coverage.xml @@ -123,7 +123,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: fail_ci_if_error: false files: ./coverage.xml @@ -228,7 +228,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: fail_ci_if_error: false files: ./coverage.xml diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index 70c3153076..d8623c8087 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -137,6 +137,6 @@ jobs: shell: bash - name: Upload coverage if: ${{ github.head_ref != 'dev' && github.event.pull_request.merged != true }} - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: files: ./coverage.xml diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 7e01f55cd9..d9ce9976b8 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -72,7 +72,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: fail_ci_if_error: false files: ./coverage.xml @@ -119,7 +119,7 @@ jobs: BUILD_MONAI=1 ./runtests.sh --build --quick --min coverage xml --ignore-errors - name: Upload coverage - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: fail_ci_if_error: false files: ./coverage.xml From fe3ed448aba53f09d185f0decb95ca8debbff756 Mon Sep 17 00:00:00 2001 From: Vladislav Tumko <56307628+vectorvp@users.noreply.github.com> Date: Mon, 2 Dec 2024 10:37:51 +0400 Subject: [PATCH 15/23] docs: update brats classes description (#8246) Resolves #8225. ### Description Updated brats18 classes description. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. 
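To make the updated class mapping concrete, a small illustrative example (the channel order TC, WT, ET follows the docstring):

```python
import torch
from monai.transforms import ConvertToMultiChannelBasedOnBratsClasses

label = torch.tensor([[0, 1, 2, 4]])  # a row of brats18 labels
out = ConvertToMultiChannelBasedOnBratsClasses()(label)  # shape (3, 1, 4), one channel per subregion
# channel 0 (TC) is True where the label is 1 or 4
# channel 1 (WT) is True where the label is 1, 2 or 4
# channel 2 (ET) is True where the label is 4
```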
Signed-off-by: Vladislav Tumko <56307628+vectorvp@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/utility/array.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 84422a9ee5..2963c8a2f8 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1054,12 +1054,11 @@ def __call__( class ConvertToMultiChannelBasedOnBratsClasses(Transform): """ - Convert labels to multi channels based on brats18 classes: - label 1 is the necrotic and non-enhancing tumor core - label 2 is the peritumoral edema - label 4 is the GD-enhancing tumor - The possible classes are TC (Tumor core), WT (Whole tumor) - and ET (Enhancing tumor). + Convert labels to multi channels based on `brats18 `_ classes, + which include TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor): + label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion, + label 2 is the peritumoral edema, which is counted only under WT subregion, + label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] From 7f88a469f94de0a84d935589edf57e04f62da8b9 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 2 Dec 2024 23:00:54 +0800 Subject: [PATCH 16/23] Change default value of `patch_norm` to False in `SwinUNETR` (#8249) Fixes #8248 ### Description Update the default value of `patch_norm` to False in `SwinUNETR` to align with the default value in `SwinTransformer`. This [change](https://github.com/Project-MONAI/MONAI/commit/3ee4cd22a8cc7b6b4cb3c5fd228dfa9ef153e60c#diff-04583cc0f4aed09787775eec8cece2b1fd70290799a7a3a2671353f2a3cf9af3R105) modifies the default behavior of the model https://github.com/Project-MONAI/MONAI/blob/e6cae1ced23f32feb2f42943d15b7dd49481c794/monai/networks/nets/swin_unetr.py#L960 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/swin_unetr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 32b817d584..77f0d2ec2f 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -75,7 +75,7 @@ def __init__( dropout_path_rate: float = 0.0, normalize: bool = True, norm_layer: type[LayerNorm] = nn.LayerNorm, - patch_norm: bool = True, + patch_norm: bool = False, use_checkpoint: bool = False, spatial_dims: int = 3, downsample: str | nn.Module = "merging", @@ -102,7 +102,7 @@ def __init__( dropout_path_rate: drop path rate. normalize: normalize output intermediate features in each stage. norm_layer: normalization layer. - patch_norm: whether to apply normalization to the patch embedding. 
+            patch_norm: whether to apply normalization to the patch embedding. Default is False.
             use_checkpoint: use gradient checkpointing for reduced memory usage.
             spatial_dims: number of spatial dims.
             downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a

From 9808ce23c9640462748754ba08c2617747944829 Mon Sep 17 00:00:00 2001
From: Zifu Wang
Date: Tue, 3 Dec 2024 06:13:08 +0100
Subject: [PATCH 17/23] Modify Dice, Jaccard and Tversky losses (#8138)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixes #8094.

### Description
The Dice, Jaccard and Tversky losses in `monai.losses.dice` and `monai.losses.tversky` are modified based on [JDTLoss](https://github.com/zifuwanggg/JDTLosses/blob/master/losses/jdt_loss.py) and [segmentation_models.pytorch](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/losses/_functional.py).

In the original versions, when `squared_pred=False`, the loss functions are incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, the Dice loss is minimized when the predicted value is 1, which is clearly erroneous. To address this, the intersection term is rewritten as $\frac{\|x\|_p^p + \|y\|_p^p - \|x-y\|_p^p}{2}$. When $p$ is 2 (`squared_pred=True`), this reformulation becomes the classical inner product: $\langle x,y \rangle$. When $p$ is 1 (`squared_pred=False`), the reformulation has been proven to retain equivalence with the original versions when the ground truth is binary (i.e. one-hot hard labels). Moreover, since the new versions are minimized if and only if the prediction is identical to the ground truth, even when the ground truth includes fractional numbers, they resolve the issue with soft labels [1, 2].

In summary, there are three scenarios:

* [Scenario 1] $x$ is nonnegative and $y$ is binary: The new versions are the same as the original versions.
* [Scenario 2] Both $x$ and $y$ are nonnegative: The new versions differ from the original versions. The new versions are minimized if and only if $x=y$, while the original versions may not be, making them incorrect.
* [Scenario 3] Either $x$ or $y$ is negative: The new versions differ from the original versions. The new versions are minimized if and only if $x=y$, while the original versions may not be, making them incorrect.

Due to these differences, particularly in Scenarios 2 and 3, some tests fail with the new versions:

* The target is non-binary: `test_multi_scale`
* The input is negative: `test_dice_loss`, `test_tversky_loss`, `test_generalized_dice_loss`, `test_masked_loss`, `test_seg_loss_integration`

The failures in `test_multi_scale` are expected since the original versions are incorrectly defined for non-binary targets. Furthermore, because Dice, Jaccard, and Tversky losses are fundamentally defined over probabilities, which should be nonnegative, the new versions should not be tested against negative input or target values.
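For completeness, the binary-label equivalence at $p=1$ can be seen in one line: if $y \in \{0, 1\}$ and $x \in [0, 1]$ elementwise, then $|x - y| = x + y - 2xy$, so $\frac{\|x\|_1 + \|y\|_1 - \|x-y\|_1}{2} = \sum_i x_i y_i$, which is exactly the original intersection term.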
### Example ``` import torch import torch.linalg as LA import torch.nn.functional as F torch.manual_seed(0) b, c, h, w = 4, 3, 32, 32 dims = (0, 2, 3) pred = torch.rand(b, c, h, w).softmax(dim=1) soft_label = torch.rand(b, c, h, w).softmax(dim=1) hard_label = torch.randint(low=0, high=c, size=(b, h, w)) one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2).float() def dice_old(x, y, ord, dims): cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord intersection = torch.sum(x * y, dim=dims) return 2 * intersection / cardinality def dice_new(x, y, ord, dims): cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord difference = LA.vector_norm(x - y, ord=ord, dim=dims) ** ord intersection = (cardinality - difference) / 2 return 2 * intersection / cardinality print(dice_old(pred, one_hot_label, 1, dims), dice_new(pred, one_hot_label, 1, dims)) print(dice_old(pred, soft_label, 1, dims), dice_new(pred, soft_label, 1, dims)) print(dice_old(pred, pred, 1, dims), dice_new(pred, pred, 1, dims)) print(dice_old(pred, one_hot_label, 2, dims), dice_new(pred, one_hot_label, 2, dims)) print(dice_old(pred, soft_label, 2, dims), dice_new(pred, soft_label, 2, dims)) print(dice_old(pred, pred, 2, dims), dice_new(pred, pred, 2, dims)) # tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317]) # tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700]) # tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.]) # tensor([0.4921, 0.4904, 0.4935]) tensor([0.4921, 0.4904, 0.4935]) # tensor([0.9489, 0.9499, 0.9503]) tensor([0.9489, 0.9499, 0.9503]) # tensor([1., 1., 1.]) tensor([1., 1., 1.]) ``` ### References [1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. *MICCAI 2023*. [2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. *NeurIPS 2023*. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. 
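For reference, a minimal usage sketch of the new `soft_label` flag (tensor shapes and values are illustrative only):

```python
# With soft_label=True, the loss attains its minimum exactly when the
# prediction equals the (possibly fractional) target, smoothing terms aside.
import torch
from monai.losses import DiceLoss

pred = torch.rand(2, 3, 16, 16).softmax(dim=1)
soft_target = torch.rand(2, 3, 16, 16).softmax(dim=1)

loss_fn = DiceLoss(soft_label=True)
print(loss_fn(pred, soft_target))   # > 0 for differing prediction and target
print(loss_fn(pred, pred.clone()))  # ~0 when they are identical
```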
--------- Signed-off-by: Zifu Wang Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- monai/losses/dice.py | 55 +++++++++++++---------- monai/losses/tversky.py | 19 ++++---- monai/losses/utils.py | 68 +++++++++++++++++++++++++++++ tests/test_dice_loss.py | 16 +++++++ tests/test_generalized_dice_loss.py | 16 +++++++ tests/test_tversky_loss.py | 16 +++++++ 6 files changed, 160 insertions(+), 30 deletions(-) create mode 100644 monai/losses/utils.py diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 3f02fae6b8..4108820bec 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -23,6 +23,7 @@ from monai.losses.focal_loss import FocalLoss from monai.losses.spatial_mask import MaskedLoss +from monai.losses.utils import compute_tp_fp_fn from monai.networks import one_hot from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after @@ -39,8 +40,16 @@ class DiceLoss(_Loss): The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of the inter-over-union calculation to smooth results respectively, these values should be small. - The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric - Medical Image Segmentation, 3DV, 2016. + The original papers: + + Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric + Medical Image Segmentation. 3DV 2016. + + Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with + Soft Labels. NeurIPS 2023. + + Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with + Soft Labels. MICCAI 2023. """ @@ -58,6 +67,7 @@ def __init__( smooth_dr: float = 1e-5, batch: bool = False, weight: Sequence[float] | float | int | torch.Tensor | None = None, + soft_label: bool = False, ) -> None: """ Args: @@ -89,6 +99,8 @@ def __init__( of the sequence should be the same as the number of classes. If not ``include_background``, the number of classes should not include the background category class 0). The value/values should be no less than 0. Defaults to None. + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. 
@@ -114,6 +126,7 @@ def __init__( weight = torch.as_tensor(weight) if weight is not None else None self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor + self.soft_label = soft_label def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -174,21 +187,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis - intersection = torch.sum(target * input, dim=reduce_axis) - - if self.squared_pred: - ground_o = torch.sum(target**2, dim=reduce_axis) - pred_o = torch.sum(input**2, dim=reduce_axis) - else: - ground_o = torch.sum(target, dim=reduce_axis) - pred_o = torch.sum(input, dim=reduce_axis) - - denominator = ground_o + pred_o - - if self.jaccard: - denominator = 2.0 * (denominator - intersection) + ord = 2 if self.squared_pred else 1 + tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label) + if not self.jaccard: + fp *= 0.5 + fn *= 0.5 + numerator = 2 * tp + self.smooth_nr + denominator = 2 * (tp + fp + fn) + self.smooth_dr - f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr) + f: torch.Tensor = 1 - numerator / denominator num_of_classes = target.shape[1] if self.class_weight is not None and num_of_classes != 1: @@ -272,6 +279,7 @@ def __init__( smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, + soft_label: bool = False, ) -> None: """ Args: @@ -295,6 +303,8 @@ def __init__( batch: whether to sum the intersection and union areas over the batch dimension before the dividing. Defaults to False, intersection over union is computed from each item in the batch. If True, the class-weighted intersection and union areas are first summed across the batches. + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. 
@@ -319,6 +329,7 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + self.soft_label = soft_label def w_func(self, grnd): if self.w_type == str(Weight.SIMPLE): @@ -370,13 +381,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: reduce_axis = [0] + reduce_axis - intersection = torch.sum(target * input, reduce_axis) - ground_o = torch.sum(target, reduce_axis) - pred_o = torch.sum(input, reduce_axis) - - denominator = ground_o + pred_o + tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label) + fp *= 0.5 + fn *= 0.5 + denominator = 2 * (tp + fp + fn) + ground_o = torch.sum(target, reduce_axis) w = self.w_func(ground_o.float()) infs = torch.isinf(w) if self.batch: @@ -388,7 +399,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: w = w + infs * max_values final_reduce_dim = 0 if self.batch else 1 - numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr + numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr f: torch.Tensor = 1.0 - (numer / denom) diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 4f22bf84b4..154f34c526 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -17,6 +17,7 @@ import torch from torch.nn.modules.loss import _Loss +from monai.losses.utils import compute_tp_fp_fn from monai.networks import one_hot from monai.utils import LossReduction @@ -28,6 +29,9 @@ class TverskyLoss(_Loss): Sadegh et al. (2017) Tversky loss function for image segmentation using 3D fully convolutional deep networks. (https://arxiv.org/abs/1706.05721) + Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with + Soft Labels. MICCAI 2023. + Adapted from: https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631 @@ -46,6 +50,7 @@ def __init__( smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, + soft_label: bool = False, ) -> None: """ Args: @@ -70,6 +75,8 @@ def __init__( batch: whether to sum the intersection and union areas over the batch dimension before the dividing. Defaults to False, a Dice loss value is computed independently from each item in the batch before any `reduction`. + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. 
@@ -93,6 +100,7 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + self.soft_label = soft_label def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -134,20 +142,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") - p0 = input - p1 = 1 - p0 - g0 = target - g1 = 1 - g0 - # reducing only spatial dimensions (not batch nor channels) reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis - tp = torch.sum(p0 * g0, reduce_axis) - fp = self.alpha * torch.sum(p0 * g1, reduce_axis) - fn = self.beta * torch.sum(p1 * g0, reduce_axis) + tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False) + fp *= self.alpha + fn *= self.beta numerator = tp + self.smooth_nr denominator = tp + fp + fn + self.smooth_dr diff --git a/monai/losses/utils.py b/monai/losses/utils.py new file mode 100644 index 0000000000..782fd9c9c2 --- /dev/null +++ b/monai/losses/utils.py @@ -0,0 +1,68 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.linalg as LA + + +def compute_tp_fp_fn( + input: torch.Tensor, + target: torch.Tensor, + reduce_axis: list[int], + ord: int, + soft_label: bool, + decoupled: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + input: the shape should be BNH[WD], where N is the number of classes. + target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. + reduce_axis: the axis to be reduced. + ord: the order of the vector norm. + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. + decoupled: whether the input and the target should be decoupled when computing fp and fn. + Only for the original implementation when soft_label is False. 
+ + Adapted from: + https://github.com/zifuwanggg/JDTLosses + """ + + # the original implementation that is erroneous with soft labels + if ord == 1 and not soft_label: + tp = torch.sum(input * target, dim=reduce_axis) + # the original implementation of Dice and Jaccard loss + if decoupled: + fp = torch.sum(input, dim=reduce_axis) - tp + fn = torch.sum(target, dim=reduce_axis) - tp + # the original implementation of Tversky loss + else: + fp = torch.sum(input * (1 - target), dim=reduce_axis) + fn = torch.sum((1 - input) * target, dim=reduce_axis) + # the new implementation that is correct with soft labels + # and it is identical to the original implementation with hard labels + else: + pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis) + ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis) + difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis) + + if ord > 1: + pred_o = torch.pow(pred_o, exponent=ord) + ground_o = torch.pow(ground_o, exponent=ord) + difference = torch.pow(difference, exponent=ord) + + tp = (pred_o + ground_o - difference) / 2 + fp = pred_o - tp + fn = ground_o - tp + + return tp, fp, fn diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index 14aa6ec241..cea6ccf113 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -34,6 +34,22 @@ }, 0.416657, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.0, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.307773, + ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0}, { diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index 5738f4a089..9706c2e746 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -34,6 +34,22 @@ }, 0.416597, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.0, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.307748, + ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 0.0, "smooth_dr": 0.0}, { diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index 0365503ea2..73a841a55d 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -34,6 +34,22 @@ }, 0.416657, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": 
torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+ },
+ 0.0,
+ ],
+ [ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+ {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
+ {
+ "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+ "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+ },
+ 0.307773,
+ ],
 [ # shape: (2, 2, 3), (2, 1, 3)
 {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
 {

From e604d1841fe60c0ffb6978ae4116535ca8d8f34f Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Wed, 4 Dec 2024 11:09:22 +0800
Subject: [PATCH 18/23] Fix TypeError in meshgrid (#8252)

Fixes #8251

Removing `indexing="ij"` will not affect anything, since it is the default behavior in torch.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/networks/blocks/pos_embed_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py
index 21586e56da..a9c5176bc2 100644
--- a/monai/networks/blocks/pos_embed_utils.py
+++ b/monai/networks/blocks/pos_embed_utils.py
@@ -56,7 +56,7 @@ def build_sincos_position_embedding(
 grid_h = torch.arange(h, dtype=torch.float32)
 grid_w = torch.arange(w, dtype=torch.float32)

- grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing="ij")
+ grid_h, grid_w = torch.meshgrid(grid_h, grid_w)

 if embed_dim % 4 != 0:
 raise AssertionError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
@@ -75,7 +75,7 @@ def build_sincos_position_embedding(
 grid_w = torch.arange(w, dtype=torch.float32)
 grid_d = torch.arange(d, dtype=torch.float32)

- grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing="ij")
+ grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d)

 if embed_dim % 6 != 0:
 raise AssertionError("Embed dimension must be divisible by 6 for 3D sin-cos position embedding")

From 21920a34bc00a114e430e1943e1fd1f572880919 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 10 Dec 2024 20:03:35 +0800
Subject: [PATCH 19/23] Add platform-specific constraints to setup.cfg (#8260)

Fixes #8258

### Description
Include `platform_system` conditions for dependencies in `setup.cfg`.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
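As background for the conditions used below: `platform_system` and `platform_machine` are standard PEP 508 environment markers, and pip only installs a dependency when its marker evaluates to true on the target interpreter. A small illustrative sketch (assuming the `packaging` library, which ships with pip/setuptools, is available):

```python
# Evaluate PEP 508 environment markers like the ones added to setup.cfg below.
from packaging.markers import Marker

linux_only = Marker("platform_system == 'Linux' and python_version >= '3.9'")
not_arm = Marker("'arm' not in platform_machine and 'aarch' not in platform_machine")

print(linux_only.evaluate())  # True only on a Linux interpreter at Python 3.9+
print(not_arm.evaluate())     # False on ARM/aarch64 machines
```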
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 setup.cfg | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index ecfd717aff..0c69051218 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -61,10 +61,10 @@ all =
 tqdm>=4.47.0
 lmdb
 psutil
- cucim-cu12; python_version >= '3.9' and python_version <= '3.10'
+ cucim-cu12; platform_system == "Linux" and python_version >= '3.9' and python_version <= '3.10'
 openslide-python
- tifffile
- imagecodecs
+ tifffile; platform_system == "Linux" or platform_system == "Darwin"
+ imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
 pandas
 einops
 transformers>=4.36.0, <4.41.0; python_version <= '3.10'
@@ -78,7 +78,7 @@ all =
 pynrrd
 pydicom
 h5py
- nni
+ nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
 optuna
 onnx>=1.13.0
 onnxruntime; python_version <= '3.10'
@@ -116,13 +116,13 @@ lmdb =
 psutil =
 psutil
 cucim =
- cucim-cu12
+ cucim-cu12; platform_system == "Linux" and python_version >= '3.9' and python_version <= '3.10'
 openslide =
 openslide-python
 tifffile =
- tifffile
+ tifffile; platform_system == "Linux" or platform_system == "Darwin"
 imagecodecs =
- imagecodecs
+ imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
 pandas =
 pandas
 einops =
@@ -152,7 +152,7 @@ pydicom =
 h5py =
 h5py
 nni =
- nni
+ nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
 optuna =
 optuna
 onnx =

From e1e3d8ebc1c7247aad9f1bffc649c5a20084340f Mon Sep 17 00:00:00 2001
From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Date: Thu, 19 Dec 2024 02:50:42 +0000
Subject: [PATCH 20/23] Modify Workflow to Allow IterableDataset Inputs (#8263)

### Description
This modifies the behaviour of `Workflow` to permit `IterableDataset` to be used correctly. A check against the `epoch_length` value is removed, to allow that value to be `None`, and a test is added to verify this. The length of a data loader is not defined when using iterable datasets, so a try/except is added to allow that length to be queried safely.

This is related to my work on the streaming support: in my [prototype gist](https://gist.github.com/ericspod/1904713716b45631260784ac3fcd6fb3) I had to provide a bogus epoch length value at first and then change it to `None` later, once the evaluator object was created. This PR will remove the need for this hack.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
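To make the motivation concrete, a small sketch of why the epoch length must stay `None` here (hypothetical data, mirroring the new test below): a `DataLoader` wrapping an `IterableDataset` has no defined length, so querying it raises `TypeError`, which `Workflow` now catches.

```python
# len() is undefined for a DataLoader over an IterableDataset without __len__.
from monai.data import DataLoader, IterableDataset

dl = DataLoader(IterableDataset(list(range(10))))
try:
    print(len(dl))
except TypeError:
    print("no length available; Workflow now leaves epoch_length as None")
```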
--------- Signed-off-by: Eric Kerfoot Signed-off-by: Eric Kerfoot Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot --- monai/engines/workflow.py | 22 +++++++++++----------- tests/test_iterable_dataset.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 3629659db1..0c36da6d3d 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -12,7 +12,7 @@ from __future__ import annotations import warnings -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence, Sized from typing import TYPE_CHECKING, Any import torch @@ -121,24 +121,24 @@ def __init__( to_kwargs: dict | None = None, amp_kwargs: dict | None = None, ) -> None: - if iteration_update is not None: - super().__init__(iteration_update) - else: - super().__init__(self._iteration) + super().__init__(self._iteration if iteration_update is None else iteration_update) if isinstance(data_loader, DataLoader): - sampler = data_loader.__dict__["sampler"] + sampler = getattr(data_loader, "sampler", None) + + # set the epoch value for DistributedSampler objects when an epoch starts if isinstance(sampler, DistributedSampler): @self.on(Events.EPOCH_STARTED) def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) - if epoch_length is None: + # if the epoch_length isn't given, attempt to get it from the length of the data loader + if epoch_length is None and isinstance(data_loader, Sized): + try: epoch_length = len(data_loader) - else: - if epoch_length is None: - raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.") + except TypeError: # raised when data_loader has an iterable dataset with no length, or is some other type + pass # deliberately leave epoch_length as None # set all sharable data for the workflow based on Ignite engine.state self.state: Any = State( @@ -147,7 +147,7 @@ def set_sampler_epoch(engine: Engine) -> None: iteration=0, epoch=0, max_epochs=max_epochs, - epoch_length=epoch_length, + epoch_length=epoch_length, # None when the dataset is iterable and so has no length output=None, batch=None, metrics={}, diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index cfa711e4c0..fb554e391c 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -18,8 +18,10 @@ import nibabel as nib import numpy as np +import torch.nn as nn from monai.data import DataLoader, Dataset, IterableDataset +from monai.engines import SupervisedEvaluator from monai.transforms import Compose, LoadImaged, SimulateDelayd @@ -59,6 +61,17 @@ def test_shape(self): for d in dataloader: self.assertTupleEqual(d["image"].shape[1:], expected_shape) + def test_supervisedevaluator(self): + """ + Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader. 
+ """ + data = list(range(10)) + dl = DataLoader(IterableDataset(data)) + evaluator = SupervisedEvaluator(device="cpu", val_data_loader=dl, network=nn.Identity()) + evaluator.run() # fails if the epoch length or other internal setup is not done correctly + + self.assertEqual(evaluator.state.iteration, len(data)) + if __name__ == "__main__": unittest.main() From efff647a332f9520e7b7d7565893bd16ab26e041 Mon Sep 17 00:00:00 2001 From: Hsin-Yuan Hsieh <84929237+Jerome-Hsieh@users.noreply.github.com> Date: Sat, 21 Dec 2024 22:18:23 +0800 Subject: [PATCH 21/23] enhance download_and_extract (#8216) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #5463 ### Description According to issue, the error messages are not very intuitive. I think maybe we can check if the file name matches the downloaded file’s base name before starting the download. If it doesn’t match, it will notify user. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: jerome_Hsieh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/apps/utils.py | 39 +++++++++++++++++++++++++++++- tests/test_download_and_extract.py | 3 ++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index c2e17d3247..95c1450f2a 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -15,6 +15,7 @@ import json import logging import os +import re import shutil import sys import tarfile @@ -30,7 +31,9 @@ from monai.config.type_definitions import PathLike from monai.utils import look_up_option, min_version, optional_import +requests, has_requests = optional_import("requests") gdown, has_gdown = optional_import("gdown", "4.7.3") +BeautifulSoup, has_bs4 = optional_import("bs4", name="BeautifulSoup") if TYPE_CHECKING: from tqdm import tqdm @@ -298,6 +301,29 @@ def extractall( ) +def get_filename_from_url(data_url: str) -> str: + """ + Get the filename from the URL link. + """ + try: + response = requests.head(data_url, allow_redirects=True) + content_disposition = response.headers.get("Content-Disposition") + if content_disposition: + filename = re.findall('filename="?([^";]+)"?', content_disposition) + if filename: + return str(filename[0]) + if "drive.google.com" in data_url: + response = requests.get(data_url) + if "text/html" in response.headers.get("Content-Type", ""): + soup = BeautifulSoup(response.text, "html.parser") + filename_div = soup.find("span", {"class": "uc-name-size"}) + if filename_div: + return str(filename_div.find("a").text) + return _basename(data_url) + except Exception as e: + raise Exception(f"Error processing URL: {e}") from e + + def download_and_extract( url: str, filepath: PathLike = "", @@ -327,7 +353,18 @@ def download_and_extract( be False. progress: whether to display progress bar. 
""" + url_filename_ext = "".join(Path(get_filename_from_url(url)).suffixes) + filepath_ext = "".join(Path(_basename(filepath)).suffixes) + if filepath not in ["", "."]: + if filepath_ext == "": + new_filepath = Path(filepath).with_suffix(url_filename_ext) + logger.warning( + f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}" + ) + filepath = new_filepath + if filepath_ext and filepath_ext != url_filename_ext: + raise ValueError(f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}") with tempfile.TemporaryDirectory() as tmp_dir: - filename = filepath or Path(tmp_dir, _basename(url)).resolve() + filename = filepath or Path(tmp_dir, get_filename_from_url(url)).resolve() download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress) extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base) diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index 555f7dc250..439a11bbc1 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -20,9 +20,10 @@ from parameterized import parameterized from monai.apps import download_and_extract, download_url, extractall -from tests.utils import skip_if_downloading_fails, skip_if_quick, testing_data_config +from tests.utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick, testing_data_config +@SkipIfNoModule("requests") class TestDownloadAndExtract(unittest.TestCase): @skip_if_quick From d36f0c80f716c5ad040f0f2cad11407e68d0f33a Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Mon, 23 Dec 2024 11:33:57 +0800 Subject: [PATCH 22/23] enable gpu load nifti (#8188) Related to https://github.com/Project-MONAI/MONAI/issues/8241 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
--------- Signed-off-by: Yiheng Wang Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/data/image_reader.py | 86 ++++++++++++++++++++++++++++++++---- monai/data/meta_tensor.py | 1 - monai/transforms/io/array.py | 1 - tests/test_init_reader.py | 19 ++++++++ tests/test_load_image.py | 41 ++++++++++++++++- 5 files changed, 136 insertions(+), 12 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index b4ae562911..5bc38f69ea 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -12,8 +12,11 @@ from __future__ import annotations import glob +import gzip +import io import os import re +import tempfile import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence @@ -51,6 +54,9 @@ pydicom, has_pydicom = optional_import("pydicom") nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) +cp, has_cp = optional_import("cupy") +kvikio, has_kvikio = optional_import("kvikio") + __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] @@ -137,14 +143,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict): ) -def _stack_images(image_list: list, meta_dict: dict): +def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False): if len(image_list) <= 1: return image_list[0] if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) + if to_cupy and has_cp: + return cp.concatenate(image_list, axis=channel_dim) return np.concatenate(image_list, axis=channel_dim) # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 + if to_cupy and has_cp: + return cp.stack(image_list, axis=0) return np.stack(image_list, axis=0) @@ -864,12 +874,18 @@ class NibabelReader(ImageReader): Load NIfTI format images based on Nibabel library. Args: - as_closest_canonical: if True, load the image as closest to canonical axis format. - squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) channel_dim: the channel dimension of the input image, default is None. this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. if None, `original_channel_dim` will be either `no_channel` or `-1`. most Nifti files are usually "channel last", no need to specify this argument for them. + as_closest_canonical: if True, load the image as closest to canonical axis format. + squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) + to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading. + Default is False. CuPy and Kvikio are required for this option. + Note: For compressed NIfTI files, some operations may still be performed on CPU memory, + and the acceleration may not be significant. In some cases, it may be slower than loading on CPU. + In practical use, it's recommended to add a warm up call before the actual loading. + A related tutorial will be prepared in the future, and the document will be updated accordingly. 
kwargs: additional args for `nibabel.load` API. more details about available args: https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py @@ -880,14 +896,42 @@ def __init__( channel_dim: str | int | None = None, as_closest_canonical: bool = False, squeeze_non_spatial_dims: bool = False, + to_gpu: bool = False, **kwargs, ): super().__init__() self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.as_closest_canonical = as_closest_canonical self.squeeze_non_spatial_dims = squeeze_non_spatial_dims + if to_gpu and (not has_cp or not has_kvikio): + warnings.warn( + "NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading." + ) + to_gpu = False + + if to_gpu: + self.warmup_kvikio() + + self.to_gpu = to_gpu self.kwargs = kwargs + def warmup_kvikio(self): + """ + Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc. + This can accelerate the data loading process when `to_gpu` is set to True. + """ + if has_cp and has_kvikio: + a = cp.arange(100) + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file_name = tmp_file.name + f = kvikio.CuFile(tmp_file_name, "w") + f.write(a) + f.close() + + b = cp.empty_like(a) + f = kvikio.CuFile(tmp_file_name, "r") + f.read(b) + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by Nibabel reader. @@ -916,6 +960,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) + self.filenames = filenames kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -936,10 +981,13 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. """ + # TODO: the actual type is list[np.ndarray | cp.ndarray] + # should figure out how to define correct types without having cupy not found error + # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918 img_array: list[np.ndarray] = [] compatible_meta: dict = {} - for i in ensure_tuple(img): + for i, filename in zip(ensure_tuple(img), self.filenames): header = self._get_meta_dict(i) header[MetaKeys.AFFINE] = self._get_affine(i) header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) @@ -949,7 +997,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: header[MetaKeys.AFFINE] = self._get_affine(i) header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) header[MetaKeys.SPACE] = SpaceKeys.RAS - data = self._get_array_data(i) + data = self._get_array_data(i, filename) if self.squeeze_non_spatial_dims: for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1): if data.shape[d - 1] == 1: @@ -963,7 +1011,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim _copy_compatible_dict(header, compatible_meta) - return _stack_images(img_array, compatible_meta), compatible_meta + return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta def _get_meta_dict(self, img) -> dict: """ @@ -1015,14 +1063,34 @@ def _get_spatial_shape(self, img): spatial_rank = max(min(ndim, 3), 1) return np.asarray(size[:spatial_rank]) - def _get_array_data(self, img): + def _get_array_data(self, img, filename): """ Get the raw array data of the image, converted to Numpy array. Args: img: a Nibabel image object loaded from an image file. 
-
- """
+ filename: file name of the image.
+
+ """
+ if self.to_gpu:
+ file_size = os.path.getsize(filename)
+ image = cp.empty(file_size, dtype=cp.uint8)
+ with kvikio.CuFile(filename, "r") as f:
+ f.read(image)
+ if filename.endswith(".nii.gz"):
+ # for compressed data, have to transfer to CPU to decompress
+ # and then transfer back to GPU. It is not efficient compared to .nii files
+ # and may be slower than CPU loading in some cases.
+ warnings.warn("Loading compressed NIfTI file into GPU may not be efficient.")
+ compressed_data = cp.asnumpy(image)
+ with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
+ decompressed_data = gz_file.read()
+
+ image = cp.frombuffer(decompressed_data, dtype=cp.uint8)
+ data_shape = img.shape
+ data_offset = img.dataobj.offset
+ data_dtype = img.dataobj.dtype
+ return image[data_offset:].view(data_dtype).reshape(data_shape, order="F")
 return np.asanyarray(img.dataobj, order="C")
diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index ac171e8508..c4c491e1b9 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -553,7 +553,6 @@ def ensure_torch_and_prune_meta(
 However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
 """
 img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray
-
 # if not tracking metadata, return `torch.Tensor`
 if not isinstance(img, MetaTensor):
 return img
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index 4e71870fc9..1023cd7a7d 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -286,7 +286,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
 " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
 f" The current registered: {self.readers}.\n{msg}"
 )
-
 img_array: NdarrayOrTensor
 img_array, meta_data = reader.get_data(img)
 img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py
index cb45cb5146..8331f742ec 100644
--- a/tests/test_init_reader.py
+++ b/tests/test_init_reader.py
@@ -30,6 +30,17 @@ def test_load_image(self):
 inst = LoadImaged("image", reader=r)
 self.assertIsInstance(inst, LoadImaged)

+ @SkipIfNoModule("nibabel")
+ @SkipIfNoModule("cupy")
+ @SkipIfNoModule("kvikio")
+ def test_load_image_to_gpu(self):
+ for to_gpu in [True, False]:
+ instance1 = LoadImage(reader="NibabelReader", to_gpu=to_gpu)
+ self.assertIsInstance(instance1, LoadImage)
+
+ instance2 = LoadImaged("image", reader="NibabelReader", to_gpu=to_gpu)
+ self.assertIsInstance(instance2, LoadImaged)
+
 @SkipIfNoModule("itk")
 @SkipIfNoModule("nibabel")
 @SkipIfNoModule("PIL")
@@ -58,6 +69,14 @@ def test_readers(self):
 inst = NrrdReader()
 self.assertIsInstance(inst, NrrdReader)

+ @SkipIfNoModule("nibabel")
+ @SkipIfNoModule("cupy")
+ @SkipIfNoModule("kvikio")
+ def test_readers_to_gpu(self):
+ for to_gpu in [True, False]:
+ inst = NibabelReader(to_gpu=to_gpu)
+ self.assertIsInstance(inst, NibabelReader)
+

 if __name__ == "__main__":
 unittest.main()
diff --git a/tests/test_load_image.py b/tests/test_load_image.py
index 0207079d7d..a3e6d7bcfc 100644
--- a/tests/test_load_image.py
+++ b/tests/test_load_image.py
@@ -29,7 +29,7 @@
 from monai.data.meta_tensor import MetaTensor
 from monai.transforms import LoadImage
 from monai.utils import optional_import
-from tests.utils import assert_allclose, skip_if_downloading_fails, 
testing_data_config
+from tests.utils import SkipIfNoModule, assert_allclose, skip_if_downloading_fails, testing_data_config

 itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
 ITKReader, _ = optional_import("monai.data", name="ITKReader", as_type="decorator")
@@ -74,6 +74,22 @@ def get_data(self, _obj):

 TEST_CASE_5 = [{"reader": NibabelReader(mmap=False)}, ["test_image.nii.gz"], (128, 128, 128)]

+TEST_CASE_GPU_1 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii.gz"], (128, 128, 128)]
+
+TEST_CASE_GPU_2 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii"], (128, 128, 128)]
+
+TEST_CASE_GPU_3 = [
+ {"reader": "nibabelreader", "to_gpu": True},
+ ["test_image.nii", "test_image2.nii", "test_image3.nii"],
+ (3, 128, 128, 128),
+]
+
+TEST_CASE_GPU_4 = [
+ {"reader": "nibabelreader", "to_gpu": True},
+ ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
+ (3, 128, 128, 128),
+]
+
 TEST_CASE_6 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]

 TEST_CASE_7 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]
@@ -196,6 +212,29 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape):
 assert_allclose(result.affine, torch.eye(4))
 self.assertTupleEqual(result.shape, expected_shape)

+ @SkipIfNoModule("nibabel")
+ @SkipIfNoModule("cupy")
+ @SkipIfNoModule("kvikio")
+ @parameterized.expand([TEST_CASE_GPU_1, TEST_CASE_GPU_2, TEST_CASE_GPU_3, TEST_CASE_GPU_4])
+ def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
+ test_image = np.random.rand(128, 128, 128)
+ with tempfile.TemporaryDirectory() as tempdir:
+ for i, name in enumerate(filenames):
+ filenames[i] = os.path.join(tempdir, name)
+ nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
+ result = LoadImage(image_only=True, **input_param)(filenames)
+ ext = "".join(Path(name).suffixes)
+ self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext))
+ self.assertEqual(result.meta["space"], "RAS")
+ assert_allclose(result.affine, torch.eye(4))
+ self.assertTupleEqual(result.shape, expected_shape)
+
+ # verify gpu and cpu loaded data are the same
+ input_param_cpu = input_param.copy()
+ input_param_cpu["to_gpu"] = False
+ result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames)
+ self.assertTrue(torch.equal(result_cpu, result.cpu()))
+
 @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9])
 def test_itk_reader(self, input_param, filenames, expected_shape):
 test_image = np.random.rand(128, 128, 128)

From 996e876e7542f683508aa04e74b97e284bbde72b Mon Sep 17 00:00:00 2001
From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Date: Mon, 30 Dec 2024 21:14:55 +0800
Subject: [PATCH 23/23] 8274-mitigate-gpu-load-check (#8275)

Fixes #8274 .

### Description
I tried to reproduce the issue on an A100 with the same container, but could not. Therefore, I think we can make a small change to the test here, and if the same issues persist, I will look into it further.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Yiheng Wang --- tests/test_load_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_load_image.py b/tests/test_load_image.py index a3e6d7bcfc..aa8b71b7fa 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -233,7 +233,7 @@ def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape): input_param_cpu = input_param.copy() input_param_cpu["to_gpu"] = False result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames) - self.assertTrue(torch.equal(result_cpu, result.cpu())) + self.assertTrue(torch.allclose(result_cpu, result.cpu(), atol=1e-6)) @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9]) def test_itk_reader(self, input_param, filenames, expected_shape):
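For illustration, the effect of the relaxed comparison in this last change (values are made up): bitwise equality fails on tiny numeric drift between the CPU and GPU loading paths, while a small absolute tolerance accepts it.

```python
# torch.equal demands bitwise equality; torch.allclose allows a tolerance.
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-7  # tiny drift, e.g. from a different decode path

print(torch.equal(a, b))                # False
print(torch.allclose(a, b, atol=1e-6))  # True
```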