Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/create inputwidget for prediction outputs #582

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
20fd598
prediction_result_list_widget
yrkim98 Jan 29, 2025
980c050
need this still
yrkim98 Jan 29, 2025
4147232
use imagedata
yrkim98 Jan 29, 2025
fc68eb8
oops, should be ImageData.np_data not ImageData.data
yrkim98 Jan 29, 2025
33c003e
change way we insert a new segmentation threshold layer
yrkim98 Jan 29, 2025
c7f13cd
prepend [threshold] to layer name only if this is the first time a th…
yrkim98 Jan 29, 2025
ae26f80
no key check for dictionary
yrkim98 Jan 29, 2025
0af12e6
no key check for dictionary
yrkim98 Jan 29, 2025
92ba3fa
init prediction layers as empty- since for now we will not have any p…
yrkim98 Jan 29, 2025
a45fbb2
dont store ImageData object, can just store np.data directly
yrkim98 Jan 29, 2025
f2889fb
black
yrkim98 Jan 29, 2025
7586d8a
maintain previous selections
yrkim98 Jan 30, 2025
a25cf4a
instance not working correctly
yrkim98 Jan 30, 2025
a598a88
threshold only selected layers
yrkim98 Jan 30, 2025
c22520e
revert
yrkim98 Jan 30, 2025
63cc313
only threshold selected idx
yrkim98 Jan 30, 2025
e935624
rename process checked singal function
yrkim98 Jan 30, 2025
148fa75
lint
yrkim98 Jan 30, 2025
fe7d42c
use the correct layer instance
yrkim98 Jan 30, 2025
c56c302
working
yrkim98 Jan 30, 2025
80eaf0f
optional instead of none
yrkim98 Jan 30, 2025
dec344e
changing function and varible names to use binary map/prob map
yrkim98 Jan 30, 2025
e798a8e
revert uncesseary line
yrkim98 Feb 13, 2025
70283fd
comment for viewer generalized
yrkim98 Feb 13, 2025
0b1330b
we do need this
yrkim98 Feb 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ def get_seg_layers(self) -> list[Layer]:
if layer.name.startswith("[seg]")
]

