diff --git a/dataquality/__init__.py b/dataquality/__init__.py
index 4185e86dd..50560dc57 100644
--- a/dataquality/__init__.py
+++ b/dataquality/__init__.py
@@ -31,7 +31,7 @@
 """
 
-__version__ = "1.5.1"
+__version__ = "1.6.0"
 
 import sys
 from typing import Any, List, Optional
diff --git a/dataquality/analytics.py b/dataquality/analytics.py
index 61c4ce78b..864847a5d 100644
--- a/dataquality/analytics.py
+++ b/dataquality/analytics.py
@@ -23,8 +23,8 @@
 class ProfileModel(BaseModel):
     """User profile"""
 
-    packages: Optional[Dict[str, str]]
-    uuid: Optional[str]
+    packages: Optional[Dict[str, str]] = None
+    uuid: Optional[str] = None
 
 
 class Analytics(Borg):
@@ -106,7 +106,7 @@ def _setup_user(self) -> ProfileModel:
         """This function is used to setup the user information.
         This includes all installed packages.
         """
-        profile = ProfileModel(**{"uuid": str(hex(uuid.getnode()))})
+        profile = ProfileModel(uuid=str(hex(uuid.getnode())))
         try:
             profile.packages = _installed_modules()
         except Exception:
diff --git a/dataquality/core/_config.py b/dataquality/core/_config.py
index 72b6be2ab..92e9e72cf 100644
--- a/dataquality/core/_config.py
+++ b/dataquality/core/_config.py
@@ -3,12 +3,11 @@
 import warnings
 from enum import Enum
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 import requests
 from packaging import version
-from pydantic import BaseModel
-from pydantic.class_validators import validator
+from pydantic import BaseModel, ConfigDict, field_validator
 from pydantic.types import UUID4
 from requests.exceptions import ConnectionError as ReqConnectionError
@@ -30,7 +29,7 @@ class GalileoConfigVars(str, Enum):
     CONSOLE_URL = "GALILEO_CONSOLE_URL"
 
     @staticmethod
-    def get_config_mapping() -> Dict[str, Optional[str]]:
+    def get_config_mapping() -> Dict[str, Any]:
         return {i.name.lower(): os.environ.get(i.value) for i in GalileoConfigVars}
 
     @staticmethod
@@ -85,10 +84,7 @@ class Config(BaseModel):
     minio_fqdn: Optional[str] = None
     is_exoscale_cluster: bool = False
 
-    class Config:
-        validate_assignment = True
-        arbitrary_types_allowed = True
-        underscore_attrs_are_private = True
+    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
 
     def update_file_config(self) -> None:
         config_json = self.dict()
@@ -96,7 +92,7 @@ def update_file_config(self) -> None:
         with open(config_data.DEFAULT_GALILEO_CONFIG_FILE, "w+") as f:
             f.write(json.dumps(config_json, default=str))
 
-    @validator("api_url", pre=True, always=True, allow_reuse=True)
+    @field_validator("api_url", mode="before")
     def add_scheme(cls, v: str) -> str:
         if v and not v.startswith("http"):
             # api url needs the scheme
@@ -209,7 +205,7 @@ def set_config(initial_startup: bool = False) -> Config:
     if os.path.exists(config_data.DEFAULT_GALILEO_CONFIG_FILE):
         with open(config_data.DEFAULT_GALILEO_CONFIG_FILE) as f:
             try:
-                config_vars: Dict[str, str] = json.load(f)
+                config_vars: Dict[str, Any] = json.load(f)
             # If there's an issue reading the config file for any reason, quit and
             # start fresh
             except Exception as e:
diff --git a/dataquality/loggers/logger_config/base_logger_config.py b/dataquality/loggers/logger_config/base_logger_config.py
index 05f3e29ed..848bc8dbb 100644
--- a/dataquality/loggers/logger_config/base_logger_config.py
+++ b/dataquality/loggers/logger_config/base_logger_config.py
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set
 
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator
 
 from dataquality.schemas.condition import Condition
 from dataquality.schemas.ner import TaggingSchema
@@ -13,11 +13,11 @@ class BaseLoggerConfig(BaseModel):
     tasks: Any = None
     observed_num_labels: Any = None
     observed_labels: Any = None
-    tagging_schema: Optional[TaggingSchema]
+    tagging_schema: Optional[TaggingSchema] = None
     last_epoch: int = 0
-    cur_epoch: Optional[int]
-    cur_split: Optional[Split]
-    cur_inference_name: Optional[str]
+    cur_epoch: Optional[int] = None
+    cur_split: Optional[Split] = None
+    cur_inference_name: Optional[str] = None
     training_logged: bool = False
     validation_logged: bool = False
     test_logged: bool = False
@@ -36,24 +36,27 @@ class BaseLoggerConfig(BaseModel):
     finish: Callable = lambda: None  # Overwritten in Semantic Segmentation
     # True when calling `init` with a run that already exists
     existing_run: bool = False
-    dataloader_random_sampling = False
+    dataloader_random_sampling: bool = False
+    # model_config = ConfigDict(validate_assignment=True)
     remove_embs: bool = False
 
-    class Config:
-        validate_assignment = True
+    model_config = ConfigDict(validate_assignment=True)
 
     def reset(self, factory: bool = False) -> None:
         """Reset all class vars"""
         self.__init__()  # type: ignore
 
-    @validator("cur_split")
+    @field_validator("cur_split", mode="after")
     def inference_sets_inference_name(
-        cls, field_value: Split, values: Dict[str, Any]
+        cls, field_value: Split, validation_info: ValidationInfo
     ) -> Split:
+        values = validation_info.data
         if field_value == Split.inference:
-            assert values.get(
-                "cur_inference_name"
-            ), "Please specify inference_name when setting split to inference"
+            split_name = values.get("cur_inference_name")
+            if not split_name:
+                raise ValueError(
+                    "Please specify inference_name when setting split to inference"
+                )
         return field_value
diff --git a/dataquality/loggers/logger_config/seq2seq/seq2seq_base.py b/dataquality/loggers/logger_config/seq2seq/seq2seq_base.py
index 0a09ba0e7..9e06435c5 100644
--- a/dataquality/loggers/logger_config/seq2seq/seq2seq_base.py
+++ b/dataquality/loggers/logger_config/seq2seq/seq2seq_base.py
@@ -2,6 +2,7 @@
 from typing import Dict, List, Optional, Set, Union
 
 from peft import PeftModel
+from pydantic import ConfigDict
 from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast
 
 from dataquality.loggers.logger_config.base_logger_config import BaseLoggerConfig
@@ -23,9 +24,7 @@ class Seq2SeqLoggerConfig(BaseLoggerConfig):
     # Decoder only below
     id_to_formatted_prompt_length: Dict[str, Dict[int, int]] = defaultdict(dict)
     response_template: Optional[List[int]] = None
-
-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
 
 seq2seq_logger_config = Seq2SeqLoggerConfig()
diff --git a/dataquality/loggers/logger_config/text_classification.py b/dataquality/loggers/logger_config/text_classification.py
index 2a34ddf07..5c6772fdd 100644
--- a/dataquality/loggers/logger_config/text_classification.py
+++ b/dataquality/loggers/logger_config/text_classification.py
@@ -1,7 +1,7 @@
 from typing import List, Optional, Set
 
 import numpy as np
-from pydantic import validator
+from pydantic import ConfigDict, field_validator
 
 from dataquality.loggers.logger_config.base_logger_config import BaseLoggerConfig
 
@@ -10,11 +10,9 @@ class TextClassificationLoggerConfig(BaseLoggerConfig):
     labels: Optional[List[str]] = None
     observed_num_labels: int = 0
     observed_labels: Set[str] = set()
+    model_config = ConfigDict(validate_assignment=True)
 
-    class Config:
-        validate_assignment = True
-
-    @validator("labels", always=True, pre=True, allow_reuse=True)
+    @field_validator("labels", mode="before")
     def clean_labels(cls, labels: List[str]) -> List[str]:
         if labels is None:
             return labels
diff --git a/dataquality/loggers/logger_config/text_multi_label.py b/dataquality/loggers/logger_config/text_multi_label.py
index 42aef86d4..bb371d400 100644
--- a/dataquality/loggers/logger_config/text_multi_label.py
+++ b/dataquality/loggers/logger_config/text_multi_label.py
@@ -2,7 +2,7 @@
 from typing import DefaultDict, List, Optional, Set
 
 import numpy as np
-from pydantic import validator
+from pydantic import ConfigDict, field_validator
 
 from dataquality.loggers.logger_config.base_logger_config import BaseLoggerConfig
 
@@ -14,11 +14,9 @@ class TextMultiLabelLoggerConfig(BaseLoggerConfig):
     tasks: Optional[List[str]] = None
     observed_num_tasks: int = 0
     binary: bool = True  # For binary multi label
+    model_config = ConfigDict(validate_assignment=True)
 
-    class Config:
-        validate_assignment = True
-
-    @validator("labels", always=True, pre=True)
+    @field_validator("labels", mode="before")
     def clean_labels(cls, labels: List[List[str]]) -> List[List[str]]:
         cleaned_labels = []
         if isinstance(labels, np.ndarray):
diff --git a/dataquality/loggers/logger_config/text_ner.py b/dataquality/loggers/logger_config/text_ner.py
index 7f6408ffa..a137520e6 100644
--- a/dataquality/loggers/logger_config/text_ner.py
+++ b/dataquality/loggers/logger_config/text_ner.py
@@ -1,7 +1,7 @@
 from typing import Dict, List, Tuple
 
 import numpy as np
-from pydantic import validator
+from pydantic import ConfigDict, field_validator
 
 from dataquality.loggers.logger_config.base_logger_config import BaseLoggerConfig
 
@@ -9,15 +9,12 @@ class TextNERLoggerConfig(BaseLoggerConfig):
     gold_spans: Dict[str, List[Tuple[int, int, str]]] = {}
     sample_length: Dict[str, int] = {}
-
-    class Config:
-        validate_assignment = True
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
 
     def get_sample_key(self, split: str, sample_id: int) -> str:
         return f"{split}_{sample_id}"
 
-    @validator("labels", always=True, pre=True, allow_reuse=True)
+    @field_validator("labels", mode="before")
     def clean_labels(cls, labels: List[str]) -> List[str]:
         if isinstance(labels, np.ndarray):
             labels = labels.tolist()
diff --git a/dataquality/schemas/condition.py b/dataquality/schemas/condition.py
index 93ecd0c93..1381ccd5c 100644
--- a/dataquality/schemas/condition.py
+++ b/dataquality/schemas/condition.py
@@ -1,8 +1,7 @@
 from enum import Enum
 from typing import Dict, List, Optional, Tuple, Union
 
-from pydantic import BaseModel
-from pydantic.class_validators import validator
+from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
 from vaex.dataframe import DataFrame
 
@@ -197,8 +196,9 @@ class Condition(BaseModel):
     agg: AggregateFunction
     operator: Operator
     threshold: float
-    metric: Optional[str] = None
-    filters: Optional[List[ConditionFilter]] = []
+    metric: Optional[str] = Field(default=None, validate_default=True)
+    filters: List[ConditionFilter] = Field(default_factory=list, validate_default=True)
+    model_config = ConfigDict(validate_assignment=True)
 
     def evaluate(self, df: DataFrame) -> Tuple[bool, float]:
         filtered_df = self._apply_filters(df)
@@ -226,24 +226,28 @@ def __call__(self, df: DataFrame) -> None:
         """Asserts the condition"""
         assert self.evaluate(df)[0]
 
-    @validator("filters", pre=True, always=True)
+    @field_validator("filters", mode="before")
     def validate_filters(
-        cls, v: Optional[List[ConditionFilter]], values: Dict
+        cls, value: Optional[List[ConditionFilter]], validation_info: ValidationInfo
     ) -> Optional[List[ConditionFilter]]:
-        if not v:
+        values: Dict = validation_info.data
+        if not value:
             agg = values["agg"]
             if agg == AggregateFunction.pct:
                 raise ValueError("Percentage aggregate requires a filter")
-        return v
+        return value
 
-    @validator("metric", pre=True, always=True)
-    def validate_metric(cls, v: Optional[str], values: Dict) -> Optional[str]:
-        if not v:
-            agg = values["agg"]
+    @field_validator("metric", mode="before")
+    def validate_metric(
+        cls, value: Optional[str], validation_info: ValidationInfo
+    ) -> Optional[str]:
+        values: Dict = validation_info.data
+        if value is None:
+            agg = values.get("agg")
             if agg != AggregateFunction.pct:
                 raise ValueError(
                     f"You must set a metric for non-percentage aggregate function {agg}"
                 )
-        return v
+        return value
diff --git a/dataquality/schemas/dataframe.py b/dataquality/schemas/dataframe.py
index 61e9a8801..f80324de2 100644
--- a/dataquality/schemas/dataframe.py
+++ b/dataquality/schemas/dataframe.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from enum import Enum, unique
 
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from vaex.dataframe import DataFrame
 
 
@@ -9,9 +9,7 @@ class BaseLoggerDataFrames(BaseModel):
     prob: DataFrame
     emb: DataFrame
     data: DataFrame
-
-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
 
 @unique
diff --git a/dataquality/schemas/edit.py b/dataquality/schemas/edit.py
index 78675a911..7226890fa 100644
--- a/dataquality/schemas/edit.py
+++ b/dataquality/schemas/edit.py
@@ -1,7 +1,15 @@
 from enum import Enum, unique
-from typing import Any, Dict, Optional
-
-from pydantic import UUID4, BaseModel, StrictInt, StrictStr, validator
+from typing import Dict, Optional
+
+from pydantic import (
+    UUID4,
+    BaseModel,
+    ConfigDict,
+    StrictInt,
+    StrictStr,
+    ValidationInfo,
+    field_validator,
+)
 
 from dataquality.schemas.metrics import FilterParams
 
@@ -40,35 +48,42 @@ class Edit(BaseModel):
         How many words (forward or back) to shift the end of the span by
     """
 
-    filter: Optional[FilterParams]
+    model_config = ConfigDict(validate_assignment=True)
+
+    filter: Optional[FilterParams] = None
 
-    new_label: Optional[StrictStr]
+    new_label: Optional[StrictStr] = None
 
-    search_string: Optional[StrictStr]
-    text_replacement: Optional[StrictStr]
+    search_string: Optional[StrictStr] = None
+    text_replacement: Optional[StrictStr] = None
     use_regex: bool = False
 
-    shift_span_start_num_words: Optional[StrictInt]
-    shift_span_end_num_words: Optional[StrictInt]
+    shift_span_start_num_words: Optional[StrictInt] = None
+    shift_span_end_num_words: Optional[StrictInt] = None
 
-    project_id: Optional[UUID4]
-    run_id: Optional[UUID4]
-    split: Optional[str]
+    project_id: Optional[UUID4] = None
+    run_id: Optional[UUID4] = None
+    split: Optional[str] = None
     task: Optional[str] = None
     inference_name: Optional[str] = None
 
-    note: Optional[StrictStr]
+    note: Optional[StrictStr] = None
     edit_action: EditAction
 
-    @validator("edit_action", pre=True)
-    def new_label_if_relabel(cls, edit_action: EditAction, values: Dict) -> EditAction:
+    @field_validator("edit_action", mode="before")
+    def new_label_if_relabel(
+        cls, edit_action: EditAction, validation_info: ValidationInfo
+    ) -> EditAction:
+        values: Dict = validation_info.data
+
         if edit_action == EditAction.relabel and values.get("new_label") is None:
             raise ValueError("If your edit is relabel, you must set new_label")
         return edit_action
 
-    @validator("edit_action", pre=True)
+    @field_validator("edit_action", mode="before")
     def text_replacement_if_update_text(
-        cls, edit_action: EditAction, values: Dict
+        cls, edit_action: EditAction, validation_info: ValidationInfo
     ) -> EditAction:
+        values: Dict = validation_info.data
         if edit_action == EditAction.update_text and (
             values.get("text_replacement") is None
             or values.get("search_string") is None
@@ -79,8 +94,11 @@ def text_replacement_if_update_text(
             )
         return edit_action
 
-    @validator("edit_action", pre=True)
-    def shift_span_validator(cls, edit_action: EditAction, values: Dict) -> EditAction:
+    @field_validator("edit_action", mode="before")
+    def shift_span_validator(
+        cls, edit_action: EditAction, validation_info: ValidationInfo
+    ) -> EditAction:
+        values: Dict = validation_info.data
         err = (
             "If your edit is shift_span, you must set search_string and at least "
             "one of shift_span_start_num_words or shift_span_end_num_words"
         )
@@ -95,10 +113,11 @@ def shift_span_validator(cls, edit_action: EditAction, values: Dict) -> EditAction:
             raise ValueError(err)
         return edit_action
 
-    @validator("edit_action", pre=True, always=True)
+    @field_validator("edit_action", mode="before")
     def validate_edit_action_for_split(
-        cls, edit_action: EditAction, values: Dict[str, Any]
+        cls, edit_action: EditAction, validation_info: ValidationInfo
     ) -> EditAction:
+        values: Dict = validation_info.data
         if not values.get("split"):
             return edit_action
         split = values["split"]
diff --git a/dataquality/schemas/metrics.py b/dataquality/schemas/metrics.py
index 23a451afe..421cb33f7 100644
--- a/dataquality/schemas/metrics.py
+++ b/dataquality/schemas/metrics.py
@@ -1,6 +1,6 @@
-from typing import Dict, List, Optional
+from typing import List, Optional
 
-from pydantic import BaseModel, Field, StrictStr, root_validator
+from pydantic import BaseModel, ConfigDict, Field, StrictStr, model_validator
 
 
 class HashableBaseModel(BaseModel):
@@ -47,16 +47,20 @@ class LassoSelection(HashableBaseModel):
     provided by plotly when creating a lasso selection
     """
 
+    model_config = ConfigDict(validate_assignment=True)
+
     x: List[float]
     y: List[float]
 
-    @root_validator()
-    def validate_xy(cls: BaseModel, values: Dict[str, List]) -> Dict[str, List]:
-        if len(values.get("x", [])) != len(values.get("y", [])):
+    @model_validator(mode="after")
+    def validate_xy(self) -> "LassoSelection":
+        x, y = self.x, self.y
+        if len(x) != len(y):
             raise ValueError("x and y must have the same number of points")
-        if len(values.get("x", [])) < 1:
+        if len(x) < 1:
             raise ValueError("x and y must have at least 1 value")
-        return values
+
+        return self
 
 
 class FilterParams(HashableBaseModel):
diff --git a/dataquality/schemas/report.py b/dataquality/schemas/report.py
index 9e1ddf19d..84e00da32 100644
--- a/dataquality/schemas/report.py
+++ b/dataquality/schemas/report.py
@@ -11,9 +11,9 @@ class ConditionStatus(str, Enum):
 
 class SplitConditionData(BaseModel):
     split: str
-    inference_name: Optional[str]
+    inference_name: Optional[str] = None
     status: ConditionStatus
-    link: Optional[str]
+    link: Optional[str] = None
     ground_truth: float
diff --git a/dataquality/utils/semantic_segmentation/errors.py b/dataquality/utils/semantic_segmentation/errors.py
index 7a185a5bb..6c6ff6906 100644
--- a/dataquality/utils/semantic_segmentation/errors.py
+++ b/dataquality/utils/semantic_segmentation/errors.py
@@ -133,7 +133,7 @@ def calculate_classification_error(
     # count the number of pixels in the pred mask relevant region that are
     # not the correct class
     areas = np.bincount(incorrect_pixels, minlength=number_classes)
-    argmax = np.argmax(areas)
+    argmax = int(np.argmax(areas))
     return ClassificationErrorData(
         accuracy=float_accuracy,
         mislabeled_class=argmax,
diff --git a/pyproject.toml b/pyproject.toml
index 3fbea16d0..e61c88f68 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,12 +13,13 @@ readme = "README.md"
 license = {text = 'See LICENSE'}
 requires-python = ">=3.8"
 dependencies = [
-    "pydantic>=1.8.2,<2.0.0",
+    "pydantic>=2.0.0",
+    "pydantic-settings>=2.0.0",
     "requests>=2.25.1",
     "types-requests>=2.25.2",
     "pandas>=0.20.0",
    "pyarrow>=5.0.0",
-    "vaex-core==4.16.0",
+    "vaex-core==4.17.1",
     "vaex-hdf5>=0.12,<0.13",
     "diskcache>=5.2.1",
     "resource>=0.2.1",
@@ -90,7 +91,7 @@ test = [
     "setfit==0.7.0",
     "accelerate>=0.19.0",
     "typing-inspect==0.8.0",
-    "typing-extensions==4.0.1",
+    "typing-extensions>=4.9.0",
     "lightning",
 ]
 dev = [
@@ -111,12 +112,9 @@ cuda = [
     "cudf-cu11==23.2.0",
     "cuml-cu11==23.2.0"
 ]
-minio = [
-    "minio>=7.1.0,<7.2.0"
-]
-setfit = [
-    "setfit==0.7.0"
-]
+minio = ["minio>=7.1.0,<7.2.0"]
+setfit = ["setfit==0.7.0"]
+
 
 [tool.setuptools.dynamic]
 version = {attr = "dataquality.__version__"}
@@ -205,4 +203,3 @@ exclude = '''
   )/
 '''
-
diff --git a/tests/core/test_report.py b/tests/core/test_report.py
index 816137d71..8a92951db 100644
--- a/tests/core/test_report.py
+++ b/tests/core/test_report.py
@@ -131,7 +131,7 @@ def test_build_run_report_e2e(
             ConditionFilter(
                 metric="is_drifted",
                 operator=Operator.eq,
-                value=True,
+                value=1.0,
             )
         ],
     )
@@ -202,8 +202,10 @@ def test_build_run_report_e2e(
             },
         ],
     }
+
     mock_notify_email.assert_called_once_with(
         expected_report_data, "run_report", ["foo@bar.com"]
     )
+
     # Assert that caching prevented all 6 calls to get_dataframes
     assert mock_get_dataframe.call_count == 3
diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py
index 36c9f5fbe..e49a0615e 100644
--- a/tests/inference/test_inference.py
+++ b/tests/inference/test_inference.py
@@ -30,8 +30,8 @@ def test_set_split_inference_missing_inference_name(
         dataquality.set_split("inference")
 
     assert (
-        e.value.errors()[0]["msg"]
-        == "Please specify inference_name when setting split to inference"
+        "Please specify inference_name when setting split to inference"
+        in e.value.errors()[0]["msg"]
     )
diff --git a/tests/loggers/test_multi_label.py b/tests/loggers/test_multi_label.py
index db431c43b..98118e8c6 100644
--- a/tests/loggers/test_multi_label.py
+++ b/tests/loggers/test_multi_label.py
@@ -224,7 +224,7 @@ def test_set_tasks_not_set_binary(
         dq.set_labels_for_run(tasks)
 
     err = e.value.errors()[0]["msg"]
-    assert err.startswith("Labels must be a list of lists.")
+    assert "Labels must be a list of lists." in err
     assert "If you are running a binary multi-label case," in err
diff --git a/tests/schemas/test_metrics.py b/tests/schemas/test_metrics.py
index ecf3e03ec..df98a3cbf 100644
--- a/tests/schemas/test_metrics.py
+++ b/tests/schemas/test_metrics.py
@@ -7,6 +7,6 @@ def test_nested_filter_params_are_hashable() -> None:
         "ids": [1, 2, 3, 4],
         "lasso": {"x": [0.1, 0.1, 0.2], "y": [0.4, 0.5, 0.6]},
         "inference_filter": {"is_otb": True},
-        "meta_filter": [{"name": "foo", "isin": [1, 2, 3]}],
+        "meta_filter": [{"name": "foo", "isin": ["1", "2", "3"]}],
     }
-    assert FilterParams(**f2).__hash__()
+    assert FilterParams.model_validate(f2).__hash__()
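For reference, this is the core Pydantic v1 -> v2 pattern the diff applies across the logger configs and schemas: `class Config` becomes `model_config = ConfigDict(...)`, `@validator(..., pre=True, always=True)` becomes `@field_validator(..., mode="before"/"after")`, and the `values: Dict` argument becomes a `ValidationInfo` object. This is a minimal illustrative sketch on a hypothetical `ExampleLoggerConfig` model, not code from the repository.

# Illustrative sketch of the Pydantic v1 -> v2 migration pattern used in this diff.
# The model and field names below are hypothetical, not dataquality code.
from typing import Optional

from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator


class ExampleLoggerConfig(BaseModel):
    # v2: model_config = ConfigDict(...) replaces the inner `class Config`
    model_config = ConfigDict(validate_assignment=True)

    # v2 requires an explicit default for Optional fields
    cur_inference_name: Optional[str] = None
    cur_split: Optional[str] = None

    # v1: @validator("cur_split") with a `values: Dict` argument
    # v2: @field_validator(..., mode="after") with a ValidationInfo argument;
    #     previously validated fields are available via info.data
    @field_validator("cur_split", mode="after")
    def inference_requires_name(
        cls, value: Optional[str], info: ValidationInfo
    ) -> Optional[str]:
        if value == "inference" and not info.data.get("cur_inference_name"):
            # raise ValueError instead of asserting; pydantic wraps it in a ValidationError
            raise ValueError(
                "Please specify inference_name when setting split to inference"
            )
        return value


if __name__ == "__main__":
    cfg = ExampleLoggerConfig(cur_inference_name="my-run", cur_split="inference")
    print(cfg.cur_split)  # "inference"

One design note: raising ValueError (rather than asserting) is what the diff itself switches to, which is why the tests now check that the expected message is contained in `e.value.errors()[0]["msg"]` instead of comparing for equality — pydantic v2 prefixes validator messages with "Value error, ".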