def test_set_progress_within_bounds():
    """Values inside [minimum, maximum], endpoints included, are stored as given."""
    tracker: ProgressTracker = ProgressTracker(
        progress_minimum=0, progress_maximum=10
    )
    for value in (2, 9, 0, 10):
        tracker.set_progress(value)
        assert tracker.get_progress() == value


def test_set_progress_greater_than_max():
    """Values above progress_maximum are rejected with ValueError."""
    tracker: ProgressTracker = ProgressTracker(
        progress_minimum=0, progress_maximum=10
    )
    for too_large in (11, 10394):
        with pytest.raises(ValueError):
            tracker.set_progress(too_large)


def test_set_progress_less_than_min():
    """Values below progress_minimum are rejected with ValueError."""
    tracker: ProgressTracker = ProgressTracker(
        progress_minimum=0, progress_maximum=10
    )
    for too_small in (-1, -1948):
        with pytest.raises(ValueError):
            tracker.set_progress(too_small)
def _model_for_experiment(experiment_name: str) -> ExperimentsModel:
    """
    Build an ExperimentsModel backed by the fixture directories that sit next
    to this test module, with experiment_name selected.
    """
    tests_dir: Path = Path(__file__).parent
    settings = FakeUserSettings(
        cyto_dl_home_path=tests_dir / "cyto_dl_home",
        user_experiments_path=tests_dir / "experiments_home",
    )
    model = ExperimentsModel(settings)
    model.set_experiment_name(experiment_name)
    return model


def test_get_csv_path() -> None:
    # Arrange
    model = _model_for_experiment("0_exp")
    expected = Path(__file__).parent / "experiments_home" / "0_exp" / "data"

    # Act / Assert
    assert model.get_csv_path() == expected


def test_get_metrics_csv_path() -> None:
    # Arrange
    model = _model_for_experiment("0_exp")
    expected = Path(__file__).parent / "experiments_home" / "0_exp" / "csv"

    # Act / Assert
    assert model.get_metrics_csv_path() == expected


def test_get_latest_metrics_csv_version_no_versions() -> None:
    # Arrange
    model = _model_for_experiment("0_exp")

    # Act / Assert
    assert model.get_latest_metrics_csv_version() == -1


def test_get_latest_metrics_csv_version_no_directory() -> None:
    # Arrange
    model = _model_for_experiment("2_exp")

    # Act / Assert
    assert model.get_latest_metrics_csv_version() == -1


def test_get_latest_metrics_csv_version_version_1() -> None:
    # Arrange
    model = _model_for_experiment("1_exp")

    # Act / Assert
    assert model.get_latest_metrics_csv_version() == 1
