diff --git a/src/allencell_ml_segmenter/core/event.py b/src/allencell_ml_segmenter/core/event.py index a2851d5d..f602bc80 100644 --- a/src/allencell_ml_segmenter/core/event.py +++ b/src/allencell_ml_segmenter/core/event.py @@ -32,6 +32,7 @@ class Event(Enum): ACTION_PREDICTION_POSTPROCESSING_AUTO_THRESHOLD = ( "postprocessing_auto_threshold" ) + ACTION_PREDICTION_INITIATED = "prediction_initiated" # Curation ACTION_CURATION_RAW_SELECTED = "curation_raw_selected" ACTION_CURATION_SEG1_SELECTED = "curation_seg1_selected" diff --git a/src/allencell_ml_segmenter/prediction/file_input_widget.py b/src/allencell_ml_segmenter/prediction/file_input_widget.py index e464b33b..2d8b7b62 100644 --- a/src/allencell_ml_segmenter/prediction/file_input_widget.py +++ b/src/allencell_ml_segmenter/prediction/file_input_widget.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import List, Optional -from napari.utils.events import Event +from napari.utils.events import Event as NapariEvent from qtpy.QtCore import Qt from qtpy.QtWidgets import ( QWidget, @@ -15,6 +15,7 @@ QSizePolicy, ) +from allencell_ml_segmenter.core.event import Event from allencell_ml_segmenter.widgets.input_button_widget import ( InputButton, FileInputMode, @@ -28,7 +29,7 @@ CheckBoxListWidget, ) -from napari.viewer import Viewer +from allencell_ml_segmenter.main.viewer import Viewer class PredictionFileInput(QWidget): @@ -168,6 +169,11 @@ def __init__(self, model: PredictionModel, viewer: Viewer): grid_layout.setColumnStretch(1, 0) frame.layout().addLayout(grid_layout) + self._model.subscribe( + Event.ACTION_PREDICTION_INITIATED, + self, + self._set_selected_image_paths_from_napari + ) def _on_screen_slot(self) -> None: """Prohibits usage of non-related input fields if top button is checked.""" @@ -184,11 +190,16 @@ def _from_directory_slot(self) -> None: self._browse_dir_edit.setEnabled(True) self._model.set_prediction_input_mode(PredictionInputMode.FROM_PATH) - def _update_layer_list(self, event: Optional[Event] = None) -> None: + def _update_layer_list(self, event: Optional[NapariEvent] = None) -> None: self._image_list.clear() for layer in self._viewer.get_layers(): self._image_list.add_item(layer.name) + def _set_selected_image_paths_from_napari(self, event: Optional[Event] = None) -> None: + selected_indices: List[int] = self._image_list.get_checked_rows() + selected_paths: List[Path] = [self._viewer.get_layers()[i].source.path for i in selected_indices] + self._model.set_selected_paths(selected_paths) + # TODO: replace with correct implementation and move to a service def map_input_file_directory_to_path_list( self, input_file_directory: str diff --git a/src/allencell_ml_segmenter/prediction/model.py b/src/allencell_ml_segmenter/prediction/model.py index 53a42954..3220d488 100644 --- a/src/allencell_ml_segmenter/prediction/model.py +++ b/src/allencell_ml_segmenter/prediction/model.py @@ -26,6 +26,7 @@ def __init__(self): self._image_input_channel_index: int = None self._input_mode: PredictionInputMode = None self._output_directory: Path = None + self._selected_paths: List[Path] = None # state related to ModelInputWidget self._preprocessing_method: str = None @@ -33,6 +34,9 @@ def __init__(self): self._postprocessing_simple_threshold: float = None self._postprocessing_auto_threshold: str = None + # app state + self._prediction_running: bool = False + def get_input_image_path(self) -> Path: """ Gets list of paths to input images. @@ -151,3 +155,20 @@ def set_prediction_input_mode(self, mode: PredictionInputMode) -> None: def get_prediction_input_mode(self) -> PredictionInputMode: return self._input_mode + + def set_selected_paths(self, paths: List[Path]) -> None: + self._selected_paths = paths + + def get_selected_paths(self) -> List[Path]: + return self._selected_paths + + def set_prediction_running(self, is_prediction_running: bool) -> None: + # To run some setup for predictions + self.dispatch(Event.ACTION_PREDICTION_INITIATED) + # Shoots off a prediction run + self.dispatch(Event.PROCESS_PREDICTION) + self._prediction_running = is_prediction_running + + def get_prediction_running(self) -> bool: + return self._prediction_running + diff --git a/src/allencell_ml_segmenter/prediction/view.py b/src/allencell_ml_segmenter/prediction/view.py index 6b825048..5d6e6190 100644 --- a/src/allencell_ml_segmenter/prediction/view.py +++ b/src/allencell_ml_segmenter/prediction/view.py @@ -91,7 +91,7 @@ def __init__( def run_btn_handler(self): # Just to test service for now. - self._prediction_model.dispatch(Event.PROCESS_PREDICTION) + self._prediction_model.set_prediction_running(True) def doWork(self): print("doWork - prediction") diff --git a/src/allencell_ml_segmenter/services/prediction_service.py b/src/allencell_ml_segmenter/services/prediction_service.py index 17b455cc..da063974 100644 --- a/src/allencell_ml_segmenter/services/prediction_service.py +++ b/src/allencell_ml_segmenter/services/prediction_service.py @@ -13,7 +13,7 @@ from pathlib import Path from typing import Union, Dict, List -# from cyto_dl.api.model import CytoDLModel +from cyto_dl.api.model import CytoDLModel from napari.utils.notifications import show_warning @@ -74,14 +74,35 @@ def _predict_model(self, _: Event) -> None: ) continue_prediction = False - # Create a CSV if user selects a folder of input images. - if ( - self._prediction_model.get_prediction_input_mode() + # Check to see if user has selected an input mode + input_mode_selected: PredictionInputMode = self._prediction_model.get_prediction_input_mode() + if not input_mode_selected: + show_warning("Please select input images before running prediction.") + continue_prediction = False + elif ( + input_mode_selected == PredictionInputMode.FROM_PATH ): + # User has selected a directory or a csv as input images input_path: Path = self._prediction_model.get_input_image_path() if input_path.is_dir(): + # if input path selected is a directory, we need to manually write a CSV for cyto-dl self.write_csv_for_inputs(list(input_path.glob("*.*"))) + elif input_path.suffix != ".csv": + # This should not be possible with FileInputWidget- throw an error. + raise ValueError("Somehow the user has selected a non-csv/directory for input images. Should not be possible with FileInputWidget") + elif input_mode_selected == PredictionInputMode.FROM_NAPARI_LAYERS: + # User has selected napari image layers as input images + selected_paths_from_napari: List[Path] = self._prediction_model.get_selected_paths() + if len(selected_paths_from_napari) < 1: + # No image layers selected + show_warning("Please select at least 1 image from the napari layer before running prediction.") + continue_prediction = False + else: + # If user selects input images from napari, we need to manually write a csv for cyto-dl + self.write_csv_for_inputs(self._prediction_model.get_selected_paths()) + + if continue_prediction: cyto_api: CytoDLModel = CytoDLModel()