Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to django-ninja v1 / pydantic v2 #540

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
535 changes: 292 additions & 243 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ psycopg2 = "~2.9.9"
redis = { version = "~5.0.3", extras = [ "hiredis" ] }
uritemplate = "^4.1.1"
iso3166 = "^2.1.1"
rio-tiler = "5.0.3" # TODO: upgrade blocked on pydantic 2
rio-tiler = "6.8.0" # TODO: update to 7.x.
mercantile = "^1.2.1"
django-ninja = "~0.22.2" # TODO: upgrade blocked on pydantic 2
django-ninja = "^1.3.0"
celery = "^5.3.6"
django-extensions = "^3.2.3"
pillow = "^10.2.0"
Expand Down
53 changes: 28 additions & 25 deletions rdwatch/core/schemas/region_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Annotated, Any, Literal

from ninja import Field, Schema
from pydantic import confloat, constr, validator
from pydantic import Field, StringConstraints, field_validator

from django.contrib.gis.gdal import GDALException
from django.contrib.gis.geos import GEOSGeometry, Polygon
Expand All @@ -13,18 +13,19 @@
class RegionFeature(Schema):
type: Literal['region']
region_id: str # a Region isn't limited to their format for RDWATCH
version: str | None
version: str | None = None
mgrs: str
model_content: Literal['empty', 'annotation', 'proposed'] | None
start_date: datetime | None
end_date: datetime | None
model_content: Literal['empty', 'annotation', 'proposed'] | None = None
start_date: datetime | None = Field(default=None, validate_default=True)
end_date: datetime | None = Field(default=None, validate_default=True)
originator: str

# Optional fields
comments: str | None
performer_cache: dict[Any, Any] | None
comments: str | None = None
performer_cache: dict[Any, Any] | None = None

@validator('start_date', 'end_date', pre=True)
@field_validator('start_date', 'end_date', mode='before')
@classmethod
def parse_dates(cls, v: str | None) -> datetime | None:
if v is None:
return v
Expand All @@ -34,8 +35,8 @@ def parse_dates(cls, v: str | None) -> datetime | None:
class SiteSummaryFeature(Schema):
type: Literal['site_summary']
# match the site_id of format KR_R001_0001 or KR_R001_9990
site_id: constr(regex=r'^.{1,255}_\d{4,8}$')
version: str | None
site_id: Annotated[str, StringConstraints(pattern=r'^.{1,255}_\d{4,8}$')]
version: str | None = None
mgrs: str
status: Literal[
'positive_annotated',
Expand All @@ -52,32 +53,34 @@ class SiteSummaryFeature(Schema):
'system_confirmed',
'system_rejected',
]
start_date: datetime | None
end_date: datetime | None
model_content: Literal['annotation', 'proposed'] | None
start_date: datetime | None = Field(default=None, validate_default=True)
end_date: datetime | None = Field(default=None, validate_default=True)
model_content: Literal['annotation', 'proposed'] | None = None
originator: str

# Optional fields
comments: str | None
score: confloat(ge=0.0, le=1.0) | None
validated: Literal['True', 'False'] | None
annotation_cache: dict[Any, Any] | None

@validator('start_date', 'end_date', pre=True)
comments: str | None = None
score: Annotated[float, Field(ge=0.0, le=1.0)] | None = Field(
default=None, validate_default=True
)
validated: Literal['True', 'False'] | None = None
annotation_cache: dict[Any, Any] | None = None

@field_validator('start_date', 'end_date', mode='before')
@classmethod
def parse_dates(cls, v: str | None) -> datetime | None:
if v is None:
return v
return datetime.strptime(v, '%Y-%m-%d')

@validator('score', pre=True, always=True)
@field_validator('score', mode='before')
@classmethod
def parse_score(cls, v: float | None) -> float:
"""
Score is an optional field, and defaults to 1.0 if one isn't provided
https://smartgitlab.com/TE/standards/-/wikis/Region-Model-Specification#score-float-optional
"""
if v is None:
return 1.0
return v
return v if v is not None else 1.0

@property
def site_number(self) -> int:
Expand All @@ -97,7 +100,7 @@ class Feature(Schema):
def parsed_geometry(self) -> GEOSGeometry:
return GEOSGeometry(json.dumps(self.geometry))

@validator('geometry', pre=True)
@field_validator('geometry', mode='before')
def parse_geometry(cls, v: dict[str, Any]) -> dict[str, Any]:
try:
geom = GEOSGeometry(json.dumps(v))
Expand All @@ -112,7 +115,7 @@ class RegionModel(Schema):
type: Literal['FeatureCollection']
features: list[Feature]

