diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index b9521cde8e..09d33203a1 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -1,8 +1,89 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, + Tuple, +) + import numpy as np +from deepmd.dpmodel.utils import ( + AtomExcludeMask, + PairExcludeMask, +) + from .make_base_atomic_model import ( make_base_atomic_model, ) -BaseAtomicModel = make_base_atomic_model(np.ndarray) +BaseAtomicModel_ = make_base_atomic_model(np.ndarray) + + +class BaseAtomicModel(BaseAtomicModel_): + def __init__( + self, + atom_exclude_types: List[int] = [], + pair_exclude_types: List[Tuple[int, int]] = [], + ): + super().__init__() + self.reinit_atom_exclude(atom_exclude_types) + self.reinit_pair_exclude(pair_exclude_types) + + def reinit_atom_exclude( + self, + exclude_types: List[int] = [], + ): + self.atom_exclude_types = exclude_types + if exclude_types == []: + self.atom_excl = None + else: + self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types) + + def reinit_pair_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.pair_exclude_types = exclude_types + if exclude_types == []: + self.pair_excl = None + else: + self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types) + + def forward_common_atomic( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlist: np.ndarray, + mapping: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + ) -> Dict[str, np.ndarray]: + _, nloc, _ = nlist.shape + atype = extended_atype[:, :nloc] + if self.pair_excl is not None: + pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype) + # exclude neighbors in the nlist + nlist = np.where(pair_mask == 1, nlist, -1) + + ret_dict = self.forward_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + + if self.atom_excl is not None: + atom_mask = self.atom_excl.build_type_exclude_mask(atype) + for kk in ret_dict.keys(): + ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None] + + return ret_dict + + def serialize(self) -> dict: + return { + "atom_exclude_types": self.atom_exclude_types, + "pair_exclude_types": self.pair_exclude_types, + } diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index cd349749fa..96ef6d30ae 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -46,11 +46,12 @@ def __init__( descriptor, fitting, type_map: Optional[List[str]] = None, + **kwargs, ): - super().__init__() self.type_map = type_map self.descriptor = descriptor self.fitting = fitting + super().__init__(**kwargs) def fitting_output_def(self) -> FittingOutputDef: """Get the output def of the fitting net.""" @@ -132,14 +133,18 @@ def forward_atomic( return ret def serialize(self) -> dict: - return { - "@class": "Model", - "type": "standard", - "@version": 1, - "type_map": self.type_map, - "descriptor": self.descriptor.serialize(), - "fitting": self.fitting.serialize(), - } + dd = super().serialize() + dd.update( + { + "@class": "Model", + "type": "standard", + "@version": 1, + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + "fitting": self.fitting.serialize(), + } + ) + return dd @classmethod def deserialize(cls, data) -> "DPAtomicModel": @@ -147,9 +152,10 @@ def deserialize(cls, data) -> "DPAtomicModel": check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("@class") data.pop("type") - descriptor_obj = BaseDescriptor.deserialize(data["descriptor"]) - fitting_obj = BaseFitting.deserialize(data["fitting"]) - obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"]) + descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) + fitting_obj = BaseFitting.deserialize(data.pop("fitting")) + type_map = data.pop("type_map", None) + obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data) return obj def get_dim_fparam(self) -> int: diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 6d8aea499e..03c1249d4b 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -52,9 +52,9 @@ def __init__( models: List[BaseAtomicModel], **kwargs, ): - super().__init__() self.models = models self.mixed_types_list = [model.mixed_types() for model in self.models] + super().__init__(**kwargs) def mixed_types(self) -> bool: """If true, the model @@ -273,15 +273,19 @@ def __init__( self.smin_alpha = smin_alpha def serialize(self) -> dict: - return { - "@class": "Model", - "type": "zbl", - "@version": 1, - "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), - "sw_rmin": self.sw_rmin, - "sw_rmax": self.sw_rmax, - "smin_alpha": self.smin_alpha, - } + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "type": "zbl", + "@version": 1, + "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), + "sw_rmin": self.sw_rmin, + "sw_rmax": self.sw_rmax, + "smin_alpha": self.smin_alpha, + } + ) + return dd @classmethod def deserialize(cls, data) -> "DPZBLLinearAtomicModel": @@ -289,11 +293,11 @@ def deserialize(cls, data) -> "DPZBLLinearAtomicModel": check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("@class") data.pop("type") - sw_rmin = data["sw_rmin"] - sw_rmax = data["sw_rmax"] - smin_alpha = data["smin_alpha"] + sw_rmin = data.pop("sw_rmin") + sw_rmax = data.pop("sw_rmax") + smin_alpha = data.pop("smin_alpha") - dp_model, zbl_model = LinearAtomicModel.deserialize(data["models"]) + dp_model, zbl_model = LinearAtomicModel.deserialize(data.pop("models")) return cls( dp_model=dp_model, @@ -301,6 +305,7 @@ def deserialize(cls, data) -> "DPZBLLinearAtomicModel": sw_rmin=sw_rmin, sw_rmax=sw_rmax, smin_alpha=smin_alpha, + **data, ) def _compute_weight( diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py index d4186c990d..df6e39dd2e 100644 --- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py @@ -48,6 +48,16 @@ def get_rcut(self) -> float: def get_type_map(self) -> Optional[List[str]]: """Get the type map.""" + def get_ntypes(self) -> int: + """Get the number of atom types.""" + tmap = self.get_type_map() + if tmap is not None: + return len(tmap) + else: + raise ValueError( + "cannot infer the number of types from a None type map" + ) + @abstractmethod def get_sel(self) -> List[int]: """Returns the number of selected atoms for each type.""" diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index ddece80f2d..5469ee80d2 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -109,14 +109,18 @@ def mixed_types(self) -> bool: return True def serialize(self) -> dict: - return { - "@class": "Model", - "type": "pairtab", - "@version": 1, - "tab": self.tab.serialize(), - "rcut": self.rcut, - "sel": self.sel, - } + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "type": "pairtab", + "@version": 1, + "tab": self.tab.serialize(), + "rcut": self.rcut, + "sel": self.sel, + } + ) + return dd @classmethod def deserialize(cls, data) -> "PairTabAtomicModel": @@ -124,10 +128,10 @@ def deserialize(cls, data) -> "PairTabAtomicModel": check_version_compatibility(data.pop("@version", 1), 1, 1) data.pop("@class") data.pop("type") - rcut = data["rcut"] - sel = data["sel"] - tab = PairTab.deserialize(data["tab"]) - tab_model = cls(None, rcut, sel) + rcut = data.pop("rcut") + sel = data.pop("sel") + tab = PairTab.deserialize(data.pop("tab")) + tab_model = cls(None, rcut, sel, **data) tab_model.tab = tab tab_model.tab_info = tab_model.tab.tab_info tab_model.tab_data = tab_model.tab.tab_data diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 1261906148..e8b1ecc390 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -213,7 +213,7 @@ def call_lower( extended_coord, fparam=fparam, aparam=aparam ) del extended_coord, fparam, aparam - atomic_ret = self.forward_atomic( + atomic_ret = self.forward_common_atomic( cc_ext, extended_atype, nlist, diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 4a6e269f25..6f06785c56 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -38,4 +38,6 @@ def get_model(data: dict) -> DPModel: descriptor=descriptor, fitting=fitting, type_map=data["type_map"], + atom_exclude_types=data.get("atom_exclude_types", []), + pair_exclude_types=data.get("pair_exclude_types", []), ) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 1e5f976baf..d6de3dfc88 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -1,16 +1,56 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, + Tuple, +) + import torch from deepmd.dpmodel.atomic_model import ( make_base_atomic_model, ) +from deepmd.pt.utils import ( + AtomExcludeMask, + PairExcludeMask, +) BaseAtomicModel_ = make_base_atomic_model(torch.Tensor) class BaseAtomicModel(BaseAtomicModel_): + def __init__( + self, + atom_exclude_types: List[int] = [], + pair_exclude_types: List[Tuple[int, int]] = [], + ): + super().__init__() + self.reinit_atom_exclude(atom_exclude_types) + self.reinit_pair_exclude(pair_exclude_types) + + def reinit_atom_exclude( + self, + exclude_types: List[int] = [], + ): + self.atom_exclude_types = exclude_types + if exclude_types == []: + self.atom_excl = None + else: + self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types) + + def reinit_pair_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.pair_exclude_types = exclude_types + if exclude_types == []: + self.pair_excl = None + else: + self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types) + # export public methods that are not abstract get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel) get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei) @@ -18,3 +58,42 @@ class BaseAtomicModel(BaseAtomicModel_): @torch.jit.export def get_model_def_script(self) -> str: return self.model_def_script + + def forward_common_atomic( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + _, nloc, _ = nlist.shape + atype = extended_atype[:, :nloc] + + if self.pair_excl is not None: + pair_mask = self.pair_excl(nlist, extended_atype) + # exclude neighbors in the nlist + nlist = torch.where(pair_mask == 1, nlist, -1) + + ret_dict = self.forward_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + + if self.atom_excl is not None: + atom_mask = self.atom_excl(atype) + for kk in ret_dict.keys(): + ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None] + + return ret_dict + + def serialize(self) -> dict: + return { + "atom_exclude_types": self.atom_exclude_types, + "pair_exclude_types": self.pair_exclude_types, + } diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index d2c1743d30..63e91ff428 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -49,8 +49,14 @@ class DPAtomicModel(torch.nn.Module, BaseAtomicModel): For example `type_map[1]` gives the name of the type 1. """ - def __init__(self, descriptor, fitting, type_map: Optional[List[str]]): - super().__init__() + def __init__( + self, + descriptor, + fitting, + type_map: Optional[List[str]], + **kwargs, + ): + torch.nn.Module.__init__(self) self.model_def_script = "" ntypes = len(type_map) self.type_map = type_map @@ -59,6 +65,8 @@ def __init__(self, descriptor, fitting, type_map: Optional[List[str]]): self.rcut = self.descriptor.get_rcut() self.sel = self.descriptor.get_sel() self.fitting_net = fitting + # order matters ntypes and type_map should be initialized first. + BaseAtomicModel.__init__(self, **kwargs) def fitting_output_def(self) -> FittingOutputDef: """Get the output def of the fitting net.""" @@ -95,22 +103,29 @@ def mixed_types(self) -> bool: return self.descriptor.mixed_types() def serialize(self) -> dict: - return { - "@class": "Model", - "type": "standard", - "@version": 1, - "type_map": self.type_map, - "descriptor": self.descriptor.serialize(), - "fitting": self.fitting_net.serialize(), - } + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "@version": 1, + "type": "standard", + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + "fitting": self.fitting_net.serialize(), + } + ) + return dd @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 1, 1) - descriptor_obj = BaseDescriptor.deserialize(data["descriptor"]) - fitting_obj = BaseFitting.deserialize(data["fitting"]) - obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"]) + data.pop("@class", None) + data.pop("type", None) + descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) + fitting_obj = BaseFitting.deserialize(data.pop("fitting")) + type_map = data.pop("type_map", None) + obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data) return obj def forward_atomic( diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 52f5f1d13c..5efbe533da 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -54,10 +54,11 @@ def __init__( models: List[BaseAtomicModel], **kwargs, ): - super().__init__() + torch.nn.Module.__init__(self) self.models = torch.nn.ModuleList(models) self.atomic_bias = None self.mixed_types_list = [model.mixed_types() for model in self.models] + BaseAtomicModel.__init__(self, **kwargs) def mixed_types(self) -> bool: """If true, the model @@ -307,32 +308,39 @@ def __init__( self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE) def serialize(self) -> dict: - return { - "@class": "Model", - "type": "zbl", - "@version": 1, - "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), - "sw_rmin": self.sw_rmin, - "sw_rmax": self.sw_rmax, - "smin_alpha": self.smin_alpha, - } + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "@version": 1, + "type": "zbl", + "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), + "sw_rmin": self.sw_rmin, + "sw_rmax": self.sw_rmax, + "smin_alpha": self.smin_alpha, + } + ) + return dd @classmethod def deserialize(cls, data) -> "DPZBLLinearAtomicModel": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 1, 1) - sw_rmin = data["sw_rmin"] - sw_rmax = data["sw_rmax"] - smin_alpha = data["smin_alpha"] + sw_rmin = data.pop("sw_rmin") + sw_rmax = data.pop("sw_rmax") + smin_alpha = data.pop("smin_alpha") - dp_model, zbl_model = LinearAtomicModel.deserialize(data["models"]) + dp_model, zbl_model = LinearAtomicModel.deserialize(data.pop("models")) + data.pop("@class", None) + data.pop("type", None) return cls( dp_model=dp_model, zbl_model=zbl_model, sw_rmin=sw_rmin, sw_rmax=sw_rmax, smin_alpha=smin_alpha, + **data, ) def _compute_weight( diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index c0b7c65d7a..47a20d3be9 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -52,11 +52,12 @@ class PairTabAtomicModel(torch.nn.Module, BaseAtomicModel): def __init__( self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs ): - super().__init__() + torch.nn.Module.__init__(self) self.model_def_script = "" self.tab_file = tab_file self.rcut = rcut self.tab = self._set_pairtab(tab_file, rcut) + BaseAtomicModel.__init__(self, **kwargs) # handle deserialization with no input file if self.tab_file is not None: @@ -125,23 +126,29 @@ def mixed_types(self) -> bool: return True def serialize(self) -> dict: - return { - "@class": "Model", - "type": "pairtab", - "@version": 1, - "tab": self.tab.serialize(), - "rcut": self.rcut, - "sel": self.sel, - } + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "@version": 1, + "type": "pairtab", + "tab": self.tab.serialize(), + "rcut": self.rcut, + "sel": self.sel, + } + ) + return dd @classmethod def deserialize(cls, data) -> "PairTabAtomicModel": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 1, 1) - rcut = data["rcut"] - sel = data["sel"] - tab = PairTab.deserialize(data["tab"]) - tab_model = cls(None, rcut, sel) + rcut = data.pop("rcut") + sel = data.pop("sel") + tab = PairTab.deserialize(data.pop("tab")) + data.pop("@class", None) + data.pop("type", None) + tab_model = cls(None, rcut, sel, **data) tab_model.tab = tab tab_model.register_buffer("tab_info", torch.from_numpy(tab_model.tab.tab_info)) tab_model.register_buffer("tab_data", torch.from_numpy(tab_model.tab.tab_data)) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index bb3cd30ff9..fc2cf60531 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -136,6 +136,13 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) """Update mean and stddev for descriptor elements.""" return self.sea.compute_input_stats(merged, path) + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + """Update the type exclusions.""" + self.sea.reinit_exclude(exclude_types) + def forward( self, coord_ext: torch.Tensor, @@ -288,10 +295,10 @@ def __init__( self.prec = PRECISION_DICT[self.precision] self.resnet_dt = resnet_dt self.old_impl = old_impl - self.exclude_types = exclude_types self.ntypes = len(sel) - self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types) self.type_one_side = type_one_side + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) self.sel = sel self.sec = torch.tensor( @@ -424,6 +431,13 @@ def get_stats(self) -> Dict[str, StatItem]: ) return self.stats + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def forward( self, nlist: torch.Tensor, diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index b823a051f5..87eb391a7e 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -71,11 +71,15 @@ def get_zbl_model(model_params): rmin = model_params["sw_rmin"] rmax = model_params["sw_rmax"] + atom_exclude_types = model_params.get("atom_exclude_types", []) + pair_exclude_types = model_params.get("pair_exclude_types", []) return DPZBLModel( dp_model, pt_model, rmin, rmax, + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, ) @@ -98,7 +102,16 @@ def get_model(model_params): if "ener" in fitting_net["type"]: fitting_net["return_energy"] = True fitting = BaseFitting(**fitting_net) - model = DPModel(descriptor, fitting, type_map=model_params["type_map"]) + atom_exclude_types = model_params.get("atom_exclude_types", []) + pair_exclude_types = model_params.get("pair_exclude_types", []) + + model = DPModel( + descriptor, + fitting, + type_map=model_params["type_map"], + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + ) model.model_def_script = json.dumps(model_params) return model diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 3efd3fb046..98f0a18241 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -208,7 +208,7 @@ def forward_common_lower( extended_coord, fparam=fparam, aparam=aparam ) del extended_coord, fparam, aparam - atomic_ret = self.forward_atomic( + atomic_ret = self.forward_common_atomic( cc_ext, extended_atype, nlist, diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 20876d9be7..8e8338210f 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -280,9 +280,8 @@ def __init__( self.precision = precision self.prec = PRECISION_DICT[self.precision] self.rcond = rcond - self.exclude_types = exclude_types - - self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) + # order matters, should be place after the assignment of ntypes + self.reinit_exclude(exclude_types) net_dim_out = self._net_out_dim() # init constants @@ -358,6 +357,13 @@ def __init__( log.info("Set seed to %d in fitting net.", seed) torch.manual_seed(seed) + def reinit_exclude( + self, + exclude_types: List[int] = [], + ): + self.exclude_types = exclude_types + self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) + def serialize(self) -> dict: """Serialize the fitting to dict.""" return { diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 889f7ccc4d..ca660f8e95 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -784,6 +784,8 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": check_version_compatibility(data.pop("@version", 1), 1, 1) descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix) fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix) + data.pop("atom_exclude_types") + data.pop("pair_exclude_types") return cls( descriptor=descriptor, fitting_net=fitting, @@ -814,4 +816,7 @@ def serialize(self, suffix: str = "") -> dict: "type_map": self.type_map, "descriptor": self.descrpt.serialize(suffix=suffix), "fitting": self.fitting.serialize(suffix=suffix), + # not supported yet + "atom_exclude_types": [], + "pair_exclude_types": [], } diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 8366f7bb38..8e3196cba1 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1182,6 +1182,9 @@ def model_args(exclude_hybrid=False): doc_srtab_add_bias = "Whether add energy bias from the statistics of the data to short-range tabulated atomic energy. It only takes effect when `use_srtab` is provided." doc_compress_config = "Model compression configurations" doc_spin = "The settings for systems with spin." + doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types" + doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other." + hybrid_models = [] if not exclude_hybrid: hybrid_models.extend( @@ -1234,6 +1237,20 @@ def model_args(exclude_hybrid=False): Argument( "sw_rmax", float, optional=True, doc=doc_only_tf_supported + doc_sw_rmax ), + Argument( + "pair_exclude_types", + list, + optional=True, + default=[], + doc=doc_only_pt_supported + doc_pair_exclude_types, + ), + Argument( + "atom_exclude_types", + list, + optional=True, + default=[], + doc=doc_only_pt_supported + doc_atom_exclude_types, + ), Argument( "srtab_add_bias", bool, diff --git a/source/tests/common/dpmodel/test_dp_atomic_model.py b/source/tests/common/dpmodel/test_dp_atomic_model.py index b32c8ae11a..f97299cf72 100644 --- a/source/tests/common/dpmodel/test_dp_atomic_model.py +++ b/source/tests/common/dpmodel/test_dp_atomic_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import itertools import unittest import numpy as np @@ -38,10 +39,14 @@ def test_self_consistency( mixed_types=ds.mixed_types(), ) type_map = ["foo", "bar"] - md0 = DPAtomicModel(ds, ft, type_map=type_map) - md1 = DPAtomicModel.deserialize(md0.serialize()) - ret0 = md0.forward_atomic(self.coord_ext, self.atype_ext, self.nlist) - ret1 = md1.forward_atomic(self.coord_ext, self.atype_ext, self.nlist) + for atom_excl, pair_excl in itertools.product([[], [1]], [[], [[0, 1]]]): + md0 = DPAtomicModel(ds, ft, type_map=type_map) + md0.reinit_atom_exclude(atom_excl) + md0.reinit_pair_exclude(pair_excl) + md1 = DPAtomicModel.deserialize(md0.serialize()) - np.testing.assert_allclose(ret0["energy"], ret1["energy"]) + ret0 = md0.forward_common_atomic(self.coord_ext, self.atype_ext, self.nlist) + ret1 = md1.forward_common_atomic(self.coord_ext, self.atype_ext, self.nlist) + + np.testing.assert_allclose(ret0["energy"], ret1["energy"]) diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index be599b0805..71e4002128 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -57,6 +57,12 @@ def save_data_to_model(self, model_file: str, data: dict) -> None: out_hook = out_backend.deserialize_hook out_hook(model_file, data) + def tearDown(self): + prefix = "test_consistent_io_" + self.__class__.__name__.lower() + for ii in Path(".").glob(prefix + ".*"): + if Path(ii).exists(): + Path(ii).unlink() + def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() for backend_name in ("tensorflow", "pytorch", "dpmodel"): @@ -173,3 +179,6 @@ def setUp(self): "backend": "test", "model_def_script": model_def_script, } + + def tearDown(self): + IOTest.tearDown(self) diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index b3aa778ca0..da5033a3b6 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -17,6 +17,7 @@ INSTALLED_PT, INSTALLED_TF, CommonTest, + parameterized, ) from .common import ( ModelTest, @@ -37,11 +38,24 @@ ) +@parameterized( + ( + [], + [[0, 1]], + ), + ( + [], + [1], + ), +) class TestEner(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: + pair_exclude_types, atom_exclude_types = self.param return { "type_map": ["O", "H"], + "pair_exclude_types": pair_exclude_types, + "atom_exclude_types": atom_exclude_types, "descriptor": { "type": "se_e2_a", "sel": [20, 20], @@ -73,6 +87,12 @@ def data(self) -> dict: pt_class = EnergyModelPT args = model_args() + def skip_tf(self): + return ( + self.data["pair_exclude_types"] != [] + or self.data["atom_exclude_types"] != [] + ) + def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" data = data.copy() diff --git a/source/tests/pt/model/test_deeppot.py b/source/tests/pt/model/test_deeppot.py index 102e1f6b0c..697ebb6411 100644 --- a/source/tests/pt/model/test_deeppot.py +++ b/source/tests/pt/model/test_deeppot.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json +import os import unittest from argparse import ( Namespace, @@ -53,6 +54,11 @@ def setUp(self): trainer.wrapper(**input_dict, label=label_dict, cur_lr=1.0) self.model = "model.pt" + def tearDown(self): + for f in os.listdir("."): + if f in ["lcurve.out", self.input_json]: + os.remove(f) + def test_dp_test(self): dp = DeepPot(str(self.model)) cell = np.array( diff --git a/source/tests/pt/model/test_dp_atomic_model.py b/source/tests/pt/model/test_dp_atomic_model.py index bb0d20ab02..88bb3ab763 100644 --- a/source/tests/pt/model/test_dp_atomic_model.py +++ b/source/tests/pt/model/test_dp_atomic_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import itertools import unittest import numpy as np @@ -50,17 +51,27 @@ def test_self_consistency(self): mixed_types=ds.mixed_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - md0 = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE) - md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE) - args = [ - to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] - ] - ret0 = md0.forward_atomic(*args) - ret1 = md1.forward_atomic(*args) - np.testing.assert_allclose( - to_numpy_array(ret0["energy"]), - to_numpy_array(ret1["energy"]), - ) + + # test the case of exclusion + for atom_excl, pair_excl in itertools.product([[], [1]], [[], [[0, 1]]]): + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + md0.reinit_atom_exclude(atom_excl) + md0.reinit_pair_exclude(pair_excl) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE) + args = [ + to_torch_tensor(ii) + for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + ret0 = md0.forward_common_atomic(*args) + ret1 = md1.forward_common_atomic(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) def test_dp_consistency(self): rng = np.random.default_rng() @@ -84,8 +95,8 @@ def test_dp_consistency(self): args1 = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] ] - ret0 = md0.forward_atomic(*args0) - ret1 = md1.forward_atomic(*args1) + ret0 = md0.forward_common_atomic(*args0) + ret1 = md1.forward_common_atomic(*args1) np.testing.assert_allclose( ret0["energy"], to_numpy_array(ret1["energy"]), @@ -110,3 +121,44 @@ def test_jit(self): md0 = torch.jit.script(md0) self.assertEqual(md0.get_rcut(), self.rcut) self.assertEqual(md0.get_type_map(), type_map) + + def test_excl_consistency(self): + type_map = ["foo", "bar"] + + # test the case of exclusion + for atom_excl, pair_excl in itertools.product([[], [1]], [[], [[0, 1]]]): + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE) + + md0.reinit_atom_exclude(atom_excl) + md0.reinit_pair_exclude(pair_excl) + # hacking! + md1.descriptor.reinit_exclude(pair_excl) + md1.fitting_net.reinit_exclude(atom_excl) + + args = [ + to_torch_tensor(ii) + for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + ret0 = md0.forward_common_atomic(*args) + ret1 = md1.forward_common_atomic(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) diff --git a/source/tests/pt/test_dp_test.py b/source/tests/pt/test_dp_test.py index 8d7dc9cd58..08bd2ce623 100644 --- a/source/tests/pt/test_dp_test.py +++ b/source/tests/pt/test_dp_test.py @@ -62,11 +62,10 @@ def tearDown(self): for f in os.listdir("."): if f.startswith("model") and f.endswith(".pt"): os.remove(f) - if f in ["lcurve.out"]: + if f in ["lcurve.out", self.input_json]: os.remove(f) if f in ["stat_files"]: shutil.rmtree(f) - os.remove(self.input_json) if __name__ == "__main__":