From 7ae1876dc0779a72b26d067d894360215fce5ce9 Mon Sep 17 00:00:00 2001 From: franz Date: Thu, 1 Feb 2024 14:53:47 -0600 Subject: [PATCH] revert formatting --- dataquality/__init__.py | 1 - dataquality/analytics.py | 1 + dataquality/integrations/setfit.py | 2 +- dataquality/integrations/torch.py | 8 ++++---- dataquality/integrations/ultralytics.py | 4 ++-- dataquality/internal.py | 1 + dataquality/loggers/base_logger.py | 6 ++---- dataquality/loggers/data_logger/base_data_logger.py | 9 +++------ tests/integrations/setfit/test_setfit.py | 2 +- tests/test_telemetrics.py | 4 ++-- 10 files changed, 17 insertions(+), 21 deletions(-) diff --git a/dataquality/__init__.py b/dataquality/__init__.py index e21c48fa5..c26a326e8 100644 --- a/dataquality/__init__.py +++ b/dataquality/__init__.py @@ -30,7 +30,6 @@ dataquality.get_insights() """ - __version__ = "1.6.1" import sys diff --git a/dataquality/analytics.py b/dataquality/analytics.py index 864847a5d..cc2fe324f 100644 --- a/dataquality/analytics.py +++ b/dataquality/analytics.py @@ -31,6 +31,7 @@ class Analytics(Borg): """Analytics is used to track errors and logs in the background""" _telemetrics_disabled: bool = True + _initialized: bool = False def __init__(self, ApiClient: Type[ApiClient], config: Config) -> None: """To initialize the Analytics class you need diff --git a/dataquality/integrations/setfit.py b/dataquality/integrations/setfit.py index ecbfef348..b8f5eb04d 100644 --- a/dataquality/integrations/setfit.py +++ b/dataquality/integrations/setfit.py @@ -353,7 +353,7 @@ def do_model_eval( dq_evaluate( encoded_data[split], split=split, - meta=meta_columns + meta=meta_columns, # for inference set the split to inference # and pass an inference_name="inference_run_1" ) diff --git a/dataquality/integrations/torch.py b/dataquality/integrations/torch.py index d0fd43128..cfa874694 100644 --- a/dataquality/integrations/torch.py +++ b/dataquality/integrations/torch.py @@ -296,10 +296,10 @@ def watch( def unwatch(model: Optional[Module] = None, force: bool = True) -> None: """Unwatches the model. Run after the run is finished. :param force: Force unwatch even if the model is not watched""" - torch_helper_data: TorchHelper = ( - dq.get_model_logger().logger_config.helper_data.get( - "torch_helper", TorchHelper() - ) + torch_helper_data: ( + TorchHelper + ) = dq.get_model_logger().logger_config.helper_data.get( + "torch_helper", TorchHelper() ) model = model or torch_helper_data.model diff --git a/dataquality/integrations/ultralytics.py b/dataquality/integrations/ultralytics.py index e6ab5d2b4..43defc392 100644 --- a/dataquality/integrations/ultralytics.py +++ b/dataquality/integrations/ultralytics.py @@ -234,8 +234,8 @@ def _after_pred_step(self, *args: Any, **kwargs: Any) -> None: logging_data = process_batch_data(self.bl.batch) if not self.nms_fn: raise Exception("NMS function not found") - postprocess = ( - lambda x: x if self.split == Split.validation else self.postprocess + postprocess = lambda x: ( + x if self.split == Split.validation else self.postprocess ) preds = postprocess(preds) nms = self.nms_fn( diff --git a/dataquality/internal.py b/dataquality/internal.py index e24a57a6d..3393dc9f1 100644 --- a/dataquality/internal.py +++ b/dataquality/internal.py @@ -1,4 +1,5 @@ """Internal functions to help Galileans""" + from typing import Dict from dataquality import config diff --git a/dataquality/loggers/base_logger.py b/dataquality/loggers/base_logger.py index 2a32c1533..dcea90c8e 100644 --- a/dataquality/loggers/base_logger.py +++ b/dataquality/loggers/base_logger.py @@ -208,8 +208,7 @@ def non_inference_logged(cls) -> bool: ) @abstractmethod - def log(self) -> None: - ... + def log(self) -> None: ... @staticmethod def _convert_tensor_ndarray( @@ -327,8 +326,7 @@ def _cleanup(cls) -> None: pm = PatchManager() pm.unpatch() - def upload(self) -> None: - ... + def upload(self) -> None: ... @classmethod def get_all_subclasses(cls: Type[T]) -> List[Type[T]]: diff --git a/dataquality/loggers/data_logger/base_data_logger.py b/dataquality/loggers/data_logger/base_data_logger.py index c8e0ea769..da51c0508 100644 --- a/dataquality/loggers/data_logger/base_data_logger.py +++ b/dataquality/loggers/data_logger/base_data_logger.py @@ -578,8 +578,7 @@ def validate_and_format(self) -> None: @classmethod @abstractmethod - def validate_labels(cls) -> None: - ... + def validate_labels(cls) -> None: ... def validate_metadata(self, batch_size: int) -> None: if len(self.meta.keys()) > self.MAX_META_COLS: @@ -689,8 +688,7 @@ def get_data_logger_attr(cls: object) -> str: @abstractmethod def separate_dataframe( cls, df: DataFrame, prob_only: bool = False, split: Optional[str] = None - ) -> BaseLoggerDataFrames: - ... + ) -> BaseLoggerDataFrames: ... def validate_kwargs(self, kwargs: Dict) -> None: """Raises if a function that shouldn't get kwargs gets any""" @@ -698,8 +696,7 @@ def validate_kwargs(self, kwargs: Dict) -> None: raise GalileoException(f"Unexpected arguments: {tuple(kwargs.keys())}") @abstractmethod - def _get_input_df(self) -> DataFrame: - ... + def _get_input_df(self) -> DataFrame: ... @classmethod def set_tagging_schema(cls, tagging_schema: TaggingSchema) -> None: diff --git a/tests/integrations/setfit/test_setfit.py b/tests/integrations/setfit/test_setfit.py index ff096861a..7fe91ecf6 100644 --- a/tests/integrations/setfit/test_setfit.py +++ b/tests/integrations/setfit/test_setfit.py @@ -191,7 +191,7 @@ def test_setfit_trainer( dq_evaluate( dataset, split="test", - column_mapping=column_mapping + column_mapping=column_mapping, # for inference set the split to inference # and pass an inference_name="inference_run_1" ) diff --git a/tests/test_telemetrics.py b/tests/test_telemetrics.py index f65821182..7a03f87fd 100644 --- a/tests/test_telemetrics.py +++ b/tests/test_telemetrics.py @@ -26,7 +26,7 @@ def test_mock_log_galileo_import(): a = Analytics(MockClient, {"api_url": "https://console.dev.rungalileo.io"}) a.last_log = {} a.log_import("test") - assert a.last_log["value"] == "test", "No import detected" + assert a.last_log.get("value") == "test", "No import detected" def test_log_galileo_exception(): @@ -52,7 +52,7 @@ def test_log_galileo__import(): assert ac._initialized, "Analytics not initialized" ac._telemetrics_disabled = False ac.log_import("test") - assert ac.last_log["value"] == "test", "No import detected" + assert ac.last_log.get("value") == "test", "No import detected" def test_mock_log_galileo_import_disabled():