Skip to content

Commit

Permalink
feat(deps): Update to pydantic v2 (#820)
Browse files Browse the repository at this point in the history
  • Loading branch information
franz101 authored Jan 25, 2024
1 parent 2351270 commit 9d6d925
Show file tree
Hide file tree
Showing 19 changed files with 125 additions and 110 deletions.
2 changes: 1 addition & 1 deletion dataquality/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""


__version__ = "1.5.1"
__version__ = "1.6.0"

import sys
from typing import Any, List, Optional
Expand Down
6 changes: 3 additions & 3 deletions dataquality/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
class ProfileModel(BaseModel):
"""User profile"""

packages: Optional[Dict[str, str]]
uuid: Optional[str]
packages: Optional[Dict[str, str]] = None
uuid: Optional[str] = None


class Analytics(Borg):
Expand Down Expand Up @@ -106,7 +106,7 @@ def _setup_user(self) -> ProfileModel:
"""This function is used to setup the user information.
This includes all installed packages.
"""
profile = ProfileModel(**{"uuid": str(hex(uuid.getnode()))})
profile = ProfileModel(uuid=str(hex(uuid.getnode())))
try:
profile.packages = _installed_modules()
except Exception:
Expand Down
16 changes: 6 additions & 10 deletions dataquality/core/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import warnings
from enum import Enum
from pathlib import Path
from typing import Dict, Optional
from typing import Any, Dict, Optional

import requests
from packaging import version
from pydantic import BaseModel
from pydantic.class_validators import validator
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic.types import UUID4
from requests.exceptions import ConnectionError as ReqConnectionError

Expand All @@ -30,7 +29,7 @@ class GalileoConfigVars(str, Enum):
CONSOLE_URL = "GALILEO_CONSOLE_URL"

@staticmethod
def get_config_mapping() -> Dict[str, Optional[str]]:
def get_config_mapping() -> Dict[str, Any]:
return {i.name.lower(): os.environ.get(i.value) for i in GalileoConfigVars}

@staticmethod
Expand Down Expand Up @@ -85,18 +84,15 @@ class Config(BaseModel):
minio_fqdn: Optional[str] = None
is_exoscale_cluster: bool = False

class Config:
validate_assignment = True
arbitrary_types_allowed = True
underscore_attrs_are_private = True
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

def update_file_config(self) -> None:
    """Write this config to the default Galileo config file as JSON.

    The file at ``config_data.DEFAULT_GALILEO_CONFIG_FILE`` is opened in
    ``"w+"`` mode, so any previous contents are replaced.
    """
    # pydantic v2 renamed ``BaseModel.dict()`` to ``model_dump()``; the old
    # name still works but emits a deprecation warning on every call.
    config_json = self.model_dump()

    with open(config_data.DEFAULT_GALILEO_CONFIG_FILE, "w+") as f:
        # default=str stringifies any value json cannot encode natively.
        f.write(json.dumps(config_json, default=str))

@validator("api_url", pre=True, always=True, allow_reuse=True)
@field_validator("api_url", mode="before")
def add_scheme(cls, v: str) -> str:
if v and not v.startswith("http"):
# api url needs the scheme
Expand Down Expand Up @@ -209,7 +205,7 @@ def set_config(initial_startup: bool = False) -> Config:
if os.path.exists(config_data.DEFAULT_GALILEO_CONFIG_FILE):
with open(config_data.DEFAULT_GALILEO_CONFIG_FILE) as f:
try:
config_vars: Dict[str, str] = json.load(f)
config_vars: Dict[str, Any] = json.load(f)
# If there's an issue reading the config file for any reason, quit and
# start fresh
except Exception as e:
Expand Down
29 changes: 16 additions & 13 deletions dataquality/loggers/logger_config/base_logger_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set

from pydantic import BaseModel, validator
from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator

from dataquality.schemas.condition import Condition
from dataquality.schemas.ner import TaggingSchema
Expand All @@ -13,11 +13,11 @@ class BaseLoggerConfig(BaseModel):
tasks: Any = None
observed_num_labels: Any = None
observed_labels: Any = None
tagging_schema: Optional[TaggingSchema]
tagging_schema: Optional[TaggingSchema] = None
last_epoch: int = 0
cur_epoch: Optional[int]
cur_split: Optional[Split]
cur_inference_name: Optional[str]
cur_epoch: Optional[int] = None
cur_split: Optional[Split] = None
cur_inference_name: Optional[str] = None
training_logged: bool = False
validation_logged: bool = False
test_logged: bool = False
Expand All @@ -36,24 +36,27 @@ class BaseLoggerConfig(BaseModel):
finish: Callable = lambda: None # Overwritten in Semantic Segmentation
# True when calling `init` with a run that already exists
existing_run: bool = False
dataloader_random_sampling = False
dataloader_random_sampling: bool = False
# model_config = ConfigDict(validate_assignment=True)
remove_embs: bool = False

class Config:
validate_assignment = True
model_config = ConfigDict(validate_assignment=True)

def reset(self, factory: bool = False) -> None:
    """Reset this config by re-running ``__init__`` with no arguments.

    NOTE(review): presumably this restores every field to its declared
    default; confirm against the model's field definitions.

    Args:
        factory: Not read by this implementation.
    """
    self.__init__()  # type: ignore

@validator("cur_split")
@field_validator("cur_split", mode="after")
def inference_sets_inference_name(
cls, field_value: Split, values: Dict[str, Any]
cls, field_value: Split, validation_info: ValidationInfo
) -> Split:
values = validation_info.data
if field_value == Split.inference:
assert values.get(
"cur_inference_name"
), "Please specify inference_name when setting split to inference"
split_name = values.get("cur_inference_name")
if not split_name:
raise ValueError(
"Please specify inference_name when setting split to inference"
)
return field_value


Expand Down
5 changes: 2 additions & 3 deletions dataquality/loggers/logger_config/seq2seq/seq2seq_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict, List, Optional, Set, Union

from peft import PeftModel
from pydantic import ConfigDict
from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast

from dataquality.loggers.logger_config.base_logger_config import BaseLoggerConfig
Expand All @@ -23,9 +24,7 @@ class Seq2SeqLoggerConfig(BaseLoggerConfig):
# Decoder only below
id_to_formatted_prompt_length: Dict[str, Dict[int, int]] = defaultdict(dict)
response_template: Optional[List[int]] = None

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)


