Skip to content

Commit

Permalink
fixed some linter issues
Browse files Browse the repository at this point in the history
  • Loading branch information
saeliddp committed Feb 16, 2024
1 parent 70a1307 commit 0a7bdde
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 27 deletions.
19 changes: 10 additions & 9 deletions src/allencell_ml_segmenter/core/progress_tracker.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,41 @@
from abc import abstractmethod


class ProgressTracker():
"""
Base class for all ProgressTrackers to inherit from. A ProgressTracker
maintains an integer measure of progress between progress_minimum and
progress_maximum. The progress value can be used by PyQt progress bars
for example.
"""
def __init__(self, progress_minimum: int=0, progress_maximum: int=0):

def __init__(self, progress_minimum: int = 0, progress_maximum: int = 0):
self._progress_minimum: int = progress_minimum
self._progress_maximum: int = progress_maximum
self._progress: int = progress_minimum

def get_progress_minimum(self) -> int:
return self._progress_minimum

def get_progress_maximum(self) -> int:
return self._progress_maximum

def get_progress(self) -> int:
return self._progress

def set_progress_minimum(self, progress_minimum: int) -> None:
self._progress_minimum = progress_minimum

def set_progress_maximum(self, progress_maximum: int) -> None:
self._progress_maximum = progress_maximum

def set_progress(self, progress: int) -> None:
self._progress = progress if progress >= self._progress_minimum and progress <= self._progress_maximum else self._progress

@abstractmethod
def start_tracker(self) -> None:
pass

@abstractmethod
def stop_tracker(self) -> None:
pass
pass
13 changes: 7 additions & 6 deletions src/allencell_ml_segmenter/core/view.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from abc import abstractmethod
from qtpy.QtWidgets import QWidget, QProgressDialog
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from pathlib import Path

from allencell_ml_segmenter.core.subscriber import Subscriber
from allencell_ml_segmenter.core.progress_tracker import ProgressTracker


class ViewMeta(type(QWidget), type(Subscriber)):
pass

Expand All @@ -25,7 +25,8 @@ def run(self):
# for i in range(1, 101):
# self.taskProgress.emit(i)
# self.msleep(100) # Simulating some work



class ProgressThread(QThread):
task_progress: pyqtSignal = pyqtSignal(int)

Expand Down Expand Up @@ -54,10 +55,10 @@ def startLongTaskWithProgressBar(self, progress_tracker: ProgressTracker):
self.progressThread = ProgressThread(progress_tracker)

self.progressDialog = QProgressDialog(
f"{self.getTypeOfWork()} in Progress",
"Cancel",
progress_tracker.get_progress_minimum(),
progress_tracker.get_progress_maximum(),
f"{self.getTypeOfWork()} in Progress",
"Cancel",
progress_tracker.get_progress_minimum(),
progress_tracker.get_progress_maximum(),
self
)
self.progressDialog.setWindowTitle(f"{self.getTypeOfWork()} Progress")
Expand Down
11 changes: 6 additions & 5 deletions src/allencell_ml_segmenter/training/metrics_csv_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from pathlib import Path
from csv import DictReader

class MetricsCSVEventHandler(FileSystemEventHandler):

class MetricsCSVEventHandler(FileSystemEventHandler):
"""
A MetricsCSVEventHandler calls progress_callback upon any changes to
the provided target_path CSV file, passing the latest epoch in the modified
Expand All @@ -13,19 +14,19 @@ def __init__(self, target_path: Path, progress_callback: callable):
super().__init__()
self._target_path: Path = target_path
self._progress_callback: callable = progress_callback

def _get_latest_epoch(self) -> int:
if not self._target_path.exists():
return 0

latest: int = 0
with self._target_path.open("r", newline="") as fr:
dict_reader: DictReader = DictReader(fr)
for row in dict_reader:
latest = int(row["epoch"]) if int(row["epoch"]) > latest else latest
latest = int(row["epoch"]) if int(row["epoch"]) > latest else latest

return latest

def on_any_event(self, event: FileSystemEvent) -> None:
if self._target_path.exists() and self._target_path.samefile(event.src_path):
self._progress_callback(self._get_latest_epoch())
self._progress_callback(self._get_latest_epoch())
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from allencell_ml_segmenter.core.progress_tracker import ProgressTracker
from allencell_ml_segmenter.training.metrics_csv_event_handler import MetricsCSVEventHandler


class MetricsCSVProgressTracker(ProgressTracker):
"""
A MetricsCSVProgressTracker measures progress by observing a metrics CSV file
produced by cyto-dl and taking the greatest epoch listed inside of it as its
measure of progress. Relies heavily on current cyto-dl file logging procedure.
"""

def __init__(self, csv_path: Path, progress_minimum: int=0, progress_maximum: int=0):
def __init__(self, csv_path: Path, progress_minimum: int = 0, progress_maximum: int = 0):
super().__init__(progress_minimum, progress_maximum)

self._csv_path: Path = csv_path
Expand All @@ -27,9 +28,9 @@ def start_tracker(self) -> None:
self.stop_tracker()
self._observer = Observer()
event_handler: MetricsCSVEventHandler = MetricsCSVEventHandler(self._target_path, self.set_progress)
self._observer.schedule(event_handler, path=self._csv_path, recursive=True)
self._observer.schedule(event_handler, path=self._csv_path, recursive=True)
self._observer.start()

def stop_tracker(self) -> None:
if self._observer:
self._observer.stop()
Expand All @@ -49,4 +50,4 @@ def _get_last_csv_version(self) -> int:
last_version = int(version_str) if int(version_str) > last_version else last_version
except ValueError:
continue
return last_version
return last_version
6 changes: 3 additions & 3 deletions src/allencell_ml_segmenter/training/view.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pathlib import Path
import sys
from allencell_ml_segmenter.main.i_viewer import IViewer
from qtpy.QtCore import Qt
from qtpy.QtWidgets import (
Expand All @@ -25,7 +24,7 @@
ImageSelectionWidget,
)
from allencell_ml_segmenter.training.training_model import TrainingModel
from hydra.core.global_hydra import GlobalHydra
# from hydra.core.global_hydra import GlobalHydra
from aicsimageio import AICSImage
from aicsimageio.readers import TiffReader

Expand All @@ -34,6 +33,7 @@
from allencell_ml_segmenter.training.training_model import PatchSize
from allencell_ml_segmenter.training.metrics_csv_progress_tracker import MetricsCSVProgressTracker


class TrainingView(View):
"""
Holds widgets pertinent to training processes - ImageSelectionWidget & ModelSelectionWidget.
Expand Down Expand Up @@ -200,7 +200,7 @@ def __init__(

# apply styling
self.setStyleSheet(Style.get_stylesheet("training_view.qss"))

def train_btn_handler(self) -> None:
"""
Starts training process
Expand Down

0 comments on commit 0a7bdde

Please sign in to comment.