From e7925f38ba87b4cc3514121941117ab4aad205e3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 21 Nov 2024 00:31:14 -0500 Subject: [PATCH] feat(jax): energy, dos, dipole, polar, property atomic model & model (#4384) ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced several new atomic model classes: `DPAtomicModelDipole`, `DPAtomicModelDOS`, `DPAtomicModelEnergy`, `DPAtomicModelPolar`, and `DPAtomicModelProperty`. - Added new model classes: `DipoleModel`, `DOSModel`, `PolarModel`, and `PropertyModel` for enhanced functionalities. - Implemented a new function to create JAX-compatible models from existing DP models, improving integration with JAX. - **Bug Fixes** - Enhanced test suite to support JAX backend, ensuring compatibility and flexibility in testing. - **Documentation** - Updated public API to include new models and functionalities. --------- Signed-off-by: Jinzhe Zeng --- .../atomic_model/polar_atomic_model.py | 17 +- .../jax/atomic_model/dipole_atomic_model.py | 11 + deepmd/jax/atomic_model/dos_atomic_model.py | 11 + deepmd/jax/atomic_model/dp_atomic_model.py | 78 ++++--- .../jax/atomic_model/energy_atomic_model.py | 11 + deepmd/jax/atomic_model/polar_atomic_model.py | 11 + .../jax/atomic_model/property_atomic_model.py | 13 ++ deepmd/jax/model/__init__.py | 16 ++ deepmd/jax/model/dipole_model.py | 17 ++ deepmd/jax/model/dos_model.py | 16 ++ deepmd/jax/model/dp_model.py | 86 ++++++++ deepmd/jax/model/ener_model.py | 64 +----- deepmd/jax/model/model.py | 2 + deepmd/jax/model/polar_model.py | 17 ++ deepmd/jax/model/property_model.py | 19 ++ source/tests/consistent/model/test_dipole.py | 206 ++++++++++++++++++ source/tests/consistent/model/test_dos.py | 14 +- source/tests/consistent/model/test_polar.py | 200 +++++++++++++++++ .../tests/consistent/model/test_property.py | 196 +++++++++++++++++ 19 files changed, 910 insertions(+), 95 deletions(-) create mode 100644 deepmd/jax/atomic_model/dipole_atomic_model.py create mode 100644 deepmd/jax/atomic_model/dos_atomic_model.py create mode 100644 deepmd/jax/atomic_model/energy_atomic_model.py create mode 100644 deepmd/jax/atomic_model/polar_atomic_model.py create mode 100644 deepmd/jax/atomic_model/property_atomic_model.py create mode 100644 deepmd/jax/model/dipole_model.py create mode 100644 deepmd/jax/model/dos_model.py create mode 100644 deepmd/jax/model/dp_model.py create mode 100644 deepmd/jax/model/polar_model.py create mode 100644 deepmd/jax/model/property_model.py create mode 100644 source/tests/consistent/model/test_dipole.py create mode 100644 source/tests/consistent/model/test_polar.py create mode 100644 source/tests/consistent/model/test_property.py diff --git a/deepmd/dpmodel/atomic_model/polar_atomic_model.py b/deepmd/dpmodel/atomic_model/polar_atomic_model.py index 6e1d32ff35..bc7860491c 100644 --- a/deepmd/dpmodel/atomic_model/polar_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/polar_atomic_model.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat import numpy as np from deepmd.dpmodel.fitting.polarizability_fitting import ( @@ -34,29 +35,29 @@ def apply_out_stat( The atom types. nf x nloc """ + xp = array_api_compat.array_namespace(atype) out_bias, out_std = self._fetch_out_stat(self.bias_keys) - if self.fitting_net.shift_diag: + if self.fitting.shift_diag: nframes, nloc = atype.shape dtype = out_bias[self.bias_keys[0]].dtype for kk in self.bias_keys: ntypes = out_bias[kk].shape[0] - temp = np.zeros(ntypes, dtype=dtype) - temp = np.mean( - np.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2), + temp = xp.mean( + xp.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2), axis=1, ) modified_bias = temp[atype] # (nframes, nloc, 1) modified_bias = ( - modified_bias[..., np.newaxis] * (self.fitting_net.scale[atype]) + modified_bias[..., xp.newaxis] * (self.fitting.scale[atype]) ) - eye = np.eye(3, dtype=dtype) - eye = np.tile(eye, (nframes, nloc, 1, 1)) + eye = xp.eye(3, dtype=dtype) + eye = xp.tile(eye, (nframes, nloc, 1, 1)) # (nframes, nloc, 3, 3) - modified_bias = modified_bias[..., np.newaxis] * eye + modified_bias = modified_bias[..., xp.newaxis] * eye # nf x nloc x odims, out_bias: ntypes x odims ret[kk] = ret[kk] + modified_bias diff --git a/deepmd/jax/atomic_model/dipole_atomic_model.py b/deepmd/jax/atomic_model/dipole_atomic_model.py new file mode 100644 index 0000000000..9993efa144 --- /dev/null +++ b/deepmd/jax/atomic_model/dipole_atomic_model.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.dipole_atomic_model import ( + DPDipoleAtomicModel as DPAtomicModelDipoleDP, +) +from deepmd.jax.atomic_model.dp_atomic_model import ( + make_jax_dp_atomic_model_from_dpmodel, +) + + +class DPAtomicModelDipole(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDipoleDP)): + pass diff --git a/deepmd/jax/atomic_model/dos_atomic_model.py b/deepmd/jax/atomic_model/dos_atomic_model.py new file mode 100644 index 0000000000..b11542de2a --- /dev/null +++ b/deepmd/jax/atomic_model/dos_atomic_model.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.dos_atomic_model import ( + DPDOSAtomicModel as DPAtomicModelDOSDP, +) +from deepmd.jax.atomic_model.dp_atomic_model import ( + make_jax_dp_atomic_model_from_dpmodel, +) + + +class DPAtomicModelDOS(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDOSDP)): + pass diff --git a/deepmd/jax/atomic_model/dp_atomic_model.py b/deepmd/jax/atomic_model/dp_atomic_model.py index 5898fd3ff8..adfc22c6fa 100644 --- a/deepmd/jax/atomic_model/dp_atomic_model.py +++ b/deepmd/jax/atomic_model/dp_atomic_model.py @@ -23,31 +23,53 @@ ) -@flax_module -class DPAtomicModel(DPAtomicModelDP): - base_descriptor_cls = BaseDescriptor - """The base descriptor class.""" - base_fitting_cls = BaseFitting - """The base fitting class.""" - - def __setattr__(self, name: str, value: Any) -> None: - value = base_atomic_model_set_attr(name, value) - return super().__setattr__(name, value) - - def forward_common_atomic( - self, - extended_coord: jnp.ndarray, - extended_atype: jnp.ndarray, - nlist: jnp.ndarray, - mapping: Optional[jnp.ndarray] = None, - fparam: Optional[jnp.ndarray] = None, - aparam: Optional[jnp.ndarray] = None, - ) -> dict[str, jnp.ndarray]: - return super().forward_common_atomic( - extended_coord, - extended_atype, - jax.lax.stop_gradient(nlist), - mapping=mapping, - fparam=fparam, - aparam=aparam, - ) +def make_jax_dp_atomic_model_from_dpmodel( + dpmodel_atomic_model: type[DPAtomicModelDP], +) -> type[DPAtomicModelDP]: + """Make a JAX backend DP atomic model from a DPModel backend DP atomic model. + + Parameters + ---------- + dpmodel_atomic_model : type[DPAtomicModelDP] + The DPModel backend DP atomic model. + + Returns + ------- + type[DPAtomicModel] + The JAX backend DP atomic model. + """ + + @flax_module + class jax_atomic_model(dpmodel_atomic_model): + base_descriptor_cls = BaseDescriptor + """The base descriptor class.""" + base_fitting_cls = BaseFitting + """The base fitting class.""" + + def __setattr__(self, name: str, value: Any) -> None: + value = base_atomic_model_set_attr(name, value) + return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + ) -> dict[str, jnp.ndarray]: + return super().forward_common_atomic( + extended_coord, + extended_atype, + jax.lax.stop_gradient(nlist), + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + + return jax_atomic_model + + +class DPAtomicModel(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDP)): + pass diff --git a/deepmd/jax/atomic_model/energy_atomic_model.py b/deepmd/jax/atomic_model/energy_atomic_model.py new file mode 100644 index 0000000000..34c1b26341 --- /dev/null +++ b/deepmd/jax/atomic_model/energy_atomic_model.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.energy_atomic_model import ( + DPEnergyAtomicModel as DPAtomicModelEnergyDP, +) +from deepmd.jax.atomic_model.dp_atomic_model import ( + make_jax_dp_atomic_model_from_dpmodel, +) + + +class DPAtomicModelEnergy(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelEnergyDP)): + pass diff --git a/deepmd/jax/atomic_model/polar_atomic_model.py b/deepmd/jax/atomic_model/polar_atomic_model.py new file mode 100644 index 0000000000..c4d8319d3e --- /dev/null +++ b/deepmd/jax/atomic_model/polar_atomic_model.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.polar_atomic_model import ( + DPPolarAtomicModel as DPAtomicModelPolarDP, +) +from deepmd.jax.atomic_model.dp_atomic_model import ( + make_jax_dp_atomic_model_from_dpmodel, +) + + +class DPAtomicModelPolar(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelPolarDP)): + pass diff --git a/deepmd/jax/atomic_model/property_atomic_model.py b/deepmd/jax/atomic_model/property_atomic_model.py new file mode 100644 index 0000000000..170ab7c3ef --- /dev/null +++ b/deepmd/jax/atomic_model/property_atomic_model.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.property_atomic_model import ( + DPPropertyAtomicModel as DPAtomicModelPropertyDP, +) +from deepmd.jax.atomic_model.dp_atomic_model import ( + make_jax_dp_atomic_model_from_dpmodel, +) + + +class DPAtomicModelProperty( + make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelPropertyDP) +): + pass diff --git a/deepmd/jax/model/__init__.py b/deepmd/jax/model/__init__.py index bba5bc766a..fd31999aab 100644 --- a/deepmd/jax/model/__init__.py +++ b/deepmd/jax/model/__init__.py @@ -1,12 +1,28 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .dipole_model import ( + DipoleModel, +) +from .dos_model import ( + DOSModel, +) from .dp_zbl_model import ( DPZBLLinearEnergyAtomicModel, ) from .ener_model import ( EnergyModel, ) +from .polar_model import ( + PolarModel, +) +from .property_model import ( + PropertyModel, +) __all__ = [ "EnergyModel", "DPZBLLinearEnergyAtomicModel", + "DOSModel", + "DipoleModel", + "PolarModel", + "PropertyModel", ] diff --git a/deepmd/jax/model/dipole_model.py b/deepmd/jax/model/dipole_model.py new file mode 100644 index 0000000000..e4734131c9 --- /dev/null +++ b/deepmd/jax/model/dipole_model.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.dpmodel.model.dipole_model import DipoleModel as DipoleModelDP +from deepmd.jax.atomic_model.dipole_atomic_model import ( + DPAtomicModelDipole, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) +from deepmd.jax.model.dp_model import ( + make_jax_dp_model_from_dpmodel, +) + + +@BaseModel.register("dipole") +class DipoleModel(make_jax_dp_model_from_dpmodel(DipoleModelDP, DPAtomicModelDipole)): + pass diff --git a/deepmd/jax/model/dos_model.py b/deepmd/jax/model/dos_model.py new file mode 100644 index 0000000000..589fb8f73f --- /dev/null +++ b/deepmd/jax/model/dos_model.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.model.dos_model import DOSModel as DOSModelDP +from deepmd.jax.atomic_model.dos_atomic_model import ( + DPAtomicModelDOS, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) +from deepmd.jax.model.dp_model import ( + make_jax_dp_model_from_dpmodel, +) + + +@BaseModel.register("dos") +class DOSModel(make_jax_dp_model_from_dpmodel(DOSModelDP, DPAtomicModelDOS)): + pass diff --git a/deepmd/jax/model/dp_model.py b/deepmd/jax/model/dp_model.py new file mode 100644 index 0000000000..436582f22b --- /dev/null +++ b/deepmd/jax/model/dp_model.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +from deepmd.dpmodel.model import ( + DPModelCommon, +) +from deepmd.jax.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.jax.common import ( + flax_module, +) +from deepmd.jax.env import ( + jax, + jnp, +) +from deepmd.jax.model.base_model import ( + forward_common_atomic, +) + + +def make_jax_dp_model_from_dpmodel( + dpmodel_model: type[DPModelCommon], jax_atomicmodel: type[DPAtomicModel] +) -> type[DPModelCommon]: + """Make a JAX backend DP model from a DPModel backend DP model. + + Parameters + ---------- + dpmodel_model : type[DPModelCommon] + The DPModel backend DP model. + jax_atomicmodel : type[DPAtomicModel] + The JAX backend DP atomic model. + + Returns + ------- + type[DPModelCommon] + The JAX backend DP model. + """ + + @flax_module + class jax_model(dpmodel_model): + def __setattr__(self, name: str, value: Any) -> None: + if name == "atomic_model": + value = jax_atomicmodel.deserialize(value.serialize()) + return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ): + return forward_common_atomic( + self, + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + def format_nlist( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + extra_nlist_sort: bool = False, + ): + return dpmodel_model.format_nlist( + self, + jax.lax.stop_gradient(extended_coord), + extended_atype, + nlist, + extra_nlist_sort=extra_nlist_sort, + ) + + return jax_model diff --git a/deepmd/jax/model/ener_model.py b/deepmd/jax/model/ener_model.py index a1865f5635..1d3e8a1d80 100644 --- a/deepmd/jax/model/ener_model.py +++ b/deepmd/jax/model/ener_model.py @@ -1,66 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, - Optional, -) - from deepmd.dpmodel.model import EnergyModel as EnergyModelDP -from deepmd.jax.atomic_model.dp_atomic_model import ( - DPAtomicModel, -) -from deepmd.jax.common import ( - flax_module, -) -from deepmd.jax.env import ( - jax, - jnp, +from deepmd.jax.atomic_model.energy_atomic_model import ( + DPAtomicModelEnergy, ) from deepmd.jax.model.base_model import ( BaseModel, - forward_common_atomic, +) +from deepmd.jax.model.dp_model import ( + make_jax_dp_model_from_dpmodel, ) @BaseModel.register("ener") -@flax_module -class EnergyModel(EnergyModelDP): - def __setattr__(self, name: str, value: Any) -> None: - if name == "atomic_model": - value = DPAtomicModel.deserialize(value.serialize()) - return super().__setattr__(name, value) - - def forward_common_atomic( - self, - extended_coord: jnp.ndarray, - extended_atype: jnp.ndarray, - nlist: jnp.ndarray, - mapping: Optional[jnp.ndarray] = None, - fparam: Optional[jnp.ndarray] = None, - aparam: Optional[jnp.ndarray] = None, - do_atomic_virial: bool = False, - ): - return forward_common_atomic( - self, - extended_coord, - extended_atype, - nlist, - mapping=mapping, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - ) - - def format_nlist( - self, - extended_coord: jnp.ndarray, - extended_atype: jnp.ndarray, - nlist: jnp.ndarray, - extra_nlist_sort: bool = False, - ): - return EnergyModelDP.format_nlist( - self, - jax.lax.stop_gradient(extended_coord), - extended_atype, - nlist, - extra_nlist_sort=extra_nlist_sort, - ) +class EnergyModel(make_jax_dp_model_from_dpmodel(EnergyModelDP, DPAtomicModelEnergy)): + pass diff --git a/deepmd/jax/model/model.py b/deepmd/jax/model/model.py index 8b7d375841..dc350e968c 100644 --- a/deepmd/jax/model/model.py +++ b/deepmd/jax/model/model.py @@ -47,6 +47,8 @@ def get_standard_model(data: dict): descriptor = BaseDescriptor.get_class_by_type(descriptor_type)( **data["descriptor"], ) + if fitting_type in {"dipole", "polar"}: + data["fitting_net"]["embedding_width"] = descriptor.get_dim_emb() fitting = BaseFitting.get_class_by_type(fitting_type)( ntypes=descriptor.get_ntypes(), dim_descrpt=descriptor.get_dim_out(), diff --git a/deepmd/jax/model/polar_model.py b/deepmd/jax/model/polar_model.py new file mode 100644 index 0000000000..cbeccbec59 --- /dev/null +++ b/deepmd/jax/model/polar_model.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.dpmodel.model.polar_model import PolarModel as PolarModelDP +from deepmd.jax.atomic_model.polar_atomic_model import ( + DPAtomicModelPolar, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) +from deepmd.jax.model.dp_model import ( + make_jax_dp_model_from_dpmodel, +) + + +@BaseModel.register("polar") +class PolarModel(make_jax_dp_model_from_dpmodel(PolarModelDP, DPAtomicModelPolar)): + pass diff --git a/deepmd/jax/model/property_model.py b/deepmd/jax/model/property_model.py new file mode 100644 index 0000000000..bf53a039ff --- /dev/null +++ b/deepmd/jax/model/property_model.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.dpmodel.model.property_model import PropertyModel as PropertyModelDP +from deepmd.jax.atomic_model.property_atomic_model import ( + DPAtomicModelProperty, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) +from deepmd.jax.model.dp_model import ( + make_jax_dp_model_from_dpmodel, +) + + +@BaseModel.register("property") +class PropertyModel( + make_jax_dp_model_from_dpmodel(PropertyModelDP, DPAtomicModelProperty) +): + pass diff --git a/source/tests/consistent/model/test_dipole.py b/source/tests/consistent/model/test_dipole.py new file mode 100644 index 0000000000..6fe0709379 --- /dev/null +++ b/source/tests/consistent/model/test_dipole.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.dipole_model import DipoleModel as DipoleModelDP +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_JAX, + INSTALLED_PT, + INSTALLED_TF, + CommonTest, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.dipole_model import DipoleModel as DipoleModelPT +else: + DipoleModelPT = None +if INSTALLED_TF: + from deepmd.tf.model.tensor import DipoleModel as DipoleModelTF +else: + DipoleModelTF = None +if INSTALLED_JAX: + from deepmd.jax.model.dipole_model import DipoleModel as DipoleModelJAX + from deepmd.jax.model.model import get_model as get_model_jax +else: + DipoleModelJAX = None +from deepmd.utils.argcheck import ( + model_args, +) + + +class TestDipole(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + return { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 1.8, + "rcut": 6.0, + "neuron": [2, 4, 8], + "resnet_dt": False, + "axis_neuron": 8, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "type": "dipole", + "neuron": [4, 4, 4], + "resnet_dt": True, + # TODO: add numb_fparam argument to dipole fitting + "_numb_fparam": 0, + "precision": "float64", + "seed": 1, + }, + } + + tf_class = DipoleModelTF + dp_class = DipoleModelDP + pt_class = DipoleModelPT + jax_class = DipoleModelJAX + args = model_args() + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_tf: + return self.RefBackend.TF + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + @property + def skip_tf(self): + return True # need to fix tf consistency + + @property + def skip_jax(self) -> bool: + return not INSTALLED_JAX + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = data.copy() + if cls is DipoleModelDP: + return get_model_dp(data) + elif cls is DipoleModelPT: + model = get_model_pt(data) + model.atomic_model.out_bias.uniform_() + return model + elif cls is DipoleModelJAX: + return get_model_jax(data) + return cls(**data, **self.additional_data) + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + # TF requires the atype to be sort + idx_map = np.argsort(self.atype.ravel()) + self.atype = self.atype[:, idx_map] + self.coords = self.coords[:, idx_map] + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + return self.build_tf_model( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ret_key="dipole", + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_model( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_model( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... + if backend in {self.RefBackend.DP, self.RefBackend.JAX}: + return ( + ret["dipole_redu"].ravel(), + ret["dipole"].ravel(), + ) + elif backend is self.RefBackend.PT: + return ( + ret["global_dipole"].ravel(), + ret["dipole"].ravel(), + ) + elif backend is self.RefBackend.TF: + return ( + ret[0].ravel(), + ret[1].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/consistent/model/test_dos.py b/source/tests/consistent/model/test_dos.py index 8f0b0309cc..83e33e499a 100644 --- a/source/tests/consistent/model/test_dos.py +++ b/source/tests/consistent/model/test_dos.py @@ -13,6 +13,7 @@ ) from ..common import ( + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -30,6 +31,11 @@ from deepmd.tf.model.dos import DOSModel as DOSModelTF else: DOSModelTF = None +if INSTALLED_JAX: + from deepmd.jax.model.dos_model import DOSModel as DOSModelJAX + from deepmd.jax.model.model import get_model as get_model_jax +else: + DOSModelJAX = None from deepmd.utils.argcheck import ( model_args, ) @@ -49,6 +55,7 @@ def data(self) -> dict: "resnet_dt": False, "axis_neuron": 8, "precision": "float64", + "type_one_side": True, "seed": 1, }, "fitting_net": { @@ -65,6 +72,7 @@ def data(self) -> dict: tf_class = DOSModelTF dp_class = DOSModelDP pt_class = DOSModelPT + jax_class = DOSModelJAX args = model_args() def get_reference_backend(self): @@ -86,7 +94,7 @@ def skip_tf(self): @property def skip_jax(self) -> bool: - return True + return not INSTALLED_JAX def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" @@ -97,6 +105,8 @@ def pass_data_to_cls(self, cls, data) -> Any: model = get_model_pt(data) model.atomic_model.out_bias.uniform_() return model + elif cls is DOSModelJAX: + return get_model_jax(data) return cls(**data, **self.additional_data) def setUp(self) -> None: @@ -172,7 +182,7 @@ def eval_jax(self, jax_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... - if backend is self.RefBackend.DP: + if backend in {self.RefBackend.DP, self.RefBackend.JAX}: return ( ret["dos_redu"].ravel(), ret["dos"].ravel(), diff --git a/source/tests/consistent/model/test_polar.py b/source/tests/consistent/model/test_polar.py new file mode 100644 index 0000000000..c6ab334b5f --- /dev/null +++ b/source/tests/consistent/model/test_polar.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.dpmodel.model.polar_model import PolarModel as PolarModelDP +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_JAX, + INSTALLED_PT, + INSTALLED_TF, + CommonTest, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.polar_model import PolarModel as PolarModelPT +else: + PolarModelPT = None +if INSTALLED_TF: + from deepmd.tf.model.tensor import PolarModel as PolarModelTF +else: + PolarModelTF = None +if INSTALLED_JAX: + from deepmd.jax.model.model import get_model as get_model_jax + from deepmd.jax.model.polar_model import PolarModel as PolarModelJAX +else: + PolarModelJAX = None +from deepmd.utils.argcheck import ( + model_args, +) + + +class TestPolar(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + return { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 1.8, + "rcut": 6.0, + "neuron": [2, 4, 8], + "resnet_dt": False, + "axis_neuron": 8, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "type": "polar", + "neuron": [4, 4, 4], + "resnet_dt": True, + # TODO: add numb_fparam argument to polar fitting + "_numb_fparam": 0, + "precision": "float64", + "seed": 1, + }, + } + + tf_class = PolarModelTF + dp_class = PolarModelDP + pt_class = PolarModelPT + jax_class = PolarModelJAX + args = model_args() + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_tf: + return self.RefBackend.TF + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + @property + def skip_tf(self): + return True # need to fix tf consistency + + @property + def skip_jax(self) -> bool: + return not INSTALLED_JAX + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = data.copy() + if cls is PolarModelDP: + return get_model_dp(data) + elif cls is PolarModelPT: + model = get_model_pt(data) + model.atomic_model.out_bias.uniform_() + return model + elif cls is PolarModelJAX: + return get_model_jax(data) + return cls(**data, **self.additional_data) + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + # TF requires the atype to be sort + idx_map = np.argsort(self.atype.ravel()) + self.atype = self.atype[:, idx_map] + self.coords = self.coords[:, idx_map] + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + return self.build_tf_model( + obj, self.natoms, self.coords, self.atype, self.box, suffix, ret_key="polar" + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_model( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_model( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... + if backend in {self.RefBackend.DP, self.RefBackend.JAX}: + return ( + ret["polarizability_redu"].ravel(), + ret["polarizability"].ravel(), + ) + elif backend is self.RefBackend.PT: + return ( + ret["global_polar"].ravel(), + ret["polar"].ravel(), + ) + elif backend is self.RefBackend.TF: + return ( + ret[0].ravel(), + ret[1].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/consistent/model/test_property.py b/source/tests/consistent/model/test_property.py new file mode 100644 index 0000000000..cb5f2f901e --- /dev/null +++ b/source/tests/consistent/model/test_property.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.dpmodel.model.property_model import PropertyModel as PropertyModelDP +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_JAX, + INSTALLED_PT, + CommonTest, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.property_model import PropertyModel as PropertyModelPT +else: + PropertyModelPT = None +if INSTALLED_JAX: + from deepmd.jax.model.model import get_model as get_model_jax + from deepmd.jax.model.property_model import PropertyModel as PropertyModelJAX +else: + PropertyModelJAX = None +from deepmd.utils.argcheck import ( + model_args, +) + + +class TestProperty(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + return { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 1.8, + "rcut": 6.0, + "neuron": [2, 4, 8], + "resnet_dt": False, + "axis_neuron": 8, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "type": "property", + "neuron": [4, 4, 4], + "resnet_dt": True, + # TODO: add numb_fparam argument to property fitting + "_numb_fparam": 0, + "precision": "float64", + "seed": 1, + }, + } + + tf_class = None + dp_class = PropertyModelDP + pt_class = PropertyModelPT + jax_class = PropertyModelJAX + args = model_args() + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_tf: + return self.RefBackend.TF + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + @property + def skip_tf(self): + return True # need to fix tf consistency + + @property + def skip_jax(self) -> bool: + return not INSTALLED_JAX + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = data.copy() + if cls is PropertyModelDP: + return get_model_dp(data) + elif cls is PropertyModelPT: + model = get_model_pt(data) + model.atomic_model.out_bias.uniform_() + return model + elif cls is PropertyModelJAX: + return get_model_jax(data) + return cls(**data, **self.additional_data) + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + # TF requires the atype to be sort + idx_map = np.argsort(self.atype.ravel()) + self.atype = self.atype[:, idx_map] + self.coords = self.coords[:, idx_map] + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + return self.build_tf_model( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ret_key="property", + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_model( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_model( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... + if backend in {self.RefBackend.DP, self.RefBackend.JAX}: + return ( + ret["property_redu"].ravel(), + ret["property"].ravel(), + ) + elif backend is self.RefBackend.PT: + return ( + ret["property"].ravel(), + ret["atom_property"].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}")