"experiments_home" + / "1_exp" + / "csv" + / "version_1" + / "test_metrics_csv_2_epochs.csv" + ) + handler: MetricsCSVEventHandler = MetricsCSVEventHandler( + test_csv_path, callback_mock + ) + fs_event_mock: Mock = Mock(src_path=test_csv_path) + handler.on_any_event(fs_event_mock) + callback_mock.assert_called_with(2) + + +def test_empty_csv(): + callback_mock: Mock = Mock() + test_csv_path: Path = ( + Path(allencell_ml_segmenter.__file__).parent + / "_tests" + / "main" + / "experiments_home" + / "1_exp" + / "csv" + / "version_0" + / "test_metrics_csv_empty.csv" + ) + handler: MetricsCSVEventHandler = MetricsCSVEventHandler( + test_csv_path, callback_mock + ) + fs_event_mock: Mock = Mock(src_path=test_csv_path) + handler.on_any_event(fs_event_mock) + callback_mock.assert_called_with(0) + + +def test_nonexistent_csv(): + callback_mock: Mock = Mock() + test_csv_path: Path = ( + Path(allencell_ml_segmenter.__file__).parent + / "_tests" + / "main" + / "experiments_home" + / "0_exp" + / "csv" + / "version_0" + / "test_metrics_does_not_exist.csv" + ) + handler: MetricsCSVEventHandler = MetricsCSVEventHandler( + test_csv_path, callback_mock + ) + fs_event_mock: Mock = Mock(src_path=test_csv_path) + handler.on_any_event(fs_event_mock) + callback_mock.assert_not_called() diff --git a/src/allencell_ml_segmenter/core/progress_tracker.py b/src/allencell_ml_segmenter/core/progress_tracker.py new file mode 100644 index 00000000..ba0bcca0 --- /dev/null +++ b/src/allencell_ml_segmenter/core/progress_tracker.py @@ -0,0 +1,56 @@ +from abc import abstractmethod + + +class ProgressTracker: + """ + Base class for all ProgressTrackers to inherit from. A ProgressTracker + maintains an integer measure of progress between progress_minimum and + progress_maximum. The progress value can be used by PyQt progress bars + for example. 
+ """ + + def __init__(self, progress_minimum: int = 0, progress_maximum: int = 0): + self._progress_minimum: int = progress_minimum + self._progress_maximum: int = progress_maximum + self._progress: int = progress_minimum + + def get_progress_minimum(self) -> int: + return self._progress_minimum + + def get_progress_maximum(self) -> int: + return self._progress_maximum + + def get_progress(self) -> int: + return self._progress + + def set_progress(self, progress: int) -> None: + """ + If param progress > progress_maximum, throws ValueError. + If param progress < progress minimum, throws ValueError. + Otherwise sets this trackers progress to param progress. + """ + if progress > self._progress_maximum: + raise ValueError( + "cannot set progress to value greater than progress_maximum" + ) + if progress < self._progress_minimum: + raise ValueError( + "cannot set progress to value less than progress_minimum" + ) + + self._progress = progress + + @abstractmethod + def start_tracker(self) -> None: + """ + Enables updates to the progress measure from another thread. + """ + pass + + @abstractmethod + def stop_tracker(self) -> None: + """ + Stops any threads that may be active for progress updates. + Must be called before losing reference to the instance of the ProgressTracker. 
+ """ + pass diff --git a/src/allencell_ml_segmenter/core/view.py b/src/allencell_ml_segmenter/core/view.py index 1699c1c8..51fa9cfc 100644 --- a/src/allencell_ml_segmenter/core/view.py +++ b/src/allencell_ml_segmenter/core/view.py @@ -3,6 +3,7 @@ from PyQt5.QtCore import Qt, QThread, pyqtSignal from allencell_ml_segmenter.core.subscriber import Subscriber +from allencell_ml_segmenter.core.progress_tracker import ProgressTracker class ViewMeta(type(QWidget), type(Subscriber)): @@ -10,20 +11,34 @@ class ViewMeta(type(QWidget), type(Subscriber)): class LongTaskThread(QThread): - taskProgress = pyqtSignal(int) def __init__(self, do_work: callable, parent=None): - super(LongTaskThread, self).__init__(parent) + super().__init__(parent) self._do_work = do_work + # override def run(self): print("running") - # time.sleep(5) self._do_work() - # for i in range(1, 101): - # self.taskProgress.emit(i) - # self.msleep(100) # Simulating some work + +class ProgressThread(QThread): + # pyqtSignal must be class attribute + # https://www.riverbankcomputing.com/static/Docs/PyQt5/signals_slots.html#defining-new-signals-with-pyqtsignal + task_progress: pyqtSignal = pyqtSignal(int) + + def __init__(self, progress_tracker: ProgressTracker, parent=None): + super().__init__(parent) + self._progress_tracker: ProgressTracker = progress_tracker + + # override + def run(self): + while ( + self._progress_tracker.get_progress() + < self._progress_tracker.get_progress_maximum() + ): + self.task_progress.emit(self._progress_tracker.get_progress()) + self.msleep(100) class View(QWidget, Subscriber, metaclass=ViewMeta): @@ -36,7 +51,46 @@ class View(QWidget, Subscriber, metaclass=ViewMeta): def __init__(self): QWidget.__init__(self) - def startLongTask(self): + def startLongTaskWithProgressBar( + self, progress_tracker: ProgressTracker + ) -> None: + self.longTaskThread = LongTaskThread(do_work=self.doWork) + self.progressThread = ProgressThread(progress_tracker) + + self.progressDialog = 
QProgressDialog( + f"{self.getTypeOfWork()} in Progress", + "Cancel", + progress_tracker.get_progress_minimum(), + progress_tracker.get_progress_maximum(), + self, + ) + self.progressDialog.setWindowTitle(f"{self.getTypeOfWork()} Progress") + self.progressDialog.setWindowModality(Qt.ApplicationModal) + self.progressDialog.canceled.connect(self.longTaskThread.terminate) + self.progressDialog.canceled.connect(self.progressThread.terminate) + # stop the watchdog thread for file watching inside of the progress tracker + self.progressDialog.canceled.connect(progress_tracker.stop_tracker) + + self.progressDialog.show() + + self.longTaskThread.finished.connect(self.progressDialog.reset) + self.longTaskThread.finished.connect(self.longTaskThread.deleteLater) + self.longTaskThread.finished.connect(self.progressDialog.close) + self.longTaskThread.finished.connect(self.showResults) + + # progressThread's task_progress.emit now calls updateProgress + self.progressThread.task_progress.connect(self.updateProgress) + # if the longTaskThread or the progressThread finishes, we no longer + # need to update progress, so we should stop the progress tracker + self.progressThread.finished.connect(progress_tracker.stop_tracker) + self.longTaskThread.finished.connect(progress_tracker.stop_tracker) + + progress_tracker.start_tracker() + self.progressThread.start() + self.longTaskThread.start() + + # will remove once prediction is also ported to progress bar + def startLongTask(self) -> None: self.longTaskThread = LongTaskThread(do_work=self.doWork) self.progressDialog = QProgressDialog( f"{self.getTypeOfWork()} in Progress", "Cancel", 0, 0, self @@ -58,7 +112,7 @@ def startLongTask(self): def showResults(self): pass - def updateProgress(self, value): + def updateProgress(self, value: int) -> None: self.progressDialog.setValue(value) @abstractmethod diff --git a/src/allencell_ml_segmenter/main/experiments_model.py b/src/allencell_ml_segmenter/main/experiments_model.py index 
class MetricsCSVEventHandler(FileSystemEventHandler):
    """
    A MetricsCSVEventHandler calls progress_callback upon any changes to
    the provided target_path CSV file, passing the latest epoch in the modified
    CSV to the callback.
    """

    def __init__(self, target_path: Path, progress_callback: Callable):
        """
        :param target_path: CSV file whose events should trigger the callback
        :param progress_callback: called with the greatest epoch (int) found
            in target_path each time an event for that file is received
        """
        super().__init__()
        self._target_path: Path = target_path
        self._progress_callback: Callable = progress_callback

    def _get_latest_epoch(self) -> int:
        """
        Returns the greatest value of the "epoch" column in the target CSV,
        or 0 if the file cannot be opened or holds no parsable epoch.
        """
        latest: int = 0
        try:
            with self._target_path.open("r", newline="") as fr:
                for row in DictReader(fr):
                    # Rows with a missing, blank, or non-numeric epoch are
                    # skipped rather than aborting the scan (presumably the
                    # file can be seen mid-write; confirm against the writer).
                    try:
                        epoch: int = int(row["epoch"])
                    except (KeyError, TypeError, ValueError):
                        continue
                    latest = max(latest, epoch)
        except OSError:
            # File is missing or vanished between the event and the read.
            pass

        return latest

    # override
    def on_any_event(self, event: FileSystemEvent) -> None:
        """
        Invokes the progress callback when event refers to the target CSV.
        Events for any other path, and events received while the target does
        not exist, are ignored.
        """
        try:
            # samefile stats both paths, so it raises if either is gone
            # (e.g. a delete event for some other file in the watched tree).
            is_target: bool = self._target_path.samefile(event.src_path)
        except OSError:
            return

        if is_target:
            self._progress_callback(self._get_latest_epoch())
