From dec344ee8c3a2ca4c44aca7fe8cf0e6f23deb3c3 Mon Sep 17 00:00:00 2001 From: "brian.kim" Date: Thu, 30 Jan 2025 15:14:36 -0800 Subject: [PATCH] changing function and varible names to use binary map/prob map --- .../_tests/fakes/fake_viewer.py | 4 +-- .../core/prediction_result_input_widget.py | 4 ++- src/allencell_ml_segmenter/main/i_viewer.py | 4 +-- src/allencell_ml_segmenter/main/viewer.py | 4 +-- .../thresholding/thresholding_service.py | 26 +++++++++++-------- 5 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py b/src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py index 8797825a..5aa4b95b 100644 --- a/src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py +++ b/src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py @@ -118,7 +118,7 @@ def get_seg_layers(self) -> list[Layer]: if layer.name.startswith("[seg]") ] - def insert_threshold( + def insert_binary_map_into_layer( self, layer: ImageLayer, img: np.ndarray, seg_layers: bool = False ) -> None: self.threshold_inserted[f"[threshold] {layer.name}"] = img @@ -138,7 +138,7 @@ def get_source_path(self, layer: Layer) -> Optional[Path]: return None - def get_all_segmentation_labels(self) -> list[Layer]: + def get_all_layers_containing_prob_map(self) -> list[Layer]: return [ layer for layer in self.get_all_images() diff --git a/src/allencell_ml_segmenter/core/prediction_result_input_widget.py b/src/allencell_ml_segmenter/core/prediction_result_input_widget.py index b18f7276..d4a44fa0 100644 --- a/src/allencell_ml_segmenter/core/prediction_result_input_widget.py +++ b/src/allencell_ml_segmenter/core/prediction_result_input_widget.py @@ -30,7 +30,9 @@ def __init__( def _update_layer_list(self, event: Optional[NapariEvent] = None) -> None: previous_selections: list[int] = self._image_list.get_checked_rows() self._image_list.clear() - self._prediction_layers = self._viewer.get_all_segmentation_labels() + self._prediction_layers = ( + self._viewer.get_all_layers_containing_prob_map() + ) for idx, prediction_output_layer in enumerate(self._prediction_layers): self._image_list.add_item( prediction_output_layer.name, diff --git a/src/allencell_ml_segmenter/main/i_viewer.py b/src/allencell_ml_segmenter/main/i_viewer.py index 8d2f941f..9762677b 100644 --- a/src/allencell_ml_segmenter/main/i_viewer.py +++ b/src/allencell_ml_segmenter/main/i_viewer.py @@ -83,7 +83,7 @@ def get_seg_layers(self) -> list[Layer]: pass @abstractmethod - def insert_threshold( + def insert_binary_map_into_layer( self, layer_name: str, img: np.ndarray, seg_layers: bool = False ) -> None: """ @@ -107,5 +107,5 @@ def get_source_path(self, layer: Layer) -> Optional[Path]: pass @abstractmethod - def get_all_segmentation_labels(self) -> list[Layer]: + def get_all_layers_containing_prob_map(self) -> list[Layer]: pass diff --git a/src/allencell_ml_segmenter/main/viewer.py b/src/allencell_ml_segmenter/main/viewer.py index 48f4ae22..988ff8d7 100644 --- a/src/allencell_ml_segmenter/main/viewer.py +++ b/src/allencell_ml_segmenter/main/viewer.py @@ -141,7 +141,7 @@ def get_seg_layers(self) -> list[Layer]: if layer.name.startswith("[seg]") ] - def insert_threshold( + def insert_binary_map_into_layer( self, layer: Layer, image: np.ndarray, @@ -176,7 +176,7 @@ def get_source_path(self, layer: Layer) -> Optional[Path]: return Path(layer.metadata["source_path"]) return None - def get_all_segmentation_labels(self) -> list[Layer]: + def get_all_layers_containing_prob_map(self) -> list[Layer]: """ Get all segmentation labels layers that currently exist in the viewer. """ diff --git a/src/allencell_ml_segmenter/thresholding/thresholding_service.py b/src/allencell_ml_segmenter/thresholding/thresholding_service.py index 1335a577..df68df44 100644 --- a/src/allencell_ml_segmenter/thresholding/thresholding_service.py +++ b/src/allencell_ml_segmenter/thresholding/thresholding_service.py @@ -69,8 +69,8 @@ def _handle_thresholding_error(self, error: Exception) -> None: show_info("Thresholding failed: " + str(error)) def _on_threshold_changed(self, _: Event) -> None: - segmentation_labels: list[Layer] = ( - self._viewer.get_all_segmentation_labels() + layers_containing_prob_map: list[Layer] = ( + self._viewer.get_all_layers_containing_prob_map() ) # determine thresholding function to use @@ -80,11 +80,15 @@ def _on_threshold_changed(self, _: Event) -> None: ) else: thresh_function = self._threshold_image - for idx, layer in enumerate(segmentation_labels): - selected_idx: Optional[list[int]] = ( - self._file_input_model.get_selected_idx() - ) - if selected_idx is not None and idx in selected_idx: + + selected_idx: Optional[list[int]] = ( + self._file_input_model.get_selected_idx() + ) + + if selected_idx is not None: + for idx in selected_idx: + layer: Layer = layers_containing_prob_map[idx] + # Creating helper functions for mypy strict typing def thresholding_task( layer_instance: Layer = layer, @@ -98,16 +102,16 @@ def thresholding_task( raise ValueError( "Layer metadata must be a dictionary containing the 'prob_map' key in order to threshold." ) - + # This thresholding task returns a binary map return thresh_function(layer_instance.metadata["prob_map"]) def on_return( - threshold_output: np.ndarray, + resulting_binary_map: np.ndarray, layer_instance: Layer = layer, ) -> None: - self._viewer.insert_threshold( + self._viewer.insert_binary_map_into_layer( layer_instance, - threshold_output, + resulting_binary_map, self._main_model.are_predictions_in_viewer(), )