Skip to content

Commit

Permalink
used new version of black to reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
saeliddp committed Feb 16, 2024
1 parent 04b8b88 commit cdee8a5
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 15 deletions.
9 changes: 7 additions & 2 deletions src/allencell_ml_segmenter/core/progress_tracker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod


class ProgressTracker():
class ProgressTracker:
"""
Base class for all ProgressTrackers to inherit from. A ProgressTracker
maintains an integer measure of progress between progress_minimum and
Expand Down Expand Up @@ -30,7 +30,12 @@ def set_progress_maximum(self, progress_maximum: int) -> None:
self._progress_maximum = progress_maximum

def set_progress(self, progress: int) -> None:
self._progress = progress if progress >= self._progress_minimum and progress <= self._progress_maximum else self._progress
self._progress = (
progress
if progress >= self._progress_minimum
and progress <= self._progress_maximum
else self._progress
)

@abstractmethod
def start_tracker(self) -> None:
Expand Down
7 changes: 5 additions & 2 deletions src/allencell_ml_segmenter/core/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def __init__(self, progress_tracker: ProgressTracker, parent=None):
self._progress_tracker: ProgressTracker = progress_tracker

def run(self):
while self._progress_tracker.get_progress() < self._progress_tracker.get_progress_maximum():
while (
self._progress_tracker.get_progress()
< self._progress_tracker.get_progress_maximum()
):
self.task_progress.emit(self._progress_tracker.get_progress())
self.msleep(100)

Expand All @@ -59,7 +62,7 @@ def startLongTaskWithProgressBar(self, progress_tracker: ProgressTracker):
"Cancel",
progress_tracker.get_progress_minimum(),
progress_tracker.get_progress_maximum(),
self
self,
)
self.progressDialog.setWindowTitle(f"{self.getTypeOfWork()} Progress")
self.progressDialog.setWindowModality(Qt.ApplicationModal)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ def _get_latest_epoch(self) -> int:
with self._target_path.open("r", newline="") as fr:
dict_reader: DictReader = DictReader(fr)
for row in dict_reader:
latest = int(row["epoch"]) if int(row["epoch"]) > latest else latest
latest = (
int(row["epoch"]) if int(row["epoch"]) > latest else latest
)

return latest

def on_any_event(self, event: FileSystemEvent) -> None:
if self._target_path.exists() and self._target_path.samefile(event.src_path):
if self._target_path.exists() and self._target_path.samefile(
event.src_path
):
self._progress_callback(self._get_latest_epoch())
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from watchdog.observers.api import BaseObserver
from watchdog.observers import Observer
from allencell_ml_segmenter.core.progress_tracker import ProgressTracker
from allencell_ml_segmenter.training.metrics_csv_event_handler import MetricsCSVEventHandler
from allencell_ml_segmenter.training.metrics_csv_event_handler import (
MetricsCSVEventHandler,
)


class MetricsCSVProgressTracker(ProgressTracker):
Expand All @@ -12,23 +14,34 @@ class MetricsCSVProgressTracker(ProgressTracker):
measure of progress. Relies heavily on current cyto-dl file logging procedure.
"""

def __init__(self, csv_path: Path, progress_minimum: int = 0, progress_maximum: int = 0):
def __init__(
self,
csv_path: Path,
progress_minimum: int = 0,
progress_maximum: int = 0,
):
super().__init__(progress_minimum, progress_maximum)

self._csv_path: Path = csv_path
if not csv_path.exists():
csv_path.mkdir(parents=True)

self._target_path: Path = (
csv_path / f"version_{self._get_last_csv_version() + 1}" / "metrics.csv"
csv_path
/ f"version_{self._get_last_csv_version() + 1}"
/ "metrics.csv"
)
self._observer: BaseObserver = None

def start_tracker(self) -> None:
self.stop_tracker()
self._observer = Observer()
event_handler: MetricsCSVEventHandler = MetricsCSVEventHandler(self._target_path, self.set_progress)
self._observer.schedule(event_handler, path=self._csv_path, recursive=True)
event_handler: MetricsCSVEventHandler = MetricsCSVEventHandler(
self._target_path, self.set_progress
)
self._observer.schedule(
event_handler, path=self._csv_path, recursive=True
)
self._observer.start()

def stop_tracker(self) -> None:
Expand All @@ -47,7 +60,11 @@ def _get_last_csv_version(self) -> int:
if child.is_dir():
version_str: str = child.name.split("_")[-1]
try:
last_version = int(version_str) if int(version_str) > last_version else last_version
last_version = (
int(version_str)
if int(version_str) > last_version
else last_version
)
except ValueError:
continue
return last_version
12 changes: 9 additions & 3 deletions src/allencell_ml_segmenter/training/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@
ImageSelectionWidget,
)
from allencell_ml_segmenter.training.training_model import TrainingModel

# from hydra.core.global_hydra import GlobalHydra
from aicsimageio import AICSImage
from aicsimageio.readers import TiffReader

from allencell_ml_segmenter.widgets.label_with_hint_widget import LabelWithHint
from qtpy.QtGui import QIntValidator
from allencell_ml_segmenter.training.training_model import PatchSize
from allencell_ml_segmenter.training.metrics_csv_progress_tracker import MetricsCSVProgressTracker
from allencell_ml_segmenter.training.metrics_csv_progress_tracker import (
MetricsCSVProgressTracker,
)


class TrainingView(View):
Expand Down Expand Up @@ -205,8 +208,11 @@ def train_btn_handler(self) -> None:
"""
Starts training process
"""
progress_tracker: MetricsCSVProgressTracker = MetricsCSVProgressTracker(
self._experiments_model.get_metrics_csv_path(), progress_maximum=self._training_model.get_max_epoch()
progress_tracker: MetricsCSVProgressTracker = (
MetricsCSVProgressTracker(
self._experiments_model.get_metrics_csv_path(),
progress_maximum=self._training_model.get_max_epoch(),
)
)
self.startLongTaskWithProgressBar(progress_tracker)

Expand Down

0 comments on commit cdee8a5

Please sign in to comment.