class MetricsCSVProgressTracker(ProgressTracker):
    """
    A MetricsCSVProgressTracker measures progress by observing a metrics CSV file
    produced by cyto-dl and taking the greatest epoch listed inside of it as its
    measure of progress. Relies heavily on current cyto-dl file logging procedure.
    """

    def __init__(self, csv_path: Path, num_epochs: int, version_number: int):
        """
        :param csv_path: path to cyto-dl csv directory for an experiment
        :param num_epochs: maximum number of epochs that will be recorded in the csv
        :param version_number: experiment version to track
        """
        super().__init__(progress_minimum=0, progress_maximum=num_epochs)

        self._csv_path: Path = csv_path
        # The watched directory must exist before an observer is scheduled on
        # it. exist_ok avoids the race of checking existence and then creating.
        csv_path.mkdir(parents=True, exist_ok=True)

        self._target_path: Path = (
            csv_path / f"version_{version_number}" / "metrics.csv"
        )
        self._observer: Optional[BaseObserver] = None

    def _report_epoch(self, epoch: int) -> None:
        """
        Records epoch as the current progress, clamped to the tracker's range
        so an unexpected value in the CSV cannot raise inside the observer
        callback.
        """
        lowest: int = self.get_progress_minimum()
        highest: int = self.get_progress_maximum()
        self.set_progress(min(max(epoch, lowest), highest))

    # override
    def start_tracker(self) -> None:
        """
        Starts a watchdog observer on the csv directory (recursively) that
        updates this tracker's progress whenever the target metrics file
        changes. Any previously started observer is stopped first.
        """
        self.stop_tracker()
        observer: BaseObserver = Observer()
        event_handler: MetricsCSVEventHandler = MetricsCSVEventHandler(
            self._target_path, self._report_epoch
        )
        observer.schedule(
            event_handler, path=str(self._csv_path), recursive=True
        )
        observer.start()
        # Only keep the observer once it is actually running.
        self._observer = observer

    # override
    def stop_tracker(self) -> None:
        """
        Stops the active observer, if any. Safe to call repeatedly.
        """
        observer: Optional[BaseObserver] = self._observer
        self._observer = None
        if observer is not None:
            observer.stop()
def train_btn_handler(self) -> None:
    """
    Starts training process
    """
    # Gather the tracker inputs in the same order the constructor takes them.
    metrics_csv_dir = self._experiments_model.get_metrics_csv_path()
    max_epoch = self._training_model.get_max_epoch()
    # Track the version directory one past the latest existing version.
    version_to_track = (
        self._experiments_model.get_latest_metrics_csv_version() + 1
    )

    tracker: MetricsCSVProgressTracker = MetricsCSVProgressTracker(
        metrics_csv_dir, max_epoch, version_to_track
    )
    self.startLongTaskWithProgressBar(tracker)