Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pydantic =1 support #75

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ jobs:
make lint
make test
make docs-build

mamba install --name descent --yes "pydantic <2"
make test

- name: CodeCov
uses: codecov/codecov-action@v3.1.1
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ format:
$(CONDA_ENV_RUN) ruff check --fix --select I $(PACKAGE_DIR)

test:
$(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-report=xml --color=yes $(PACKAGE_DIR)/tests/
$(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-append --cov-report=xml --color=yes $(PACKAGE_DIR)/tests/

docs-build:
$(CONDA_ENV_RUN) mkdocs build
Expand Down
143 changes: 92 additions & 51 deletions descent/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,48 +26,53 @@ def _unflatten_tensors(
return tensors


class _PotentialKey(pydantic.BaseModel):
"""

TODO: Needed until interchange upgrades to pydantic >=2
"""
if pydantic.__version__.startswith("1."):
_PotentialKey = openff.interchange.models.PotentialKey
PotentialKeyList = list[_PotentialKey]
else:

id: str
mult: int | None = None
associated_handler: str | None = None
bond_order: float | None = None

def __hash__(self) -> int:
return hash((self.id, self.mult, self.associated_handler, self.bond_order))
class _PotentialKey(pydantic.BaseModel):
"""

def __eq__(self, other: object) -> bool:
import openff.interchange.models
TODO: Needed until interchange upgrades to pydantic >=2
"""

return (
isinstance(other, (_PotentialKey, openff.interchange.models.PotentialKey))
and self.id == other.id
and self.mult == other.mult
and self.associated_handler == other.associated_handler
and self.bond_order == other.bond_order
)
id: str
mult: int | None = None
associated_handler: str | None = None
bond_order: float | None = None

def __hash__(self) -> int:
return hash((self.id, self.mult, self.associated_handler, self.bond_order))

def __eq__(self, other: object) -> bool:
import openff.interchange.models

return (
isinstance(
other, (_PotentialKey, openff.interchange.models.PotentialKey)
)
and self.id == other.id
and self.mult == other.mult
and self.associated_handler == other.associated_handler
and self.bond_order == other.bond_order
)

def _convert_keys(value: typing.Any) -> typing.Any:
if not isinstance(value, list):
return value

def _convert_keys(value: typing.Any) -> typing.Any:
if not isinstance(value, list):
value = [
_PotentialKey(**v.dict())
if isinstance(v, openff.interchange.models.PotentialKey)
else v
for v in value
]
return value

value = [
_PotentialKey(**v.dict())
if isinstance(v, openff.interchange.models.PotentialKey)
else v
for v in value
PotentialKeyList = typing.Annotated[
list[_PotentialKey], pydantic.BeforeValidator(_convert_keys)
]
return value


PotentialKeyList = typing.Annotated[
list[_PotentialKey], pydantic.BeforeValidator(_convert_keys)
]


class AttributeConfig(pydantic.BaseModel):
Expand All @@ -89,17 +94,35 @@ class AttributeConfig(pydantic.BaseModel):
"none indicates no constraint.",
)

@pydantic.model_validator(mode="after")
def _validate_keys(self):
"""Ensure that the keys in `scales` and `limits` match `cols`."""
if pydantic.__version__.startswith("1."):

if any(key not in self.cols for key in self.scales):
raise ValueError("cannot scale non-trainable parameters")
@pydantic.root_validator
def _validate_keys(cls, values):
cols = values.get("cols")

if any(key not in self.cols for key in self.limits):
raise ValueError("cannot clamp non-trainable parameters")
scales = values.get("scales")
limits = values.get("limits")

return self
if any(key not in cols for key in scales):
raise ValueError("cannot scale non-trainable parameters")
if any(key not in cols for key in limits):
raise ValueError("cannot clamp non-trainable parameters")

return values

else:

@pydantic.model_validator(mode="after")
def _validate_keys(self):
"""Ensure that the keys in `scales` and `limits` match `cols`."""

if any(key not in self.cols for key in self.scales):
raise ValueError("cannot scale non-trainable parameters")

if any(key not in self.cols for key in self.limits):
raise ValueError("cannot clamp non-trainable parameters")

return self


class ParameterConfig(AttributeConfig):
Expand All @@ -118,18 +141,36 @@ class ParameterConfig(AttributeConfig):
"If ``None``, no parameters will be excluded.",
)

@pydantic.model_validator(mode="after")
def _validate_include_exclude(self):
"""Ensure that the keys in `include` and `exclude` are disjoint."""
if pydantic.__version__.startswith("1."):

@pydantic.root_validator
def _validate_include_exclude(cls, values):
include = values.get("include")
exclude = values.get("exclude")

if include is not None and exclude is not None:
include = {*include}
exclude = {*exclude}

if include & exclude:
raise ValueError("cannot include and exclude the same parameter")

return values

else:

@pydantic.model_validator(mode="after")
def _validate_include_exclude(self):
"""Ensure that the keys in `include` and `exclude` are disjoint."""

if self.include is not None and self.exclude is not None:
include = {*self.include}
exclude = {*self.exclude}
if self.include is not None and self.exclude is not None:
include = {*self.include}
exclude = {*self.exclude}

if include & exclude:
raise ValueError("cannot include and exclude the same parameter")
if include & exclude:
raise ValueError("cannot include and exclude the same parameter")

return self
return self


class Trainable:
Expand Down
5 changes: 4 additions & 1 deletion devtools/envs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ dependencies:

# Core packages
- smee >=0.10.0
- pydantic-units # TODO: Remove this line once smee deps are updated

- pytorch
- pydantic
- pyarrow
- datasets

- pydantic

### Levenberg Marquardt
- scipy

Expand Down Expand Up @@ -57,5 +59,6 @@ dependencies:
- mkdocs-literate-nav
- mkdocstrings
- mkdocstrings-python
- griffe <1
- black
- mike
Loading