Skip to content

Commit

Permalink
feat: support exclude atypes in atomic model (#3357)
Browse files Browse the repository at this point in the history
With this pr, we do not need to implement the exclude types for each
descriptor and fitting.

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
  • Loading branch information
wanghan-iapcm and Han Wang authored Feb 29, 2024
1 parent d09af56 commit 342b0c3
Show file tree
Hide file tree
Showing 23 changed files with 470 additions and 107 deletions.
83 changes: 82 additions & 1 deletion deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,89 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
Tuple,
)

import numpy as np

from deepmd.dpmodel.utils import (
AtomExcludeMask,
PairExcludeMask,
)

from .make_base_atomic_model import (
make_base_atomic_model,
)

BaseAtomicModel = make_base_atomic_model(np.ndarray)
BaseAtomicModel_ = make_base_atomic_model(np.ndarray)


class BaseAtomicModel(BaseAtomicModel_):
def __init__(
self,
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

def reinit_atom_exclude(
self,
exclude_types: List[int] = [],
):
self.atom_exclude_types = exclude_types
if exclude_types == []:
self.atom_excl = None
else:
self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types)

def reinit_pair_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.pair_exclude_types = exclude_types
if exclude_types == []:
self.pair_excl = None
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

def forward_common_atomic(
self,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlist: np.ndarray,
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]
if self.pair_excl is not None:
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
# exclude neighbors in the nlist
nlist = np.where(pair_mask == 1, nlist, -1)

ret_dict = self.forward_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

if self.atom_excl is not None:
atom_mask = self.atom_excl.build_type_exclude_mask(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]

return ret_dict

def serialize(self) -> dict:
return {
"atom_exclude_types": self.atom_exclude_types,
"pair_exclude_types": self.pair_exclude_types,
}
30 changes: 18 additions & 12 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ def __init__(
descriptor,
fitting,
type_map: Optional[List[str]] = None,
**kwargs,
):
super().__init__()
self.type_map = type_map
self.descriptor = descriptor
self.fitting = fitting
super().__init__(**kwargs)

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand Down Expand Up @@ -132,24 +133,29 @@ def forward_atomic(
return ret

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"@version": 1,
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
}
dd = super().serialize()
dd.update(
{
"@class": "Model",
"type": "standard",
"@version": 1,
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
}
)
return dd

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor"))
fitting_obj = BaseFitting.deserialize(data.pop("fitting"))
type_map = data.pop("type_map", None)
obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data)
return obj

def get_dim_fparam(self) -> int:
Expand Down
33 changes: 19 additions & 14 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def __init__(
models: List[BaseAtomicModel],
**kwargs,
):
super().__init__()
self.models = models
self.mixed_types_list = [model.mixed_types() for model in self.models]
super().__init__(**kwargs)

def mixed_types(self) -> bool:
"""If true, the model
Expand Down Expand Up @@ -273,34 +273,39 @@ def __init__(
self.smin_alpha = smin_alpha

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
}
dd = BaseAtomicModel.serialize(self)
dd.update(
{
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
}
)
return dd

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
sw_rmin = data["sw_rmin"]
sw_rmax = data["sw_rmax"]
smin_alpha = data["smin_alpha"]
sw_rmin = data.pop("sw_rmin")
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

dp_model, zbl_model = LinearAtomicModel.deserialize(data["models"])
dp_model, zbl_model = LinearAtomicModel.deserialize(data.pop("models"))

return cls(
dp_model=dp_model,
zbl_model=zbl_model,
sw_rmin=sw_rmin,
sw_rmax=sw_rmax,
smin_alpha=smin_alpha,
**data,
)

def _compute_weight(
Expand Down
10 changes: 10 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ def get_rcut(self) -> float:
def get_type_map(self) -> Optional[List[str]]:
"""Get the type map."""

def get_ntypes(self) -> int:
"""Get the number of atom types."""
tmap = self.get_type_map()
if tmap is not None:
return len(tmap)
else:
raise ValueError(
"cannot infer the number of types from a None type map"
)

@abstractmethod
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
Expand Down
28 changes: 16 additions & 12 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,29 @@ def mixed_types(self) -> bool:
return True

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "pairtab",
"@version": 1,
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}
dd = BaseAtomicModel.serialize(self)
dd.update(
{
"@class": "Model",
"type": "pairtab",
"@version": 1,
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}
)
return dd

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
tab_model = cls(None, rcut, sel)
rcut = data.pop("rcut")
sel = data.pop("sel")
tab = PairTab.deserialize(data.pop("tab"))
tab_model = cls(None, rcut, sel, **data)
tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
tab_model.tab_data = tab_model.tab.tab_data
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def call_lower(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
atomic_ret = self.forward_atomic(
atomic_ret = self.forward_common_atomic(
cc_ext,
extended_atype,
nlist,
Expand Down
2 changes: 2 additions & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ def get_model(data: dict) -> DPModel:
descriptor=descriptor,
fitting=fitting,
type_map=data["type_map"],
atom_exclude_types=data.get("atom_exclude_types", []),
pair_exclude_types=data.get("pair_exclude_types", []),
)
79 changes: 79 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,99 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from typing import (
Dict,
List,
Optional,
Tuple,
)

import torch

from deepmd.dpmodel.atomic_model import (
make_base_atomic_model,
)
from deepmd.pt.utils import (
AtomExcludeMask,
PairExcludeMask,
)

BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)


class BaseAtomicModel(BaseAtomicModel_):
def __init__(
self,
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

def reinit_atom_exclude(
self,
exclude_types: List[int] = [],
):
self.atom_exclude_types = exclude_types
if exclude_types == []:
self.atom_excl = None
else:
self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types)

def reinit_pair_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.pair_exclude_types = exclude_types
if exclude_types == []:
self.pair_excl = None
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

# export public methods that are not abstract
get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel)
get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei)

@torch.jit.export
def get_model_def_script(self) -> str:
return self.model_def_script

def forward_common_atomic(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]

if self.pair_excl is not None:
pair_mask = self.pair_excl(nlist, extended_atype)
# exclude neighbors in the nlist
nlist = torch.where(pair_mask == 1, nlist, -1)

ret_dict = self.forward_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

if self.atom_excl is not None:
atom_mask = self.atom_excl(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]

return ret_dict

def serialize(self) -> dict:
return {
"atom_exclude_types": self.atom_exclude_types,
"pair_exclude_types": self.pair_exclude_types,
}
Loading

0 comments on commit 342b0c3

Please sign in to comment.