Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Nov 7, 2024
1 parent 84cb819 commit bd27d4f
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
env:
NUM_WORKERS: 0
- name: Test TF2 eager mode
run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0
run: pytest --cov=deepmd source/tests/consistent/io/test_io.py source/tests/jax/jax2tf --durations=0
env:
NUM_WORKERS: 0
DP_TEST_TF2_ONLY: 1
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,10 @@ convention = "numpy"
banned-module-level-imports = [
"deepmd.tf",
"deepmd.pt",
"deepmd.jax",
"tensorflow",
"torch",
"jax",
]

[tool.ruff.lint.flake8-tidy-imports.banned-api]
Expand All @@ -419,6 +421,7 @@ banned-module-level-imports = [
"deepmd/jax/**" = ["TID253"]
"source/tests/tf/**" = ["TID253"]
"source/tests/pt/**" = ["TID253"]
"source/tests/jax/**" = ["TID253"]
"source/tests/universal/pt/**" = ["TID253"]
"source/ipi/tests/**" = ["TID253"]
"source/lmp/tests/**" = ["TID253"]
Expand Down
1 change: 1 addition & 0 deletions source/tests/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
8 changes: 8 additions & 0 deletions source/tests/jax/jax2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import pytest

from ...utils import (
DP_TEST_TF2_ONLY,
)

pytest.mark.skipif(not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1")
153 changes: 153 additions & 0 deletions source/tests/jax/jax2tf/test_nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from deepmd.jax.jax2tf.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)
from deepmd.jax.jax2tf.region import (
inter2phys,
)

dtype = tnp.float64


class TestNeighList(tf.test.TestCase):
def setUp(self):
self.nf = 3
self.nloc = 3
self.ns = 5 * 5 * 3
self.nall = self.ns * self.nloc
self.cell = tnp.array([[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype)
self.icoord = tnp.array([[0, 0, 0], [0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype)
self.atype = tnp.array([-1, 0, 1], dtype=tnp.int32)
[self.cell, self.icoord, self.atype] = [
tnp.expand_dims(ii, 0) for ii in [self.cell, self.icoord, self.atype]
]
self.coord = inter2phys(self.icoord, self.cell).reshape([-1, self.nloc * 3])
self.cell = self.cell.reshape([-1, 9])
[self.cell, self.coord, self.atype] = [
tnp.tile(ii, [self.nf, 1]) for ii in [self.cell, self.coord, self.atype]
]
self.rcut = 1.01
self.prec = 1e-10
self.nsel = [10, 10]
self.ref_nlist = tnp.array(
[
[-1] * sum(self.nsel),
[1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1],
[1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 2, 2, 2, 2, 2, 2, -1, -1, -1, -1],
]
)

def test_build_notype(self):
ecoord, eatype, mapping = extend_coord_with_ghosts(
self.coord, self.atype, self.cell, self.rcut
)
nlist = build_neighbor_list(
ecoord,
eatype,
self.nloc,
self.rcut,
sum(self.nsel),
distinguish_types=False,
)
self.assertAllClose(nlist[0], nlist[1])
nlist_mask = nlist[0] == -1
nlist_loc = mapping[0][nlist[0]]
nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc)
self.assertAllClose(
tnp.sort(nlist_loc, axis=-1),
tnp.sort(self.ref_nlist, axis=-1),
)

def test_build_type(self):
ecoord, eatype, mapping = extend_coord_with_ghosts(
self.coord, self.atype, self.cell, self.rcut
)
nlist = build_neighbor_list(
ecoord,
eatype,
self.nloc,
self.rcut,
self.nsel,
distinguish_types=True,
)
self.assertAllClose(nlist[0], nlist[1])
nlist_mask = nlist[0] == -1
nlist_loc = mapping[0][nlist[0]]
nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc)
for ii in range(2):
self.assertAllClose(
tnp.sort(tnp.split(nlist_loc, self.nsel, axis=-1)[ii], axis=-1),
tnp.sort(tnp.split(self.ref_nlist, self.nsel, axis=-1)[ii], axis=-1),
)

def test_extend_coord(self):
ecoord, eatype, mapping = extend_coord_with_ghosts(
self.coord, self.atype, self.cell, self.rcut
)
# expected ncopy x nloc
self.assertEqual(list(ecoord.shape), [self.nf, self.nall * 3])
self.assertEqual(list(eatype.shape), [self.nf, self.nall])
self.assertEqual(list(mapping.shape), [self.nf, self.nall])
# check the nloc part is identical with original coord
self.assertAllClose(
ecoord[:, : self.nloc * 3], self.coord, rtol=self.prec, atol=self.prec
)
# check the shift vectors are aligned with grid
shift_vec = (
ecoord.reshape([-1, self.ns, self.nloc, 3])
- self.coord.reshape([-1, self.nloc, 3])[:, None, :, :]
)
shift_vec = shift_vec.reshape([-1, self.nall, 3])
# hack!!! assumes identical cell across frames
shift_vec = tnp.matmul(
shift_vec, tf.linalg.inv(self.cell.reshape([self.nf, 3, 3])[0])
)
# nf x nall x 3
shift_vec = tnp.round(shift_vec)
# check: identical shift vecs
self.assertAllClose(shift_vec[0], shift_vec[1], rtol=self.prec, atol=self.prec)
# check: shift idx aligned with grid
mm, _, cc = tf.unique_with_counts(shift_vec[0][:, 0])
self.assertAllClose(
tnp.sort(mm),
tnp.array([-2, -1, 0, 1, 2], dtype=dtype),
rtol=self.prec,
atol=self.prec,
)
self.assertAllClose(
cc,
tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32),
rtol=self.prec,
atol=self.prec,
)
mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 1])
self.assertAllClose(
tnp.sort(mm),
tnp.array([-2, -1, 0, 1, 2], dtype=dtype),
rtol=self.prec,
atol=self.prec,
)
self.assertAllClose(
cc,
tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32),
rtol=self.prec,
atol=self.prec,
)
mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 2])
self.assertAllClose(
tnp.sort(mm),
tnp.array([-1, 0, 1], dtype=dtype),
rtol=self.prec,
atol=self.prec,
)
self.assertAllClose(
cc,
tnp.array([self.ns * self.nloc // 3] * 3, dtype=tnp.int32),
rtol=self.prec,
atol=self.prec,
)
53 changes: 53 additions & 0 deletions source/tests/jax/jax2tf/test_region.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from deepmd.jax.jax2tf.region import (
inter2phys,
to_face_distance,
)

from ...seed import (
GLOBAL_SEED,
)


class TestRegion(tf.test.TestCase):
def setUp(self):
self.cell = tnp.array(
[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]],
)
self.cell = tnp.reshape(self.cell, [1, 1, -1, 3])
self.cell = tnp.tile(self.cell, [4, 5, 1, 1])
self.prec = 1e-8

def test_inter_to_phys(self):
rng = tf.random.Generator.from_seed(GLOBAL_SEED)
inter = rng.normal(shape=[4, 5, 3, 3])
phys = inter2phys(inter, self.cell)
for ii in range(4):
for jj in range(5):
expected_phys = tnp.matmul(inter[ii, jj], self.cell[ii, jj])
self.assertAllClose(
phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec
)

def test_to_face_dist(self):
cell0 = self.cell[0][0]
vol = tf.linalg.det(cell0)
# area of surfaces xy, xz, yz
sxy = tf.linalg.norm(tnp.cross(cell0[0], cell0[1]))
sxz = tf.linalg.norm(tnp.cross(cell0[0], cell0[2]))
syz = tf.linalg.norm(tnp.cross(cell0[1], cell0[2]))
# vol / area gives distance
dz = vol / sxy
dy = vol / sxz
dx = vol / syz
expected = tnp.array([dx, dy, dz])
dists = to_face_distance(self.cell)
for ii in range(4):
for jj in range(5):
self.assertAllClose(
dists[ii][jj], expected, rtol=self.prec, atol=self.prec
)

0 comments on commit bd27d4f

Please sign in to comment.