@validator('features')
@field_validator('features')
def ensure_one_region_feature(cls, v: list[Feature]):
region_features = [
feature for feature in v if isinstance(feature.properties, RegionFeature)
Expand Down
17 changes: 9 additions & 8 deletions rdwatch/core/schemas/site_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
from typing import Literal

from ninja import Schema
from pydantic import validator
from pydantic import field_validator


class SiteEvaluationRequest(Schema):
label: str | None
label: str | None = None
geom: dict | None = None
score: float | None
start_date: datetime | None
end_date: datetime | None
notes: str | None
status: Literal['PROPOSAL', 'APPROVED', 'REJECTED'] | None
score: float | None = None
start_date: datetime | None = None
end_date: datetime | None = None
notes: str | None = None
status: Literal['PROPOSAL', 'APPROVED', 'REJECTED'] | None = None

@validator('start_date', 'end_date', pre=True)
@field_validator('start_date', 'end_date', mode='before')
@classmethod
def parse_dates(cls, v: str | None) -> datetime | None:
if v is None:
return v
Expand Down
118 changes: 61 additions & 57 deletions rdwatch/core/schemas/site_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Annotated, Any, Literal, TypeAlias

from ninja import Field, Schema
from pydantic import confloat, constr, root_validator, validator
from pydantic import Field, StringConstraints, field_validator, model_validator

from django.contrib.gis.gdal import GDALException
from django.contrib.gis.geos import GEOSGeometry
Expand All @@ -20,17 +20,17 @@


class SiteFeatureCache(Schema):
originator_file: str | None
timestamp: datetime | None
commit_hash: str | None
originator_file: str | None = None
timestamp: datetime | None = None
commit_hash: str | None = None


class SiteFeature(Schema):
type: Literal['site']
region_id: constr(min_length=1, max_length=255)
site_id: constr(regex=r'^.{1,255}_\d{4,8}$')
version: constr(regex=r'^\d+\.\d+\.\d+$')
mgrs: constr(regex=r'^\d{2}[A-Z]{3}$')
region_id: Annotated[str, StringConstraints(min_length=1, max_length=255)]
site_id: Annotated[str, StringConstraints(pattern=r'^.{1,255}_\d{4,8}$')]
version: Annotated[str, StringConstraints(pattern=r'^\d+\.\d+\.\d+$')]
mgrs: Annotated[str, StringConstraints(pattern=r'^\d{2}[A-Z]{3}$')]
status: Literal[
'positive_annotated',
'positive_partial',
Expand All @@ -46,40 +46,42 @@ class SiteFeature(Schema):
'system_confirmed',
'system_rejected',
]
start_date: datetime | None
end_date: datetime | None
start_date: datetime | None = Field(default=None, validate_default=True)
end_date: datetime | None = Field(default=None, validate_default=True)
model_content: Literal['annotation', 'proposed', 'update']
originator: str

# Optional fields
score: confloat(ge=0.0, le=1.0) | None
validated: Literal['True', 'False'] | None
cache: SiteFeatureCache | None
score: Annotated[float, Field(ge=0.0, le=1.0)] | None = Field(
default=None, validate_default=True
)
validated: Literal['True', 'False'] | None = None
cache: SiteFeatureCache | None = None
predicted_phase_transition: (
Literal[
'Active Construction',
'Post Construction',
]
| None
)
predicted_phase_transition_date: str | None
misc_info: dict[str, Any] | None
) = None
predicted_phase_transition_date: str | None = None
misc_info: dict[str, Any] | None = None

@validator('start_date', 'end_date', pre=True)
@field_validator('start_date', 'end_date', mode='before')
@classmethod
def parse_dates(cls, v: str | None) -> datetime | None:
if v is None:
return v
return datetime.strptime(v, '%Y-%m-%d')

@validator('score', pre=True, always=True)
@field_validator('score', mode='before')
@classmethod
def parse_score(cls, v: float | None) -> float:
"""
Score is an optional field, and defaults to 1.0 if one isn't provided
https://smartgitlab.com/TE/standards/-/wikis/Site-Model-Specification#score-float-optional
"""
if v is None:
return 1.0
return v
return v if v is not None else 1.0

@property
def site_number(self) -> int:
Expand All @@ -89,17 +91,20 @@ def site_number(self) -> int:

class ObservationFeature(Schema):
type: Literal['observation']
observation_date: datetime | None
source: str | None
sensor_name: Literal['Landsat 8', 'Sentinel-2', 'WorldView', 'Planet'] | None
current_phase: list[CurrentPhase] | None
is_occluded: list[bool] | None
is_site_boundary: list[bool] | None

