Skip to content

Commit

Permalink
Merge pull request #583 from AllenCell/refactor/thresholdingmodel_com…
Browse files Browse the repository at this point in the history
…poses_fileinputmodel

Make ThresholdingModel compose FileInputModel, Add ability to remove binary maps
  • Loading branch information
yrkim98 authored Feb 17, 2025
2 parents 2eef549 + c2b14bd commit 6f5d95f
Show file tree
Hide file tree
Showing 15 changed files with 335 additions and 158 deletions.
3 changes: 3 additions & 0 deletions src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,6 @@ def get_all_layers_containing_prob_map(self) -> list[Layer]:
if getattr(layer, "metadata", None)
and "prob_map" in layer.metadata
]

def clear_binary_map_from_layer(self, layer: Layer) -> None:
self.threshold_inserted.pop(f"[threshold] {layer.name}")
12 changes: 12 additions & 0 deletions src/allencell_ml_segmenter/_tests/main/test_viewer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from napari.layers import Labels
import numpy as np

from allencell_ml_segmenter.main.viewer import Viewer


def test_clear_binary_map_from_layer() -> None:
layer: Labels = Labels(np.ones((10, 10, 10), dtype=bool))

Viewer.clear_binary_map_from_layer(Viewer, layer)

assert np.all(layer.data == 0) # assert all values are now zeros
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,27 @@ def thresholding_model() -> ThresholdingModel:
def test_set_thresholding_value_dispatches_event(thresholding_model):
fake_subscriber: FakeSubscriber = FakeSubscriber()
thresholding_model.subscribe(
Event.ACTION_THRESHOLDING_VALUE_CHANGED,
Event.ACTION_EXECUTE_THRESHOLDING,
fake_subscriber,
fake_subscriber.handle,
)

thresholding_model.set_thresholding_value(2)

assert fake_subscriber.was_handled(Event.ACTION_THRESHOLDING_VALUE_CHANGED)
assert fake_subscriber.was_handled(Event.ACTION_EXECUTE_THRESHOLDING)


def test_set_autothresholding_enabled_dispatches_event(thresholding_model):
fake_subscriber: FakeSubscriber = FakeSubscriber()
thresholding_model.subscribe(
Event.ACTION_THRESHOLDING_AUTOTHRESHOLDING_SELECTED,
Event.ACTION_EXECUTE_THRESHOLDING,
fake_subscriber,
fake_subscriber.handle,
)

thresholding_model.set_autothresholding_enabled(True)

assert fake_subscriber.was_handled(
Event.ACTION_THRESHOLDING_AUTOTHRESHOLDING_SELECTED
)
assert fake_subscriber.was_handled(Event.ACTION_EXECUTE_THRESHOLDING)


def test_dispatch_save_thresholded_images(thresholding_model):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import numpy as np

