
Commit

revert formatting
franz101 committed Feb 1, 2024
1 parent 8446fcd commit 7ae1876
Showing 10 changed files with 17 additions and 21 deletions.
1 change: 0 additions & 1 deletion dataquality/__init__.py
@@ -30,7 +30,6 @@
     dataquality.get_insights()
 """
 
-
 __version__ = "1.6.1"
 
 import sys
1 change: 1 addition & 0 deletions dataquality/analytics.py
@@ -31,6 +31,7 @@ class Analytics(Borg):
"""Analytics is used to track errors and logs in the background"""

_telemetrics_disabled: bool = True
_initialized: bool = False

def __init__(self, ApiClient: Type[ApiClient], config: Config) -> None:
"""To initialize the Analytics class you need
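For context on the `_initialized` flag added above: `Analytics` subclasses `Borg`, a shared-state pattern, and a class-level `_initialized` default is the usual way to guard one-time setup there. The following is a minimal, hypothetical sketch of that idea only — the `Borg` base with a `_shared_state` dict and the `TrackerSketch` class are illustrative assumptions, not this repository's code; the real `Analytics.__init__` (truncated in the hunk) takes an `ApiClient` and a `Config`.

class Borg:
    """Borg pattern: every instance shares one attribute dictionary."""

    _shared_state: dict = {}  # assumption: typical Borg implementation

    def __init__(self) -> None:
        self.__dict__ = self._shared_state


class TrackerSketch(Borg):
    """Illustrative stand-in for a Borg-based analytics/tracking class."""

    _initialized: bool = False  # class-level default, flipped on first init

    def __init__(self) -> None:
        super().__init__()
        if not self._initialized:
            # One-time setup; later instances see _initialized via shared state.
            self.last_log: dict = {}
            self._initialized = True


a, b = TrackerSketch(), TrackerSketch()
assert a.last_log is b.last_log  # both instances share the same state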
2 changes: 1 addition & 1 deletion dataquality/integrations/setfit.py
@@ -353,7 +353,7 @@ def do_model_eval(
     dq_evaluate(
         encoded_data[split],
         split=split,
-        meta=meta_columns
+        meta=meta_columns,
         # for inference set the split to inference
         # and pass an inference_name="inference_run_1"
     )
8 changes: 4 additions & 4 deletions dataquality/integrations/torch.py
@@ -296,10 +296,10 @@ def watch(
 def unwatch(model: Optional[Module] = None, force: bool = True) -> None:
     """Unwatches the model. Run after the run is finished.
     :param force: Force unwatch even if the model is not watched"""
-    torch_helper_data: TorchHelper = (
-        dq.get_model_logger().logger_config.helper_data.get(
-            "torch_helper", TorchHelper()
-        )
+    torch_helper_data: (
+        TorchHelper
+    ) = dq.get_model_logger().logger_config.helper_data.get(
+        "torch_helper", TorchHelper()
     )
 
     model = model or torch_helper_data.model
4 changes: 2 additions & 2 deletions dataquality/integrations/ultralytics.py
@@ -234,8 +234,8 @@ def _after_pred_step(self, *args: Any, **kwargs: Any) -> None:
         logging_data = process_batch_data(self.bl.batch)
         if not self.nms_fn:
             raise Exception("NMS function not found")
-        postprocess = (
-            lambda x: x if self.split == Split.validation else self.postprocess
+        postprocess = lambda x: (
+            x if self.split == Split.validation else self.postprocess
         )
         preds = postprocess(preds)
         nms = self.nms_fn(
1 change: 1 addition & 0 deletions dataquality/internal.py
@@ -1,4 +1,5 @@
"""Internal functions to help Galileans"""

from typing import Dict

from dataquality import config
6 changes: 2 additions & 4 deletions dataquality/loggers/base_logger.py
@@ -208,8 +208,7 @@ def non_inference_logged(cls) -> bool:
         )
 
     @abstractmethod
-    def log(self) -> None:
-        ...
+    def log(self) -> None: ...
 
     @staticmethod
     def _convert_tensor_ndarray(
@@ -327,8 +326,7 @@ def _cleanup(cls) -> None:
         pm = PatchManager()
         pm.unpatch()
 
-    def upload(self) -> None:
-        ...
+    def upload(self) -> None: ...
 
     @classmethod
     def get_all_subclasses(cls: Type[T]) -> List[Type[T]]:
9 changes: 3 additions & 6 deletions dataquality/loggers/data_logger/base_data_logger.py
@@ -578,8 +578,7 @@ def validate_and_format(self) -> None:

     @classmethod
     @abstractmethod
-    def validate_labels(cls) -> None:
-        ...
+    def validate_labels(cls) -> None: ...
 
     def validate_metadata(self, batch_size: int) -> None:
         if len(self.meta.keys()) > self.MAX_META_COLS:
@@ -689,17 +688,15 @@ def get_data_logger_attr(cls: object) -> str:
     @abstractmethod
     def separate_dataframe(
         cls, df: DataFrame, prob_only: bool = False, split: Optional[str] = None
-    ) -> BaseLoggerDataFrames:
-        ...
+    ) -> BaseLoggerDataFrames: ...
 
     def validate_kwargs(self, kwargs: Dict) -> None:
         """Raises if a function that shouldn't get kwargs gets any"""
         if kwargs.keys():
             raise GalileoException(f"Unexpected arguments: {tuple(kwargs.keys())}")
 
     @abstractmethod
-    def _get_input_df(self) -> DataFrame:
-        ...
+    def _get_input_df(self) -> DataFrame: ...
 
     @classmethod
     def set_tagging_schema(cls, tagging_schema: TaggingSchema) -> None:
2 changes: 1 addition & 1 deletion tests/integrations/setfit/test_setfit.py
@@ -191,7 +191,7 @@ def test_setfit_trainer(
     dq_evaluate(
         dataset,
         split="test",
-        column_mapping=column_mapping
+        column_mapping=column_mapping,
         # for inference set the split to inference
         # and pass an inference_name="inference_run_1"
     )
4 changes: 2 additions & 2 deletions tests/test_telemetrics.py
@@ -26,7 +26,7 @@ def test_mock_log_galileo_import():
     a = Analytics(MockClient, {"api_url": "https://console.dev.rungalileo.io"})
     a.last_log = {}
     a.log_import("test")
-    assert a.last_log["value"] == "test", "No import detected"
+    assert a.last_log.get("value") == "test", "No import detected"
 
 
 def test_log_galileo_exception():
@@ -52,7 +52,7 @@ def test_log_galileo__import():
     assert ac._initialized, "Analytics not initialized"
     ac._telemetrics_disabled = False
     ac.log_import("test")
-    assert ac.last_log["value"] == "test", "No import detected"
+    assert ac.last_log.get("value") == "test", "No import detected"
 
 
 def test_mock_log_galileo_import_disabled():
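The two test changes above swap `last_log["value"]` for `last_log.get("value")`. A quick standalone sketch of why that matters in an assertion — the `last_log` dict here is a stand-in, not the real `Analytics` state: item access on a missing key raises `KeyError` before the assertion message is ever shown, while `.get` returns `None` and lets the test fail with "No import detected".

last_log: dict = {}  # stand-in for Analytics.last_log before anything is logged

try:
    assert last_log["value"] == "test", "No import detected"
except KeyError as err:
    print(f"item access raised KeyError({err}); the assert message never appeared")

try:
    assert last_log.get("value") == "test", "No import detected"
except AssertionError as err:
    print(f"dict.get returned None, so the test fails cleanly: {err}")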
