Skip to content

Commit

Permalink
changing function and varible names to use binary map/prob map
Browse files Browse the repository at this point in the history
  • Loading branch information
yrkim98 committed Jan 30, 2025
1 parent 80eaf0f commit dec344e
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_seg_layers(self) -> list[Layer]:
if layer.name.startswith("[seg]")
]

def insert_threshold(
def insert_binary_map_into_layer(
self, layer: ImageLayer, img: np.ndarray, seg_layers: bool = False
) -> None:
self.threshold_inserted[f"[threshold] {layer.name}"] = img
Expand All @@ -138,7 +138,7 @@ def get_source_path(self, layer: Layer) -> Optional[Path]:

return None

def get_all_segmentation_labels(self) -> list[Layer]:
def get_all_layers_containing_prob_map(self) -> list[Layer]:
return [
layer
for layer in self.get_all_images()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def __init__(
def _update_layer_list(self, event: Optional[NapariEvent] = None) -> None:
previous_selections: list[int] = self._image_list.get_checked_rows()
self._image_list.clear()
self._prediction_layers = self._viewer.get_all_segmentation_labels()
self._prediction_layers = (
self._viewer.get_all_layers_containing_prob_map()
)
for idx, prediction_output_layer in enumerate(self._prediction_layers):
self._image_list.add_item(
prediction_output_layer.name,
Expand Down
4 changes: 2 additions & 2 deletions src/allencell_ml_segmenter/main/i_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_seg_layers(self) -> list[Layer]:
pass

@abstractmethod
def insert_threshold(
def insert_binary_map_into_layer(
self, layer_name: str, img: np.ndarray, seg_layers: bool = False
) -> None:
"""
Expand All @@ -107,5 +107,5 @@ def get_source_path(self, layer: Layer) -> Optional[Path]:
pass

@abstractmethod
def get_all_segmentation_labels(self) -> list[Layer]:
def get_all_layers_containing_prob_map(self) -> list[Layer]:
pass
4 changes: 2 additions & 2 deletions src/allencell_ml_segmenter/main/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def get_seg_layers(self) -> list[Layer]:
if layer.name.startswith("[seg]")
]

def insert_threshold(
def insert_binary_map_into_layer(
self,
layer: Layer,
image: np.ndarray,
Expand Down Expand Up @@ -176,7 +176,7 @@ def get_source_path(self, layer: Layer) -> Optional[Path]:
return Path(layer.metadata["source_path"])
return None

def get_all_segmentation_labels(self) -> list[Layer]:
def get_all_layers_containing_prob_map(self) -> list[Layer]:
"""
Get all segmentation labels layers that currently exist in the viewer.
"""
Expand Down
26 changes: 15 additions & 11 deletions src/allencell_ml_segmenter/thresholding/thresholding_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def _handle_thresholding_error(self, error: Exception) -> None:
show_info("Thresholding failed: " + str(error))

def _on_threshold_changed(self, _: Event) -> None:
segmentation_labels: list[Layer] = (
self._viewer.get_all_segmentation_labels()
layers_containing_prob_map: list[Layer] = (
self._viewer.get_all_layers_containing_prob_map()
)

# determine thresholding function to use
Expand All @@ -80,11 +80,15 @@ def _on_threshold_changed(self, _: Event) -> None:
)
else:
thresh_function = self._threshold_image
for idx, layer in enumerate(segmentation_labels):
selected_idx: Optional[list[int]] = (
self._file_input_model.get_selected_idx()
)
if selected_idx is not None and idx in selected_idx:

selected_idx: Optional[list[int]] = (
self._file_input_model.get_selected_idx()
)

if selected_idx is not None:
for idx in selected_idx:
layer: Layer = layers_containing_prob_map[idx]

# Creating helper functions for mypy strict typing
def thresholding_task(
layer_instance: Layer = layer,
Expand All @@ -98,16 +102,16 @@ def thresholding_task(
raise ValueError(
"Layer metadata must be a dictionary containing the 'prob_map' key in order to threshold."
)

# This thresholding task returns a binary map
return thresh_function(layer_instance.metadata["prob_map"])

def on_return(
threshold_output: np.ndarray,
resulting_binary_map: np.ndarray,
layer_instance: Layer = layer,
) -> None:
self._viewer.insert_threshold(
self._viewer.insert_binary_map_into_layer(
layer_instance,
threshold_output,
resulting_binary_map,
self._main_model.are_predictions_in_viewer(),
)

Expand Down

0 comments on commit dec344e

Please sign in to comment.