From d9240135a3c592442c015fe68b127c6a0e10f69c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 6 Nov 2024 15:40:29 -0500 Subject: [PATCH 01/31] checkpoint Signed-off-by: Jinzhe Zeng --- deepmd/jax/jax2tf/__init__.py | 3 + deepmd/jax/jax2tf/make_model.py | 111 +++++++++++++ deepmd/jax/jax2tf/nlist.py | 218 ++++++++++++++++++++++++++ deepmd/jax/jax2tf/region.py | 104 ++++++++++++ deepmd/jax/jax2tf/serialization.py | 140 ++++++++++++++++- deepmd/jax/jax2tf/tfmodel.py | 39 +++-- deepmd/jax/jax2tf/transform_output.py | 113 +++++++++++++ 7 files changed, 705 insertions(+), 23 deletions(-) create mode 100644 deepmd/jax/jax2tf/make_model.py create mode 100644 deepmd/jax/jax2tf/nlist.py create mode 100644 deepmd/jax/jax2tf/region.py create mode 100644 deepmd/jax/jax2tf/transform_output.py diff --git a/deepmd/jax/jax2tf/__init__.py b/deepmd/jax/jax2tf/__init__.py index 88a928f04d..c2cda24bd7 100644 --- a/deepmd/jax/jax2tf/__init__.py +++ b/deepmd/jax/jax2tf/__init__.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import tensorflow as tf +import tensorflow.experimental.numpy as tnp if not tf.executing_eagerly(): # TF disallow temporary eager execution @@ -9,3 +10,5 @@ "If you are converting a model between different backends, " "considering converting to the `.dp` format first." ) + +tnp.experimental_enable_numpy_behavior() diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py new file mode 100644 index 0000000000..7ff07a6a7a --- /dev/null +++ b/deepmd/jax/jax2tf/make_model.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, +) + +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from deepmd.dpmodel.output_def import ( + ModelOutputDef, +) +from deepmd.jax.jax2tf.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.jax.jax2tf.region import ( + normalize_coord, +) +from deepmd.jax.jax2tf.transform_output import ( + communicate_extended_output, +) + + +def model_call_from_call_lower( + *, # enforce keyword-only arguments + call_lower: Callable[ + [ + tnp.ndarray, + tnp.ndarray, + tnp.ndarray, + Optional[tnp.ndarray], + Optional[tnp.ndarray], + bool, + ], + dict[str, tnp.ndarray], + ], + rcut: float, + sel: list[int], + mixed_types: bool, + model_output_def: ModelOutputDef, + coord: tnp.ndarray, + atype: tnp.ndarray, + box: Optional[tnp.ndarray] = None, + fparam: Optional[tnp.ndarray] = None, + aparam: Optional[tnp.ndarray] = None, + do_atomic_virial: bool = False, +): + """Return model prediction from lower interface. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,tnp.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + atype_shape = tf.shape(atype) + nframes, nloc = atype_shape[0], atype_shape[1] + cc, bb, fp, ap = coord, box, fparam, aparam + del coord, box, fparam, aparam + if tf.shape(bb)[-1] == 0: + coord_normalized = normalize_coord( + cc.reshape(nframes, nloc, 3), + bb.reshape(nframes, 3, 3), + ) + else: + coord_normalized = cc + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, bb, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=not mixed_types, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + model_predict_lower = call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + ) + model_predict = communicate_extended_output( + model_predict_lower, + model_output_def, + mapping, + do_atomic_virial=do_atomic_virial, + ) + return model_predict diff --git a/deepmd/jax/jax2tf/nlist.py b/deepmd/jax/jax2tf/nlist.py new file mode 100644 index 0000000000..0a687746b0 --- /dev/null +++ b/deepmd/jax/jax2tf/nlist.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, + Union, +) + +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from .region import ( + to_face_distance, +) + + +## translated from torch implementation by chatgpt +def build_neighbor_list( + coord: tnp.ndarray, + atype: tnp.ndarray, + nloc: int, + rcut: float, + sel: Union[int, list[int]], + distinguish_types: bool = True, +) -> tnp.ndarray: + """Build neighbor list for a single frame. keeps nsel neighbors. + + Parameters + ---------- + coord : tnp.ndarray + exptended coordinates of shape [batch_size, nall x 3] + atype : tnp.ndarray + extended atomic types of shape [batch_size, nall] + type < 0 the atom is treat as virtual atoms. + nloc : int + number of local atoms. + rcut : float + cut-off radius + sel : int or list[int] + maximal number of neighbors (of each type). + if distinguish_types==True, nsel should be list and + the length of nsel should be equal to number of + types. + distinguish_types : bool + distinguish different types. + + Returns + ------- + neighbor_list : tnp.ndarray + Neighbor list of shape [batch_size, nloc, nsel], the neighbors + are stored in an ascending order. If the number of + neighbors is less than nsel, the positions are masked + with -1. The neighbor list of an atom looks like + |------ nsel ------| + xx xx xx xx -1 -1 -1 + if distinguish_types==True and we have two types + |---- nsel[0] -----| |---- nsel[1] -----| + xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1 + For virtual atoms all neighboring positions are filled with -1. + + """ + batch_size = tf.shape(coord)[0] + coord = tnp.reshape(coord, (batch_size, -1)) + nall = tf.shape(coord)[1] // 3 + # fill virtual atoms with large coords so they are not neighbors of any + # real atom. + if tf.size(coord) > 0: + xmax = tnp.max(coord) + 2.0 * rcut + else: + xmax = tf.cast(2.0 * rcut, coord.dtype) + # nf x nall + is_vir = atype < 0 + coord1 = tnp.where( + is_vir[:, :, None], xmax, tnp.reshape(coord, (batch_size, nall, 3)) + ) + coord1 = tnp.reshape(coord1, (batch_size, nall * 3)) + if isinstance(sel, int): + sel = [sel] + nsel = sum(sel) + coord0 = coord1[:, : nloc * 3] + diff = ( + tnp.reshape(coord1, [batch_size, -1, 3])[:, None, :, :] + - tnp.reshape(coord0, [batch_size, -1, 3])[:, :, None, :] + ) + rr = tf.linalg.norm(diff, axis=-1) + # if central atom has two zero distances, sorting sometimes can not exclude itself + rr -= tf.eye(nloc, nall, dtype=diff.dtype)[tnp.newaxis, :, :] + nlist = tnp.argsort(rr, axis=-1) + rr = tnp.sort(rr, axis=-1) + rr = rr[:, :, 1:] + nlist = nlist[:, :, 1:] + nnei = tf.shape(rr)[2] + if nsel <= nnei: + rr = rr[:, :, :nsel] + nlist = nlist[:, :, :nsel] + else: + rr = tnp.concatenate( + [rr, tnp.ones([batch_size, nloc, nsel - nnei], dtype=rr.dtype) + rcut], + axis=-1, + ) + nlist = tnp.concatenate( + [nlist, tnp.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)], + axis=-1, + ) + nlist = tnp.where( + tnp.logical_or((rr > rcut), is_vir[:, :nloc, None]), + tnp.full_like(nlist, -1), + nlist, + ) + + if distinguish_types: + return nlist_distinguish_types(nlist, atype, sel) + else: + return nlist + + +def nlist_distinguish_types( + nlist: tnp.ndarray, + atype: tnp.ndarray, + sel: list[int], +): + """Given a nlist that does not distinguish atom types, return a nlist that + distinguish atom types. + + """ + nloc = tf.shape(nlist)[1] + ret_nlist = [] + tmp_atype = tnp.tile(atype[:, None, :], (1, nloc, 1)) + mask = nlist == -1 + tnlist_0 = tnp.where(mask, tnp.zeros_like(nlist), nlist) + tnlist = tnp.take_along_axis(tmp_atype, tnlist_0, axis=2) + tnlist = tnp.where(mask, tnp.full_like(tnlist, -1), tnlist) + for ii, ss in enumerate(sel): + pick_mask = tf.cast(tnlist == ii, tnp.int32) + sorted_indices = tnp.argsort(-pick_mask, kind="stable", axis=-1) + pick_mask_sorted = -tnp.sort(-pick_mask, axis=-1) + inlist = tnp.take_along_axis(nlist, sorted_indices, axis=2) + inlist = tnp.where( + ~tf.cast(pick_mask_sorted, tf.bool), tnp.full_like(inlist, -1), inlist + ) + ret_nlist.append(inlist[..., :ss]) + ret = tf.concat(ret_nlist, axis=-1) + return ret + + +def tf_outer(a, b): + return tf.einsum("i,j->ij", a, b) + + +## translated from torch implementation by chatgpt +def extend_coord_with_ghosts( + coord: tnp.ndarray, + atype: tnp.ndarray, + cell: Optional[tnp.ndarray], + rcut: float, +): + """Extend the coordinates of the atoms by appending peridoc images. + The number of images is large enough to ensure all the neighbors + within rcut are appended. + + Parameters + ---------- + coord : tnp.ndarray + original coordinates of shape [-1, nloc*3]. + atype : tnp.ndarray + atom type of shape [-1, nloc]. + cell : tnp.ndarray + simulation cell tensor of shape [-1, 9]. + rcut : float + the cutoff radius + + Returns + ------- + extended_coord: tnp.ndarray + extended coordinates of shape [-1, nall*3]. + extended_atype: tnp.ndarray + extended atom type of shape [-1, nall]. + index_mapping: tnp.ndarray + mapping extended index to the local index + + """ + atype_shape = tf.shape(atype) + nf, nloc = atype_shape[0], atype_shape[1] + # int64 for index + aidx = tf.range(nloc, dtype=tnp.int64) + aidx = tnp.tile(aidx[tnp.newaxis, :], (nf, 1)) + if tf.shape(cell)[-1] == 0: + nall = nloc + extend_coord = coord + extend_atype = atype + extend_aidx = aidx + else: + coord = tnp.reshape(coord, (nf, nloc, 3)) + cell = tnp.reshape(cell, (nf, 3, 3)) + to_face = to_face_distance(cell) + nbuff = tf.cast(tnp.ceil(rcut / to_face), tnp.int64) + nbuff = tnp.max(nbuff, axis=0) + xi = tf.range(nbuff[0], nbuff[0] + 1, 1, dtype=tnp.int64) + yi = tf.range(nbuff[1], nbuff[1] + 1, 1, dtype=tnp.int64) + zi = tf.range(nbuff[2], nbuff[2] + 1, 1, dtype=tnp.int64) + xyz = tf_outer(xi, tnp.asarray([1, 0, 0]))[:, tnp.newaxis, tnp.newaxis, :] + xyz = xyz + tf_outer(yi, tnp.asarray([0, 1, 0]))[tnp.newaxis, :, tnp.newaxis, :] + xyz = xyz + tf_outer(zi, tnp.asarray([0, 0, 1]))[tnp.newaxis, tnp.newaxis, :, :] + xyz = tnp.reshape(xyz, (-1, 3)) + xyz = tf.cast(xyz, coord.dtype) + shift_idx = tnp.take(xyz, tnp.argsort(tf.linalg.norm(xyz, axis=1)), axis=0) + ns = tf.shape(shift_idx)[0] + nall = ns * nloc + shift_vec = tnp.einsum("sd,fdk->fsk", shift_idx, cell) + # shift_vec = tnp.tensordot(shift_idx, cell, axes=([1], [1])) + # shift_vec = tnp.transpose(shift_vec, (1, 0, 2)) + extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :] + extend_atype = tnp.tile(atype[:, :, tnp.newaxis], (1, ns, 1)) + extend_aidx = tnp.tile(aidx[:, :, tnp.newaxis], (1, ns, 1)) + + return ( + tnp.reshape(extend_coord, (nf, nall * 3)), + tnp.reshape(extend_atype, (nf, nall)), + tnp.reshape(extend_aidx, (nf, nall)), + ) diff --git a/deepmd/jax/jax2tf/region.py b/deepmd/jax/jax2tf/region.py new file mode 100644 index 0000000000..96024bd79a --- /dev/null +++ b/deepmd/jax/jax2tf/region.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + + +def phys2inter( + coord: tnp.ndarray, + cell: tnp.ndarray, +) -> tnp.ndarray: + """Convert physical coordinates to internal(direct) coordinates. + + Parameters + ---------- + coord : tnp.ndarray + physical coordinates of shape [*, na, 3]. + cell : tnp.ndarray + simulation cell tensor of shape [*, 3, 3]. + + Returns + ------- + inter_coord: tnp.ndarray + the internal coordinates + + """ + rec_cell = tf.linalg.inv(cell) + return tnp.matmul(coord, rec_cell) + + +def inter2phys( + coord: tnp.ndarray, + cell: tnp.ndarray, +) -> tnp.ndarray: + """Convert internal(direct) coordinates to physical coordinates. + + Parameters + ---------- + coord : tnp.ndarray + internal coordinates of shape [*, na, 3]. + cell : tnp.ndarray + simulation cell tensor of shape [*, 3, 3]. + + Returns + ------- + phys_coord: tnp.ndarray + the physical coordinates + + """ + return tnp.matmul(coord, cell) + + +def normalize_coord( + coord: tnp.ndarray, + cell: tnp.ndarray, +) -> tnp.ndarray: + """Apply PBC according to the atomic coordinates. + + Parameters + ---------- + coord : tnp.ndarray + original coordinates of shape [*, na, 3]. + cell : tnp.ndarray + simulation cell shape [*, 3, 3]. + + Returns + ------- + wrapped_coord: tnp.ndarray + wrapped coordinates of shape [*, na, 3]. + + """ + icoord = phys2inter(coord, cell) + icoord = tnp.remainder(icoord, 1.0) + return inter2phys(icoord, cell) + + +def to_face_distance( + cell: tnp.ndarray, +) -> tnp.ndarray: + """Compute the to-face-distance of the simulation cell. + + Parameters + ---------- + cell : tnp.ndarray + simulation cell tensor of shape [*, 3, 3]. + + Returns + ------- + dist: tnp.ndarray + the to face distances of shape [*, 3] + + """ + cshape = tf.shape(cell) + dist = b_to_face_distance(tnp.reshape(cell, [-1, 3, 3])) + return tnp.reshape(dist, tf.concat([cshape[:-2], [3]], axis=0)) + + +def b_to_face_distance(cell): + volume = tf.linalg.det(cell) + c_yz = tf.linalg.cross(cell[:, 1, ...], cell[:, 2, ...]) + _h2yz = volume / tf.linalg.norm(c_yz, axis=-1) + c_zx = tf.linalg.cross(cell[:, 2, ...], cell[:, 0, ...]) + _h2zx = volume / tf.linalg.norm(c_zx, axis=-1) + c_xy = tf.linalg.cross(cell[:, 0, ...], cell[:, 1, ...]) + _h2xy = volume / tf.linalg.norm(c_xy, axis=-1) + return tnp.stack([_h2yz, _h2zx, _h2xy], axis=1) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index dff43a11fc..7e560f6008 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -1,11 +1,18 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json +from typing import ( + Optional, +) import tensorflow as tf +import tensorflow.experimental.numpy as tnp from jax.experimental import ( jax2tf, ) +from deepmd.jax.jax2tf.make_model import ( + model_call_from_call_lower, +) from deepmd.jax.model.base_model import ( BaseModel, ) @@ -28,7 +35,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None: tf_model = tf.Module() - def exported_whether_do_atomic_virial(do_atomic_virial): + def exported_whether_do_atomic_virial(do_atomic_virial, has_ghost_atoms): def call_lower_with_fixed_do_atomic_virial( coord, atype, nlist, mapping, fparam, aparam ): @@ -42,13 +49,20 @@ def call_lower_with_fixed_do_atomic_virial( do_atomic_virial=do_atomic_virial, ) + # nghost >= 1 is assumed if there is + # other workaround does not work, such as + # nall; nloc + nghost - 1 + if has_ghost_atoms: + nghost = "nghost" + else: + nghost = "0" return jax2tf.convert( call_lower_with_fixed_do_atomic_virial, polymorphic_shapes=[ - "(nf, nloc + nghost, 3)", - "(nf, nloc + nghost)", + f"(nf, nloc + {nghost}, 3)", + f"(nf, nloc + {nghost})", f"(nf, nloc, {model.get_nnei()})", - "(nf, nloc + nghost)", + f"(nf, nloc + {nghost})", f"(nf, {model.get_dim_fparam()})", f"(nf, nloc, {model.get_dim_aparam()})", ], @@ -71,8 +85,14 @@ def call_lower_with_fixed_do_atomic_virial( def call_lower_without_atomic_virial( coord, atype, nlist, mapping, fparam, aparam ): - return exported_whether_do_atomic_virial(do_atomic_virial=False)( - coord, atype, nlist, mapping, fparam, aparam + return tf.cond( + tf.shape(coord)[1] == tf.shape(nlist)[1], + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=False + )(coord, atype, nlist, mapping, fparam, aparam), + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=True + )(coord, atype, nlist, mapping, fparam, aparam), ) tf_model.call_lower = call_lower_without_atomic_virial @@ -89,12 +109,116 @@ def call_lower_without_atomic_virial( ], ) def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): - return exported_whether_do_atomic_virial(do_atomic_virial=True)( - coord, atype, nlist, mapping, fparam, aparam + return tf.cond( + tf.shape(coord)[1] == tf.shape(nlist)[1], + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=True, has_ghost_atoms=False + )(coord, atype, nlist, mapping, fparam, aparam), + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=True, has_ghost_atoms=True + )(coord, atype, nlist, mapping, fparam, aparam), ) tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial + def make_call_whether_do_atomic_virial(do_atomic_virial: bool): + if do_atomic_virial: + call_lower = call_lower_with_atomic_virial + else: + call_lower = call_lower_without_atomic_virial + + def call( + coord: tnp.ndarray, + atype: tnp.ndarray, + box: Optional[tnp.ndarray] = None, + fparam: Optional[tnp.ndarray] = None, + aparam: Optional[tnp.ndarray] = None, + ): + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + + Returns + ------- + ret_dict + The result dict of type dict[str,jnp.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + return model_call_from_call_lower( + call_lower=call_lower, + rcut=model.get_rcut(), + sel=model.get_sel(), + mixed_types=model.mixed_types(), + model_output_def=model.model_output_def(), + coord=coord, + atype=atype, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + return call + + @tf.function( + autograph=True, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, None], tf.float64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_with_atomic_virial( + coord: tnp.ndarray, + atype: tnp.ndarray, + box: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, + ): + return make_call_whether_do_atomic_virial(do_atomic_virial=True)( + coord, atype, box, fparam, aparam + ) + + tf_model.call_atomic_virial = call_with_atomic_virial + + @tf.function( + autograph=True, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, None], tf.float64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_without_atomic_virial( + coord: tnp.ndarray, + atype: tnp.ndarray, + box: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, + ): + return make_call_whether_do_atomic_virial(do_atomic_virial=False)( + coord, atype, box, fparam, aparam + ) + + tf_model.call = call_without_atomic_virial + # set functions to export other attributes @tf.function def get_type_map(): diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py index 8f04014a97..0d7b13ba1f 100644 --- a/deepmd/jax/jax2tf/tfmodel.py +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -7,9 +7,6 @@ import jax.experimental.jax2tf as jax2tf import tensorflow as tf -from deepmd.dpmodel.model.make_model import ( - model_call_from_call_lower, -) from deepmd.dpmodel.output_def import ( FittingOutputDef, ModelOutputDef, @@ -55,6 +52,8 @@ def __init__( self._call_lower_atomic_virial = jax2tf.call_tf( self.model.call_lower_atomic_virial ) + self._call = jax2tf.call_tf(self.model.call) + self._call_atomic_virial = jax2tf.call_tf(self.model.call_atomic_virial) self.type_map = decode_list_of_bytes(self.model.get_type_map().numpy().tolist()) self.rcut = self.model.get_rcut().numpy().item() self.dim_fparam = self.model.get_dim_fparam().numpy().item() @@ -142,18 +141,28 @@ def call( The keys are defined by the `ModelOutputDef`. """ - return model_call_from_call_lower( - call_lower=self.call_lower, - rcut=self.get_rcut(), - sel=self.get_sel(), - mixed_types=self.mixed_types(), - model_output_def=self.model_output_def(), - coord=coord, - atype=atype, - box=box, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, + if do_atomic_virial: + call = self._call_atomic_virial + else: + call = self._call + # Attempt to convert a value (None) with an unsupported type () to a Tensor. + if box is None: + box = jnp.empty((coord.shape[0], 0, 0), dtype=jnp.float64) + if fparam is None: + fparam = jnp.empty( + (coord.shape[0], self.get_dim_fparam()), dtype=jnp.float64 + ) + if aparam is None: + aparam = jnp.empty( + (coord.shape[0], coord.shape[1], self.get_dim_aparam()), + dtype=jnp.float64, + ) + return call( + coord, + atype, + box, + fparam, + aparam, ) def model_output_def(self): diff --git a/deepmd/jax/jax2tf/transform_output.py b/deepmd/jax/jax2tf/transform_output.py new file mode 100644 index 0000000000..f853744c02 --- /dev/null +++ b/deepmd/jax/jax2tf/transform_output.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from deepmd.dpmodel.output_def import ( + ModelOutputDef, + OutputVariableDef, + get_deriv_name, + get_reduce_name, +) + + +def get_leading_dims( + vv: tnp.ndarray, + vdef: OutputVariableDef, +) -> tnp.ndarray: + """Get the dimensions of nf x nloc. + + Parameters + ---------- + vv : np.ndarray + The input array from which to compute the leading dimensions. + vdef : OutputVariableDef + The output variable definition containing the shape to exclude from `vv`. + + Returns + ------- + list + A list of leading dimensions of `vv`, excluding the last `len(vdef.shape)` dimensions. + """ + vshape = tf.shape(vv) + return vshape[: (len(vshape) - len(vdef.shape))] + + +def communicate_extended_output( + model_ret: dict[str, tnp.ndarray], + model_output_def: ModelOutputDef, + mapping: tnp.ndarray, # nf x nloc + do_atomic_virial: bool = False, +) -> dict[str, tnp.ndarray]: + """Transform the output of the model network defined on + local and ghost (extended) atoms to local atoms. + + """ + new_ret = {} + for kk in model_output_def.keys_outp(): + vv = model_ret[kk] + vdef = model_output_def[kk] + new_ret[kk] = vv + if vdef.reducible: + kk_redu = get_reduce_name(kk) + new_ret[kk_redu] = model_ret[kk_redu] + kk_derv_r, kk_derv_c = get_deriv_name(kk) + mldims = tf.shape(mapping) + vldims = get_leading_dims(vv, vdef) + if vdef.r_differentiable: + if model_ret[kk_derv_r] is not None: + derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005 + indices = mapping.reshape(tf.shape(mapping)[0], -1, 1) + # concat frame idx + indices = tf.concat( + [ + tf.repeat( + tf.range(tf.shape(indices)[0], dtype=indices.dtype), + tf.shape(mapping)[1], + ).reshape(tf.shape(indices)), + indices, + ], + axis=-1, + ) + force = tf.scatter_nd( + indices, + model_ret[kk_derv_r], + tf.cast(tf.concat([vldims, derv_r_ext_dims], axis=0), tf.int64), + ) + new_ret[kk_derv_r] = force.reshape( + tf.concat([tf.shape(force)[:2], list(vdef.shape), [3]], axis=0) + ) + else: + # name holders + new_ret[kk_derv_r] = None + if vdef.c_differentiable: + assert vdef.r_differentiable + if model_ret[kk_derv_c] is not None: + derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005 + indices = mapping.reshape(tf.shape(mapping)[0], -1, 1) + # concat frame idx + indices = tf.concat( + [ + tf.repeat( + tf.range(tf.shape(indices)[0], dtype=indices.dtype), + tf.shape(mapping)[1], + ).reshape(tf.shape(indices)), + indices, + ], + axis=-1, + ) + virial = tf.scatter_nd( + indices, + model_ret[kk_derv_c], + tf.cast(tf.concat([vldims, derv_c_ext_dims], axis=0), tf.int64), + ) + new_ret[kk_derv_c] = virial.reshape( + tf.concat([tf.shape(virial)[:2], list(vdef.shape), [9]], axis=0) + ) + new_ret[kk_derv_c + "_redu"] = tnp.sum(new_ret[kk_derv_c], axis=1) + else: + new_ret[kk_derv_c] = None + new_ret[kk_derv_c + "_redu"] = None + if not do_atomic_virial: + # pop atomic virial, because it is not correctly calculated. + new_ret.pop(kk_derv_c) + return new_ret From 373ea65d12223908549f4901c5e787c3f32dd228 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 6 Nov 2024 16:42:15 -0500 Subject: [PATCH 02/31] bugfix Signed-off-by: Jinzhe Zeng --- deepmd/jax/jax2tf/make_model.py | 11 +++++------ deepmd/jax/jax2tf/nlist.py | 9 ++++----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py index 7ff07a6a7a..9a3811c644 100644 --- a/deepmd/jax/jax2tf/make_model.py +++ b/deepmd/jax/jax2tf/make_model.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Callable, - Optional, ) import tensorflow as tf @@ -29,8 +28,8 @@ def model_call_from_call_lower( tnp.ndarray, tnp.ndarray, tnp.ndarray, - Optional[tnp.ndarray], - Optional[tnp.ndarray], + tnp.ndarray, + tnp.ndarray, bool, ], dict[str, tnp.ndarray], @@ -41,9 +40,9 @@ def model_call_from_call_lower( model_output_def: ModelOutputDef, coord: tnp.ndarray, atype: tnp.ndarray, - box: Optional[tnp.ndarray] = None, - fparam: Optional[tnp.ndarray] = None, - aparam: Optional[tnp.ndarray] = None, + box: tnp.ndarray = None, + fparam: tnp.ndarray = None, + aparam: tnp.ndarray = None, do_atomic_virial: bool = False, ): """Return model prediction from lower interface. diff --git a/deepmd/jax/jax2tf/nlist.py b/deepmd/jax/jax2tf/nlist.py index 0a687746b0..5a0ed58b63 100644 --- a/deepmd/jax/jax2tf/nlist.py +++ b/deepmd/jax/jax2tf/nlist.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Optional, Union, ) @@ -149,7 +148,7 @@ def tf_outer(a, b): def extend_coord_with_ghosts( coord: tnp.ndarray, atype: tnp.ndarray, - cell: Optional[tnp.ndarray], + cell: tnp.ndarray, rcut: float, ): """Extend the coordinates of the atoms by appending peridoc images. @@ -193,9 +192,9 @@ def extend_coord_with_ghosts( to_face = to_face_distance(cell) nbuff = tf.cast(tnp.ceil(rcut / to_face), tnp.int64) nbuff = tnp.max(nbuff, axis=0) - xi = tf.range(nbuff[0], nbuff[0] + 1, 1, dtype=tnp.int64) - yi = tf.range(nbuff[1], nbuff[1] + 1, 1, dtype=tnp.int64) - zi = tf.range(nbuff[2], nbuff[2] + 1, 1, dtype=tnp.int64) + xi = tf.range(-nbuff[0], nbuff[0] + 1, 1, dtype=tnp.int64) + yi = tf.range(-nbuff[1], nbuff[1] + 1, 1, dtype=tnp.int64) + zi = tf.range(-nbuff[2], nbuff[2] + 1, 1, dtype=tnp.int64) xyz = tf_outer(xi, tnp.asarray([1, 0, 0]))[:, tnp.newaxis, tnp.newaxis, :] xyz = xyz + tf_outer(yi, tnp.asarray([0, 1, 0]))[tnp.newaxis, :, tnp.newaxis, :] xyz = xyz + tf_outer(zi, tnp.asarray([0, 0, 1]))[tnp.newaxis, tnp.newaxis, :, :] From 933e4dfcd982f800c98ed1a4bbdf0c982b65fa6a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 6 Nov 2024 16:48:48 -0500 Subject: [PATCH 03/31] bugfix Signed-off-by: Jinzhe Zeng --- deepmd/jax/jax2tf/make_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py index 9a3811c644..feb58c74cb 100644 --- a/deepmd/jax/jax2tf/make_model.py +++ b/deepmd/jax/jax2tf/make_model.py @@ -74,7 +74,7 @@ def model_call_from_call_lower( nframes, nloc = atype_shape[0], atype_shape[1] cc, bb, fp, ap = coord, box, fparam, aparam del coord, box, fparam, aparam - if tf.shape(bb)[-1] == 0: + if tf.shape(bb)[-1] != 0: coord_normalized = normalize_coord( cc.reshape(nframes, nloc, 3), bb.reshape(nframes, 3, 3), From 94d20543b5c3a1d83eb4636524a9125de343f2f8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 6 Nov 2024 17:04:46 -0500 Subject: [PATCH 04/31] nopbc Signed-off-by: Jinzhe Zeng --- deepmd/jax/infer/deep_eval.py | 6 +++++ deepmd/jax/model/hlo.py | 18 +++++++++++--- deepmd/jax/utils/serialization.py | 35 ++++++++++++++++++++++----- source/tests/consistent/io/test_io.py | 27 +++++++++++++++++++++ 4 files changed, 77 insertions(+), 9 deletions(-) diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index fc526a502e..b9d1974c27 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -97,6 +97,12 @@ def __init__( stablehlo_atomic_virial=model_data["@variables"][ "stablehlo_atomic_virial" ].tobytes(), + stablehlo_no_ghost=model_data["@variables"][ + "stablehlo_no_ghost" + ].tobytes(), + stablehlo_atomic_virial_no_ghost=model_data["@variables"][ + "stablehlo_atomic_virial_no_ghost" + ].tobytes(), model_def_script=model_data["model_def_script"], **model_data["constants"], ) diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py index 2946f8bec7..4d59957456 100644 --- a/deepmd/jax/model/hlo.py +++ b/deepmd/jax/model/hlo.py @@ -46,6 +46,8 @@ def __init__( self, stablehlo, stablehlo_atomic_virial, + stablehlo_no_ghost, + stablehlo_atomic_virial_no_ghost, model_def_script, type_map, rcut, @@ -62,6 +64,10 @@ def __init__( self._call_lower_atomic_virial = jax_export.deserialize( stablehlo_atomic_virial ).call + self._call_lower_no_ghost = jax_export.deserialize(stablehlo_no_ghost).call + self._call_lower_atomic_virial_no_ghost = jax_export.deserialize( + stablehlo_atomic_virial_no_ghost + ).call self.stablehlo = stablehlo self.type_map = type_map self.rcut = rcut @@ -174,10 +180,16 @@ def call_lower( aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, ): - if do_atomic_virial: - call_lower = self._call_lower_atomic_virial + if extended_coord.shape[1] > nlist.shape[1]: + if do_atomic_virial: + call_lower = self._call_lower_atomic_virial + else: + call_lower = self._call_lower else: - call_lower = self._call_lower + if do_atomic_virial: + call_lower = self._call_lower_atomic_virial_no_ghost + else: + call_lower = self._call_lower_no_ghost return call_lower( extended_coord, extended_atype, diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 6ab99a81f0..1ed26f2d40 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -53,7 +53,9 @@ def deserialize_to_file(model_file: str, data: dict) -> None: nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost") - def exported_whether_do_atomic_virial(do_atomic_virial): + def exported_whether_do_atomic_virial( + do_atomic_virial: bool, has_ghost_atoms: bool + ): def call_lower_with_fixed_do_atomic_virial( coord, atype, nlist, mapping, fparam, aparam ): @@ -67,13 +69,18 @@ def call_lower_with_fixed_do_atomic_virial( do_atomic_virial=do_atomic_virial, ) + if has_ghost_atoms: + nghost_ = nghost + else: + nghost_ = 0 + return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))( jax.ShapeDtypeStruct( - (nf, nloc + nghost, 3), jnp.float64 + (nf, nloc + nghost_, 3), jnp.float64 ), # extended_coord - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype + jax.ShapeDtypeStruct((nf, nloc + nghost_), jnp.int32), # extended_atype jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping + jax.ShapeDtypeStruct((nf, nloc + nghost_), jnp.int64), # mapping jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64) if model.get_dim_fparam() else None, # fparam @@ -82,18 +89,34 @@ def call_lower_with_fixed_do_atomic_virial( else None, # aparam ) - exported = exported_whether_do_atomic_virial(do_atomic_virial=False) + exported = exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=True + ) exported_atomic_virial = exported_whether_do_atomic_virial( - do_atomic_virial=True + do_atomic_virial=True, has_ghost_atoms=True ) serialized: bytearray = exported.serialize() serialized_atomic_virial = exported_atomic_virial.serialize() + + exported_no_ghost = exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=False + ) + exported_atomic_virial_no_ghost = exported_whether_do_atomic_virial( + do_atomic_virial=True, has_ghost_atoms=False + ) + serialized_no_ghost: bytearray = exported_no_ghost.serialize() + serialized_atomic_virial_no_ghost = exported_atomic_virial_no_ghost.serialize() + data = data.copy() data.setdefault("@variables", {}) data["@variables"]["stablehlo"] = np.void(serialized) data["@variables"]["stablehlo_atomic_virial"] = np.void( serialized_atomic_virial ) + data["@variables"]["stablehlo_no_ghost"] = np.void(serialized_no_ghost) + data["@variables"]["stablehlo_atomic_virial_no_ghost"] = np.void( + serialized_atomic_virial_no_ghost + ) data["constants"] = { "type_map": model.get_type_map(), "rcut": model.get_rcut(), diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index ca213da13c..8eb26e7ac3 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -142,6 +142,7 @@ def test_deep_eval(self): nframes = self.atype.shape[0] prefix = "test_consistent_io_" + self.__class__.__name__.lower() rets = [] + rets_nopbc = [] for backend_name, suffix_idx in ( # unfortunately, jax2tf cannot work with tf v1 behaviors ("jax", 2) if DP_TEST_TF2_ONLY else ("tensorflow", 0), @@ -182,6 +183,23 @@ def test_deep_eval(self): atomic=True, ) rets.append(ret) + ret = deep_eval.eval( + self.coords, + None, + self.atype, + fparam=fparam, + aparam=aparam, + ) + rets_nopbc.append(ret) + ret = deep_eval.eval( + self.coords, + None, + self.atype, + fparam=fparam, + aparam=aparam, + atomic=True, + ) + rets_nopbc.append(ret) for ret in rets[1:]: for vv1, vv2 in zip(rets[0], ret): if np.isnan(vv2).all(): @@ -189,6 +207,15 @@ def test_deep_eval(self): continue np.testing.assert_allclose(vv1, vv2, rtol=1e-12, atol=1e-12) + for idx, ret in enumerate(rets_nopbc[1:]): + for vv1, vv2 in zip(rets_nopbc[0], ret): + if np.isnan(vv2).all(): + # expect all nan if not supported + continue + np.testing.assert_allclose( + vv1, vv2, rtol=1e-12, atol=1e-12, err_msg=f"backend {idx+1}" + ) + class TestDeepPot(unittest.TestCase, IOTest): def setUp(self): From 84cb819aa31fe646511ee484083384a3f1af57e1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 6 Nov 2024 17:27:26 -0500 Subject: [PATCH 05/31] clean up default values Signed-off-by: Jinzhe Zeng --- deepmd/jax/jax2tf/make_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py index feb58c74cb..d21fc998b5 100644 --- a/deepmd/jax/jax2tf/make_model.py +++ b/deepmd/jax/jax2tf/make_model.py @@ -40,9 +40,9 @@ def model_call_from_call_lower( model_output_def: ModelOutputDef, coord: tnp.ndarray, atype: tnp.ndarray, - box: tnp.ndarray = None, - fparam: tnp.ndarray = None, - aparam: tnp.ndarray = None, + box: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, do_atomic_virial: bool = False, ): """Return model prediction from lower interface. From bd27d4f7f730bcda0401759b68b7857a280e35a1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 7 Nov 2024 02:56:16 -0500 Subject: [PATCH 06/31] add tests Signed-off-by: Jinzhe Zeng --- .github/workflows/test_python.yml | 2 +- pyproject.toml | 3 + source/tests/jax/__init__.py | 1 + source/tests/jax/jax2tf/__init__.py | 8 ++ source/tests/jax/jax2tf/test_nlist.py | 153 +++++++++++++++++++++++++ source/tests/jax/jax2tf/test_region.py | 53 +++++++++ 6 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 source/tests/jax/__init__.py create mode 100644 source/tests/jax/jax2tf/__init__.py create mode 100644 source/tests/jax/jax2tf/test_nlist.py create mode 100644 source/tests/jax/jax2tf/test_region.py diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 422dcb5f17..f03aa410f9 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -58,7 +58,7 @@ jobs: env: NUM_WORKERS: 0 - name: Test TF2 eager mode - run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0 + run: pytest --cov=deepmd source/tests/consistent/io/test_io.py source/tests/jax/jax2tf --durations=0 env: NUM_WORKERS: 0 DP_TEST_TF2_ONLY: 1 diff --git a/pyproject.toml b/pyproject.toml index 802e920014..f911f6905c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -405,8 +405,10 @@ convention = "numpy" banned-module-level-imports = [ "deepmd.tf", "deepmd.pt", + "deepmd.jax", "tensorflow", "torch", + "jax", ] [tool.ruff.lint.flake8-tidy-imports.banned-api] @@ -419,6 +421,7 @@ banned-module-level-imports = [ "deepmd/jax/**" = ["TID253"] "source/tests/tf/**" = ["TID253"] "source/tests/pt/**" = ["TID253"] +"source/tests/jax/**" = ["TID253"] "source/tests/universal/pt/**" = ["TID253"] "source/ipi/tests/**" = ["TID253"] "source/lmp/tests/**" = ["TID253"] diff --git a/source/tests/jax/__init__.py b/source/tests/jax/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/jax/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/jax/jax2tf/__init__.py b/source/tests/jax/jax2tf/__init__.py new file mode 100644 index 0000000000..3c27417989 --- /dev/null +++ b/source/tests/jax/jax2tf/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import pytest + +from ...utils import ( + DP_TEST_TF2_ONLY, +) + +pytest.mark.skipif(not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1") diff --git a/source/tests/jax/jax2tf/test_nlist.py b/source/tests/jax/jax2tf/test_nlist.py new file mode 100644 index 0000000000..feb8deb92b --- /dev/null +++ b/source/tests/jax/jax2tf/test_nlist.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from deepmd.jax.jax2tf.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.jax.jax2tf.region import ( + inter2phys, +) + +dtype = tnp.float64 + + +class TestNeighList(tf.test.TestCase): + def setUp(self): + self.nf = 3 + self.nloc = 3 + self.ns = 5 * 5 * 3 + self.nall = self.ns * self.nloc + self.cell = tnp.array([[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype) + self.icoord = tnp.array([[0, 0, 0], [0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype) + self.atype = tnp.array([-1, 0, 1], dtype=tnp.int32) + [self.cell, self.icoord, self.atype] = [ + tnp.expand_dims(ii, 0) for ii in [self.cell, self.icoord, self.atype] + ] + self.coord = inter2phys(self.icoord, self.cell).reshape([-1, self.nloc * 3]) + self.cell = self.cell.reshape([-1, 9]) + [self.cell, self.coord, self.atype] = [ + tnp.tile(ii, [self.nf, 1]) for ii in [self.cell, self.coord, self.atype] + ] + self.rcut = 1.01 + self.prec = 1e-10 + self.nsel = [10, 10] + self.ref_nlist = tnp.array( + [ + [-1] * sum(self.nsel), + [1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], + [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 2, 2, 2, 2, 2, 2, -1, -1, -1, -1], + ] + ) + + def test_build_notype(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + nlist = build_neighbor_list( + ecoord, + eatype, + self.nloc, + self.rcut, + sum(self.nsel), + distinguish_types=False, + ) + self.assertAllClose(nlist[0], nlist[1]) + nlist_mask = nlist[0] == -1 + nlist_loc = mapping[0][nlist[0]] + nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc) + self.assertAllClose( + tnp.sort(nlist_loc, axis=-1), + tnp.sort(self.ref_nlist, axis=-1), + ) + + def test_build_type(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + nlist = build_neighbor_list( + ecoord, + eatype, + self.nloc, + self.rcut, + self.nsel, + distinguish_types=True, + ) + self.assertAllClose(nlist[0], nlist[1]) + nlist_mask = nlist[0] == -1 + nlist_loc = mapping[0][nlist[0]] + nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc) + for ii in range(2): + self.assertAllClose( + tnp.sort(tnp.split(nlist_loc, self.nsel, axis=-1)[ii], axis=-1), + tnp.sort(tnp.split(self.ref_nlist, self.nsel, axis=-1)[ii], axis=-1), + ) + + def test_extend_coord(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + # expected ncopy x nloc + self.assertEqual(list(ecoord.shape), [self.nf, self.nall * 3]) + self.assertEqual(list(eatype.shape), [self.nf, self.nall]) + self.assertEqual(list(mapping.shape), [self.nf, self.nall]) + # check the nloc part is identical with original coord + self.assertAllClose( + ecoord[:, : self.nloc * 3], self.coord, rtol=self.prec, atol=self.prec + ) + # check the shift vectors are aligned with grid + shift_vec = ( + ecoord.reshape([-1, self.ns, self.nloc, 3]) + - self.coord.reshape([-1, self.nloc, 3])[:, None, :, :] + ) + shift_vec = shift_vec.reshape([-1, self.nall, 3]) + # hack!!! assumes identical cell across frames + shift_vec = tnp.matmul( + shift_vec, tf.linalg.inv(self.cell.reshape([self.nf, 3, 3])[0]) + ) + # nf x nall x 3 + shift_vec = tnp.round(shift_vec) + # check: identical shift vecs + self.assertAllClose(shift_vec[0], shift_vec[1], rtol=self.prec, atol=self.prec) + # check: shift idx aligned with grid + mm, _, cc = tf.unique_with_counts(shift_vec[0][:, 0]) + self.assertAllClose( + tnp.sort(mm), + tnp.array([-2, -1, 0, 1, 2], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + self.assertAllClose( + cc, + tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32), + rtol=self.prec, + atol=self.prec, + ) + mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 1]) + self.assertAllClose( + tnp.sort(mm), + tnp.array([-2, -1, 0, 1, 2], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + self.assertAllClose( + cc, + tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32), + rtol=self.prec, + atol=self.prec, + ) + mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 2]) + self.assertAllClose( + tnp.sort(mm), + tnp.array([-1, 0, 1], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + self.assertAllClose( + cc, + tnp.array([self.ns * self.nloc // 3] * 3, dtype=tnp.int32), + rtol=self.prec, + atol=self.prec, + ) diff --git a/source/tests/jax/jax2tf/test_region.py b/source/tests/jax/jax2tf/test_region.py new file mode 100644 index 0000000000..54286d4dac --- /dev/null +++ b/source/tests/jax/jax2tf/test_region.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from deepmd.jax.jax2tf.region import ( + inter2phys, + to_face_distance, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +class TestRegion(tf.test.TestCase): + def setUp(self): + self.cell = tnp.array( + [[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], + ) + self.cell = tnp.reshape(self.cell, [1, 1, -1, 3]) + self.cell = tnp.tile(self.cell, [4, 5, 1, 1]) + self.prec = 1e-8 + + def test_inter_to_phys(self): + rng = tf.random.Generator.from_seed(GLOBAL_SEED) + inter = rng.normal(shape=[4, 5, 3, 3]) + phys = inter2phys(inter, self.cell) + for ii in range(4): + for jj in range(5): + expected_phys = tnp.matmul(inter[ii, jj], self.cell[ii, jj]) + self.assertAllClose( + phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec + ) + + def test_to_face_dist(self): + cell0 = self.cell[0][0] + vol = tf.linalg.det(cell0) + # area of surfaces xy, xz, yz + sxy = tf.linalg.norm(tnp.cross(cell0[0], cell0[1])) + sxz = tf.linalg.norm(tnp.cross(cell0[0], cell0[2])) + syz = tf.linalg.norm(tnp.cross(cell0[1], cell0[2])) + # vol / area gives distance + dz = vol / sxy + dy = vol / sxz + dx = vol / syz + expected = tnp.array([dx, dy, dz]) + dists = to_face_distance(self.cell) + for ii in range(4): + for jj in range(5): + self.assertAllClose( + dists[ii][jj], expected, rtol=self.prec, atol=self.prec + ) From a00aae8e5b28a52e8054e0586282aff8912158bc Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 7 Nov 2024 03:34:41 -0500 Subject: [PATCH 07/31] skip the whole module Signed-off-by: Jinzhe Zeng --- source/tests/jax/jax2tf/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/source/tests/jax/jax2tf/__init__.py b/source/tests/jax/jax2tf/__init__.py index 3c27417989..c4cb776cc4 100644 --- a/source/tests/jax/jax2tf/__init__.py +++ b/source/tests/jax/jax2tf/__init__.py @@ -5,4 +5,6 @@ DP_TEST_TF2_ONLY, ) -pytest.mark.skipif(not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1") +pytest.mark.skipif( + not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True +) From 95c476bf51f7df38b925b3f357c81a78a4d5169c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 7 Nov 2024 03:44:34 -0500 Subject: [PATCH 08/31] fix tests on ci Signed-off-by: Jinzhe Zeng --- source/tests/jax/jax2tf/__init__.py | 9 --------- source/tests/jax/jax2tf/test_nlist.py | 28 ++++++++++++++++++-------- source/tests/jax/jax2tf/test_region.py | 27 ++++++++++++++++++------- 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/source/tests/jax/jax2tf/__init__.py b/source/tests/jax/jax2tf/__init__.py index c4cb776cc4..6ceb116d85 100644 --- a/source/tests/jax/jax2tf/__init__.py +++ b/source/tests/jax/jax2tf/__init__.py @@ -1,10 +1 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import pytest - -from ...utils import ( - DP_TEST_TF2_ONLY, -) - -pytest.mark.skipif( - not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True -) diff --git a/source/tests/jax/jax2tf/test_nlist.py b/source/tests/jax/jax2tf/test_nlist.py index feb8deb92b..c3fe3d034d 100644 --- a/source/tests/jax/jax2tf/test_nlist.py +++ b/source/tests/jax/jax2tf/test_nlist.py @@ -1,17 +1,29 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import tensorflow as tf -import tensorflow.experimental.numpy as tnp +import pytest -from deepmd.jax.jax2tf.nlist import ( - build_neighbor_list, - extend_coord_with_ghosts, +from ...utils import ( + DP_TEST_TF2_ONLY, ) -from deepmd.jax.jax2tf.region import ( - inter2phys, + +pytest.mark.skipif( + not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True ) -dtype = tnp.float64 + +if DP_TEST_TF2_ONLY: + import tensorflow as tf + import tensorflow.experimental.numpy as tnp + + from deepmd.jax.jax2tf.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, + ) + from deepmd.jax.jax2tf.region import ( + inter2phys, + ) + + dtype = tnp.float64 class TestNeighList(tf.test.TestCase): diff --git a/source/tests/jax/jax2tf/test_region.py b/source/tests/jax/jax2tf/test_region.py index 54286d4dac..b7e3a7d89f 100644 --- a/source/tests/jax/jax2tf/test_region.py +++ b/source/tests/jax/jax2tf/test_region.py @@ -1,18 +1,31 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import tensorflow as tf -import tensorflow.experimental.numpy as tnp -from deepmd.jax.jax2tf.region import ( - inter2phys, - to_face_distance, +import pytest + +from ...utils import ( + DP_TEST_TF2_ONLY, ) -from ...seed import ( - GLOBAL_SEED, +pytest.mark.skipif( + not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True ) +if DP_TEST_TF2_ONLY: + import tensorflow as tf + import tensorflow.experimental.numpy as tnp + + from deepmd.jax.jax2tf.region import ( + inter2phys, + to_face_distance, + ) + + from ...seed import ( + GLOBAL_SEED, + ) + + class TestRegion(tf.test.TestCase): def setUp(self): self.cell = tnp.array( From 8bef185dce5e02bbe2d5be151390dcf5a12ed9ba Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 7 Nov 2024 04:01:26 -0500 Subject: [PATCH 09/31] fix Signed-off-by: Jinzhe Zeng --- source/tests/jax/jax2tf/test_nlist.py | 15 +++++++-------- source/tests/jax/jax2tf/test_region.py | 19 ++++++++----------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/source/tests/jax/jax2tf/test_nlist.py b/source/tests/jax/jax2tf/test_nlist.py index c3fe3d034d..3b380bc2cc 100644 --- a/source/tests/jax/jax2tf/test_nlist.py +++ b/source/tests/jax/jax2tf/test_nlist.py @@ -1,20 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import pytest +import tensorflow as tf +import tensorflow.experimental.numpy as tnp from ...utils import ( DP_TEST_TF2_ONLY, ) -pytest.mark.skipif( - not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True -) - - if DP_TEST_TF2_ONLY: - import tensorflow as tf - import tensorflow.experimental.numpy as tnp - from deepmd.jax.jax2tf.nlist import ( build_neighbor_list, extend_coord_with_ghosts, @@ -26,6 +20,11 @@ dtype = tnp.float64 +pytest.mark.skipif( + not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True +) + + class TestNeighList(tf.test.TestCase): def setUp(self): self.nf = 3 diff --git a/source/tests/jax/jax2tf/test_region.py b/source/tests/jax/jax2tf/test_region.py index b7e3a7d89f..ce2536ddb0 100644 --- a/source/tests/jax/jax2tf/test_region.py +++ b/source/tests/jax/jax2tf/test_region.py @@ -2,28 +2,25 @@ import pytest +import tensorflow as tf +import tensorflow.experimental.numpy as tnp +from ...seed import ( + GLOBAL_SEED, +) from ...utils import ( DP_TEST_TF2_ONLY, ) -pytest.mark.skipif( - not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True -) - - if DP_TEST_TF2_ONLY: - import tensorflow as tf - import tensorflow.experimental.numpy as tnp - from deepmd.jax.jax2tf.region import ( inter2phys, to_face_distance, ) - from ...seed import ( - GLOBAL_SEED, - ) +pytest.mark.skipif( + not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True +) class TestRegion(tf.test.TestCase): From 5e1062157422a30dd46fe44ada3ad06461382492 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 7 Nov 2024 04:49:52 -0500 Subject: [PATCH 10/31] fix skip testing. I am still confused why it doesn't work Signed-off-by: Jinzhe Zeng --- source/tests/jax/jax2tf/test_nlist.py | 7 +++---- source/tests/jax/jax2tf/test_region.py | 8 ++++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/source/tests/jax/jax2tf/test_nlist.py b/source/tests/jax/jax2tf/test_nlist.py index 3b380bc2cc..049d948f84 100644 --- a/source/tests/jax/jax2tf/test_nlist.py +++ b/source/tests/jax/jax2tf/test_nlist.py @@ -20,11 +20,10 @@ dtype = tnp.float64 -pytest.mark.skipif( - not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True +@pytest.mark.skipif( + not DP_TEST_TF2_ONLY, + reason="TF2 conflicts with TF1", ) - - class TestNeighList(tf.test.TestCase): def setUp(self): self.nf = 3 diff --git a/source/tests/jax/jax2tf/test_region.py b/source/tests/jax/jax2tf/test_region.py index ce2536ddb0..f38b661401 100644 --- a/source/tests/jax/jax2tf/test_region.py +++ b/source/tests/jax/jax2tf/test_region.py @@ -18,11 +18,11 @@ to_face_distance, ) -pytest.mark.skipif( - not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True -) - +@pytest.mark.skipif( + not DP_TEST_TF2_ONLY, + reason="TF2 conflicts with TF1", +) class TestRegion(tf.test.TestCase): def setUp(self): self.cell = tnp.array( From fc7b6b7a1198a3fe8fad1ba0fb8ba14d3ea39309 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 7 Nov 2024 15:31:55 -0500 Subject: [PATCH 11/31] try to resolve OOM issue Signed-off-by: Jinzhe Zeng --- source/tests/consistent/io/test_io.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 8eb26e7ac3..39bbf8056e 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy +import gc import shutil import unittest from pathlib import ( @@ -108,6 +109,9 @@ def test_data_equal(self): data.pop(kk, None) reference_data.pop(kk, None) np.testing.assert_equal(data, reference_data) + # try to resolve OOM issue in the CI + del data, reference_data + gc.collect() def test_deep_eval(self): self.coords = np.array( @@ -200,6 +204,8 @@ def test_deep_eval(self): atomic=True, ) rets_nopbc.append(ret) + del deep_eval + gc.collect() for ret in rets[1:]: for vv1, vv2 in zip(rets[0], ret): if np.isnan(vv2).all(): From 31f9e8772b4ee719591c2fdfa0bc5c1f229f748e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 13:02:21 -0500 Subject: [PATCH 12/31] limit threads during tests Signed-off-by: Jinzhe Zeng --- source/tests/jax/__init__.py | 9 +++++++++ source/tests/jax/jax2tf/__init__.py | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/source/tests/jax/__init__.py b/source/tests/jax/__init__.py index 6ceb116d85..52e6a17be2 100644 --- a/source/tests/jax/__init__.py +++ b/source/tests/jax/__init__.py @@ -1 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import os + +os.environ["XLA_FLAGS"] = " ".join( + ( + "--xla_cpu_multi_thread_eigen=false", + "intra_op_parallelism_threads=1", + "inter_op_parallelism_threads=1", + ) +) diff --git a/source/tests/jax/jax2tf/__init__.py b/source/tests/jax/jax2tf/__init__.py index 6ceb116d85..d5efa3d5dc 100644 --- a/source/tests/jax/jax2tf/__init__.py +++ b/source/tests/jax/jax2tf/__init__.py @@ -1 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf + +# limit the number of threads +tf.config.threading.set_inter_op_parallelism_threads(1) +tf.config.threading.set_intra_op_parallelism_threads(1) From 050c20c6044839795c4c56f1562a31b2bbb78bd0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 15:23:34 -0500 Subject: [PATCH 13/31] set xla flags before any imports Signed-off-by: Jinzhe Zeng --- source/tests/__init__.py | 10 ++++++++++ source/tests/jax/__init__.py | 9 --------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/source/tests/__init__.py b/source/tests/__init__.py index 6ceb116d85..5ca68af64d 100644 --- a/source/tests/__init__.py +++ b/source/tests/__init__.py @@ -1 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import os + +# set XLA FLAGS before any jax import +os.environ["XLA_FLAGS"] = " ".join( + ( + "--xla_cpu_multi_thread_eigen=false", + "intra_op_parallelism_threads=1", + "inter_op_parallelism_threads=1", + ) +) diff --git a/source/tests/jax/__init__.py b/source/tests/jax/__init__.py index 52e6a17be2..6ceb116d85 100644 --- a/source/tests/jax/__init__.py +++ b/source/tests/jax/__init__.py @@ -1,10 +1 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import os - -os.environ["XLA_FLAGS"] = " ".join( - ( - "--xla_cpu_multi_thread_eigen=false", - "intra_op_parallelism_threads=1", - "inter_op_parallelism_threads=1", - ) -) From 19f798c052a9c506837e4e241bbaf73916129497 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 15:25:06 -0500 Subject: [PATCH 14/31] set NPROC Signed-off-by: Jinzhe Zeng --- source/tests/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/source/tests/__init__.py b/source/tests/__init__.py index 5ca68af64d..c0392efed3 100644 --- a/source/tests/__init__.py +++ b/source/tests/__init__.py @@ -9,3 +9,4 @@ "inter_op_parallelism_threads=1", ) ) +os.environ["NPROC"] = "1" From befc0c7f31bbe2c5795abfb1a788c9c8e5146ab1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 16:48:53 -0500 Subject: [PATCH 15/31] I don;t understand why the tests fail randomly agter I add some tests... Signed-off-by: Jinzhe Zeng --- source/tests/jax/jax2tf/__init__.py | 11 ++++++++--- source/tests/jax/jax2tf/test_nlist.py | 3 ++- source/tests/jax/jax2tf/test_region.py | 3 ++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/source/tests/jax/jax2tf/__init__.py b/source/tests/jax/jax2tf/__init__.py index d5efa3d5dc..51fa555855 100644 --- a/source/tests/jax/jax2tf/__init__.py +++ b/source/tests/jax/jax2tf/__init__.py @@ -1,6 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import tensorflow as tf -# limit the number of threads -tf.config.threading.set_inter_op_parallelism_threads(1) -tf.config.threading.set_intra_op_parallelism_threads(1) +from ...common import ( + DP_TEST_TF2_ONLY, +) + +if DP_TEST_TF2_ONLY: + # limit the number of threads + tf.config.threading.set_inter_op_parallelism_threads(1) + tf.config.threading.set_intra_op_parallelism_threads(1) diff --git a/source/tests/jax/jax2tf/test_nlist.py b/source/tests/jax/jax2tf/test_nlist.py index 049d948f84..dadedd81a9 100644 --- a/source/tests/jax/jax2tf/test_nlist.py +++ b/source/tests/jax/jax2tf/test_nlist.py @@ -2,13 +2,14 @@ import pytest import tensorflow as tf -import tensorflow.experimental.numpy as tnp from ...utils import ( DP_TEST_TF2_ONLY, ) if DP_TEST_TF2_ONLY: + import tensorflow.experimental.numpy as tnp + from deepmd.jax.jax2tf.nlist import ( build_neighbor_list, extend_coord_with_ghosts, diff --git a/source/tests/jax/jax2tf/test_region.py b/source/tests/jax/jax2tf/test_region.py index f38b661401..4db7deeec5 100644 --- a/source/tests/jax/jax2tf/test_region.py +++ b/source/tests/jax/jax2tf/test_region.py @@ -3,7 +3,6 @@ import pytest import tensorflow as tf -import tensorflow.experimental.numpy as tnp from ...seed import ( GLOBAL_SEED, @@ -13,6 +12,8 @@ ) if DP_TEST_TF2_ONLY: + import tensorflow.experimental.numpy as tnp + from deepmd.jax.jax2tf.region import ( inter2phys, to_face_distance, From 89e8371df2636fd32a58f5d05c75703d7d389c98 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 16:57:26 -0500 Subject: [PATCH 16/31] typo Signed-off-by: Jinzhe Zeng --- source/tests/jax/jax2tf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/jax/jax2tf/__init__.py b/source/tests/jax/jax2tf/__init__.py index 51fa555855..755ffb04a4 100644 --- a/source/tests/jax/jax2tf/__init__.py +++ b/source/tests/jax/jax2tf/__init__.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import tensorflow as tf -from ...common import ( +from ...utils import ( DP_TEST_TF2_ONLY, ) From 8570079c8ec10fa39fc3be8659c39e92aa6d4581 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 18:04:28 -0500 Subject: [PATCH 17/31] release memory? Signed-off-by: Jinzhe Zeng --- source/tests/conftest.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 source/tests/conftest.py diff --git a/source/tests/conftest.py b/source/tests/conftest.py new file mode 100644 index 0000000000..5faf1230bf --- /dev/null +++ b/source/tests/conftest.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import gc + +import pytest + + +@pytest.fixture(scope="package", autouse=True) +def automatic_memory_release(): + """Release memory after each package.""" + # pre + yield + # post + gc.collect() From 89041edc0a335b2e727993730704ea901d86f992 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 18:16:11 -0500 Subject: [PATCH 18/31] --cov-append --- .github/workflows/test_python.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index f03aa410f9..b7c43dbcc5 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -58,7 +58,7 @@ jobs: env: NUM_WORKERS: 0 - name: Test TF2 eager mode - run: pytest --cov=deepmd source/tests/consistent/io/test_io.py source/tests/jax/jax2tf --durations=0 + run: pytest --cov=deepmd --cov-append source/tests/consistent/io/test_io.py source/tests/jax/jax2tf --durations=0 env: NUM_WORKERS: 0 DP_TEST_TF2_ONLY: 1 From b9ff7555e61f0564f6011d19ce06968ce77d319b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 20:07:37 -0500 Subject: [PATCH 19/31] try scope module --- source/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/conftest.py b/source/tests/conftest.py index 5faf1230bf..c742bcd017 100644 --- a/source/tests/conftest.py +++ b/source/tests/conftest.py @@ -4,7 +4,7 @@ import pytest -@pytest.fixture(scope="package", autouse=True) +@pytest.fixture(scope="module", autouse=True) def automatic_memory_release(): """Release memory after each package.""" # pre From b1b8dd940069446da7ccceed480f47f52557a723 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 21:34:20 -0500 Subject: [PATCH 20/31] Revert "try scope module" This reverts commit b9ff7555e61f0564f6011d19ce06968ce77d319b. --- source/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/conftest.py b/source/tests/conftest.py index c742bcd017..5faf1230bf 100644 --- a/source/tests/conftest.py +++ b/source/tests/conftest.py @@ -4,7 +4,7 @@ import pytest -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(scope="package", autouse=True) def automatic_memory_release(): """Release memory after each package.""" # pre From 48dca4b56494a54f1adc6db2b1ebdf51353a84b2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 21:34:31 -0500 Subject: [PATCH 21/31] Revert "release memory?" This reverts commit 8570079c8ec10fa39fc3be8659c39e92aa6d4581. --- source/tests/conftest.py | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 source/tests/conftest.py diff --git a/source/tests/conftest.py b/source/tests/conftest.py deleted file mode 100644 index 5faf1230bf..0000000000 --- a/source/tests/conftest.py +++ /dev/null @@ -1,13 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import gc - -import pytest - - -@pytest.fixture(scope="package", autouse=True) -def automatic_memory_release(): - """Release memory after each package.""" - # pre - yield - # post - gc.collect() From 8d7a2301c31e0b959070a1123dccc984b88e12eb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 21:35:00 -0500 Subject: [PATCH 22/31] Revert "typo" This reverts commit 89e8371df2636fd32a58f5d05c75703d7d389c98. --- source/tests/jax/jax2tf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/jax/jax2tf/__init__.py b/source/tests/jax/jax2tf/__init__.py index 755ffb04a4..51fa555855 100644 --- a/source/tests/jax/jax2tf/__init__.py +++ b/source/tests/jax/jax2tf/__init__.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import tensorflow as tf -from ...utils import ( +from ...common import ( DP_TEST_TF2_ONLY, ) From 5762a31899d23b9e58c94cdf5bdd19a786af6708 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 21:35:01 -0500 Subject: [PATCH 23/31] Revert "I don;t understand why the tests fail randomly agter I add some tests..." This reverts commit befc0c7f31bbe2c5795abfb1a788c9c8e5146ab1. --- source/tests/jax/jax2tf/__init__.py | 11 +++-------- source/tests/jax/jax2tf/test_nlist.py | 3 +-- source/tests/jax/jax2tf/test_region.py | 3 +-- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/source/tests/jax/jax2tf/__init__.py b/source/tests/jax/jax2tf/__init__.py index 51fa555855..d5efa3d5dc 100644 --- a/source/tests/jax/jax2tf/__init__.py +++ b/source/tests/jax/jax2tf/__init__.py @@ -1,11 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import tensorflow as tf -from ...common import ( - DP_TEST_TF2_ONLY, -) - -if DP_TEST_TF2_ONLY: - # limit the number of threads - tf.config.threading.set_inter_op_parallelism_threads(1) - tf.config.threading.set_intra_op_parallelism_threads(1) +# limit the number of threads +tf.config.threading.set_inter_op_parallelism_threads(1) +tf.config.threading.set_intra_op_parallelism_threads(1) diff --git a/source/tests/jax/jax2tf/test_nlist.py b/source/tests/jax/jax2tf/test_nlist.py index dadedd81a9..049d948f84 100644 --- a/source/tests/jax/jax2tf/test_nlist.py +++ b/source/tests/jax/jax2tf/test_nlist.py @@ -2,14 +2,13 @@ import pytest import tensorflow as tf +import tensorflow.experimental.numpy as tnp from ...utils import ( DP_TEST_TF2_ONLY, ) if DP_TEST_TF2_ONLY: - import tensorflow.experimental.numpy as tnp - from deepmd.jax.jax2tf.nlist import ( build_neighbor_list, extend_coord_with_ghosts, diff --git a/source/tests/jax/jax2tf/test_region.py b/source/tests/jax/jax2tf/test_region.py index 4db7deeec5..f38b661401 100644 --- a/source/tests/jax/jax2tf/test_region.py +++ b/source/tests/jax/jax2tf/test_region.py @@ -3,6 +3,7 @@ import pytest import tensorflow as tf +import tensorflow.experimental.numpy as tnp from ...seed import ( GLOBAL_SEED, @@ -12,8 +13,6 @@ ) if DP_TEST_TF2_ONLY: - import tensorflow.experimental.numpy as tnp - from deepmd.jax.jax2tf.region import ( inter2phys, to_face_distance, From ee831e0334a034f3da2f021440576c8bda08096f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 21:35:06 -0500 Subject: [PATCH 24/31] Revert "set NPROC" This reverts commit 19f798c052a9c506837e4e241bbaf73916129497. --- source/tests/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/source/tests/__init__.py b/source/tests/__init__.py index c0392efed3..5ca68af64d 100644 --- a/source/tests/__init__.py +++ b/source/tests/__init__.py @@ -9,4 +9,3 @@ "inter_op_parallelism_threads=1", ) ) -os.environ["NPROC"] = "1" From e6953c948dc7bf0008cb210efe54571b5aad8a73 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 21:35:21 -0500 Subject: [PATCH 25/31] Revert "set xla flags before any imports" This reverts commit 050c20c6044839795c4c56f1562a31b2bbb78bd0. --- source/tests/__init__.py | 10 ---------- source/tests/jax/__init__.py | 9 +++++++++ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/source/tests/__init__.py b/source/tests/__init__.py index 5ca68af64d..6ceb116d85 100644 --- a/source/tests/__init__.py +++ b/source/tests/__init__.py @@ -1,11 +1 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import os - -# set XLA FLAGS before any jax import -os.environ["XLA_FLAGS"] = " ".join( - ( - "--xla_cpu_multi_thread_eigen=false", - "intra_op_parallelism_threads=1", - "inter_op_parallelism_threads=1", - ) -) diff --git a/source/tests/jax/__init__.py b/source/tests/jax/__init__.py index 6ceb116d85..52e6a17be2 100644 --- a/source/tests/jax/__init__.py +++ b/source/tests/jax/__init__.py @@ -1 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import os + +os.environ["XLA_FLAGS"] = " ".join( + ( + "--xla_cpu_multi_thread_eigen=false", + "intra_op_parallelism_threads=1", + "inter_op_parallelism_threads=1", + ) +) From c365cb64e8512ca66bc1e1695a05e52056338572 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 21:35:22 -0500 Subject: [PATCH 26/31] Revert "limit threads during tests" This reverts commit 31f9e8772b4ee719591c2fdfa0bc5c1f229f748e. --- source/tests/jax/__init__.py | 9 --------- source/tests/jax/jax2tf/__init__.py | 5 ----- 2 files changed, 14 deletions(-) diff --git a/source/tests/jax/__init__.py b/source/tests/jax/__init__.py index 52e6a17be2..6ceb116d85 100644 --- a/source/tests/jax/__init__.py +++ b/source/tests/jax/__init__.py @@ -1,10 +1 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import os - -os.environ["XLA_FLAGS"] = " ".join( - ( - "--xla_cpu_multi_thread_eigen=false", - "intra_op_parallelism_threads=1", - "inter_op_parallelism_threads=1", - ) -) diff --git a/source/tests/jax/jax2tf/__init__.py b/source/tests/jax/jax2tf/__init__.py index d5efa3d5dc..6ceb116d85 100644 --- a/source/tests/jax/jax2tf/__init__.py +++ b/source/tests/jax/jax2tf/__init__.py @@ -1,6 +1 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import tensorflow as tf - -# limit the number of threads -tf.config.threading.set_inter_op_parallelism_threads(1) -tf.config.threading.set_intra_op_parallelism_threads(1) From c94302dc5044da9391b6eb4d910bbf20d4aa0094 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 21:36:05 -0500 Subject: [PATCH 27/31] Revert "try to resolve OOM issue" This reverts commit fc7b6b7a1198a3fe8fad1ba0fb8ba14d3ea39309. --- source/tests/consistent/io/test_io.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 39bbf8056e..8eb26e7ac3 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy -import gc import shutil import unittest from pathlib import ( @@ -109,9 +108,6 @@ def test_data_equal(self): data.pop(kk, None) reference_data.pop(kk, None) np.testing.assert_equal(data, reference_data) - # try to resolve OOM issue in the CI - del data, reference_data - gc.collect() def test_deep_eval(self): self.coords = np.array( @@ -204,8 +200,6 @@ def test_deep_eval(self): atomic=True, ) rets_nopbc.append(ret) - del deep_eval - gc.collect() for ret in rets[1:]: for vv1, vv2 in zip(rets[0], ret): if np.isnan(vv2).all(): From b0a496cfafcb0dbc73a27510b3a430bfeaabcb77 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 21:38:53 -0500 Subject: [PATCH 28/31] move the time-comsuming test out of main test Signed-off-by: Jinzhe Zeng --- source/tests/consistent/io/test_io.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 8eb26e7ac3..bc9103c56e 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -73,14 +73,12 @@ def tearDown(self): shutil.rmtree(ii) @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") - @unittest.skipIf(DP_TEST_TF2_ONLY, "Conflict with TF2 eager mode.") def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() for backend_name, suffix_idx in ( - ("tensorflow", 0), + ("tensorflow", 0) if not DP_TEST_TF2_ONLY else ("jax", 0), ("pytorch", 0), ("dpmodel", 0), - ("jax", 0), ): with self.subTest(backend_name=backend_name): backend = Backend.get_backend(backend_name)() @@ -148,8 +146,10 @@ def test_deep_eval(self): ("jax", 2) if DP_TEST_TF2_ONLY else ("tensorflow", 0), ("pytorch", 0), ("dpmodel", 0), - ("jax", 0), + ("jax", 0) if DP_TEST_TF2_ONLY else (None, None), ): + if backend_name is None: + continue backend = Backend.get_backend(backend_name)() if not backend.is_available(): continue From 4e42105ea44b1bad97d0d2419315b41e4901460a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 23:57:23 -0500 Subject: [PATCH 29/31] try Signed-off-by: Jinzhe Zeng --- source/tests/jax/jax2tf/test_nlist.py | 5 +++-- source/tests/jax/jax2tf/test_region.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/source/tests/jax/jax2tf/test_nlist.py b/source/tests/jax/jax2tf/test_nlist.py index 049d948f84..28c5a0dfc8 100644 --- a/source/tests/jax/jax2tf/test_nlist.py +++ b/source/tests/jax/jax2tf/test_nlist.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import pytest +import unittest + import tensorflow as tf import tensorflow.experimental.numpy as tnp @@ -20,7 +21,7 @@ dtype = tnp.float64 -@pytest.mark.skipif( +@unittest.skipIf( not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", ) diff --git a/source/tests/jax/jax2tf/test_region.py b/source/tests/jax/jax2tf/test_region.py index f38b661401..a6baffcb33 100644 --- a/source/tests/jax/jax2tf/test_region.py +++ b/source/tests/jax/jax2tf/test_region.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import pytest +import unittest + import tensorflow as tf import tensorflow.experimental.numpy as tnp @@ -19,7 +20,7 @@ ) -@pytest.mark.skipif( +@unittest.skipIf( not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", ) From 215efc6d47e059dfe1d467cb8ce7988f4204414f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 9 Nov 2024 01:17:08 -0500 Subject: [PATCH 30/31] try to move jax2tf_tests to a different directory Signed-off-by: Jinzhe Zeng --- .github/workflows/test_python.yml | 2 +- pyproject.toml | 1 + source/{tests/jax/jax2tf => jax2tf_tests}/__init__.py | 0 source/{tests/jax/jax2tf => jax2tf_tests}/test_nlist.py | 0 source/{tests/jax/jax2tf => jax2tf_tests}/test_region.py | 0 5 files changed, 2 insertions(+), 1 deletion(-) rename source/{tests/jax/jax2tf => jax2tf_tests}/__init__.py (100%) rename source/{tests/jax/jax2tf => jax2tf_tests}/test_nlist.py (100%) rename source/{tests/jax/jax2tf => jax2tf_tests}/test_region.py (100%) diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index b7c43dbcc5..ba8858d6b9 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -58,7 +58,7 @@ jobs: env: NUM_WORKERS: 0 - name: Test TF2 eager mode - run: pytest --cov=deepmd --cov-append source/tests/consistent/io/test_io.py source/tests/jax/jax2tf --durations=0 + run: pytest --cov=deepmd --cov-append source/tests/consistent/io/test_io.py source/jax2tf_tests --durations=0 env: NUM_WORKERS: 0 DP_TEST_TF2_ONLY: 1 diff --git a/pyproject.toml b/pyproject.toml index b238ed42cf..cf2c8c3b93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -424,6 +424,7 @@ banned-module-level-imports = [ "source/tests/pt/**" = ["TID253"] "source/tests/jax/**" = ["TID253"] "source/tests/universal/pt/**" = ["TID253"] +"source/jax2tf_tests/**" = ["TID253"] "source/ipi/tests/**" = ["TID253"] "source/lmp/tests/**" = ["TID253"] "**/*.ipynb" = ["T20"] # printing in a nb file is expected diff --git a/source/tests/jax/jax2tf/__init__.py b/source/jax2tf_tests/__init__.py similarity index 100% rename from source/tests/jax/jax2tf/__init__.py rename to source/jax2tf_tests/__init__.py diff --git a/source/tests/jax/jax2tf/test_nlist.py b/source/jax2tf_tests/test_nlist.py similarity index 100% rename from source/tests/jax/jax2tf/test_nlist.py rename to source/jax2tf_tests/test_nlist.py diff --git a/source/tests/jax/jax2tf/test_region.py b/source/jax2tf_tests/test_region.py similarity index 100% rename from source/tests/jax/jax2tf/test_region.py rename to source/jax2tf_tests/test_region.py From e4bac358381cab0fef18157ca89d957fcbe567e2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 9 Nov 2024 02:05:37 -0500 Subject: [PATCH 31/31] clean up Signed-off-by: Jinzhe Zeng --- source/jax2tf_tests/test_nlist.py | 24 +++++++----------------- source/jax2tf_tests/test_region.py | 20 ++++---------------- 2 files changed, 11 insertions(+), 33 deletions(-) diff --git a/source/jax2tf_tests/test_nlist.py b/source/jax2tf_tests/test_nlist.py index 28c5a0dfc8..5b13e4231c 100644 --- a/source/jax2tf_tests/test_nlist.py +++ b/source/jax2tf_tests/test_nlist.py @@ -1,30 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import unittest import tensorflow as tf import tensorflow.experimental.numpy as tnp -from ...utils import ( - DP_TEST_TF2_ONLY, +from deepmd.jax.jax2tf.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.jax.jax2tf.region import ( + inter2phys, ) -if DP_TEST_TF2_ONLY: - from deepmd.jax.jax2tf.nlist import ( - build_neighbor_list, - extend_coord_with_ghosts, - ) - from deepmd.jax.jax2tf.region import ( - inter2phys, - ) - - dtype = tnp.float64 +dtype = tnp.float64 -@unittest.skipIf( - not DP_TEST_TF2_ONLY, - reason="TF2 conflicts with TF1", -) class TestNeighList(tf.test.TestCase): def setUp(self): self.nf = 3 diff --git a/source/jax2tf_tests/test_region.py b/source/jax2tf_tests/test_region.py index a6baffcb33..2becf08c94 100644 --- a/source/jax2tf_tests/test_region.py +++ b/source/jax2tf_tests/test_region.py @@ -1,29 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import unittest - import tensorflow as tf import tensorflow.experimental.numpy as tnp -from ...seed import ( - GLOBAL_SEED, -) -from ...utils import ( - DP_TEST_TF2_ONLY, +from deepmd.jax.jax2tf.region import ( + inter2phys, + to_face_distance, ) -if DP_TEST_TF2_ONLY: - from deepmd.jax.jax2tf.region import ( - inter2phys, - to_face_distance, - ) +GLOBAL_SEED = 20241109 -@unittest.skipIf( - not DP_TEST_TF2_ONLY, - reason="TF2 conflicts with TF1", -) class TestRegion(tf.test.TestCase): def setUp(self): self.cell = tnp.array(