From 0efc0a3cee2cf7b730b96daa4af00226d9271dee Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 28 Feb 2024 18:28:45 +0800 Subject: [PATCH 1/8] feat: support exclude atypes in atomic model --- .../atomic_model/make_base_atomic_model.py | 10 +++ .../model/atomic_model/base_atomic_model.py | 79 +++++++++++++++++++ .../pt/model/atomic_model/dp_atomic_model.py | 39 ++++++--- .../model/atomic_model/linear_atomic_model.py | 36 ++++++--- .../atomic_model/pairtab_atomic_model.py | 31 +++++--- deepmd/pt/model/descriptor/se_a.py | 18 ++++- deepmd/pt/model/model/make_model.py | 2 +- deepmd/pt/model/task/fitting.py | 12 ++- deepmd/pt/utils/exclude_mask.py | 14 ++++ source/tests/pt/model/test_dp_atomic_model.py | 77 +++++++++++++++--- 10 files changed, 263 insertions(+), 55 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py index d4186c990d..df6e39dd2e 100644 --- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py @@ -48,6 +48,16 @@ def get_rcut(self) -> float: def get_type_map(self) -> Optional[List[str]]: """Get the type map.""" + def get_ntypes(self) -> int: + """Get the number of atom types.""" + tmap = self.get_type_map() + if tmap is not None: + return len(tmap) + else: + raise ValueError( + "cannot infer the number of types from a None type map" + ) + @abstractmethod def get_sel(self) -> List[int]: """Returns the number of selected atoms for each type.""" diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 1e5f976baf..4d22533baf 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -1,16 +1,56 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, + Tuple, +) + import torch from deepmd.dpmodel.atomic_model import ( make_base_atomic_model, ) +from deepmd.pt.utils import ( + AtomExcludeMask, + PairExcludeMask, +) BaseAtomicModel_ = make_base_atomic_model(torch.Tensor) class BaseAtomicModel(BaseAtomicModel_): + def __init__( + self, + atom_exclude_types: List[int] = [], + pair_exclude_types: List[Tuple[int, int]] = [], + ): + super().__init__() + self.reinit_atom_exclude(atom_exclude_types) + self.reinit_pair_exclude(pair_exclude_types) + + def reinit_atom_exclude( + self, + exclude_types: List[int] = [], + ): + self.atom_exclude_types = exclude_types + if exclude_types == []: + self.atom_excl = None + else: + self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types) + + def reinit_pair_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.pair_exclude_types = exclude_types + if exclude_types == []: + self.pair_excl = None + else: + self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types) + # export public methods that are not abstract get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel) get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei) @@ -18,3 +58,42 @@ class BaseAtomicModel(BaseAtomicModel_): @torch.jit.export def get_model_def_script(self) -> str: return self.model_def_script + + def forward_common_atomic( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + nf, nloc, nnei = nlist.shape + nf, nall = extended_atype.shape + atype = extended_atype[:, :nloc] + if self.pair_excl is not None: + pair_mask = self.pair_excl(nlist, extended_atype) + # exclude neighbors in the nlist + nlist = torch.where(pair_mask == 1, nlist, -1) + + ret_dict = self.forward_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + + if self.atom_excl is not None: + atom_mask = self.atom_excl(atype) + for kk in ret_dict.keys(): + ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None] + + return ret_dict + + def serialize(self) -> dict: + return { + "atom_exclude_types": self.atom_exclude_types, + "pair_exclude_types": self.pair_exclude_types, + } diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 881ea4c97d..4a1fece61e 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -46,8 +46,14 @@ class DPAtomicModel(torch.nn.Module, BaseAtomicModel): For example `type_map[1]` gives the name of the type 1. """ - def __init__(self, descriptor, fitting, type_map: Optional[List[str]]): - super().__init__() + def __init__( + self, + descriptor, + fitting, + type_map: Optional[List[str]], + **kwargs, + ): + torch.nn.Module.__init__(self) self.model_def_script = "" ntypes = len(type_map) self.type_map = type_map @@ -56,6 +62,8 @@ def __init__(self, descriptor, fitting, type_map: Optional[List[str]]): self.rcut = self.descriptor.get_rcut() self.sel = self.descriptor.get_sel() self.fitting_net = fitting + # order matters ntypes and type_map should be initialized first. + BaseAtomicModel.__init__(self, **kwargs) def fitting_output_def(self) -> FittingOutputDef: """Get the output def of the fitting net.""" @@ -92,20 +100,27 @@ def mixed_types(self) -> bool: return self.descriptor.mixed_types() def serialize(self) -> dict: - return { - "@class": "Model", - "type": "standard", - "type_map": self.type_map, - "descriptor": self.descriptor.serialize(), - "fitting": self.fitting_net.serialize(), - } + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "type": "standard", + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + "fitting": self.fitting_net.serialize(), + } + ) + return dd @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) - descriptor_obj = BaseDescriptor.deserialize(data["descriptor"]) - fitting_obj = BaseFitting.deserialize(data["fitting"]) - obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"]) + data.pop("@class", None) + data.pop("type", None) + descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) + fitting_obj = BaseFitting.deserialize(data.pop("fitting")) + type_map = data.pop("type_map", None) + obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data) return obj def forward_atomic( diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 68ff303d64..a247a3bf50 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import sys from abc import ( abstractmethod, @@ -50,10 +51,11 @@ def __init__( models: List[BaseAtomicModel], **kwargs, ): - super().__init__() + torch.nn.Module.__init__(self) self.models = torch.nn.ModuleList(models) self.atomic_bias = None self.mixed_types_list = [model.mixed_types() for model in self.models] + BaseAtomicModel.__init__(self, **kwargs) def mixed_types(self) -> bool: """If true, the model @@ -300,29 +302,37 @@ def __init__( self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE) def serialize(self) -> dict: - return { - "@class": "Model", - "type": "zbl", - "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), - "sw_rmin": self.sw_rmin, - "sw_rmax": self.sw_rmax, - "smin_alpha": self.smin_alpha, - } + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "type": "zbl", + "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), + "sw_rmin": self.sw_rmin, + "sw_rmax": self.sw_rmax, + "smin_alpha": self.smin_alpha, + } + ) + return dd @classmethod def deserialize(cls, data) -> "DPZBLLinearAtomicModel": - sw_rmin = data["sw_rmin"] - sw_rmax = data["sw_rmax"] - smin_alpha = data["smin_alpha"] + data = copy.deepcopy(data) + sw_rmin = data.pop("sw_rmin") + sw_rmax = data.pop("sw_rmax") + smin_alpha = data.pop("smin_alpha") - dp_model, zbl_model = LinearAtomicModel.deserialize(data["models"]) + dp_model, zbl_model = LinearAtomicModel.deserialize(data.pop("models")) + data.pop("@class", None) + data.pop("type", None) return cls( dp_model=dp_model, zbl_model=zbl_model, sw_rmin=sw_rmin, sw_rmax=sw_rmax, smin_alpha=smin_alpha, + **data, ) def _compute_weight( diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 86bfe98c36..6549534c8a 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -48,11 +48,12 @@ class PairTabAtomicModel(torch.nn.Module, BaseAtomicModel): def __init__( self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs ): - super().__init__() + torch.nn.Module.__init__(self) self.model_def_script = "" self.tab_file = tab_file self.rcut = rcut self.tab = self._set_pairtab(tab_file, rcut) + BaseAtomicModel.__init__(self, **kwargs) # handle deserialization with no input file if self.tab_file is not None: @@ -121,20 +122,26 @@ def mixed_types(self) -> bool: return True def serialize(self) -> dict: - return { - "@class": "Model", - "type": "pairtab", - "tab": self.tab.serialize(), - "rcut": self.rcut, - "sel": self.sel, - } + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "type": "pairtab", + "tab": self.tab.serialize(), + "rcut": self.rcut, + "sel": self.sel, + } + ) + return dd @classmethod def deserialize(cls, data) -> "PairTabAtomicModel": - rcut = data["rcut"] - sel = data["sel"] - tab = PairTab.deserialize(data["tab"]) - tab_model = cls(None, rcut, sel) + rcut = data.pop("rcut") + sel = data.pop("sel") + tab = PairTab.deserialize(data.pop("tab")) + data.pop("@class", None) + data.pop("type", None) + tab_model = cls(None, rcut, sel, **data) tab_model.tab = tab tab_model.register_buffer("tab_info", torch.from_numpy(tab_model.tab.tab_info)) tab_model.register_buffer("tab_data", torch.from_numpy(tab_model.tab.tab_data)) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 033d640ad8..032e747897 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -130,6 +130,13 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) """Update mean and stddev for descriptor elements.""" return self.sea.compute_input_stats(merged, path) + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + """Update the type exclusions.""" + self.sea.reinit_exclude(exclude_types) + def forward( self, coord_ext: torch.Tensor, @@ -266,10 +273,10 @@ def __init__( self.prec = PRECISION_DICT[self.precision] self.resnet_dt = resnet_dt self.old_impl = old_impl - self.exclude_types = exclude_types self.ntypes = len(sel) - self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types) self.type_one_side = type_one_side + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) self.sel = sel self.sec = torch.tensor( @@ -402,6 +409,13 @@ def get_stats(self) -> Dict[str, StatItem]: ) return self.stats + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def forward( self, nlist: torch.Tensor, diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 79634186e4..1ce0a0143f 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -186,7 +186,7 @@ def forward_common_lower( nframes, nall = extended_atype.shape[:2] extended_coord = extended_coord.view(nframes, -1, 3) nlist = self.format_nlist(extended_coord, extended_atype, nlist) - atomic_ret = self.forward_atomic( + atomic_ret = self.forward_common_atomic( extended_coord, extended_atype, nlist, diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 6c395d3800..a0045123f6 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -285,9 +285,8 @@ def __init__( self.precision = precision self.prec = PRECISION_DICT[self.precision] self.rcond = rcond - self.exclude_types = exclude_types - - self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) + # order matters, should be place after the assignment of ntypes + self.reinit_exclude(exclude_types) net_dim_out = self._net_out_dim() # init constants @@ -363,6 +362,13 @@ def __init__( log.info("Set seed to %d in fitting net.", seed) torch.manual_seed(seed) + def reinit_exclude( + self, + exclude_types: List[int] = [], + ): + self.exclude_types = exclude_types + self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) + def serialize(self) -> dict: """Serialize the fitting to dict.""" return { diff --git a/deepmd/pt/utils/exclude_mask.py b/deepmd/pt/utils/exclude_mask.py index 74b1d8dc41..6df8df8dd0 100644 --- a/deepmd/pt/utils/exclude_mask.py +++ b/deepmd/pt/utils/exclude_mask.py @@ -22,6 +22,13 @@ def __init__( exclude_types: List[int] = [], ): super().__init__() + self.reinit(ntypes, exclude_types) + + def reinit( + self, + ntypes: int, + exclude_types: List[int] = [], + ): self.ntypes = ntypes self.exclude_types = exclude_types self.type_mask = np.array( @@ -62,6 +69,13 @@ def __init__( exclude_types: List[Tuple[int, int]] = [], ): super().__init__() + self.reinit(ntypes, exclude_types) + + def reinit( + self, + ntypes: int, + exclude_types: List[Tuple[int, int]] = [], + ): self.ntypes = ntypes self._exclude_types: Set[Tuple[int, int]] = set() for tt in exclude_types: diff --git a/source/tests/pt/model/test_dp_atomic_model.py b/source/tests/pt/model/test_dp_atomic_model.py index bb0d20ab02..451ac732f9 100644 --- a/source/tests/pt/model/test_dp_atomic_model.py +++ b/source/tests/pt/model/test_dp_atomic_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import itertools import unittest import numpy as np @@ -50,17 +51,27 @@ def test_self_consistency(self): mixed_types=ds.mixed_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - md0 = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE) - md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE) - args = [ - to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] - ] - ret0 = md0.forward_atomic(*args) - ret1 = md1.forward_atomic(*args) - np.testing.assert_allclose( - to_numpy_array(ret0["energy"]), - to_numpy_array(ret1["energy"]), - ) + + # test the case of exclusion + for atom_excl, pair_excl in itertools.product([[], [1]], [[], [[0, 1]]]): + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + md0.reinit_atom_exclude(atom_excl) + md0.reinit_pair_exclude(pair_excl) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE) + args = [ + to_torch_tensor(ii) + for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + ret0 = md0.forward_common_atomic(*args) + ret1 = md1.forward_common_atomic(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) def test_dp_consistency(self): rng = np.random.default_rng() @@ -85,7 +96,7 @@ def test_dp_consistency(self): to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] ] ret0 = md0.forward_atomic(*args0) - ret1 = md1.forward_atomic(*args1) + ret1 = md1.forward_common_atomic(*args1) np.testing.assert_allclose( ret0["energy"], to_numpy_array(ret1["energy"]), @@ -110,3 +121,45 @@ def test_jit(self): md0 = torch.jit.script(md0) self.assertEqual(md0.get_rcut(), self.rcut) self.assertEqual(md0.get_type_map(), type_map) + + def test_excl_consistency(self): + nf, nloc, nnei = self.nlist.shape + type_map = ["foo", "bar"] + + # test the case of exclusion + for atom_excl, pair_excl in itertools.product([[], [1]], [[], [[0, 1]]]): + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE) + + md0.reinit_atom_exclude(atom_excl) + md0.reinit_pair_exclude(pair_excl) + # hacking! + md1.descriptor.reinit_exclude(pair_excl) + md1.fitting_net.reinit_exclude(atom_excl) + + args = [ + to_torch_tensor(ii) + for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + ret0 = md0.forward_common_atomic(*args) + ret1 = md1.forward_common_atomic(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) From 4545059013a795e660fa912de782db990491aed6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 28 Feb 2024 20:29:37 +0800 Subject: [PATCH 2/8] support exclude in dp --- .../dpmodel/atomic_model/base_atomic_model.py | 84 ++++++++++++++++++- .../dpmodel/atomic_model/dp_atomic_model.py | 30 ++++--- .../atomic_model/linear_atomic_model.py | 33 ++++---- .../atomic_model/pairtab_atomic_model.py | 28 ++++--- deepmd/dpmodel/model/make_model.py | 2 +- deepmd/tf/model/model.py | 5 ++ .../common/dpmodel/test_dp_atomic_model.py | 15 ++-- source/tests/pt/model/test_dp_atomic_model.py | 2 +- 8 files changed, 153 insertions(+), 46 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index b9521cde8e..a905ed97ff 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -1,8 +1,90 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, + Tuple, +) + import numpy as np +from deepmd.dpmodel.utils import ( + AtomExcludeMask, + PairExcludeMask, +) + from .make_base_atomic_model import ( make_base_atomic_model, ) -BaseAtomicModel = make_base_atomic_model(np.ndarray) +BaseAtomicModel_ = make_base_atomic_model(np.ndarray) + + +class BaseAtomicModel(BaseAtomicModel_): + def __init__( + self, + atom_exclude_types: List[int] = [], + pair_exclude_types: List[Tuple[int, int]] = [], + ): + super().__init__() + self.reinit_atom_exclude(atom_exclude_types) + self.reinit_pair_exclude(pair_exclude_types) + + def reinit_atom_exclude( + self, + exclude_types: List[int] = [], + ): + self.atom_exclude_types = exclude_types + if exclude_types == []: + self.atom_excl = None + else: + self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types) + + def reinit_pair_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.pair_exclude_types = exclude_types + if exclude_types == []: + self.pair_excl = None + else: + self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types) + + def forward_common_atomic( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlist: np.ndarray, + mapping: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + ) -> Dict[str, np.ndarray]: + nf, nloc, nnei = nlist.shape + nf, nall = extended_atype.shape + atype = extended_atype[:, :nloc] + if self.pair_excl is not None: + pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype) + # exclude neighbors in the nlist + nlist = np.where(pair_mask == 1, nlist, -1) + + ret_dict = self.forward_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + + if self.atom_excl is not None: + atom_mask = self.atom_excl.build_type_exclude_mask(atype) + for kk in ret_dict.keys(): + ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None] + + return ret_dict + + def serialize(self) -> dict: + return { + "atom_exclude_types": self.atom_exclude_types, + "pair_exclude_types": self.pair_exclude_types, + } diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index cd349749fa..96ef6d30ae 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -46,11 +46,12 @@ def __init__( descriptor, fitting, type_map: Optional[List[str]] = None, + **kwargs, ): - super().__init__() self.type_map = type_map self.descriptor = descriptor self.fitting = fitting + super().__init__(**kwargs) def fitting_output_def(self) -> FittingOutputDef: """Get the output def of the fitting net.""" @@ -132,14 +133,18 @@ def forward_atomic( return ret def serialize(self) -> dict: - return { - "@class": "Model", - "type": "standard", - "@version": 1, - "type_map": self.type_map, - "descriptor": self.descriptor.serialize(), - "fitting": self.fitting.serialize(), - } + dd = super().serialize() + dd.update( + { + "@class": "Model", + "type": "standard", + "@version": 1, + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + "fitting": self.fitting.serialize(), + } + ) + return dd @classmethod def deserialize(cls, data) -> "DPAtomicModel": @@ -147,9 +152,10 @@ def deserialize(cls, data) -> "DPAtomicModel": check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("@class") data.pop("type") - descriptor_obj = BaseDescriptor.deserialize(data["descriptor"]) - fitting_obj = BaseFitting.deserialize(data["fitting"]) - obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"]) + descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) + fitting_obj = BaseFitting.deserialize(data.pop("fitting")) + type_map = data.pop("type_map", None) + obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data) return obj def get_dim_fparam(self) -> int: diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 6d8aea499e..03c1249d4b 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -52,9 +52,9 @@ def __init__( models: List[BaseAtomicModel], **kwargs, ): - super().__init__() self.models = models self.mixed_types_list = [model.mixed_types() for model in self.models] + super().__init__(**kwargs) def mixed_types(self) -> bool: """If true, the model @@ -273,15 +273,19 @@ def __init__( self.smin_alpha = smin_alpha def serialize(self) -> dict: - return { - "@class": "Model", - "type": "zbl", - "@version": 1, - "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), - "sw_rmin": self.sw_rmin, - "sw_rmax": self.sw_rmax, - "smin_alpha": self.smin_alpha, - } + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "type": "zbl", + "@version": 1, + "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), + "sw_rmin": self.sw_rmin, + "sw_rmax": self.sw_rmax, + "smin_alpha": self.smin_alpha, + } + ) + return dd @classmethod def deserialize(cls, data) -> "DPZBLLinearAtomicModel": @@ -289,11 +293,11 @@ def deserialize(cls, data) -> "DPZBLLinearAtomicModel": check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("@class") data.pop("type") - sw_rmin = data["sw_rmin"] - sw_rmax = data["sw_rmax"] - smin_alpha = data["smin_alpha"] + sw_rmin = data.pop("sw_rmin") + sw_rmax = data.pop("sw_rmax") + smin_alpha = data.pop("smin_alpha") - dp_model, zbl_model = LinearAtomicModel.deserialize(data["models"]) + dp_model, zbl_model = LinearAtomicModel.deserialize(data.pop("models")) return cls( dp_model=dp_model, @@ -301,6 +305,7 @@ def deserialize(cls, data) -> "DPZBLLinearAtomicModel": sw_rmin=sw_rmin, sw_rmax=sw_rmax, smin_alpha=smin_alpha, + **data, ) def _compute_weight( diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index ddece80f2d..5469ee80d2 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -109,14 +109,18 @@ def mixed_types(self) -> bool: return True def serialize(self) -> dict: - return { - "@class": "Model", - "type": "pairtab", - "@version": 1, - "tab": self.tab.serialize(), - "rcut": self.rcut, - "sel": self.sel, - } + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "type": "pairtab", + "@version": 1, + "tab": self.tab.serialize(), + "rcut": self.rcut, + "sel": self.sel, + } + ) + return dd @classmethod def deserialize(cls, data) -> "PairTabAtomicModel": @@ -124,10 +128,10 @@ def deserialize(cls, data) -> "PairTabAtomicModel": check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("@class") data.pop("type") - rcut = data["rcut"] - sel = data["sel"] - tab = PairTab.deserialize(data["tab"]) - tab_model = cls(None, rcut, sel) + rcut = data.pop("rcut") + sel = data.pop("sel") + tab = PairTab.deserialize(data.pop("tab")) + tab_model = cls(None, rcut, sel, **data) tab_model.tab = tab tab_model.tab_info = tab_model.tab.tab_info tab_model.tab_data = tab_model.tab.tab_data diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 7928644061..fe30ad6011 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -192,7 +192,7 @@ def call_lower( nframes, nall = extended_atype.shape[:2] extended_coord = extended_coord.reshape(nframes, -1, 3) nlist = self.format_nlist(extended_coord, extended_atype, nlist) - atomic_ret = self.forward_atomic( + atomic_ret = self.forward_common_atomic( extended_coord, extended_atype, nlist, diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 889f7ccc4d..ca660f8e95 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -784,6 +784,8 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": check_version_compatibility(data.pop("@version", 1), 1, 1) descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix) fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix) + data.pop("atom_exclude_types") + data.pop("pair_exclude_types") return cls( descriptor=descriptor, fitting_net=fitting, @@ -814,4 +816,7 @@ def serialize(self, suffix: str = "") -> dict: "type_map": self.type_map, "descriptor": self.descrpt.serialize(suffix=suffix), "fitting": self.fitting.serialize(suffix=suffix), + # not supported yet + "atom_exclude_types": [], + "pair_exclude_types": [], } diff --git a/source/tests/common/dpmodel/test_dp_atomic_model.py b/source/tests/common/dpmodel/test_dp_atomic_model.py index b32c8ae11a..f97299cf72 100644 --- a/source/tests/common/dpmodel/test_dp_atomic_model.py +++ b/source/tests/common/dpmodel/test_dp_atomic_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import itertools import unittest import numpy as np @@ -38,10 +39,14 @@ def test_self_consistency( mixed_types=ds.mixed_types(), ) type_map = ["foo", "bar"] - md0 = DPAtomicModel(ds, ft, type_map=type_map) - md1 = DPAtomicModel.deserialize(md0.serialize()) - ret0 = md0.forward_atomic(self.coord_ext, self.atype_ext, self.nlist) - ret1 = md1.forward_atomic(self.coord_ext, self.atype_ext, self.nlist) + for atom_excl, pair_excl in itertools.product([[], [1]], [[], [[0, 1]]]): + md0 = DPAtomicModel(ds, ft, type_map=type_map) + md0.reinit_atom_exclude(atom_excl) + md0.reinit_pair_exclude(pair_excl) + md1 = DPAtomicModel.deserialize(md0.serialize()) - np.testing.assert_allclose(ret0["energy"], ret1["energy"]) + ret0 = md0.forward_common_atomic(self.coord_ext, self.atype_ext, self.nlist) + ret1 = md1.forward_common_atomic(self.coord_ext, self.atype_ext, self.nlist) + + np.testing.assert_allclose(ret0["energy"], ret1["energy"]) diff --git a/source/tests/pt/model/test_dp_atomic_model.py b/source/tests/pt/model/test_dp_atomic_model.py index 451ac732f9..b62286f2c2 100644 --- a/source/tests/pt/model/test_dp_atomic_model.py +++ b/source/tests/pt/model/test_dp_atomic_model.py @@ -95,7 +95,7 @@ def test_dp_consistency(self): args1 = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] ] - ret0 = md0.forward_atomic(*args0) + ret0 = md0.forward_common_atomic(*args0) ret1 = md1.forward_common_atomic(*args1) np.testing.assert_allclose( ret0["energy"], From 5f497a5915ca27222773f776fd5d5e20229f5460 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 28 Feb 2024 21:24:10 +0800 Subject: [PATCH 3/8] support model argument, add consistency check --- deepmd/dpmodel/model/model.py | 2 ++ .../model/atomic_model/base_atomic_model.py | 1 + deepmd/pt/model/model/__init__.py | 14 ++++++++++++- deepmd/utils/argcheck.py | 17 ++++++++++++++++ source/tests/consistent/model/test_ener.py | 20 +++++++++++++++++++ source/tests/pt/model/test_deeppot.py | 6 ++++++ source/tests/pt/test_dp_test.py | 3 +-- 7 files changed, 60 insertions(+), 3 deletions(-) diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 4a6e269f25..b21a8a2c78 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -38,4 +38,6 @@ def get_model(data: dict) -> DPModel: descriptor=descriptor, fitting=fitting, type_map=data["type_map"], + atom_exclude_types=data["atom_exclude_types"], + pair_exclude_types=data["pair_exclude_types"], ) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 4d22533baf..94a8f0b839 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -71,6 +71,7 @@ def forward_common_atomic( nf, nloc, nnei = nlist.shape nf, nall = extended_atype.shape atype = extended_atype[:, :nloc] + if self.pair_excl is not None: pair_mask = self.pair_excl(nlist, extended_atype) # exclude neighbors in the nlist diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 0dc9ae20af..8ae4d3966e 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -71,11 +71,15 @@ def get_zbl_model(model_params): rmin = model_params["sw_rmin"] rmax = model_params["sw_rmax"] + atom_exclude_types = model_params.get("atom_exclude_types", []) + pair_exclude_types = model_params.get("pair_exclude_types", []) return DPZBLModel( dp_model, pt_model, rmin, rmax, + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, ) @@ -98,8 +102,16 @@ def get_model(model_params): if "ener" in fitting_net["type"]: fitting_net["return_energy"] = True fitting = Fitting(**fitting_net) + atom_exclude_types = model_params.get("atom_exclude_types", []) + pair_exclude_types = model_params.get("pair_exclude_types", []) - model = EnergyModel(descriptor, fitting, type_map=model_params["type_map"]) + model = EnergyModel( + descriptor, + fitting, + type_map=model_params["type_map"], + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + ) model.model_def_script = json.dumps(model_params) return model diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index dbe4881952..4be56b4acf 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -866,6 +866,9 @@ def model_args(exclude_hybrid=False): doc_srtab_add_bias = "Whether add energy bias from the statistics of the data to short-range tabulated atomic energy. It only takes effect when `use_srtab` is provided." doc_compress_config = "Model compression configurations" doc_spin = "The settings for systems with spin." + doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types" + doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other." + hybrid_models = [] if not exclude_hybrid: hybrid_models.extend( @@ -904,6 +907,20 @@ def model_args(exclude_hybrid=False): Argument("smin_alpha", float, optional=True, doc=doc_smin_alpha), Argument("sw_rmin", float, optional=True, doc=doc_sw_rmin), Argument("sw_rmax", float, optional=True, doc=doc_sw_rmax), + Argument( + "pair_exclude_types", + list, + optional=True, + default=[], + doc=doc_pair_exclude_types, + ), + Argument( + "atom_exclude_types", + list, + optional=True, + default=[], + doc=doc_atom_exclude_types, + ), Argument( "srtab_add_bias", bool, diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index b3aa778ca0..da5033a3b6 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -17,6 +17,7 @@ INSTALLED_PT, INSTALLED_TF, CommonTest, + parameterized, ) from .common import ( ModelTest, @@ -37,11 +38,24 @@ ) +@parameterized( + ( + [], + [[0, 1]], + ), + ( + [], + [1], + ), +) class TestEner(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: + pair_exclude_types, atom_exclude_types = self.param return { "type_map": ["O", "H"], + "pair_exclude_types": pair_exclude_types, + "atom_exclude_types": atom_exclude_types, "descriptor": { "type": "se_e2_a", "sel": [20, 20], @@ -73,6 +87,12 @@ def data(self) -> dict: pt_class = EnergyModelPT args = model_args() + def skip_tf(self): + return ( + self.data["pair_exclude_types"] != [] + or self.data["atom_exclude_types"] != [] + ) + def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" data = data.copy() diff --git a/source/tests/pt/model/test_deeppot.py b/source/tests/pt/model/test_deeppot.py index 334206a2b0..6584ce2974 100644 --- a/source/tests/pt/model/test_deeppot.py +++ b/source/tests/pt/model/test_deeppot.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json +import os import unittest from argparse import ( Namespace, @@ -49,6 +50,11 @@ def setUp(self): trainer.wrapper(**input_dict, label=label_dict, cur_lr=1.0) self.model = "model.pt" + def tearDown(self): + for f in os.listdir("."): + if f in ["lcurve.out", self.input_json]: + os.remove(f) + def test_dp_test(self): dp = DeepPot(str(self.model)) cell = np.array( diff --git a/source/tests/pt/test_dp_test.py b/source/tests/pt/test_dp_test.py index 8d7dc9cd58..08bd2ce623 100644 --- a/source/tests/pt/test_dp_test.py +++ b/source/tests/pt/test_dp_test.py @@ -62,11 +62,10 @@ def tearDown(self): for f in os.listdir("."): if f.startswith("model") and f.endswith(".pt"): os.remove(f) - if f in ["lcurve.out"]: + if f in ["lcurve.out", self.input_json]: os.remove(f) if f in ["stat_files"]: shutil.rmtree(f) - os.remove(self.input_json) if __name__ == "__main__": From dd0ebf07f4fb64ef0ff3a87b91cd8b4ce4058215 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 28 Feb 2024 21:27:01 +0800 Subject: [PATCH 4/8] fix warnings --- deepmd/dpmodel/atomic_model/base_atomic_model.py | 3 +-- deepmd/pt/model/atomic_model/base_atomic_model.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index a905ed97ff..09d33203a1 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -59,8 +59,7 @@ def forward_common_atomic( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, ) -> Dict[str, np.ndarray]: - nf, nloc, nnei = nlist.shape - nf, nall = extended_atype.shape + _, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] if self.pair_excl is not None: pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 94a8f0b839..d6de3dfc88 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -68,8 +68,7 @@ def forward_common_atomic( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: - nf, nloc, nnei = nlist.shape - nf, nall = extended_atype.shape + _, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] if self.pair_excl is not None: From c469c333fefc16900d7aedeb4b924dded0afe7f2 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 28 Feb 2024 21:28:25 +0800 Subject: [PATCH 5/8] fix warnings --- source/tests/pt/model/test_dp_atomic_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/source/tests/pt/model/test_dp_atomic_model.py b/source/tests/pt/model/test_dp_atomic_model.py index b62286f2c2..88bb3ab763 100644 --- a/source/tests/pt/model/test_dp_atomic_model.py +++ b/source/tests/pt/model/test_dp_atomic_model.py @@ -123,7 +123,6 @@ def test_jit(self): self.assertEqual(md0.get_type_map(), type_map) def test_excl_consistency(self): - nf, nloc, nnei = self.nlist.shape type_map = ["foo", "bar"] # test the case of exclusion From e586bf2f270d42429eb1ccd3106deba512f28dd1 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 28 Feb 2024 22:09:52 +0800 Subject: [PATCH 6/8] fix test --- deepmd/dpmodel/model/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index b21a8a2c78..6f06785c56 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -38,6 +38,6 @@ def get_model(data: dict) -> DPModel: descriptor=descriptor, fitting=fitting, type_map=data["type_map"], - atom_exclude_types=data["atom_exclude_types"], - pair_exclude_types=data["pair_exclude_types"], + atom_exclude_types=data.get("atom_exclude_types", []), + pair_exclude_types=data.get("pair_exclude_types", []), ) From de220f20c20f70acfa21799016bdf9083681d804 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 28 Feb 2024 22:16:11 +0800 Subject: [PATCH 7/8] rm data file --- source/tests/consistent/io/test_io.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index be599b0805..71e4002128 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -57,6 +57,12 @@ def save_data_to_model(self, model_file: str, data: dict) -> None: out_hook = out_backend.deserialize_hook out_hook(model_file, data) + def tearDown(self): + prefix = "test_consistent_io_" + self.__class__.__name__.lower() + for ii in Path(".").glob(prefix + ".*"): + if Path(ii).exists(): + Path(ii).unlink() + def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() for backend_name in ("tensorflow", "pytorch", "dpmodel"): @@ -173,3 +179,6 @@ def setUp(self): "backend": "test", "model_def_script": model_def_script, } + + def tearDown(self): + IOTest.tearDown(self) From 8a9d82996fb4691472018d6616ef86dcc254ca86 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 29 Feb 2024 13:29:12 +0800 Subject: [PATCH 8/8] add pt only support to doc --- deepmd/utils/argcheck.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index d676f23245..8e3196cba1 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1242,14 +1242,14 @@ def model_args(exclude_hybrid=False): list, optional=True, default=[], - doc=doc_pair_exclude_types, + doc=doc_only_pt_supported + doc_pair_exclude_types, ), Argument( "atom_exclude_types", list, optional=True, default=[], - doc=doc_atom_exclude_types, + doc=doc_only_pt_supported + doc_atom_exclude_types, ), Argument( "srtab_add_bias",