from allencell_ml_segmenter.core.file_input_model import FileInputModel
from allencell_ml_segmenter._tests.fakes.fake_experiments_model import (
FakeExperimentsModel,
)
Expand Down Expand Up @@ -29,7 +28,6 @@ def test_on_threshold_changed_non_prediction(test_image):
thresholding_service: ThresholdingService = ThresholdingService(
thresholding_model,
FakeExperimentsModel(),
FileInputModel(),
MainModel(),
viewer,
task_executor=SynchroTaskExecutor.global_instance(),
Expand Down Expand Up @@ -60,11 +58,9 @@ def test_on_threshold_changed_non_prediction(test_image):
viewer: FakeViewer = FakeViewer()
main_model: MainModel = MainModel()
main_model.set_predictions_in_viewer(True)
file_input_model: FileInputModel = FileInputModel()
thresholding_service: ThresholdingService = ThresholdingService(
thresholding_model,
FakeExperimentsModel(),
file_input_model,
main_model,
viewer,
task_executor=SynchroTaskExecutor.global_instance(),
Expand All @@ -84,10 +80,14 @@ def test_on_threshold_changed_non_prediction(test_image):
metadata={"prob_map": test_image},
)
viewer.add_image(test_image, name="donotthreshold")
file_input_model.set_selected_idx([0, 1])
thresholding_model.set_thresholding_layers(
viewer.get_all_layers_containing_prob_map()
)
thresholding_model.set_selected_idx([0, 1])

# ACT set a threshold to trigger
thresholding_model.set_thresholding_value(50)
thresholding_model.set_threshold_enabled(True)

# Verify a threshold layer is added for each seg layer
assert "[threshold] [seg] test_layer 1" in viewer.threshold_inserted
Expand All @@ -110,3 +110,121 @@ def test_on_threshold_changed_non_prediction(test_image):
assert np.array_equal(seg_data, (test_image > 100).astype(int))
# verify that raw layers do not get thresholded
assert len(viewer.threshold_inserted) == 2


def test_none_selected_removes_binary_maps(test_image):
"""
Test that the thresholding service does not add a threshold layer for a layer that is not a probability map
"""
# ARRANGE
thresholding_model: ThresholdingModel = ThresholdingModel()
viewer: FakeViewer = FakeViewer()
main_model: MainModel = MainModel()
main_model.set_predictions_in_viewer(True)
thresholding_service: ThresholdingService = ThresholdingService(
thresholding_model,
FakeExperimentsModel(),
main_model,
viewer,
task_executor=SynchroTaskExecutor.global_instance(),
)

# Only the [seg] layers below should produce a threshold layer since they have prob map metadata
viewer.add_image(test_image, name="[raw] test_layer 1")
viewer.add_image(
test_image,
name="[seg] test_layer 1",
metadata={"prob_map": test_image},
)
viewer.add_image(test_image, name="[raw] test_layer 2")
viewer.add_image(
test_image,
name="[seg] test_layer 2",
metadata={"prob_map": test_image},
)
viewer.add_image(test_image, name="donotthreshold")
thresholding_model.set_thresholding_layers(
viewer.get_all_layers_containing_prob_map()
)
thresholding_model.set_selected_idx([0, 1])

# create binary maps and add to viewer
thresholding_model.set_thresholding_value(50)
thresholding_model.set_threshold_enabled(True)

# Verify a threshold layer is added for each seg layer
assert "[threshold] [seg] test_layer 1" in viewer.threshold_inserted
seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 1"]
assert np.array_equal(seg_data, (test_image > 50).astype(int))
assert "[threshold] [seg] test_layer 2" in viewer.threshold_inserted
seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 2"]
assert np.array_equal(seg_data, (test_image > 50).astype(int))
# verify that raw layers do not get thresholded
assert len(viewer.threshold_inserted) == 2

# ACT remove all binary maps by selecting None
thresholding_model.disable_all()

# Assert
assert len(viewer.threshold_inserted) == 0


def test_unselection_removes_binary_maps(test_image):
"""
Test that the thresholding service does not add a threshold layer for a layer that is not a probability map
"""
# ARRANGE
thresholding_model: ThresholdingModel = ThresholdingModel()
viewer: FakeViewer = FakeViewer()
main_model: MainModel = MainModel()
main_model.set_predictions_in_viewer(True)
thresholding_service: ThresholdingService = ThresholdingService(
thresholding_model,
FakeExperimentsModel(),
main_model,
viewer,
task_executor=SynchroTaskExecutor.global_instance(),
)

# Only the [seg] layers below should produce a threshold layer since they have prob map metadata
viewer.add_image(test_image, name="[raw] test_layer 1")
viewer.add_image(
test_image,
name="[seg] test_layer 1",
metadata={"prob_map": test_image},
)
viewer.add_image(test_image, name="[raw] test_layer 2")
viewer.add_image(
test_image,
name="[seg] test_layer 2",
metadata={"prob_map": test_image},
)
viewer.add_image(test_image, name="donotthreshold")
thresholding_model.set_thresholding_layers(
viewer.get_all_layers_containing_prob_map()
)
thresholding_model.set_selected_idx([0, 1])

# create binary maps and add to viewer
thresholding_model.set_thresholding_value(50)
thresholding_model.set_threshold_enabled(True)

# Verify a threshold layer is added for each seg layer
assert "[threshold] [seg] test_layer 1" in viewer.threshold_inserted
seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 1"]
assert np.array_equal(seg_data, (test_image > 50).astype(int))
assert "[threshold] [seg] test_layer 2" in viewer.threshold_inserted
seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 2"]
assert np.array_equal(seg_data, (test_image > 50).astype(int))
# verify that raw layers do not get thresholded
assert len(viewer.threshold_inserted) == 2

# ACT remove binary map by unselecting index 1
thresholding_model.set_selected_idx(
[0]
) # remove 1, which was previously selected

# Assert
assert len(viewer.threshold_inserted) == 1
assert "[threshold] [seg] test_layer 2" not in viewer.threshold_inserted
assert "[threshold] [seg] test_layer 1" in viewer.threshold_inserted
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
ThresholdingModel,
)
from allencell_ml_segmenter.core.file_input_model import (
FileInputModel,
InputMode,
)
from allencell_ml_segmenter._tests.fakes.fake_viewer import FakeViewer
Expand All @@ -28,16 +27,9 @@ def main_model() -> MainModel:


@pytest.fixture
def thresholding_model() -> ThresholdingModel:
def thresholding_model(tmp_path: Path) -> ThresholdingModel:
model = ThresholdingModel()
model.set_thresholding_value(128)
return model


@pytest.fixture
# tmp_path is a builtin pytest fixture for a faked path
def file_input_model(tmp_path: Path) -> FileInputModel:
model = FileInputModel()
model.set_output_directory(tmp_path / "output")
model.set_input_image_path(tmp_path / "input")
model.set_input_mode(InputMode.FROM_PATH)
Expand All @@ -58,15 +50,13 @@ def viewer() -> FakeViewer:
def thresholding_view(
main_model,
thresholding_model,
file_input_model,
experiments_model,
viewer,
qtbot,
):
view = ThresholdingView(
main_model,
thresholding_model,
file_input_model,
experiments_model,
viewer,
)
Expand Down Expand Up @@ -176,15 +166,18 @@ def test_update_state_from_radios(thresholding_view, thresholding_model):


def test_check_able_to_threshold_valid(
main_model, file_input_model, experiments_model, viewer
main_model, experiments_model, viewer, tmp_path
):
thresholding_model: ThresholdingModel = ThresholdingModel()
thresholding_model.set_threshold_enabled(True)
thresholding_model.set_thresholding_value(100)
thresholding_model.set_output_directory(tmp_path / "output")
thresholding_model.set_input_image_path(tmp_path / "input")
thresholding_model.set_input_mode(InputMode.FROM_PATH)