seq2seq_logger_config = Seq2SeqLoggerConfig()
8 changes: 3 additions & 5 deletions dataquality/loggers/logger_config/text_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Optional, Set

import numpy as np
from pydantic import validator
from pydantic import ConfigDict, field_validator

from dataquality.loggers.logger_config.base_logger_config import BaseLoggerConfig

Expand All @@ -10,11 +10,9 @@ class TextClassificationLoggerConfig(BaseLoggerConfig):
labels: Optional[List[str]] = None
observed_num_labels: int = 0
observed_labels: Set[str] = set()
model_config = ConfigDict(validate_assignment=True)

class Config:
validate_assignment = True

@validator("labels", always=True, pre=True, allow_reuse=True)
@field_validator("labels", mode="before")
def clean_labels(cls, labels: List[str]) -> List[str]:
if labels is None:
return labels
Expand Down
8 changes: 3 additions & 5 deletions dataquality/loggers/logger_config/text_multi_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import DefaultDict, List, Optional, Set

import numpy as np
from pydantic import validator
from pydantic import ConfigDict, field_validator

from dataquality.loggers.logger_config.base_logger_config import BaseLoggerConfig

Expand All @@ -14,11 +14,9 @@ class TextMultiLabelLoggerConfig(BaseLoggerConfig):
tasks: Optional[List[str]] = None
observed_num_tasks: int = 0
binary: bool = True # For binary multi label
model_config = ConfigDict(validate_assignment=True)

class Config:
validate_assignment = True

@validator("labels", always=True, pre=True)
@field_validator("labels", mode="before")
def clean_labels(cls, labels: List[List[str]]) -> List[List[str]]:
cleaned_labels = []
if isinstance(labels, np.ndarray):
Expand Down
9 changes: 3 additions & 6 deletions dataquality/loggers/logger_config/text_ner.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
from typing import Dict, List, Tuple

import numpy as np
from pydantic import validator
from pydantic import ConfigDict, field_validator

from dataquality.loggers.logger_config.base_logger_config import BaseLoggerConfig


class TextNERLoggerConfig(BaseLoggerConfig):
gold_spans: Dict[str, List[Tuple[int, int, str]]] = {}
sample_length: Dict[str, int] = {}

class Config:
validate_assignment = True
arbitrary_types_allowed = True
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

def get_sample_key(self, split: str, sample_id: int) -> str:
    """Return the key for a sample: the split and sample id joined by ``_``."""
    sample_key = f"{split}_{sample_id}"
    return sample_key

@validator("labels", always=True, pre=True, allow_reuse=True)
@field_validator("labels", mode="before")
def clean_labels(cls, labels: List[str]) -> List[str]:
if isinstance(labels, np.ndarray):
labels = labels.tolist()
Expand Down
30 changes: 17 additions & 13 deletions dataquality/schemas/condition.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union

from pydantic import BaseModel
from pydantic.class_validators import validator
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from vaex.dataframe import DataFrame


Expand Down Expand Up @@ -197,8 +196,9 @@ class Condition(BaseModel):
agg: AggregateFunction
operator: Operator
threshold: float
metric: Optional[str] = None
filters: Optional[List[ConditionFilter]] = []
metric: Optional[str] = Field(default=None, validate_default=True)
filters: List[ConditionFilter] = Field(default_factory=list, validate_default=True)
model_config = ConfigDict(validate_assignment=True)

def evaluate(self, df: DataFrame) -> Tuple[bool, float]:
filtered_df = self._apply_filters(df)
Expand Down Expand Up @@ -226,24 +226,28 @@ def __call__(self, df: DataFrame) -> None:
"""Asserts the condition"""
assert self.evaluate(df)[0]

@validator("filters", pre=True, always=True)
@field_validator("filters", mode="before")
def validate_filters(
cls, v: Optional[List[ConditionFilter]], values: Dict
cls, value: Optional[List[ConditionFilter]], validation_info: ValidationInfo
) -> Optional[List[ConditionFilter]]:
if not v:
values: Dict = validation_info.data
if not value:
agg = values["agg"]
if agg == AggregateFunction.pct:
raise ValueError("Percentage aggregate requires a filter")

return v
return value

@validator("metric", pre=True, always=True)
def validate_metric(cls, v: Optional[str], values: Dict) -> Optional[str]:
if not v:
agg = values["agg"]
@field_validator("metric", mode="before")
def validate_metric(
cls, value: Optional[str], validation_info: ValidationInfo
) -> Optional[str]:
values: Dict = validation_info.data
if value is None:
agg = values.get("agg")
if agg != AggregateFunction.pct:
raise ValueError(
f"You must set a metric for non-percentage aggregate function {agg}"
)

return v
return value
6 changes: 2 additions & 4 deletions dataquality/schemas/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from dataclasses import dataclass
from enum import Enum, unique

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from vaex.dataframe import DataFrame


class BaseLoggerDataFrames(BaseModel):
prob: DataFrame
emb: DataFrame
data: DataFrame

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)


@unique
Expand Down
Loading

0 comments on commit 9d6d925

Please sign in to comment.