Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support exclude atypes in atomic model #3357

Merged
merged 12 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,89 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
Tuple,
)

import numpy as np

from deepmd.dpmodel.utils import (
AtomExcludeMask,
PairExcludeMask,
)

from .make_base_atomic_model import (
make_base_atomic_model,
)

BaseAtomicModel = make_base_atomic_model(np.ndarray)
BaseAtomicModel_ = make_base_atomic_model(np.ndarray)


class BaseAtomicModel(BaseAtomicModel_):
def __init__(
self,
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

Check warning on line 31 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L29-L31

Added lines #L29 - L31 were not covered by tests

def reinit_atom_exclude(
self,
exclude_types: List[int] = [],
):
self.atom_exclude_types = exclude_types
if exclude_types == []:
self.atom_excl = None

Check warning on line 39 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L37-L39

Added lines #L37 - L39 were not covered by tests
else:
self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types)

Check warning on line 41 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L41

Added line #L41 was not covered by tests

def reinit_pair_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.pair_exclude_types = exclude_types
if exclude_types == []:
self.pair_excl = None

Check warning on line 49 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L47-L49

Added lines #L47 - L49 were not covered by tests
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

Check warning on line 51 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L51

Added line #L51 was not covered by tests

def forward_common_atomic(
self,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlist: np.ndarray,
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]
if self.pair_excl is not None:
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)

Check warning on line 65 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L62-L65

Added lines #L62 - L65 were not covered by tests
# exclude neighbors in the nlist
nlist = np.where(pair_mask == 1, nlist, -1)

Check warning on line 67 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L67

Added line #L67 was not covered by tests

ret_dict = self.forward_atomic(

Check warning on line 69 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L69

Added line #L69 was not covered by tests
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

if self.atom_excl is not None:
atom_mask = self.atom_excl.build_type_exclude_mask(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]

Check warning on line 81 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L78-L81

Added lines #L78 - L81 were not covered by tests

return ret_dict

Check warning on line 83 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L83

Added line #L83 was not covered by tests

def serialize(self) -> dict:
return {

Check warning on line 86 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L86

Added line #L86 was not covered by tests
"atom_exclude_types": self.atom_exclude_types,
"pair_exclude_types": self.pair_exclude_types,
}
30 changes: 18 additions & 12 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@
descriptor,
fitting,
type_map: Optional[List[str]] = None,
**kwargs,
):
super().__init__()
self.type_map = type_map
self.descriptor = descriptor
self.fitting = fitting
super().__init__(**kwargs)

Check warning on line 54 in deepmd/dpmodel/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/dp_atomic_model.py#L54

Added line #L54 was not covered by tests

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand Down Expand Up @@ -132,24 +133,29 @@
return ret

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"@version": 1,
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
}
dd = super().serialize()
dd.update(

Check warning on line 137 in deepmd/dpmodel/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/dp_atomic_model.py#L136-L137

Added lines #L136 - L137 were not covered by tests
{
"@class": "Model",
"type": "standard",
"@version": 1,
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
}
)
return dd

Check warning on line 147 in deepmd/dpmodel/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/dp_atomic_model.py#L147

Added line #L147 was not covered by tests

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor"))
fitting_obj = BaseFitting.deserialize(data.pop("fitting"))
type_map = data.pop("type_map", None)
obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data)

Check warning on line 158 in deepmd/dpmodel/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/dp_atomic_model.py#L155-L158

Added lines #L155 - L158 were not covered by tests
return obj

def get_dim_fparam(self) -> int:
Expand Down
33 changes: 19 additions & 14 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@
models: List[BaseAtomicModel],
**kwargs,
):
super().__init__()
self.models = models
self.mixed_types_list = [model.mixed_types() for model in self.models]
super().__init__(**kwargs)

Check warning on line 57 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L57

Added line #L57 was not covered by tests

