def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray:
    """
    Return the whole slide image resampled to a requested resolution in micrometers per pixel (mpp).

    The pyramid level whose mpp is closest to ``mpp`` is located first. If ``mpp`` lies within
    ``atol + rtol * level_mpp`` of that level on both axes, the level is returned unchanged.
    Otherwise the coarsest level that is at least as fine as ``mpp`` on both axes is read and
    down-scaled so that its pixel spacing matches ``mpp``.

    Args:
        wsi: a whole slide image object loaded by OpenSlide.
        mpp: the requested resolution as a pair of positive numbers, in micrometers per pixel.
            The two components are compared, in order, with the pair returned by ``self.get_mpp``.
        atol: the acceptable absolute tolerance for resolution in micrometers per pixel.
        rtol: the acceptable relative tolerance for resolution in micrometers per pixel.

    Returns:
        The image as an array with the first three channels of the region that was read
        (the fourth channel returned by ``read_region`` is dropped).

    Raises:
        ValueError: if ``mpp`` is not a pair of positive numbers, if the slide reports no levels,
            or if ``mpp`` is finer than every level stored in the file (beyond the tolerance).
    """
    try:
        req_x, req_y = (float(v) for v in mpp)
    except (TypeError, ValueError) as err:
        raise ValueError(f"`mpp` must be a pair of numbers, got {mpp!r}.") from err
    if req_x <= 0 or req_y <= 0:
        raise ValueError(f"`mpp` must be positive on both axes, got {mpp!r}.")

    # NOTE(review): component 0 of each mpp pair is paired with level_dimensions[...][0] below;
    # confirm that `get_mpp` reports that axis first.
    level_mpps = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)]
    if not level_mpps:
        raise ValueError("The whole slide image does not report any resolution level.")

    # Pick the nearest level directly. Going through a tolerance-checking lookup here would
    # raise for requests far away from every level, although such requests can still be
    # served by down-scaling a finer level.
    level = min(
        range(len(level_mpps)),
        key=lambda lvl: abs(level_mpps[lvl][0] - req_x) + abs(level_mpps[lvl][1] - req_y),
    )
    near_x, near_y = level_mpps[level]
    within_tolerance = (
        abs(req_x - near_x) <= atol + rtol * near_x and abs(req_y - near_y) <= atol + rtol * near_y
    )

    if not within_tolerance:
        # The source must be at least as fine as the request on BOTH axes, so walk towards
        # level 0 until that holds. Stopping at 0 (instead of blindly using `level - 1`)
        # prevents a negative index from silently selecting the coarsest level.
        while level >= 0 and (level_mpps[level][0] > req_x or level_mpps[level][1] > req_y):
            level -= 1
        if level < 0:
            raise ValueError(
                f"The requested mpp {mpp!r} is finer than the highest available resolution "
                f"{tuple(level_mpps[0])!r}; no level can be down-scaled to it."
            )

    width, height = wsi.level_dimensions[level]
    image = np.array(wsi.read_region(location=(0, 0), level=level, size=(width, height)))[:, :, :3]
    if within_tolerance:
        return image

    src_x, src_y = level_mpps[level]
    # Clamp to one pixel so a very coarse request on a tiny level never yields an empty size.
    out_w = max(1, int(round(width * src_x / req_x)))
    out_h = max(1, int(round(height * src_y / req_y)))

    # Imported lazily so that only the resizing path depends on OpenCV.
    import cv2

    # INTER_AREA averages source pixels and avoids the aliasing of INTER_LINEAR when shrinking.
    return cv2.resize(image, dsize=(out_w, out_h), interpolation=cv2.INTER_AREA)
def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray:
    """
    Return the whole slide image at a requested resolution in micrometers per pixel (mpp).

    This method only delegates to the backend object stored in ``self.reader``; level selection
    and resampling are implemented by that backend.

    Args:
        wsi: whole slide image object from WSIReader.
        mpp: the resolution in micrometers per pixel at which the whole slide image should be extracted.
        atol: the acceptable absolute tolerance for resolution in micrometers per pixel.
        rtol: the acceptable relative tolerance for resolution in micrometers per pixel.

    Returns:
        Whatever the backend's ``get_img_at_mpp`` returns for the same arguments.
    """
    # Forward the tolerances by keyword so they cannot be swapped if a backend reorders them.
    return self.reader.get_img_at_mpp(wsi, mpp, atol=atol, rtol=rtol)
def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray:
    """
    Return the whole slide image resampled to a requested resolution in micrometers per pixel (mpp).

    The pyramid level whose mpp is closest to ``mpp`` is located first. If ``mpp`` lies within
    ``atol + rtol * level_mpp`` of that level on both axes, the level is returned unchanged.
    Otherwise the coarsest level that is at least as fine as ``mpp`` on both axes is read and
    down-scaled so that its pixel spacing matches ``mpp``.

    Args:
        wsi: a whole slide image object loaded by cuCIM.
        mpp: the requested resolution as a pair of positive numbers, in micrometers per pixel.
            The two components are compared, in order, with the pair returned by ``self.get_mpp``.
        atol: the acceptable absolute tolerance for resolution in micrometers per pixel.
        rtol: the acceptable relative tolerance for resolution in micrometers per pixel.

    Returns:
        The image as an array holding the first three channels of the region that was read.

    Raises:
        ValueError: if ``mpp`` is not a pair of positive numbers, if the slide reports no levels,
            or if ``mpp`` is finer than every level stored in the file (beyond the tolerance).
    """
    try:
        req_x, req_y = (float(v) for v in mpp)
    except (TypeError, ValueError) as err:
        raise ValueError(f"`mpp` must be a pair of numbers, got {mpp!r}.") from err
    if req_x <= 0 or req_y <= 0:
        raise ValueError(f"`mpp` must be positive on both axes, got {mpp!r}.")

    # NOTE(review): component 0 of each mpp pair is paired with level_dimensions[...][0] below;
    # confirm that `get_mpp` reports that axis first.
    level_mpps = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions["level_count"])]
    if not level_mpps:
        raise ValueError("The whole slide image does not report any resolution level.")

    # Nearest level by summed absolute difference; no tolerance is applied at this stage so
    # that far-away requests can still be served by down-scaling a finer level.
    level = min(
        range(len(level_mpps)),
        key=lambda lvl: abs(level_mpps[lvl][0] - req_x) + abs(level_mpps[lvl][1] - req_y),
    )
    near_x, near_y = level_mpps[level]
    within_tolerance = (
        abs(req_x - near_x) <= atol + rtol * near_x and abs(req_y - near_y) <= atol + rtol * near_y
    )

    if not within_tolerance:
        # Walk towards level 0 until the level is at least as fine as requested on both axes.
        # The explicit lower bound keeps `level - 1` from wrapping around to the coarsest level.
        while level >= 0 and (level_mpps[level][0] > req_x or level_mpps[level][1] > req_y):
            level -= 1
        if level < 0:
            raise ValueError(
                f"The requested mpp {mpp!r} is finer than the highest available resolution "
                f"{tuple(level_mpps[0])!r}; no level can be down-scaled to it."
            )

    width, height = wsi.resolutions["level_dimensions"][level]
    # `num_workers` is forwarded on every read, including the within-tolerance one.
    region = wsi.read_region(location=(0, 0), level=level, size=(width, height), num_workers=self.num_workers)
    image = np.array(region)[:, :, :3]
    if within_tolerance:
        return image

    src_x, src_y = level_mpps[level]
    # Clamp to one pixel so a very coarse request on a tiny level never yields an empty size.
    out_w = max(1, int(round(width * src_x / req_x)))
    out_h = max(1, int(round(height * src_y / req_y)))

    # Imported lazily so that only the resizing path depends on OpenCV.
    import cv2

    # INTER_AREA averages source pixels and avoids the aliasing of INTER_LINEAR when shrinking.
    return cv2.resize(image, dsize=(out_w, out_h), interpolation=cv2.INTER_AREA)
