From bd27d4f7f730bcda0401759b68b7857a280e35a1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 7 Nov 2024 02:56:16 -0500 Subject: [PATCH] add tests Signed-off-by: Jinzhe Zeng --- .github/workflows/test_python.yml | 2 +- pyproject.toml | 3 + source/tests/jax/__init__.py | 1 + source/tests/jax/jax2tf/__init__.py | 8 ++ source/tests/jax/jax2tf/test_nlist.py | 153 +++++++++++++++++++++++++ source/tests/jax/jax2tf/test_region.py | 53 +++++++++ 6 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 source/tests/jax/__init__.py create mode 100644 source/tests/jax/jax2tf/__init__.py create mode 100644 source/tests/jax/jax2tf/test_nlist.py create mode 100644 source/tests/jax/jax2tf/test_region.py diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 422dcb5f17..f03aa410f9 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -58,7 +58,7 @@ jobs: env: NUM_WORKERS: 0 - name: Test TF2 eager mode - run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0 + run: pytest --cov=deepmd source/tests/consistent/io/test_io.py source/tests/jax/jax2tf --durations=0 env: NUM_WORKERS: 0 DP_TEST_TF2_ONLY: 1 diff --git a/pyproject.toml b/pyproject.toml index 802e920014..f911f6905c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -405,8 +405,10 @@ convention = "numpy" banned-module-level-imports = [ "deepmd.tf", "deepmd.pt", + "deepmd.jax", "tensorflow", "torch", + "jax", ] [tool.ruff.lint.flake8-tidy-imports.banned-api] @@ -419,6 +421,7 @@ banned-module-level-imports = [ "deepmd/jax/**" = ["TID253"] "source/tests/tf/**" = ["TID253"] "source/tests/pt/**" = ["TID253"] +"source/tests/jax/**" = ["TID253"] "source/tests/universal/pt/**" = ["TID253"] "source/ipi/tests/**" = ["TID253"] "source/lmp/tests/**" = ["TID253"] diff --git a/source/tests/jax/__init__.py b/source/tests/jax/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/jax/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/jax/jax2tf/__init__.py b/source/tests/jax/jax2tf/__init__.py new file mode 100644 index 0000000000..3c27417989 --- /dev/null +++ b/source/tests/jax/jax2tf/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import pytest + +from ...utils import ( + DP_TEST_TF2_ONLY, +) + +pytest.mark.skipif(not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1") diff --git a/source/tests/jax/jax2tf/test_nlist.py b/source/tests/jax/jax2tf/test_nlist.py new file mode 100644 index 0000000000..feb8deb92b --- /dev/null +++ b/source/tests/jax/jax2tf/test_nlist.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from deepmd.jax.jax2tf.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.jax.jax2tf.region import ( + inter2phys, +) + +dtype = tnp.float64 + + +class TestNeighList(tf.test.TestCase): + def setUp(self): + self.nf = 3 + self.nloc = 3 + self.ns = 5 * 5 * 3 + self.nall = self.ns * self.nloc + self.cell = tnp.array([[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype) + self.icoord = tnp.array([[0, 0, 0], [0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype) + self.atype = tnp.array([-1, 0, 1], dtype=tnp.int32) + [self.cell, self.icoord, self.atype] = [ + tnp.expand_dims(ii, 0) for ii in [self.cell, self.icoord, self.atype] + ] + self.coord = inter2phys(self.icoord, self.cell).reshape([-1, self.nloc * 3]) + self.cell = self.cell.reshape([-1, 9]) + [self.cell, self.coord, self.atype] = [ + tnp.tile(ii, [self.nf, 1]) for ii in [self.cell, self.coord, self.atype] + ] + self.rcut = 1.01 + self.prec = 1e-10 + self.nsel = [10, 10] + self.ref_nlist = tnp.array( + [ + [-1] * sum(self.nsel), + [1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], + [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 2, 2, 2, 2, 2, 2, -1, -1, -1, -1], + ] + ) + + def test_build_notype(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + nlist = build_neighbor_list( + ecoord, + eatype, + self.nloc, + self.rcut, + sum(self.nsel), + distinguish_types=False, + ) + self.assertAllClose(nlist[0], nlist[1]) + nlist_mask = nlist[0] == -1 + nlist_loc = mapping[0][nlist[0]] + nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc) + self.assertAllClose( + tnp.sort(nlist_loc, axis=-1), + tnp.sort(self.ref_nlist, axis=-1), + ) + + def test_build_type(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + nlist = build_neighbor_list( + ecoord, + eatype, + self.nloc, + self.rcut, + self.nsel, + distinguish_types=True, + ) + self.assertAllClose(nlist[0], nlist[1]) + nlist_mask = nlist[0] == -1 + nlist_loc = mapping[0][nlist[0]] + nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc) + for ii in range(2): + self.assertAllClose( + tnp.sort(tnp.split(nlist_loc, self.nsel, axis=-1)[ii], axis=-1), + tnp.sort(tnp.split(self.ref_nlist, self.nsel, axis=-1)[ii], axis=-1), + ) + + def test_extend_coord(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + # expected ncopy x nloc + self.assertEqual(list(ecoord.shape), [self.nf, self.nall * 3]) + self.assertEqual(list(eatype.shape), [self.nf, self.nall]) + self.assertEqual(list(mapping.shape), [self.nf, self.nall]) + # check the nloc part is identical with original coord + self.assertAllClose( + ecoord[:, : self.nloc * 3], self.coord, rtol=self.prec, atol=self.prec + ) + # check the shift vectors are aligned with grid + shift_vec = ( + ecoord.reshape([-1, self.ns, self.nloc, 3]) + - self.coord.reshape([-1, self.nloc, 3])[:, None, :, :] + ) + shift_vec = shift_vec.reshape([-1, self.nall, 3]) + # hack!!! assumes identical cell across frames + shift_vec = tnp.matmul( + shift_vec, tf.linalg.inv(self.cell.reshape([self.nf, 3, 3])[0]) + ) + # nf x nall x 3 + shift_vec = tnp.round(shift_vec) + # check: identical shift vecs + self.assertAllClose(shift_vec[0], shift_vec[1], rtol=self.prec, atol=self.prec) + # check: shift idx aligned with grid + mm, _, cc = tf.unique_with_counts(shift_vec[0][:, 0]) + self.assertAllClose( + tnp.sort(mm), + tnp.array([-2, -1, 0, 1, 2], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + self.assertAllClose( + cc, + tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32), + rtol=self.prec, + atol=self.prec, + ) + mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 1]) + self.assertAllClose( + tnp.sort(mm), + tnp.array([-2, -1, 0, 1, 2], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + self.assertAllClose( + cc, + tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32), + rtol=self.prec, + atol=self.prec, + ) + mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 2]) + self.assertAllClose( + tnp.sort(mm), + tnp.array([-1, 0, 1], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + self.assertAllClose( + cc, + tnp.array([self.ns * self.nloc // 3] * 3, dtype=tnp.int32), + rtol=self.prec, + atol=self.prec, + ) diff --git a/source/tests/jax/jax2tf/test_region.py b/source/tests/jax/jax2tf/test_region.py new file mode 100644 index 0000000000..54286d4dac --- /dev/null +++ b/source/tests/jax/jax2tf/test_region.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from deepmd.jax.jax2tf.region import ( + inter2phys, + to_face_distance, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +class TestRegion(tf.test.TestCase): + def setUp(self): + self.cell = tnp.array( + [[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], + ) + self.cell = tnp.reshape(self.cell, [1, 1, -1, 3]) + self.cell = tnp.tile(self.cell, [4, 5, 1, 1]) + self.prec = 1e-8 + + def test_inter_to_phys(self): + rng = tf.random.Generator.from_seed(GLOBAL_SEED) + inter = rng.normal(shape=[4, 5, 3, 3]) + phys = inter2phys(inter, self.cell) + for ii in range(4): + for jj in range(5): + expected_phys = tnp.matmul(inter[ii, jj], self.cell[ii, jj]) + self.assertAllClose( + phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec + ) + + def test_to_face_dist(self): + cell0 = self.cell[0][0] + vol = tf.linalg.det(cell0) + # area of surfaces xy, xz, yz + sxy = tf.linalg.norm(tnp.cross(cell0[0], cell0[1])) + sxz = tf.linalg.norm(tnp.cross(cell0[0], cell0[2])) + syz = tf.linalg.norm(tnp.cross(cell0[1], cell0[2])) + # vol / area gives distance + dz = vol / sxy + dy = vol / sxz + dx = vol / syz + expected = tnp.array([dx, dy, dz]) + dists = to_face_distance(self.cell) + for ii in range(4): + for jj in range(5): + self.assertAllClose( + dists[ii][jj], expected, rtol=self.prec, atol=self.prec + )