@validator('is_occluded', 'is_site_boundary', pre=True)
def convert_bools_to_list(cls, val: str, values, field):
observation_date: datetime | None = Field(default=None, validate_default=True)
source: str | None = None
sensor_name: Literal['Landsat 8', 'Sentinel-2', 'WorldView', 'Planet'] | None = None
current_phase: list[CurrentPhase] | None = Field(
default=None, validate_default=True
)
is_occluded: list[bool] | None = Field(default=None, validate_default=True)
is_site_boundary: list[bool] | None = Field(default=None, validate_default=True)

@field_validator('is_occluded', 'is_site_boundary', mode='before')
@classmethod
def convert_bools_to_list(cls, val: str | None):
"""
Converts comma-space-seperated strings into lists of bools.
Converts comma-space-separated strings into lists of bools.
"""
if val is None:
return val
Expand All @@ -108,54 +113,53 @@ def convert_bools_to_list(cls, val: str, values, field):
]
if None in converted_list:
raise ValueError(
f'Invalid value "{val}" for field {field.name} - '
'must be a comma-space-separated formatted string.'
f'Invalid value "{val}" - must be a comma-space-separated formatted string.'
)
return converted_list

@validator('current_phase', pre=True)
@field_validator('current_phase', mode='before')
def convert_phases_to_list(cls, val: str | None):
"""
Converts comma-space-seperated strings into lists of phase strings.
Converts comma-space-separated strings into lists of phase strings.
"""
if val is None:
return val
return val.split(', ')

@validator('observation_date', pre=True)
@field_validator('observation_date', mode='before')
def parse_dates(cls, v: Any) -> datetime | None:
if v is None:
return None
if not isinstance(v, str):
raise ValueError('"observation_date" must be a valid date string.')
return datetime.strptime(v, '%Y-%m-%d')

@root_validator
def ensure_consistent_list_lengths(cls, values: dict[str, Any]):
@model_validator(mode='after')
def ensure_consistent_list_lengths(self):
lists = [
values.get(field)
getattr(self, field)
for field in ('current_phase', 'is_occluded', 'is_site_boundary')
if values.get(field) is not None
if getattr(self, field) is not None
]
if len(lists) and len({len(l) for l in lists}) != 1:
raise ValueError(
'current_phase/is_occluded/is_site_boundary lists must be the same length!'
)
return values
return self

# Optional fields
score: confloat(ge=0.0, le=1.0) | None
misc_info: dict[str, Any] | None
score: Annotated[float, Field(ge=0.0, le=1.0)] | None = Field(
default=None, validate_default=True
)
misc_info: dict[str, Any] | None = None

@validator('score', pre=True, always=True)
@field_validator('score', mode='before')
@classmethod
def parse_score(cls, v: float | None) -> float:
"""
Score is an optional field, and defaults to 1.0 if one isn't provided
https://smartgitlab.com/TE/standards/-/wikis/Site-Model-Specification#score-float-optional-1
"""
if v is None:
return 1.0
return v
return v if v is not None else 1.0


class Feature(Schema):
Expand All @@ -170,7 +174,7 @@ class Feature(Schema):
def parsed_geometry(self) -> GEOSGeometry:
return GEOSGeometry(json.dumps(self.geometry))

@validator('geometry', pre=True)
@field_validator('geometry', mode='before')
def parse_geometry(cls, v: dict[str, Any]):
try:
GEOSGeometry(json.dumps(v))
Expand All @@ -180,15 +184,15 @@ def parse_geometry(cls, v: dict[str, Any]):
raise ValueError(f'Failed to parse geometry: {e}')
return v

@root_validator
def ensure_correct_geometry_type(cls, values: dict[str, Any]):
if 'properties' not in values or 'geometry' not in values:
return values
if isinstance(values['properties'], SiteFeature) and (
values['geometry'].get('type') not in ['Polygon', 'Point']
@model_validator(mode='after')
def ensure_correct_geometry_type(self):
if not self.properties or not self.geometry:
return self
if isinstance(self.properties, SiteFeature) and (
self.geometry.get('type') not in ['Polygon', 'Point']
):
raise ValueError('Site geometry must be of type "Polygon" or "Point"')
return values
return self


class SiteModel(Schema):
Expand All @@ -211,7 +215,7 @@ def observation_features(self) -> list[Feature]:
if isinstance(feature.properties, ObservationFeature)
]

@validator('features')
@field_validator('features')
def ensure_one_site_feature(cls, v: list[Feature]):
site_features = [feature for feature in v if feature.properties.type == 'site']
if len(site_features) != 1:
Expand Down
Loading
Loading