From a9fe772d56a458c853bbb69ecef4585cb2ef564d Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 24 Mar 2024 19:18:42 +0100 Subject: [PATCH 03/22] Added function get_img_at_mpp to class TifffileWSIReader; changed resizing function to Image.resize, cucim.skimage.transform.resize --- monai/data/wsi_reader.py | 160 +++++++++++++++++++++++++++++++-------- 1 file changed, 130 insertions(+), 30 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 0c49143220..4f02cee285 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -19,7 +19,6 @@ import numpy as np import torch -import cv2 from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data.image_reader import ImageReader, _stack_images @@ -778,9 +777,14 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ + cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") + cp, _ = optional_import("cupy") + user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions['level_count'])] - closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value; + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? 
+ mpp_closest_lvl = mpp_list[closest_lvl] closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] @@ -797,13 +801,12 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) within_tolerance = within_tolerance_x & within_tolerance_y - + if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) - return closest_lvl_wsi else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x @@ -814,15 +817,16 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + wsi_arr = cp.array(closest_lvl_wsi) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, 
target_res_y), Image.BILINEAR) + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') - return closest_lvl_wsi + else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 @@ -833,15 +837,18 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_y = mpp_closest_lvl_y / user_mpp_y closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + wsi_arr = cp.array(closest_lvl_wsi) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR) + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') - return closest_lvl_wsi + + wsi_arr = cp.asnumpy(closest_lvl_wsi) + return wsi_arr def get_power(self, wsi, level: int) -> float: """ @@ -1055,9 +1062,12 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ + pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] - closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return 
the closest value; + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? + mpp_closest_lvl = mpp_list[closest_lvl] closest_lvl_dim = wsi.level_dimensions[closest_lvl] @@ -1078,9 +1088,8 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - return closest_lvl_wsi else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x @@ -1091,15 +1100,14 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) - + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') - return closest_lvl_wsi + else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 @@ -1110,15 +1118,16 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, 
def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray:
    """
    Return the whole slide image resampled to a requested resolution in micrometers per pixel (mpp).

    The page whose mpp is closest to ``mpp`` is located first. If ``mpp`` lies within
    ``atol + rtol * level_mpp`` of that page on both axes, the page is returned unchanged.
    Otherwise the coarsest page that is at least as fine as ``mpp`` on both axes is read and
    down-scaled so that its pixel spacing matches ``mpp``.

    Args:
        wsi: a whole slide image object loaded by tifffile.
        mpp: the requested resolution as a pair of positive numbers, in micrometers per pixel.
            The two components are compared, in order, with the pair returned by ``self.get_mpp``.
        atol: the acceptable absolute tolerance for resolution in micrometers per pixel.
        rtol: the acceptable relative tolerance for resolution in micrometers per pixel.

    Returns:
        The image as an array; pages within tolerance are returned exactly as ``asarray()`` yields them.

    Raises:
        ValueError: if ``mpp`` is not a pair of positive numbers, if the file has no pages,
            or if ``mpp`` is finer than every page stored in the file (beyond the tolerance).
    """
    try:
        req_x, req_y = (float(v) for v in mpp)
    except (TypeError, ValueError) as err:
        raise ValueError(f"`mpp` must be a pair of numbers, got {mpp!r}.") from err
    if req_x <= 0 or req_y <= 0:
        raise ValueError(f"`mpp` must be positive on both axes, got {mpp!r}.")

    # NOTE(review): pyramid levels are taken from `wsi.pages`; some files may expose fewer pages
    # than the levels other viewers show. Confirm this is the right container for such files.
    # NOTE(review): component 0 of each mpp pair is paired with the image width below;
    # confirm that `get_mpp` reports that axis first.
    level_mpps = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))]
    if not level_mpps:
        raise ValueError("The whole slide image does not contain any page.")

    # Nearest level by summed absolute difference; no tolerance is applied at this stage so
    # that far-away requests can still be served by down-scaling a finer level.
    level = min(
        range(len(level_mpps)),
        key=lambda lvl: abs(level_mpps[lvl][0] - req_x) + abs(level_mpps[lvl][1] - req_y),
    )
    near_x, near_y = level_mpps[level]
    within_tolerance = (
        abs(req_x - near_x) <= atol + rtol * near_x and abs(req_y - near_y) <= atol + rtol * near_y
    )

    if not within_tolerance:
        # Walk towards level 0 until the level is at least as fine as requested on both axes.
        # The explicit lower bound keeps `level - 1` from wrapping around to the coarsest level.
        while level >= 0 and (level_mpps[level][0] > req_x or level_mpps[level][1] > req_y):
            level -= 1
        if level < 0:
            raise ValueError(
                f"The requested mpp {mpp!r} is finer than the highest available resolution "
                f"{tuple(level_mpps[0])!r}; no level can be down-scaled to it."
            )

    # Read the page directly, as the resizing branches already did; `read_region` is not used
    # anywhere else in this method.
    image = wsi.pages[level].asarray()
    if within_tolerance:
        return image

    height, width = image.shape[:2]
    src_x, src_y = level_mpps[level]
    # Clamp to one pixel so a very coarse request on a tiny level never yields an empty size.
    out_w = max(1, int(round(width * src_x / req_x)))
    out_h = max(1, int(round(height * src_y / req_y)))

    pil_image, _ = optional_import("PIL", name="Image")
    # PIL expects the output size as (width, height).
    resized = pil_image.fromarray(image).resize((out_w, out_h), pil_image.BILINEAR)
    return np.array(resized)
From feac0dc57ef2cc6fd3f24204d43881fd186c5595 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 24 Mar 2024 23:17:37 +0100 Subject: [PATCH 04/22] Small changes --- monai/data/wsi_reader.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 4f02cee285..3f2e26f9e2 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -607,8 +607,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. Args: @@ -765,8 +767,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. 
- If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. Args: @@ -786,7 +790,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? 
mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] + closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] # x,y notation print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl @@ -805,7 +809,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) # size in x,y notation else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -823,8 +827,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR) - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) # output_shape in row, col notation print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') else: @@ -843,7 +846,6 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 2: Downscaling 
using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') @@ -1050,8 +1052,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. 
Args: @@ -1123,7 +1127,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) # Output size in x,y notation print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') wsi_arr = np.array(closest_lvl_wsi) @@ -1305,7 +1309,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + If the user-provided mpp is larger than the mpp of the closest level the image is downscaled to a resolution that matches the user-provided mpp. Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. @@ -1319,8 +1323,8 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # QuPath show 4 levels in the pyramid, but len(wsi.pages) is 1? 
- closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] @@ -1358,7 +1362,6 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - # closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) @@ -1378,7 +1381,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl_dim = lvl_dims[closest_lvl] closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - # closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) From 81940261f2c35ae5f8c4485054523b9306a3eeff Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 24 Mar 2024 23:21:57 +0100 Subject: [PATCH 05/22] Small changes --- monai/data/wsi_reader.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 3f2e26f9e2..9c6ee9c387 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -790,7 +790,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? 
mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] # x,y notation + closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl @@ -809,7 +809,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) # size in x,y notation + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -827,13 +827,13 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) # output_shape in row, col notation + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl ds_factor_x = mpp_closest_lvl_x / user_mpp_x @@ -1115,7 +1115,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - 
mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl ds_factor_x = mpp_closest_lvl_x / user_mpp_x @@ -1127,7 +1127,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) # Output size in x,y notation + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') wsi_arr = np.array(closest_lvl_wsi) @@ -1289,9 +1289,9 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: and wsi.pages[level].tags["YResolution"].value ): unit = wsi.pages[level].tags.get("ResolutionUnit") - if unit is not None: # Needs to be extended - # unit = str(unit.value)[8:] - unit = str(unit.value.name).lower() # TODO: Merge both methods + if unit is not None: # Needs to be improved + unit = str(unit.value)[8:] + # unit = str(unit.value.name).lower() # TODO: Merge both methods else: warnings.warn("The resolution unit is missing. `micrometer` will be used as default.") @@ -1309,8 +1309,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. 
+ If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. Args: @@ -1329,7 +1331,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_closest_lvl = mpp_list[closest_lvl] - lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] # Returns size in (height, width) + lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] closest_lvl_dim = lvl_dims[closest_lvl] closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) From 4df0b4b61e5fded01369ea531782b4df01bac3ba Mon Sep 17 00:00:00 2001 From: cxlcl Date: Fri, 22 Mar 2024 09:54:40 -0700 Subject: [PATCH 06/22] Stein's Unbiased Risk Estimator (SURE) loss and Conjugate Gradient (#7308) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Based on the discussion topic [here](https://github.com/Project-MONAI/MONAI/discussions/7161#discussion-5773293), we implemented the Conjugate-Gradient algorithm for linear operator inversion, and Stein's Unbiased Risk Estimator (SURE) [1] loss for ground-truth-date free diffusion process guidance that is proposed in [2] and illustrated in the algorithm below: Screenshot 2023-12-10 at 10 19 25 PM The Conjugate-Gradient (CG) algorithm is used to solve for the inversion of the linear operator in Line-4 in the algorithm above, where the linear operator is too large to store explicitly as a matrix (such as FFT/IFFT of an image) and invert directly. Instead, we can solve for the linear inversion iteratively as in CG. The SURE loss is applied for Line-6 above. 
This is a differentiable loss function that can be used to train/guide an operator (e.g. neural network), where the pseudo ground truth is available but the reference ground truth is not. For example, in the MRI reconstruction, the pseudo ground truth is the zero-filled reconstruction and the reference ground truth is the fully sampled reconstruction. The reference ground truth is not available due to the lack of fully sampled data. **Reference** [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics 1981 [[paper link](https://projecteuclid.org/journals/annals-of-statistics/volume-9/issue-6/Estimation-of-the-Mean-of-a-Multivariate-Normal-Distribution/10.1214/aos/1176345632.full)] [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models. MICCAI 2023 [[paper link](https://arxiv.org/pdf/2310.01799.pdf)] ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. 
--------- Signed-off-by: chaoliu Signed-off-by: cxlcl Signed-off-by: chaoliu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Nikolas Schmitz --- docs/source/losses.rst | 5 + docs/source/networks.rst | 5 + monai/losses/__init__.py | 1 + monai/losses/sure_loss.py | 200 ++++++++++++++++++++ monai/networks/layers/__init__.py | 1 + monai/networks/layers/conjugate_gradient.py | 112 +++++++++++ tests/test_conjugate_gradient.py | 55 ++++++ tests/test_sure_loss.py | 71 +++++++ 8 files changed, 450 insertions(+) create mode 100644 monai/losses/sure_loss.py create mode 100644 monai/networks/layers/conjugate_gradient.py create mode 100644 tests/test_conjugate_gradient.py create mode 100644 tests/test_sure_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 61dd959807..ba794af3eb 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -139,6 +139,11 @@ Reconstruction Losses .. autoclass:: JukeboxLoss :members: +`SURELoss` +~~~~~~~~~~ +.. autoclass:: SURELoss + :members: + Loss Wrappers ------------- diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8eada7933f..b59c8af5fc 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -408,6 +408,11 @@ Layers .. autoclass:: LLTM :members: +`ConjugateGradient` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ConjugateGradient + :members: + `Utilities` ~~~~~~~~~~~ .. 
automodule:: monai.networks.layers.convutils diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 4ebedb2084..e937b53fa4 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -41,5 +41,6 @@ from .spatial_mask import MaskedLoss from .spectral_loss import JukeboxLoss from .ssim_loss import SSIMLoss +from .sure_loss import SURELoss from .tversky import TverskyLoss from .unified_focal_loss import AsymmetricUnifiedFocalLoss diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py new file mode 100644 index 0000000000..ebf25613a6 --- /dev/null +++ b/monai/losses/sure_loss.py @@ -0,0 +1,200 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Callable, Optional + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss + + +def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + First compute the difference in the complex domain, + then get the absolute value and take the mse + + Args: + x, y - B, 2, H, W real valued tensors representing complex numbers + or B,1,H,W complex valued tensors + Returns: + l2_loss - scalar + """ + if not x.is_complex(): + x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous()) + if not y.is_complex(): + y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous()) + + diff = torch.abs(x - y) + return nn.functional.mse_loss(diff, torch.zeros_like(diff), reduction="mean") + + +def sure_loss_function( + operator: Callable, + x: torch.Tensor, + y_pseudo_gt: torch.Tensor, + y_ref: Optional[torch.Tensor] = None, + eps: Optional[float] = -1.0, + perturb_noise: Optional[torch.Tensor] = None, + complex_input: Optional[bool] = False, +) -> torch.Tensor: + """ + Args: + operator (function): The operator function that takes in an input + tensor x and returns an output tensor y. We will use this to compute + the divergence. More specifically, we will perturb the input x by a + small amount and compute the divergence between the perturbed output + and the reference output + + x (torch.Tensor): The input tensor of shape (B, C, H, W) to the + operator. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. + + y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape + (B, C, H, W) used to compute the L2 loss. For complex input, the shape is + (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) + real. + + y_ref (torch.Tensor, optional): The reference output tensor of shape + (B, C, H, W) used to compute the divergence. Defaults to None. For + complex input, the shape is (B, 2, H, W) aka C=2 real. 
For real input, + the shape is (B, 1, H, W) real. + + eps (float, optional): The perturbation scalar. Set to -1 to set it + automatically estimated based on y_pseudo_gtk + + perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W). + Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. + + complex_input(bool, optional): Whether the input is complex or not. + Defaults to False. + + Returns: + sure_loss (torch.Tensor): The SURE loss scalar. + """ + # perturb input + if perturb_noise is None: + perturb_noise = torch.randn_like(x) + if eps == -1.0: + eps = float(torch.abs(y_pseudo_gt.max())) / 1000 + # get y_ref if not provided + if y_ref is None: + y_ref = operator(x) + + # get perturbed output + x_perturbed = x + eps * perturb_noise + y_perturbed = operator(x_perturbed) + # divergence + divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore + # l2 loss between y_ref, y_pseudo_gt + if complex_input: + l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt) + else: + # real input + l2_loss = nn.functional.mse_loss(y_ref, y_pseudo_gt, reduction="mean") + + # sure loss + sure_loss = l2_loss * divergence / (x.shape[0] * x.shape[2] * x.shape[3]) + return sure_loss + + +class SURELoss(_Loss): + """ + Calculate the Stein's Unbiased Risk Estimator (SURE) loss for a given operator. + + This is a differentiable loss function that can be used to train/guide an + operator (e.g. neural network), where the pseudo ground truth is available + but the reference ground truth is not. For example, in the MRI + reconstruction, the pseudo ground truth is the zero-filled reconstruction + and the reference ground truth is the fully sampled reconstruction. Often, + the reference ground truth is not available due to the lack of fully sampled + data. + + The original SURE loss is proposed in [1]. 
The SURE loss used for guiding + the diffusion model based MRI reconstruction is proposed in [2]. + + Reference + + [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics + + [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models. + (https://arxiv.org/pdf/2310.01799.pdf) + """ + + def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> None: + """ + Args: + perturb_noise (torch.Tensor, optional): The noise vector of shape + (B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. + + eps (float, optional): The perturbation scalar. Defaults to None. + """ + super().__init__() + self.perturb_noise = perturb_noise + self.eps = eps + + def forward( + self, + operator: Callable, + x: torch.Tensor, + y_pseudo_gt: torch.Tensor, + y_ref: Optional[torch.Tensor] = None, + complex_input: Optional[bool] = False, + ) -> torch.Tensor: + """ + Args: + operator (function): The operator function that takes in an input + tensor x and returns an output tensor y. We will use this to compute + the divergence. More specifically, we will perturb the input x by a + small amount and compute the divergence between the perturbed output + and the reference output + + x (torch.Tensor): The input tensor of shape (B, C, H, W) to the + operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka + C=2 real. For real input, the shape is (B, 1, H, W) real. + + y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape + (B, C, H, W) used to compute the L2 loss. C=1 or 2: For complex + input, the shape is (B, 2, H, W) aka C=2 real. For real input, the + shape is (B, 1, H, W) real. + + y_ref (torch.Tensor, optional): The reference output tensor of the + same shape as y_pseudo_gt + + Returns: + sure_loss (torch.Tensor): The SURE loss scalar. 
+ """ + + # check inputs shapes + if x.dim() != 4: + raise ValueError(f"Input tensor x should be 4D, got {x.dim()}.") + if y_pseudo_gt.dim() != 4: + raise ValueError(f"Input tensor y_pseudo_gt should be 4D, but got {y_pseudo_gt.dim()}.") + if y_ref is not None and y_ref.dim() != 4: + raise ValueError(f"Input tensor y_ref should be 4D, but got {y_ref.dim()}.") + if x.shape != y_pseudo_gt.shape: + raise ValueError( + f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, " + f"y_pseudo_gt shape {y_pseudo_gt.shape}." + ) + if y_ref is not None and y_pseudo_gt.shape != y_ref.shape: + raise ValueError( + f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape {y_pseudo_gt.shape}, " + f"y_ref shape {y_ref.shape}." + ) + + # compute loss + loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input) + + return loss diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index d61ed57f7f..3a6e4aa554 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .conjugate_gradient import ConjugateGradient from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding from .drop_path import DropPath from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py new file mode 100644 index 0000000000..93a45930d7 --- /dev/null +++ b/monai/networks/layers/conjugate_gradient.py @@ -0,0 +1,112 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Callable + +import torch +from torch import nn + + +def _zdot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """ + Complex dot product between tensors x1 and x2: sum(x1.*x2) + """ + if torch.is_complex(x1): + assert torch.is_complex(x2), "x1 and x2 must both be complex" + return torch.sum(x1.conj() * x2) + else: + return torch.sum(x1 * x2) + + +def _zdot_single(x: torch.Tensor) -> torch.Tensor: + """ + Complex dot product between tensor x and itself + """ + res = _zdot(x, x) + if torch.is_complex(res): + return res.real + else: + return res + + +class ConjugateGradient(nn.Module): + """ + Congugate Gradient (CG) solver for linear systems Ax = y. + + For linear_op that is positive definite and self-adjoint, CG is + guaranteed to converge CG is often used to solve linear systems of the form + Ax = y, where A is too large to store explicitly, but can be computed via a + linear operator. + + As a result, here we won't set A explicitly as a matrix, but rather as a + linear operator. For example, A could be a FFT/IFFT operation + """ + + def __init__(self, linear_op: Callable, num_iter: int): + """ + Args: + linear_op: Linear operator + num_iter: Number of iterations to run CG + """ + super().__init__() + + self.linear_op = linear_op + self.num_iter = num_iter + + def update( + self, x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + perform one iteration of the CG method. 
It takes the current solution x, + the current search direction p, the current residual r, and the old + residual norm rsold as inputs. Then it computes the new solution, search + direction, residual, and residual norm, and returns them. + """ + + dy = self.linear_op(p) + p_dot_dy = _zdot(p, dy) + alpha = rsold / p_dot_dy + x = x + alpha * p + r = r - alpha * dy + rsnew = _zdot_single(r) + beta = rsnew / rsold + rsold = rsnew + p = beta * p + r + return x, p, r, rsold + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + run conjugate gradient for num_iter iterations to solve Ax = y + + Args: + x: tensor (real or complex); Initial guess for linear system Ax = y. + The size of x should be applicable to the linear operator. For + example, if the linear operator is FFT, then x is HCHW; if the + linear operator is a matrix multiplication, then x is a vector + + y: tensor (real or complex); Measurement. Same size as x + + Returns: + x: Solution to Ax = y + """ + # Compute residual + r = y - self.linear_op(x) + rsold = _zdot_single(r) + p = r + + # Update + for _i in range(self.num_iter): + x, p, r, rsold = self.update(x, p, r, rsold) + if rsold < 1e-10: + break + return x diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py new file mode 100644 index 0000000000..239dbe3ecd --- /dev/null +++ b/tests/test_conjugate_gradient.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch + +from monai.networks.layers import ConjugateGradient + + +class TestConjugateGradient(unittest.TestCase): + def test_real_valued_inverse(self): + """Test ConjugateGradient with real-valued input: when the input is real + value, the output should be the inverse of the matrix.""" + a_dim = 3 + a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float) + + def a_op(x): + return a_mat @ x + + cg_solver = ConjugateGradient(a_op, num_iter=100) + # define the measurement + y = torch.tensor([1, 2, 3], dtype=torch.float) + # solve for x + x = cg_solver(torch.zeros(a_dim), y) + x_ref = torch.linalg.solve(a_mat, y) + # assert torch.allclose(x, x_ref, atol=1e-6), 'CG solver failed to converge to reference solution' + self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) + + def test_complex_valued_inverse(self): + a_dim = 3 + a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64) + + def a_op(x): + return a_mat @ x + + cg_solver = ConjugateGradient(a_op, num_iter=100) + y = torch.tensor([1, 2, 3], dtype=torch.complex64) + x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y) + x_ref = torch.linalg.solve(a_mat, y) + self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py new file mode 100644 index 0000000000..945da657bf --- /dev/null +++ b/tests/test_sure_loss.py @@ -0,0 +1,71 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.losses import SURELoss + + +class TestSURELoss(unittest.TestCase): + def test_real_value(self): + """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.""" + sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1) + + def operator(x): + return x + + y_pseudo_gt = torch.randn(2, 1, 128, 128) + x = torch.randn(2, 1, 128, 128) + loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False) + self.assertAlmostEqual(loss.item(), 0.0) + + def test_complex_value(self): + """Test SURELoss with complex-valued input: when the input is complex value, the loss should be 0.0.""" + + def operator(x): + return x + + sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2, 2, 128, 128), eps=0.1) + y_pseudo_gt = torch.randn(2, 2, 128, 128) + x = torch.randn(2, 2, 128, 128) + loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True) + self.assertAlmostEqual(loss.item(), 0.0) + + def test_complex_general_input(self): + """Test SURELoss with complex-valued input: when the input is general complex value, the loss should be 0.0.""" + + def operator(x): + return x + + perturb_noise_real = torch.randn(2, 1, 128, 128) + perturb_noise_complex = torch.zeros(2, 2, 128, 128) + perturb_noise_complex[:, 0, :, :] = perturb_noise_real.squeeze() + y_pseudo_gt_real = torch.randn(2, 1, 128, 128) + y_pseudo_gt_complex = torch.zeros(2, 2, 128, 128) + y_pseudo_gt_complex[:, 0, :, :] = y_pseudo_gt_real.squeeze() + x_real = torch.randn(2, 1, 128, 128) + x_complex = torch.zeros(2, 2, 128, 128) + x_complex[:, 0, :, :] = x_real.squeeze() + + sure_loss_real = SURELoss(perturb_noise=perturb_noise_real, eps=0.1) + sure_loss_complex = SURELoss(perturb_noise=perturb_noise_complex, eps=0.1) + + loss_real = sure_loss_real(operator, x_real, 
y_pseudo_gt_real, complex_input=False) + loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) + self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6) + + +if __name__ == "__main__": + unittest.main() From d989c18e9ae3f5d338156ef1c32da4561bf07cbb Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 00:13:12 +0100 Subject: [PATCH 07/22] Renamed function to get_wsi_at_mpp Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 9c6ee9c387..5e3b0e9d36 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -603,7 +603,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -620,7 +620,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 rtol: the acceptable relative tolerance for resolution in micro per pixel. """ - return self.reader.get_img_at_mpp(wsi, mpp, atol, rtol) + return self.reader.get_wsi_at_mpp(wsi, mpp, atol, rtol) def get_power(self, wsi, level: int) -> float: """ @@ -763,7 +763,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. 
Please use `level` instead.") - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -1048,7 +1048,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -1305,7 +1305,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. 
From 105f00b7c8c1bd1cbee5f6fdc0841b001fc1e636 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 24 Mar 2024 23:53:13 +0000 Subject: [PATCH 08/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 5e3b0e9d36..1f036e334e 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -602,7 +602,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) - + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. @@ -829,7 +829,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') - + else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 @@ -1088,7 +1088,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) within_tolerance = within_tolerance_x & within_tolerance_y - + if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') @@ -1326,7 +1326,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp mpp_list = 
[self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] - closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] @@ -1348,7 +1348,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) within_tolerance = within_tolerance_x & within_tolerance_y - + if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') From 5db27c1e15e5b2f2db8a523572ece33edd61807a Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 02:21:13 +0100 Subject: [PATCH 09/22] Reformatted wsi_reader.py Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 42 +++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 1f036e334e..16a3150c4a 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -785,14 +785,14 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 cp, _ = optional_import("cupy") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions['level_count'])] + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions["level_count"])] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? 
mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] + closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl # Define tolerance intervals for x and y of closest level @@ -808,8 +808,10 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + ) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -821,14 +823,16 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + ) wsi_arr = cp.array(closest_lvl_wsi) target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") 
else: # Else: increase resolution (ie, decrement level) and then downsample @@ -839,15 +843,17 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + ) wsi_arr = cp.array(closest_lvl_wsi) target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") wsi_arr = cp.asnumpy(closest_lvl_wsi) return wsi_arr @@ -1075,7 +1081,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_closest_lvl = mpp_list[closest_lvl] closest_lvl_dim = wsi.level_dimensions[closest_lvl] - print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl # Define tolerance intervals for x and y of closest level @@ -1091,7 +1097,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") 
closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) else: @@ -1110,7 +1116,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") else: # Else: increase resolution (ie, decrement level) and then downsample @@ -1128,7 +1134,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1335,7 +1341,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl_dim = lvl_dims[closest_lvl] closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl # Define tolerance intervals for x and y of closest level @@ -1351,7 +1357,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, 
size=closest_lvl_dim) else: @@ -1370,7 +1376,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") else: # Else: increase resolution (ie, decrement level) and then downsample @@ -1390,7 +1396,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") wsi_arr = np.array(closest_lvl_wsi) return wsi_arr From 18e82bd0674221331cfd82ff20785774688ce296 Mon Sep 17 00:00:00 2001 From: monai-bot <64792179+monai-bot@users.noreply.github.com> Date: Mon, 25 Mar 2024 07:26:43 +0000 Subject: [PATCH 10/22] auto updates (#7577) Signed-off-by: monai-bot Signed-off-by: monai-bot Signed-off-by: Nikolas Schmitz --- tests/test_conjugate_gradient.py | 1 + tests/test_sure_loss.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py index 239dbe3ecd..64efe3b168 100644 --- a/tests/test_conjugate_gradient.py +++ b/tests/test_conjugate_gradient.py @@ -19,6 +19,7 @@ class TestConjugateGradient(unittest.TestCase): + def test_real_valued_inverse(self): """Test ConjugateGradient with real-valued input: when the input is real value, the output should be the inverse of the matrix.""" diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index 945da657bf..903f9bd2ca 100644 --- a/tests/test_sure_loss.py +++ 
b/tests/test_sure_loss.py @@ -19,6 +19,7 @@ class TestSURELoss(unittest.TestCase): + def test_real_value(self): """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.""" sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1) From 5bb531e8b3ae16316b465162998072014fb50792 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 11:18:03 +0100 Subject: [PATCH 11/22] Fixed return type Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 16a3150c4a..d7cfb444e3 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -603,7 +603,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -763,7 +763,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. 
@@ -1054,7 +1054,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. From 5214c56241509fa447e3cb4a8a59a515e287a8fe Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 12:07:06 +0100 Subject: [PATCH 12/22] Small fixes Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index d7cfb444e3..be121efa40 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -763,7 +763,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> Any: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -1311,7 +1311,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. 
Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. From 3f055a9022d027386566e2e94420f997360988da Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Mon, 25 Mar 2024 22:13:56 -0400 Subject: [PATCH 13/22] Remove nested error propagation on `ConfigComponent` instantiate (#7569) Fixes #7451 ### Description Reduces the length of error messages and error messages being propagated twice. This helps debug better when long `ConfigComponent`s are being instantiated. Refer to issue #7451 for more details ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
Signed-off-by: Suraj Pai Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/bundle/config_item.py | 5 +---- monai/utils/module.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index 844d5b30bf..e5122bf3de 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -289,10 +289,7 @@ def instantiate(self, **kwargs: Any) -> object: mode = self.get_config().get("_mode_", CompInitMode.DEFAULT) args = self.resolve_args() args.update(kwargs) - try: - return instantiate(modname, mode, **args) - except Exception as e: - raise RuntimeError(f"Failed to instantiate {self}") from e + return instantiate(modname, mode, **args) class ConfigExpression(ConfigItem): diff --git a/monai/utils/module.py b/monai/utils/module.py index 5e058c105b..6f301d8067 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -272,7 +272,7 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any: return pdb.runcall(component, **kwargs) except Exception as e: raise RuntimeError( - f"Failed to instantiate component '{__path}' with kwargs: {kwargs}" + f"Failed to instantiate component '{__path}' with keywords: {','.join(kwargs.keys())}" f"\n set '_mode_={CompInitMode.DEBUG}' to enter the debugging mode." ) from e From 3264079906ffa02055c80eb427f7157fd398b151 Mon Sep 17 00:00:00 2001 From: Juampa <1523654+juampatronics@users.noreply.github.com> Date: Tue, 26 Mar 2024 03:57:36 +0100 Subject: [PATCH 14/22] 2872 implementation of mixup, cutmix and cutout (#7198) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #2872 ### Description Implementation of mixup, cutmix and cutout as described in the original papers. Current implementation support both, the dictionary-based batches and tuples of tensors. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). 
- [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Juan Pablo de la Cruz Gutiérrez Signed-off-by: monai-bot Signed-off-by: elitap Signed-off-by: Felix Schnabel Signed-off-by: YanxuanLiu Signed-off-by: ytl0623 Signed-off-by: Dženan Zukić Signed-off-by: KumoLiu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Ishan Dutta Signed-off-by: dependabot[bot] Signed-off-by: kaibo Signed-off-by: heyufan1995 Signed-off-by: binliu Signed-off-by: axel.vlaminck Signed-off-by: Ibrahim Hadzic Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> Signed-off-by: Timothy Baker Signed-off-by: Mathijs de Boer Signed-off-by: Fabian Klopfer Signed-off-by: Lucas Robinet Signed-off-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Signed-off-by: chaoliu Signed-off-by: cxlcl Signed-off-by: chaoliu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: monai-bot <64792179+monai-bot@users.noreply.github.com> Co-authored-by: elitap Co-authored-by: Felix Schnabel Co-authored-by: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com> Co-authored-by: ytl0623 Co-authored-by: Dženan Zukić Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Ishan Dutta Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Kaibo Tang Co-authored-by: Yufan He <59374597+heyufan1995@users.noreply.github.com> Co-authored-by: binliunls 
<107988372+binliunls@users.noreply.github.com> Co-authored-by: Ben Murray Co-authored-by: axel.vlaminck Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Co-authored-by: Ibrahim Hadzic Co-authored-by: Dr. Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> Co-authored-by: Timothy J. Baker <62781117+tim-the-baker@users.noreply.github.com> Co-authored-by: Mathijs de Boer <8137653+MathijsdeBoer@users.noreply.github.com> Co-authored-by: Mathijs de Boer Co-authored-by: Fabian Klopfer Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Co-authored-by: Lucas Robinet Co-authored-by: cxlcl --- docs/source/transforms.rst | 42 +++++ docs/source/transforms_idx.rst | 10 + monai/transforms/__init__.py | 12 ++ monai/transforms/regularization/__init__.py | 10 + monai/transforms/regularization/array.py | 173 ++++++++++++++++++ monai/transforms/regularization/dictionary.py | 97 ++++++++++ tests/test_regularization.py | 90 +++++++++ 7 files changed, 434 insertions(+) create mode 100644 monai/transforms/regularization/__init__.py create mode 100644 monai/transforms/regularization/array.py create mode 100644 monai/transforms/regularization/dictionary.py create mode 100644 tests/test_regularization.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8990e7991d..bd3feb3497 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -661,6 +661,27 @@ Post-processing :members: :special-members: __call__ +Regularization +^^^^^^^^^^^^^^ + +`CutMix` +"""""""" +.. autoclass:: CutMix + :members: + :special-members: __call__ + +`CutOut` +"""""""" +.. autoclass:: CutOut + :members: + :special-members: __call__ + +`MixUp` +""""""" +.. 
autoclass:: MixUp + :members: + :special-members: __call__ + Signal ^^^^^^^ @@ -1707,6 +1728,27 @@ Post-processing (Dict) :members: :special-members: __call__ +Regularization (Dict) +^^^^^^^^^^^^^^^^^^^^^ + +`CutMixd` +""""""""" +.. autoclass:: CutMixd + :members: + :special-members: __call__ + +`CutOutd` +""""""""" +.. autoclass:: CutOutd + :members: + :special-members: __call__ + +`MixUpd` +"""""""" +.. autoclass:: MixUpd + :members: + :special-members: __call__ + Signal (Dict) ^^^^^^^^^^^^^ diff --git a/docs/source/transforms_idx.rst b/docs/source/transforms_idx.rst index f4d02a483f..650d45db71 100644 --- a/docs/source/transforms_idx.rst +++ b/docs/source/transforms_idx.rst @@ -74,6 +74,16 @@ Post-processing post.array post.dictionary +Regularization +^^^^^^^^^^^^^^ + +.. autosummary:: + :toctree: _gen + :nosignatures: + + regularization.array + regularization.dictionary + Signal ^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2aa8fbf8a1..349533fb3e 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -336,6 +336,18 @@ VoteEnsembled, VoteEnsembleDict, ) +from .regularization.array import CutMix, CutOut, MixUp +from .regularization.dictionary import ( + CutMixd, + CutMixD, + CutMixDict, + CutOutd, + CutOutD, + CutOutDict, + MixUpd, + MixUpD, + MixUpDict, +) from .signal.array import ( SignalContinuousWavelet, SignalFillEmpty, diff --git a/monai/transforms/regularization/__init__.py b/monai/transforms/regularization/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/transforms/regularization/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py new file mode 100644 index 0000000000..6c9022d647 --- /dev/null +++ b/monai/transforms/regularization/array.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from math import ceil, sqrt + +import torch + +from ..transform import RandomizableTransform + +__all__ = ["MixUp", "CutMix", "CutOut", "Mixer"] + + +class Mixer(RandomizableTransform): + def __init__(self, batch_size: int, alpha: float = 1.0) -> None: + """ + Mixer is a base class providing the basic logic for the mixup-class of + augmentations. In all cases, we need to sample the mixing weights for each + sample (lambda in the notation used in the papers). Also, pairs of samples + being mixed are picked by randomly shuffling the batch samples. + + Args: + batch_size (int): number of samples per batch. That is, samples are expected to + be of size batchsize x channels [x depth] x height x width.
+ alpha (float, optional): mixing weights are sampled from the Beta(alpha, alpha) + distribution. Defaults to 1.0, the uniform distribution. + """ + super().__init__() + if alpha <= 0: + raise ValueError(f"Expected positive number, but got {alpha = }") + self.alpha = alpha + self.batch_size = batch_size + + @abstractmethod + def apply(self, data: torch.Tensor): + raise NotImplementedError() + + def randomize(self, data=None) -> None: + """ + Sometimes you may need to apply the same transform to different tensors. + The idea is to get a sample and then apply it with apply() as often + as needed. You need to call this method every time you apply the transform to a new + batch. + """ + self._params = ( + torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32), + self.R.permutation(self.batch_size), + ) + + +class MixUp(Mixer): + """MixUp as described in: + Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz. + mixup: Beyond Empirical Risk Minimization, ICLR 2018 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. + """ + + def apply(self, data: torch.Tensor): + weight, perm = self._params + nsamples, *dims = data.shape + if len(weight) != nsamples: + raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}") + + if len(dims) not in [3, 4]: + raise ValueError("Unexpected number of dimensions") + + mixweight = weight[(Ellipsis,) + (None,) * len(dims)] + return mixweight * data + (1 - mixweight) * data[perm, ...] + + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): + self.randomize() + if labels is None: + return self.apply(data) + return self.apply(data), self.apply(labels) + + +class CutMix(Mixer): + """CutMix augmentation as described in: + Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.
+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features, + ICCV 2019 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. Here, alpha not only determines + the mixing weight but also the size of the random rectangles used for mixing. + Please refer to the paper for details. + + The most common use case is something close to: + + .. code-block:: python + + cm = CutMix(batch_size=8, alpha=0.5) + for batch in loader: + images, labels = batch + augimg, auglabels = cm(images, labels) + output = model(augimg) + loss = loss_function(output, auglabels) + ... + + """ + + def apply(self, data: torch.Tensor): + weights, perm = self._params + nsamples, _, *dims = data.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mask = torch.ones_like(data) + for s, weight in enumerate(weights): + coords = [torch.randint(0, d, size=(1,)) for d in dims] + lengths = [d * sqrt(1 - weight) for d in dims] + idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] + mask[s][idx] = 0 + + return mask * data + (1 - mask) * data[perm, ...] + + def apply_on_labels(self, labels: torch.Tensor): + weights, perm = self._params + nsamples, *dims = labels.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mixweight = weights[(Ellipsis,) + (None,) * len(dims)] + return mixweight * labels + (1 - mixweight) * labels[perm, ...] + + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): + self.randomize() + augmented = self.apply(data) + return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented + + +class CutOut(Mixer): + """Cutout as described in the paper: + Terrance DeVries, Graham W. Taylor.
+ Improved Regularization of Convolutional Neural Networks with Cutout, + arXiv:1708.04552 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. Here, alpha not only determines + the mixing weight but also the size of the random rectangles being cut out. + Please refer to the paper for details. + """ + + def apply(self, data: torch.Tensor): + weights, _ = self._params + nsamples, _, *dims = data.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mask = torch.ones_like(data) + for s, weight in enumerate(weights): + coords = [torch.randint(0, d, size=(1,)) for d in dims] + lengths = [d * sqrt(1 - weight) for d in dims] + idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] + mask[s][idx] = 0 + + return mask * data + + def __call__(self, data: torch.Tensor): + self.randomize() + return self.apply(data) diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py new file mode 100644 index 0000000000..373913da99 --- /dev/null +++ b/monai/transforms/regularization/dictionary.py @@ -0,0 +1,97 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from __future__ import annotations + +from monai.config import KeysCollection +from monai.utils.misc import ensure_tuple + +from ..transform import MapTransform +from .array import CutMix, CutOut, MixUp + +__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"] + + +class MixUpd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.MixUp`. + + Notice that the mixup transformation will be the same for all entries + for consistency, i.e. images and labels must be applied the same augmenation. + """ + + def __init__( + self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mixup = MixUp(batch_size, alpha) + + def __call__(self, data): + self.mixup.randomize() + result = dict(data) + for k in self.keys: + result[k] = self.mixup.apply(data[k]) + return result + + +class CutMixd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.CutMix`. + + Notice that the mixture weights will be the same for all entries + for consistency, i.e. images and labels must be aggregated with the same weights, + but the random crops are not. + """ + + def __init__( + self, + keys: KeysCollection, + batch_size: int, + label_keys: KeysCollection | None = None, + alpha: float = 1.0, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mixer = CutMix(batch_size, alpha) + self.label_keys = ensure_tuple(label_keys) if label_keys is not None else [] + + def __call__(self, data): + self.mixer.randomize() + result = dict(data) + for k in self.keys: + result[k] = self.mixer.apply(data[k]) + for k in self.label_keys: + result[k] = self.mixer.apply_on_labels(data[k]) + return result + + +class CutOutd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.CutOut`. 
+ + Notice that the cutout is different for every entry in the dictionary. + """ + + def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) + self.cutout = CutOut(batch_size) + + def __call__(self, data): + result = dict(data) + self.cutout.randomize() + for k in self.keys: + result[k] = self.cutout(data[k]) + return result + + +MixUpD = MixUpDict = MixUpd +CutMixD = CutMixDict = CutMixd +CutOutD = CutOutDict = CutOutd diff --git a/tests/test_regularization.py b/tests/test_regularization.py new file mode 100644 index 0000000000..d381ea72ca --- /dev/null +++ b/tests/test_regularization.py @@ -0,0 +1,90 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch + +from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd + + +class TestMixup(unittest.TestCase): + def test_mixup(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + mixup = MixUp(6, 1.0) + output = mixup(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10))) + + with self.assertRaises(ValueError): + MixUp(6, -0.5) + + mixup = MixUp(6, 0.5) + for dims in [2, 3]: + with self.assertRaises(ValueError): + shape = (5, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + mixup(sample) + + def test_mixupd(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + t = torch.rand(*shape, dtype=torch.float32) + sample = {"a": t, "b": t} + mixup = MixUpd(["a", "b"], 6) + output = mixup(sample) + self.assertTrue(torch.allclose(output["a"], output["b"])) + + with self.assertRaises(ValueError): + MixUpd(["k1", "k2"], 6, -0.5) + + +class TestCutMix(unittest.TestCase): + def test_cutmix(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + cutmix = CutMix(6, 1.0) + output = cutmix(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10))) + + def test_cutmixd(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + t = torch.rand(*shape, dtype=torch.float32) + label = torch.randint(0, 1, shape) + sample = {"a": t, "b": t, "lbl1": label, "lbl2": label} + cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2")) + output = cutmix(sample) + # croppings are different on each application + self.assertTrue(not torch.allclose(output["a"], output["b"])) + # but mixing of labels is not affected by it + self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"])) + + +class 
TestCutOut(unittest.TestCase): + def test_cutout(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + cutout = CutOut(6, 1.0) + output = cutout(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10))) + + +if __name__ == "__main__": + unittest.main() From 6fcc4a6995a012fca2d6a7928ec0ff64ce9672c3 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Wed, 31 Jul 2024 19:29:47 +0200 Subject: [PATCH 15/22] Updated function get_wsi_at_mpp; added function _resize_to_mpp_res to reduce redundancy; for get_mpp of TiffFileWSIReader: added check to prevent division by zero error. --- monai/data/wsi_reader.py | 236 ++++++++++++++++++++------------------- 1 file changed, 122 insertions(+), 114 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index be121efa40..7b8e3b12d1 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -781,7 +781,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") + # cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") user_mpp_x, user_mpp_y = mpp @@ -789,11 +789,10 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? 
- mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] + # closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + # mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol @@ -808,9 +807,8 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") closest_lvl_wsi = wsi.read_region( - (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + (0, 0), level=closest_lvl, size=wsi.resolutions["level_dimensions"][closest_lvl], num_workers=self.num_workers ) else: @@ -820,40 +818,12 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y if closest_level_is_bigger: - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_wsi = wsi.read_region( - (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers - ) - wsi_arr = cp.array(closest_lvl_wsi) - - target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - 
mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - closest_lvl_wsi = wsi.read_region( - (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers - ) - wsi_arr = cp.array(closest_lvl_wsi) - - target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = cp.asnumpy(closest_lvl_wsi) return wsi_arr @@ -941,6 +911,36 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch + + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. 
+ + """ + cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") + cp, _ = optional_import("cupy") + + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + wsi_arr = cp.array(closest_lvl_wsi) + + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) + return closest_lvl_wsi @require_pkg(pkg_name="openslide") @@ -1072,17 +1072,12 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? 
- mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.level_dimensions[closest_lvl] - - print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol @@ -1097,8 +1092,9 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=wsi.level_dimensions[closest_lvl] + ) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -1107,34 +1103,12 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y if closest_level_is_bigger: - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x 
= mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_dim = wsi.level_dimensions[closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1208,6 +1182,34 @@ def _get_patch( patch = np.moveaxis(patch, -1, self.channel_dim) return patch + + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. 
+ + """ + pil_image, _ = optional_import("PIL", name="Image") + + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + closest_lvl_dim = wsi.level_dimensions[closest_lvl] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi @require_pkg(pkg_name="tifffile") @@ -1295,21 +1297,27 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: and wsi.pages[level].tags["YResolution"].value ): unit = wsi.pages[level].tags.get("ResolutionUnit") - if unit is not None: # Needs to be improved - unit = str(unit.value)[8:] - # unit = str(unit.value.name).lower() # TODO: Merge both methods + if unit is not None: # Test with more tiff files + if isinstance(unit.value, int): + unit = str(unit.value.name).lower() + else: + unit = str(unit.value)[8:] else: warnings.warn("The resolution unit is missing. `micrometer` will be used as default.") unit = "micrometer" convert_to_micron = ConvertUnits(unit, "micrometer") - # Here x and y resolutions are rational numbers so each of them is represented by a tuple. + + # Here, x and y resolutions are rational numbers so each of them is represented by a tuple. yres = wsi.pages[level].tags["YResolution"].value xres = wsi.pages[level].tags["XResolution"].value - return convert_to_micron(yres[1] / yres[0]), convert_to_micron(xres[1] / xres[0]) - - raise ValueError("`mpp` cannot be obtained for this file. 
Please use `level` instead.") + if xres[0] & yres[0]: + return convert_to_micron(yres[1] / yres[0]), convert_to_micron(xres[1] / xres[0]) + else: + raise ValueError("The `XResolution` and/or `YResolution` property of the image is zero, " + "which is needed to obtain `mpp` for this file. Please use `level` instead.") + raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ @@ -1331,18 +1339,15 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? 
- mpp_closest_lvl = mpp_list[closest_lvl] - - lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] - closest_lvl_dim = lvl_dims[closest_lvl] - closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) + # lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] + # closest_lvl_dim = lvl_dims[closest_lvl] + # closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol @@ -1357,8 +1362,8 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + print('Tifffile, within tolerance') + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=self.get_size(wsi, closest_lvl)) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -1367,36 +1372,11 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y if closest_level_is_bigger: - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 1: Downscaling 
using factor {(ds_factor_x, ds_factor_y)}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) else: - # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_dim = lvl_dims[closest_lvl] - closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - - closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1445,7 +1425,7 @@ def _get_patch( Extracts and returns a patch image form the whole slide image. Args: - wsi: a whole slide image object loaded from a file or a lis of such objects + wsi: a whole slide image object loaded from a file or a list of such objects location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). size: (height, width) tuple giving the patch size at the given level (`level`). If None, it is set to the full image size at the given level. @@ -1477,3 +1457,31 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch + + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). 
+ + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + pil_image, _ = optional_import("PIL", name="Image") + + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + closest_lvl_dim = self.get_size(wsi, closest_lvl) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi From 4b0c9baf6a31b4fac7056e79381dc62a85809fef Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Wed, 31 Jul 2024 19:54:29 +0200 Subject: [PATCH 16/22] Minor fixes: removed unnecessary comments --- monai/data/wsi_reader.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 7b8e3b12d1..7df8256ad6 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -1337,16 +1337,11 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? 
- # lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] - # closest_lvl_dim = lvl_dims[closest_lvl] - # closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level @@ -1362,7 +1357,6 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print('Tifffile, within tolerance') closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=self.get_size(wsi, closest_lvl)) else: From 66508e92506b315fe745a4664d6cbe5a6763d2cc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Aug 2024 22:20:38 +0000 Subject: [PATCH 17/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/wsi_reader.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 53048c2d70..1ba799f095 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -911,17 +911,17 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. 
- + """ cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") @@ -1182,17 +1182,17 @@ def _get_patch( patch = np.moveaxis(patch, -1, self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ pil_image, _ = optional_import("PIL", name="Image") @@ -1447,11 +1447,11 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. 
From 441b4629f8905dba1fecbc16bb303c0c86f7ff17 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 4 Aug 2024 17:44:48 +0200 Subject: [PATCH 18/22] Added function _compute_mpp_target_res to BaseWSIReader --- monai/data/wsi_reader.py | 76 ++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 7df8256ad6..c6e2b67914 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -431,6 +431,28 @@ def get_data( metadata[key] = [m[key] for m in metadata_list] return _stack_images(patch_list, metadata), metadata + def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + + return target_res_x, target_res_y + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. @@ -911,7 +933,7 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). 
@@ -926,20 +948,13 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) - wsi_arr = cp.array(closest_lvl_wsi) - target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + wsi_arr = cp.array(wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) + return closest_lvl_wsi @@ -1182,7 +1197,7 @@ def _get_patch( patch = np.moveaxis(patch, -1, self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). 
@@ -1196,19 +1211,13 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ pil_image, _ = optional_import("PIL", name="Image") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = wsi.level_dimensions[closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi @@ -1297,13 +1306,10 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: and wsi.pages[level].tags["YResolution"].value ): unit = wsi.pages[level].tags.get("ResolutionUnit") - if unit is not None: # Test with more tiff files - if isinstance(unit.value, int): - unit = str(unit.value.name).lower() - else: - unit = str(unit.value)[8:] - else: + if unit is not None: + unit = str(unit.value.name) + if unit is None or len(unit) == 0: warnings.warn("The resolution unit is missing. `micrometer` will be used as default.") unit = "micrometer" @@ -1451,7 +1457,7 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). @@ -1461,21 +1467,15 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. 
closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ pil_image, _ = optional_import("PIL", name="Image") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = self.get_size(wsi, closest_lvl) - closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi From d73d739de08101dc3781f3cf06cd862a8a775e16 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 4 Aug 2024 17:53:19 +0200 Subject: [PATCH 19/22] Added new feature and merged updates from main repository --- monai/data/wsi_reader.py | 61 +++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 1ba799f095..e217f41c7e 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -431,6 +431,28 @@ def get_data( metadata[key] = [m[key] for m in metadata_list] return _stack_images(patch_list, metadata), metadata + def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. 
+ closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + + return target_res_x, target_res_y + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. @@ -926,20 +948,13 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) - wsi_arr = cp.array(closest_lvl_wsi) - target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + wsi_arr = cp.array(wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) + return closest_lvl_wsi @@ -1196,19 +1211,13 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ pil_image, _ = optional_import("PIL", name="Image") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, 
mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = wsi.level_dimensions[closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi @@ -1457,21 +1466,15 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. 
- + """ pil_image, _ = optional_import("PIL", name="Image") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = self.get_size(wsi, closest_lvl) - closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi From 5461801547e48d941d20c876729da15769868eb1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:36:01 +0000 Subject: [PATCH 20/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/wsi_reader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index e217f41c7e..81afafb246 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -434,13 +434,13 @@ def get_data( def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. 
- + """ mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl @@ -1466,7 +1466,7 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ pil_image, _ = optional_import("PIL", name="Image") From 59683bc07cef236bbe4a61129fa6f963533f217c Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 11 Aug 2024 22:45:54 +0200 Subject: [PATCH 21/22] Added a function _compute_mpp_tolerances which checks the mpp tolerances to BaseWSIReader; Edited docstrings --- monai/data/wsi_reader.py | 163 +++++++++++++++++---------------------- 1 file changed, 70 insertions(+), 93 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 81afafb246..7b2e2eb0db 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -431,27 +431,62 @@ def get_data( metadata[key] = [m[key] for m in metadata_list] return _stack_images(patch_list, metadata), metadata - def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, user_mpp: tuple): + def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, mpp: tuple): """ - Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + Computes the target dimensions for resizing a whole slide image + to match a user-specified resolution in microns per pixel (MPP). Args: - wsi: whole slide image object from WSIReader - user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. - closest_lvl: the wsi level that is closest to the user-provided mpp resolution. - mpp_list: list of mpp values for all levels of a whole slide image. 
+ closest_lvl: Whole slide image level closest to user-provided MPP resolution. + closest_lvl_dim: Dimensions (height, width) of the image at the closest level. + mpp_list: List of MPP values for all levels of the whole slide image. + mpp: The MPP resolution at which the whole slide image representation should be extracted. """ mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + ds_factor_x = mpp_closest_lvl_x / mpp[0] + ds_factor_y = mpp_closest_lvl_y / mpp[1] target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) return target_res_x, target_res_y + + def _compute_mpp_tolerances(self, closest_lvl, mpp_list, mpp, atol, rtol) -> bool: + """ + Determines if user-provided MPP values are within a specified tolerance of the closest + level's MPP and checks if the closest level has higher resolution than desired MPP. + + Args: + closest_lvl: Whole slide image level closest to user-provided MPP resolution. + mpp_list: List of MPP values for all levels of the whole slide image. + mpp: The MPP resolution at which the whole slide image representation should be extracted. + atol: Absolute tolerance for MPP comparison. + rtol: Relative tolerance for MPP comparison. 
+ + """ + user_mpp_x, user_mpp_y = mpp + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] + + # Define tolerance intervals for x and y of closest level + lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol + upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol + lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol + upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol + + # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level + is_within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) + is_within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) + is_within_tolerance = is_within_tolerance_x & is_within_tolerance_y + + # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp + closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x + closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y + closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + + return is_within_tolerance, closest_level_is_bigger def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ @@ -802,50 +837,27 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 rtol: the acceptable relative tolerance for resolution in micro per pixel. """ - - # cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") - user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions["level_count"])] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? 
- - # closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - - # mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] - - # Define tolerance intervals for x and y of closest level - lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol - upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol - lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol - upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol - # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level - within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) - within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) - within_tolerance = within_tolerance_x & within_tolerance_y + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: - # Take closest_level and continue with returning img at level + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. closest_lvl_wsi = wsi.read_region( (0, 0), level=closest_lvl, size=wsi.resolutions["level_dimensions"][closest_lvl], num_workers=self.num_workers ) - else: - # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp - closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x - closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y - closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y - - if closest_level_is_bigger: - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. 
+ closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) - else: - # Else: increase resolution (ie, decrement level) and then downsample - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = cp.asnumpy(closest_lvl_wsi) return wsi_arr @@ -1087,43 +1099,25 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] - - # Define tolerance intervals for x and y of closest level - lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol - upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol - lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol - upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol - - # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level - within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) - within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) - within_tolerance = within_tolerance_x & within_tolerance_y + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: - # Take closest_level and continue with returning img at level + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. 
closest_lvl_wsi = wsi.read_region( (0, 0), level=closest_lvl, size=wsi.level_dimensions[closest_lvl] ) - else: - # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp - closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x - closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y - closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y - - if closest_level_is_bigger: - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) - else: - # Else: increase resolution (ie, decrement level) and then downsample - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1342,40 +1336,23 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? 
- - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] - - # Define tolerance intervals for x and y of closest level - lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol - upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol - lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol - upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol - # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level - within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) - within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) - within_tolerance = within_tolerance_x & within_tolerance_y + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: - # Take closest_level and continue with returning img at level + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=self.get_size(wsi, closest_lvl)) - else: - # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp - closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x - closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y - closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y - - if closest_level_is_bigger: - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) - else: - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. 
+ closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr From 547442ed138734a29c472eb97613b8e6ad2a8e4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 11 Aug 2024 20:46:19 +0000 Subject: [PATCH 22/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/wsi_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 7b2e2eb0db..57df016140 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -453,7 +453,7 @@ def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, mpp: t target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) return target_res_x, target_res_y - + def _compute_mpp_tolerances(self, closest_lvl, mpp_list, mpp, atol, rtol) -> bool: """ Determines if user-provided MPP values are within a specified tolerance of the closest