def insert_threshold(
self, layer_name: str, img: np.ndarray, seg_layers: bool = False
def insert_binary_map_into_layer(
self, layer: ImageLayer, img: np.ndarray, seg_layers: bool = False
) -> None:
self.threshold_inserted[f"[threshold] {layer_name}"] = img
self.threshold_inserted[f"[threshold] {layer.name}"] = img

def get_layers_nonthreshold(self) -> list[Layer]:
return [
Expand All @@ -138,7 +138,7 @@ def get_source_path(self, layer: Layer) -> Optional[Path]:

return None

def get_all_segmentation_labels(self) -> list[Labels]:
def get_all_layers_containing_prob_map(self) -> list[Layer]:
return [
layer
for layer in self.get_all_images()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ def test_on_threshold_changed_non_prediction(test_image):
viewer: FakeViewer = FakeViewer()
main_model: MainModel = MainModel()
main_model.set_predictions_in_viewer(True)
file_input_model: FileInputModel = FileInputModel()
thresholding_service: ThresholdingService = ThresholdingService(
thresholding_model,
FakeExperimentsModel(),
FileInputModel(),
file_input_model,
main_model,
viewer,
task_executor=SynchroTaskExecutor.global_instance(),
Expand All @@ -83,6 +84,7 @@ def test_on_threshold_changed_non_prediction(test_image):
metadata={"prob_map": test_image},
)
viewer.add_image(test_image, name="donotthreshold")
file_input_model.set_selected_idx([0, 1])

# ACT set a threshold to trigger
thresholding_model.set_thresholding_value(50)
Expand Down
7 changes: 7 additions & 0 deletions src/allencell_ml_segmenter/core/file_input_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self) -> None:
self._input_mode: Optional[InputMode] = None
self._selected_paths: Optional[list[Path]] = None
self._max_channels: Optional[int] = None
self._selected_idx: Optional[list[int]] = None

def get_output_seg_directory(self) -> Optional[Path]:
"""
Expand Down Expand Up @@ -92,3 +93,9 @@ def get_input_files_as_list(self) -> List[Path]:
if selected_paths is not None:
return selected_paths
return []

def set_selected_idx(self, selected_idx: list[int]) -> None:
self._selected_idx = selected_idx

def get_selected_idx(self) -> Optional[list[int]]:
return self._selected_idx
44 changes: 44 additions & 0 deletions src/allencell_ml_segmenter/core/prediction_result_input_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any, Optional

from allencell_ml_segmenter.core.file_input_model import (
InputMode,
FileInputModel,
)
from allencell_ml_segmenter.core.file_input_widget import FileInputWidget
from allencell_ml_segmenter.main.i_viewer import IViewer
from allencell_ml_segmenter.main.segmenter_layer import LabelsLayer
from qtpy.QtCore import Qt

from allencell_ml_segmenter.prediction.service import ModelFileService

from napari.utils.events import Event as NapariEvent # type: ignore


class PredictionResultListWidget(FileInputWidget):
"""
Widget containing a list of prediction results that are selectable for thresholding
"""

def __init__(
self, model: FileInputModel, viewer: IViewer, service: ModelFileService
):
super().__init__(
model, viewer, service, include_channel_selection=False
)
self._prediction_layers: list[LabelsLayer] = []

def _update_layer_list(self, event: Optional[NapariEvent] = None) -> None:
previous_selections: list[int] = self._image_list.get_checked_rows()
self._image_list.clear()
self._prediction_layers = (
self._viewer.get_all_layers_containing_prob_map()
)
for idx, prediction_output_layer in enumerate(self._prediction_layers):
self._image_list.add_item(
prediction_output_layer.name,
set_checked=idx in previous_selections,
)

def _process_checked_signal(self, row: int, state: Qt.CheckState) -> None:
if self._model.get_input_mode() == InputMode.FROM_NAPARI_LAYERS:
self._model.set_selected_idx(self._image_list.get_checked_rows())
4 changes: 2 additions & 2 deletions src/allencell_ml_segmenter/main/i_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_seg_layers(self) -> list[Layer]:
pass

@abstractmethod
def insert_threshold(
def insert_binary_map_into_layer(
self, layer_name: str, img: np.ndarray, seg_layers: bool = False
) -> None:
"""
Expand All @@ -107,5 +107,5 @@ def get_source_path(self, layer: Layer) -> Optional[Path]:
pass

@abstractmethod
def get_all_segmentation_labels(self) -> list[LabelsLayer]:
def get_all_layers_containing_prob_map(self) -> list[Layer]:
pass
43 changes: 13 additions & 30 deletions src/allencell_ml_segmenter/main/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,45 +141,28 @@ def get_seg_layers(self) -> list[Layer]:
if layer.name.startswith("[seg]")
]

def insert_threshold(
def insert_binary_map_into_layer(
self,
layer_name: str,
layer: Layer,
image: np.ndarray,
remove_seg_layers: bool = False,
) -> None:
"""
Insert a thresholded image into the viewer.
If a layer for this thresholded image already exists, the new image will replace the old one and refresh the viewer.
If the layer does not exist, it will be added to the viewer in the correct place (on top of the original segmentation image:
Insert a binary mpa image into the viewer.
If a layer for this binary map image already exists, the new image will replace the old one and refresh the viewer.
If the layer does not exist, it will be added to the viewer in the correct place (on top of the original raw image:
index_of_segmentation + 1 in the LayerList)

:param layer_name: name of layer to insert. Will replace if one exists, will create one in a new position if needed.
:param layer: layer to replace.
:param image: image to insert
:param remove_seg_layers: boolean indicating if the layer that is being thresholded is a segmentation layer, and should be removed from the layer once it is updated with the threshold.
"""
layer_to_insert = self._get_layer_by_name(f"[threshold] {layer_name}")
if layer_to_insert is None:
# No thresholding exists, so we add it to the correct place in the viewer
layerlist = self.viewer.layers

# check if the original segementation layer is currently in the viewer, if so, remove later after
# thresholding is applied
seg_layer_og: Optional[Layer] = None
if remove_seg_layers:
seg_layer_og = self._get_layer_by_name(layer_name)

# figure out where to insert the new thresholded layer (on top of the original segmentation image)
layerlist_pos = layerlist.index(layer_name)
labels_created = Labels(image, name=f"[threshold] {layer_name}")
layerlist.insert(layerlist_pos + 1, labels_created)

# remove the original segmentation layer if it exists
if seg_layer_og:
layerlist.remove(seg_layer_og)
else:
# Thresholding already exists so just update the existing one in the viewer.
layer_to_insert.data = image
layer_to_insert.refresh()
# if threshold has not been previously applied, update name
if "threshold_applied" not in layer.metadata:
layer.name = f"[threshold] {layer.name}"
layer.data = image
layer.metadata["threshold_applied"] = True
layer.refresh()

def get_source_path(self, layer: Layer) -> Optional[Path]:
"""
Expand All @@ -193,7 +176,7 @@ def get_source_path(self, layer: Layer) -> Optional[Path]:
return Path(layer.metadata["source_path"])
return None

def get_all_segmentation_labels(self) -> list[LabelsLayer]:
def get_all_layers_containing_prob_map(self) -> list[Layer]:
"""
Get all segmentation labels layers that currently exist in the viewer.
"""
Expand Down
69 changes: 39 additions & 30 deletions src/allencell_ml_segmenter/thresholding/thresholding_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def _handle_thresholding_error(self, error: Exception) -> None:
show_info("Thresholding failed: " + str(error))

def _on_threshold_changed(self, _: Event) -> None:
segmentation_labels: list[LabelsLayer] = (
self._viewer.get_all_segmentation_labels()
layers_containing_prob_map: list[Layer] = (
self._viewer.get_all_layers_containing_prob_map()
)

# determine thresholding function to use
Expand All @@ -80,39 +80,48 @@ def _on_threshold_changed(self, _: Event) -> None:
)
else:
thresh_function = self._threshold_image
for layer in segmentation_labels:
# Creating helper functions for mypy strict typing
def thresholding_task() -> np.ndarray:
# INVARIANT: a segmentation layer must have prob_map in its metadata if it came from our plugin
# so we are only supporting thresholding images that are from the plugin itself.
if (
not isinstance(layer.metadata, dict)
or "prob_map" not in layer.metadata
):
raise ValueError(
"Layer metadata must be a dictionary containing the 'prob_map' key in order to threshold."
)

return thresh_function(layer.metadata["prob_map"])
selected_idx: Optional[list[int]] = (
self._file_input_model.get_selected_idx()
)

layer_instance: LabelsLayer = layer
if selected_idx is not None:
for idx in selected_idx:
layer: Layer = layers_containing_prob_map[idx]

# Creating helper functions for mypy strict typing
def thresholding_task(
layer_instance: Layer = layer,
) -> np.ndarray:
# INVARIANT: a segmentation layer must have prob_map in its metadata if it came from our plugin
# so we are only supporting thresholding images that are from the plugin itself.
if (
not isinstance(layer_instance.metadata, dict)
or "prob_map" not in layer_instance.metadata
):
raise ValueError(
"Layer metadata must be a dictionary containing the 'prob_map' key in order to threshold."
)
# This thresholding task returns a binary map
return thresh_function(layer_instance.metadata["prob_map"])

def on_return(
resulting_binary_map: np.ndarray,
layer_instance: Layer = layer,
) -> None:
self._viewer.insert_binary_map_into_layer(
layer_instance,
resulting_binary_map,
self._main_model.are_predictions_in_viewer(),
)

def on_return(
threshold_output: np.ndarray,
) -> None:
self._viewer.insert_threshold(
layer_instance.name,
threshold_output,
self._main_model.are_predictions_in_viewer(),
self._task_executor.exec(
task=thresholding_task,
# lambda functions capture variables by reference so need to pass layer as a default argument
on_return=on_return,
on_error=self._handle_thresholding_error,
)

self._task_executor.exec(
task=thresholding_task,
# lambda functions capture variables by reference so need to pass layer as a default argument
on_return=on_return,
on_error=self._handle_thresholding_error,
)

def _save_thresholded_images(self, _: Event) -> None:
images_to_threshold: list[Path] = (
self._file_input_model.get_input_files_as_list()
Expand Down
21 changes: 11 additions & 10 deletions src/allencell_ml_segmenter/thresholding/thresholding_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from allencell_ml_segmenter.utils.file_utils import FileUtils

from allencell_ml_segmenter.widgets.label_with_hint_widget import LabelWithHint
from allencell_ml_segmenter.core.file_input_widget import (
FileInputWidget,
from allencell_ml_segmenter.core.prediction_result_input_widget import (
PredictionResultListWidget,
)
from allencell_ml_segmenter.core.file_input_model import (
FileInputModel,
Expand Down Expand Up @@ -85,14 +85,15 @@ def __init__(
layout.addWidget(self._title, alignment=Qt.AlignmentFlag.AlignHCenter)

# selecting input image
self._file_input_widget: FileInputWidget = FileInputWidget(
self._file_input_model,
self._viewer,
self._input_files_service,
include_channel_selection=False,
self._prediction_result_input_widget: PredictionResultListWidget = (
PredictionResultListWidget(
self._file_input_model,
self._viewer,
self._input_files_service,
)
)
self._file_input_widget.setObjectName("fileInput")
layout.addWidget(self._file_input_widget)
self._prediction_result_input_widget.setObjectName("fileInput")
layout.addWidget(self._prediction_result_input_widget)

# thresholding values
self._threshold_label: LabelWithHint = LabelWithHint("Threshold")
Expand Down Expand Up @@ -326,7 +327,7 @@ def doWork(self) -> None:
self._thresholding_model.dispatch_save_thresholded_images()

def focus_changed(self) -> None:
self._file_input_widget._update_layer_list()
self._prediction_result_input_widget._update_layer_list()

def getTypeOfWork(self) -> str:
return ""
Expand Down