def mixed_types(self) -> bool:
"""If true, the model
Expand Down Expand Up @@ -273,34 +273,39 @@
self.smin_alpha = smin_alpha

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
}
dd = BaseAtomicModel.serialize(self)
dd.update(

Check warning on line 277 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L276-L277

Added lines #L276 - L277 were not covered by tests
{
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
}
)
return dd

Check warning on line 288 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L288

Added line #L288 was not covered by tests

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
sw_rmin = data["sw_rmin"]
sw_rmax = data["sw_rmax"]
smin_alpha = data["smin_alpha"]
sw_rmin = data.pop("sw_rmin")
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

Check warning on line 298 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L296-L298

Added lines #L296 - L298 were not covered by tests

dp_model, zbl_model = LinearAtomicModel.deserialize(data["models"])
dp_model, zbl_model = LinearAtomicModel.deserialize(data.pop("models"))

Check warning on line 300 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L300

Added line #L300 was not covered by tests

return cls(
dp_model=dp_model,
zbl_model=zbl_model,
sw_rmin=sw_rmin,
sw_rmax=sw_rmax,
smin_alpha=smin_alpha,
**data,
)

def _compute_weight(
Expand Down
10 changes: 10 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@
def get_type_map(self) -> Optional[List[str]]:
"""Get the type map."""

def get_ntypes(self) -> int:
"""Get the number of atom types."""
tmap = self.get_type_map()
if tmap is not None:
return len(tmap)

Check warning on line 55 in deepmd/dpmodel/atomic_model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/make_base_atomic_model.py#L53-L55

Added lines #L53 - L55 were not covered by tests
else:
raise ValueError(

Check warning on line 57 in deepmd/dpmodel/atomic_model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/make_base_atomic_model.py#L57

Added line #L57 was not covered by tests
"cannot infer the number of types from a None type map"
)

@abstractmethod
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
Expand Down
28 changes: 16 additions & 12 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,29 @@
return True

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "pairtab",
"@version": 1,
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}
dd = BaseAtomicModel.serialize(self)
dd.update(

Check warning on line 113 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L112-L113

Added lines #L112 - L113 were not covered by tests
{
"@class": "Model",
"type": "pairtab",
"@version": 1,
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}
)
return dd

Check warning on line 123 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L123

Added line #L123 was not covered by tests

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
tab_model = cls(None, rcut, sel)
rcut = data.pop("rcut")
sel = data.pop("sel")
tab = PairTab.deserialize(data.pop("tab"))
tab_model = cls(None, rcut, sel, **data)

Check warning on line 134 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L131-L134

Added lines #L131 - L134 were not covered by tests
tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
tab_model.tab_data = tab_model.tab.tab_data
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@
nframes, nall = extended_atype.shape[:2]
extended_coord = extended_coord.reshape(nframes, -1, 3)
nlist = self.format_nlist(extended_coord, extended_atype, nlist)
atomic_ret = self.forward_atomic(
atomic_ret = self.forward_common_atomic(

Check warning on line 195 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L195

Added line #L195 was not covered by tests
extended_coord,
extended_atype,
nlist,
Expand Down
2 changes: 2 additions & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ def get_model(data: dict) -> DPModel:
descriptor=descriptor,
fitting=fitting,
type_map=data["type_map"],
atom_exclude_types=data.get("atom_exclude_types", []),
pair_exclude_types=data.get("pair_exclude_types", []),
)
79 changes: 79 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,99 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from typing import (

Check warning on line 4 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L4

Added line #L4 was not covered by tests
Dict,
List,
Optional,
Tuple,
)

import torch

from deepmd.dpmodel.atomic_model import (
make_base_atomic_model,
)
from deepmd.pt.utils import (

Check warning on line 16 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L16

Added line #L16 was not covered by tests
AtomExcludeMask,
PairExcludeMask,
)

BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)


class BaseAtomicModel(BaseAtomicModel_):
def __init__(

Check warning on line 25 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L25

Added line #L25 was not covered by tests
self,
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

Check warning on line 32 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L30-L32

Added lines #L30 - L32 were not covered by tests

def reinit_atom_exclude(

Check warning on line 34 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L34

Added line #L34 was not covered by tests
self,
exclude_types: List[int] = [],
):
self.atom_exclude_types = exclude_types
if exclude_types == []:
self.atom_excl = None

Check warning on line 40 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L38-L40

Added lines #L38 - L40 were not covered by tests
else:
self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types)

Check warning on line 42 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L42

Added line #L42 was not covered by tests

def reinit_pair_exclude(

Check warning on line 44 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L44

Added line #L44 was not covered by tests
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.pair_exclude_types = exclude_types
if exclude_types == []:
self.pair_excl = None

Check warning on line 50 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L48-L50

Added lines #L48 - L50 were not covered by tests
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

Check warning on line 52 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L52

Added line #L52 was not covered by tests

# export public methods that are not abstract
get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel)
get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei)

@torch.jit.export
def get_model_def_script(self) -> str:
return self.model_def_script

def forward_common_atomic(

Check warning on line 62 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L62

Added line #L62 was not covered by tests
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]

Check warning on line 72 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L71-L72

Added lines #L71 - L72 were not covered by tests

if self.pair_excl is not None:
pair_mask = self.pair_excl(nlist, extended_atype)

Check warning on line 75 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L74-L75

Added lines #L74 - L75 were not covered by tests
# exclude neighbors in the nlist
nlist = torch.where(pair_mask == 1, nlist, -1)

Check warning on line 77 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L77

Added line #L77 was not covered by tests

ret_dict = self.forward_atomic(

Check warning on line 79 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L79

Added line #L79 was not covered by tests
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

if self.atom_excl is not None:
atom_mask = self.atom_excl(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]

Check warning on line 91 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L88-L91

Added lines #L88 - L91 were not covered by tests

return ret_dict

Check warning on line 93 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L93

Added line #L93 was not covered by tests

def serialize(self) -> dict:
return {

Check warning on line 96 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L95-L96

Added lines #L95 - L96 were not covered by tests
"atom_exclude_types": self.atom_exclude_types,
"pair_exclude_types": self.pair_exclude_types,
}
Loading