thresholding_view: ThresholdingView = ThresholdingView(
main_model,
thresholding_model,
file_input_model,
experiments_model,
viewer,
)
Expand All @@ -198,13 +191,11 @@ def test_check_able_to_threshold_no_output_dir(
thresholding_model: ThresholdingModel = ThresholdingModel()
thresholding_model.set_threshold_enabled(True)
thresholding_model.set_thresholding_value(100)
file_input_model: FileInputModel = FileInputModel()
file_input_model.set_input_mode(InputMode.FROM_PATH)
file_input_model.set_input_image_path(Path("fake_path"))
thresholding_model.set_input_mode(InputMode.FROM_PATH)
thresholding_model.set_input_image_path(Path("fake_path"))
thresholding_view: ThresholdingView = ThresholdingView(
main_model,
thresholding_model,
file_input_model,
experiments_model,
viewer,
)
Expand All @@ -218,13 +209,11 @@ def test_check_able_to_threshold_no_input_dir(
thresholding_model: ThresholdingModel = ThresholdingModel()
thresholding_model.set_threshold_enabled(True)
thresholding_model.set_thresholding_value(100)
file_input_model: FileInputModel = FileInputModel()
file_input_model.set_input_mode(InputMode.FROM_PATH)
file_input_model.set_output_directory(Path("fake_path"))
thresholding_model.set_input_mode(InputMode.FROM_PATH)
thresholding_model.set_output_directory(Path("fake_path"))
thresholding_view: ThresholdingView = ThresholdingView(
main_model,
thresholding_model,
file_input_model,
experiments_model,
viewer,
)
Expand All @@ -238,13 +227,11 @@ def test_check_able_to_threshold_no_input_method(
thresholding_model: ThresholdingModel = ThresholdingModel()
thresholding_model.set_threshold_enabled(True)
thresholding_model.set_thresholding_value(100)
file_input_model: FileInputModel = FileInputModel()
file_input_model.set_input_image_path(Path("fake_path"))
file_input_model.set_output_directory(Path("fake_path"))
thresholding_model.set_input_image_path(Path("fake_path"))
thresholding_model.set_output_directory(Path("fake_path"))
thresholding_view: ThresholdingView = ThresholdingView(
main_model,
thresholding_model,
file_input_model,
experiments_model,
viewer,
)
Expand Down
4 changes: 2 additions & 2 deletions src/allencell_ml_segmenter/core/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ class Event(Enum):
ACTION_CURATION_RAW_THREAD_ERROR = "curation_raw_thread_error"

# Thresholding events
ACTION_THRESHOLDING_VALUE_CHANGED = "thresholding_value_changed"
ACTION_THRESHOLDING_AUTOTHRESHOLDING_SELECTED = "autothresholding_selected"
ACTION_EXECUTE_THRESHOLDING = "execute_thresholding"
ACTION_SAVE_THRESHOLDING_IMAGES = "save_thresholding_images"
ACTION_THRESHOLDING_DISABLED = "disable_thresholding"

# View selection events. These can stem from a user action, or from a process (i.e. prediction process ends, and a new view is shown automatically).
VIEW_SELECTION_TRAINING = "training_selected"
Expand Down
6 changes: 0 additions & 6 deletions src/allencell_ml_segmenter/core/file_input_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,3 @@ def get_input_files_as_list(self) -> List[Path]:
if selected_paths is not None:
return selected_paths
return []

def set_selected_idx(self, selected_idx: list[int]) -> None:
self._selected_idx = selected_idx

def get_selected_idx(self) -> Optional[list[int]]:
return self._selected_idx
7 changes: 4 additions & 3 deletions src/allencell_ml_segmenter/core/file_input_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ def __init__(
self,
model: FileInputModel,
viewer: IViewer,
service: ModelFileService,
service: Optional[ModelFileService],
include_channel_selection: bool = True,
):
super().__init__()
self._include_channel_selection: bool = include_channel_selection
self._model: FileInputModel = model
self._viewer: IViewer = viewer
self._service: ModelFileService = service
self._service: Optional[ModelFileService] = service
layout: QVBoxLayout = QVBoxLayout()
self.setLayout(layout)
layout.setContentsMargins(0, 0, 0, 0)
Expand Down Expand Up @@ -271,7 +271,8 @@ def _process_checked_signal(self, row: int, state: Qt.CheckState) -> None:
len(selected_indices) == 0
and self._include_channel_selection
):
self._service.stop_channel_extraction() # stop so combobox doesn't reset after thread is finished
if self._service:
self._service.stop_channel_extraction() # stop so combobox doesn't reset after thread is finished
self._reset_channel_combobox()

self._model.set_selected_paths(
Expand Down
Loading

0 comments on commit 6f5d95f

Please sign in to comment.