Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support exclude atypes in atomic model #3357

Merged
merged 12 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,89 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
Tuple,
)

import numpy as np

from deepmd.dpmodel.utils import (
AtomExcludeMask,
PairExcludeMask,
)

from .make_base_atomic_model import (
make_base_atomic_model,
)

BaseAtomicModel = make_base_atomic_model(np.ndarray)
BaseAtomicModel_ = make_base_atomic_model(np.ndarray)


class BaseAtomicModel(BaseAtomicModel_):
def __init__(
self,
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

def reinit_atom_exclude(
self,
exclude_types: List[int] = [],
):
self.atom_exclude_types = exclude_types
if exclude_types == []:
self.atom_excl = None
else:
self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types)

def reinit_pair_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.pair_exclude_types = exclude_types
if exclude_types == []:
self.pair_excl = None
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

def forward_common_atomic(
self,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlist: np.ndarray,
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]
if self.pair_excl is not None:
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
# exclude neighbors in the nlist
nlist = np.where(pair_mask == 1, nlist, -1)

ret_dict = self.forward_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

if self.atom_excl is not None:
atom_mask = self.atom_excl.build_type_exclude_mask(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]

return ret_dict

def serialize(self) -> dict:
return {
"atom_exclude_types": self.atom_exclude_types,
"pair_exclude_types": self.pair_exclude_types,
}
30 changes: 18 additions & 12 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ def __init__(
descriptor,
fitting,
type_map: Optional[List[str]] = None,
**kwargs,
):
super().__init__()
self.type_map = type_map
self.descriptor = descriptor
self.fitting = fitting
super().__init__(**kwargs)

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand Down Expand Up @@ -132,24 +133,29 @@ def forward_atomic(
return ret

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"@version": 1,
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
}
dd = super().serialize()
dd.update(
{
"@class": "Model",
"type": "standard",
"@version": 1,
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
}
)
return dd

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor"))
fitting_obj = BaseFitting.deserialize(data.pop("fitting"))
type_map = data.pop("type_map", None)
obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data)
return obj

def get_dim_fparam(self) -> int:
Expand Down
33 changes: 19 additions & 14 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def __init__(
models: List[BaseAtomicModel],
**kwargs,
):
super().__init__()
self.models = models
self.mixed_types_list = [model.mixed_types() for model in self.models]
super().__init__(**kwargs)

def mixed_types(self) -> bool:
"""If true, the model
Expand Down Expand Up @@ -273,34 +273,39 @@ def __init__(
self.smin_alpha = smin_alpha

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
}
dd = BaseAtomicModel.serialize(self)
dd.update(
{
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
}
)
return dd

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
sw_rmin = data["sw_rmin"]
sw_rmax = data["sw_rmax"]
smin_alpha = data["smin_alpha"]
sw_rmin = data.pop("sw_rmin")
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

dp_model, zbl_model = LinearAtomicModel.deserialize(data["models"])
dp_model, zbl_model = LinearAtomicModel.deserialize(data.pop("models"))

return cls(
dp_model=dp_model,
zbl_model=zbl_model,
sw_rmin=sw_rmin,
sw_rmax=sw_rmax,
smin_alpha=smin_alpha,
**data,
)

def _compute_weight(
Expand Down
10 changes: 10 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@
def get_type_map(self) -> Optional[List[str]]:
"""Get the type map."""

def get_ntypes(self) -> int:
"""Get the number of atom types."""
tmap = self.get_type_map()
if tmap is not None:
return len(tmap)
else:
raise ValueError(

Check warning on line 57 in deepmd/dpmodel/atomic_model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/make_base_atomic_model.py#L57

Added line #L57 was not covered by tests
"cannot infer the number of types from a None type map"
)

@abstractmethod
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
Expand Down
28 changes: 16 additions & 12 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,29 @@ def mixed_types(self) -> bool:
return True

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "pairtab",
"@version": 1,
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}
dd = BaseAtomicModel.serialize(self)
dd.update(
{
"@class": "Model",
"type": "pairtab",
"@version": 1,
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}
)
return dd

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
tab_model = cls(None, rcut, sel)
rcut = data.pop("rcut")
sel = data.pop("sel")
tab = PairTab.deserialize(data.pop("tab"))
tab_model = cls(None, rcut, sel, **data)
tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
tab_model.tab_data = tab_model.tab.tab_data
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def call_lower(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
atomic_ret = self.forward_atomic(
atomic_ret = self.forward_common_atomic(
cc_ext,
extended_atype,
nlist,
Expand Down
2 changes: 2 additions & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ def get_model(data: dict) -> DPModel:
descriptor=descriptor,
fitting=fitting,
type_map=data["type_map"],
atom_exclude_types=data.get("atom_exclude_types", []),
pair_exclude_types=data.get("pair_exclude_types", []),
)
79 changes: 79 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,99 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from typing import (
Dict,
List,
Optional,
Tuple,
)

import torch

from deepmd.dpmodel.atomic_model import (
make_base_atomic_model,
)
from deepmd.pt.utils import (
AtomExcludeMask,
PairExcludeMask,
)

BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)


class BaseAtomicModel(BaseAtomicModel_):
def __init__(
self,
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

def reinit_atom_exclude(
self,
exclude_types: List[int] = [],
):
self.atom_exclude_types = exclude_types
if exclude_types == []:
self.atom_excl = None
else:
self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types)

def reinit_pair_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.pair_exclude_types = exclude_types
if exclude_types == []:
self.pair_excl = None
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

# export public methods that are not abstract
get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel)
get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei)

@torch.jit.export
def get_model_def_script(self) -> str:
return self.model_def_script

def forward_common_atomic(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]

if self.pair_excl is not None:
pair_mask = self.pair_excl(nlist, extended_atype)
# exclude neighbors in the nlist
nlist = torch.where(pair_mask == 1, nlist, -1)

ret_dict = self.forward_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

if self.atom_excl is not None:
atom_mask = self.atom_excl(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]

return ret_dict

def serialize(self) -> dict:
return {
"atom_exclude_types": self.atom_exclude_types,
"pair_exclude_types": self.pair_exclude_types,
}
Loading