-
Notifications
You must be signed in to change notification settings - Fork 526
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
- Loading branch information
Showing
6 changed files
with
219 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import pytest | ||
|
||
from ...utils import ( | ||
DP_TEST_TF2_ONLY, | ||
) | ||
|
||
pytest.mark.skipif(not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
import tensorflow as tf | ||
import tensorflow.experimental.numpy as tnp | ||
|
||
from deepmd.jax.jax2tf.nlist import ( | ||
build_neighbor_list, | ||
extend_coord_with_ghosts, | ||
) | ||
from deepmd.jax.jax2tf.region import ( | ||
inter2phys, | ||
) | ||
|
||
dtype = tnp.float64 | ||
|
||
|
||
class TestNeighList(tf.test.TestCase): | ||
def setUp(self): | ||
self.nf = 3 | ||
self.nloc = 3 | ||
self.ns = 5 * 5 * 3 | ||
self.nall = self.ns * self.nloc | ||
self.cell = tnp.array([[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype) | ||
self.icoord = tnp.array([[0, 0, 0], [0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype) | ||
self.atype = tnp.array([-1, 0, 1], dtype=tnp.int32) | ||
[self.cell, self.icoord, self.atype] = [ | ||
tnp.expand_dims(ii, 0) for ii in [self.cell, self.icoord, self.atype] | ||
] | ||
self.coord = inter2phys(self.icoord, self.cell).reshape([-1, self.nloc * 3]) | ||
self.cell = self.cell.reshape([-1, 9]) | ||
[self.cell, self.coord, self.atype] = [ | ||
tnp.tile(ii, [self.nf, 1]) for ii in [self.cell, self.coord, self.atype] | ||
] | ||
self.rcut = 1.01 | ||
self.prec = 1e-10 | ||
self.nsel = [10, 10] | ||
self.ref_nlist = tnp.array( | ||
[ | ||
[-1] * sum(self.nsel), | ||
[1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], | ||
[1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 2, 2, 2, 2, 2, 2, -1, -1, -1, -1], | ||
] | ||
) | ||
|
||
def test_build_notype(self): | ||
ecoord, eatype, mapping = extend_coord_with_ghosts( | ||
self.coord, self.atype, self.cell, self.rcut | ||
) | ||
nlist = build_neighbor_list( | ||
ecoord, | ||
eatype, | ||
self.nloc, | ||
self.rcut, | ||
sum(self.nsel), | ||
distinguish_types=False, | ||
) | ||
self.assertAllClose(nlist[0], nlist[1]) | ||
nlist_mask = nlist[0] == -1 | ||
nlist_loc = mapping[0][nlist[0]] | ||
nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc) | ||
self.assertAllClose( | ||
tnp.sort(nlist_loc, axis=-1), | ||
tnp.sort(self.ref_nlist, axis=-1), | ||
) | ||
|
||
def test_build_type(self): | ||
ecoord, eatype, mapping = extend_coord_with_ghosts( | ||
self.coord, self.atype, self.cell, self.rcut | ||
) | ||
nlist = build_neighbor_list( | ||
ecoord, | ||
eatype, | ||
self.nloc, | ||
self.rcut, | ||
self.nsel, | ||
distinguish_types=True, | ||
) | ||
self.assertAllClose(nlist[0], nlist[1]) | ||
nlist_mask = nlist[0] == -1 | ||
nlist_loc = mapping[0][nlist[0]] | ||
nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc) | ||
for ii in range(2): | ||
self.assertAllClose( | ||
tnp.sort(tnp.split(nlist_loc, self.nsel, axis=-1)[ii], axis=-1), | ||
tnp.sort(tnp.split(self.ref_nlist, self.nsel, axis=-1)[ii], axis=-1), | ||
) | ||
|
||
def test_extend_coord(self): | ||
ecoord, eatype, mapping = extend_coord_with_ghosts( | ||
self.coord, self.atype, self.cell, self.rcut | ||
) | ||
# expected ncopy x nloc | ||
self.assertEqual(list(ecoord.shape), [self.nf, self.nall * 3]) | ||
self.assertEqual(list(eatype.shape), [self.nf, self.nall]) | ||
self.assertEqual(list(mapping.shape), [self.nf, self.nall]) | ||
# check the nloc part is identical with original coord | ||
self.assertAllClose( | ||
ecoord[:, : self.nloc * 3], self.coord, rtol=self.prec, atol=self.prec | ||
) | ||
# check the shift vectors are aligned with grid | ||
shift_vec = ( | ||
ecoord.reshape([-1, self.ns, self.nloc, 3]) | ||
- self.coord.reshape([-1, self.nloc, 3])[:, None, :, :] | ||
) | ||
shift_vec = shift_vec.reshape([-1, self.nall, 3]) | ||
# hack!!! assumes identical cell across frames | ||
shift_vec = tnp.matmul( | ||
shift_vec, tf.linalg.inv(self.cell.reshape([self.nf, 3, 3])[0]) | ||
) | ||
# nf x nall x 3 | ||
shift_vec = tnp.round(shift_vec) | ||
# check: identical shift vecs | ||
self.assertAllClose(shift_vec[0], shift_vec[1], rtol=self.prec, atol=self.prec) | ||
# check: shift idx aligned with grid | ||
mm, _, cc = tf.unique_with_counts(shift_vec[0][:, 0]) | ||
self.assertAllClose( | ||
tnp.sort(mm), | ||
tnp.array([-2, -1, 0, 1, 2], dtype=dtype), | ||
rtol=self.prec, | ||
atol=self.prec, | ||
) | ||
self.assertAllClose( | ||
cc, | ||
tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32), | ||
rtol=self.prec, | ||
atol=self.prec, | ||
) | ||
mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 1]) | ||
self.assertAllClose( | ||
tnp.sort(mm), | ||
tnp.array([-2, -1, 0, 1, 2], dtype=dtype), | ||
rtol=self.prec, | ||
atol=self.prec, | ||
) | ||
self.assertAllClose( | ||
cc, | ||
tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32), | ||
rtol=self.prec, | ||
atol=self.prec, | ||
) | ||
mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 2]) | ||
self.assertAllClose( | ||
tnp.sort(mm), | ||
tnp.array([-1, 0, 1], dtype=dtype), | ||
rtol=self.prec, | ||
atol=self.prec, | ||
) | ||
self.assertAllClose( | ||
cc, | ||
tnp.array([self.ns * self.nloc // 3] * 3, dtype=tnp.int32), | ||
rtol=self.prec, | ||
atol=self.prec, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
import tensorflow as tf | ||
import tensorflow.experimental.numpy as tnp | ||
|
||
from deepmd.jax.jax2tf.region import ( | ||
inter2phys, | ||
to_face_distance, | ||
) | ||
|
||
from ...seed import ( | ||
GLOBAL_SEED, | ||
) | ||
|
||
|
||
class TestRegion(tf.test.TestCase): | ||
def setUp(self): | ||
self.cell = tnp.array( | ||
[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], | ||
) | ||
self.cell = tnp.reshape(self.cell, [1, 1, -1, 3]) | ||
self.cell = tnp.tile(self.cell, [4, 5, 1, 1]) | ||
self.prec = 1e-8 | ||
|
||
def test_inter_to_phys(self): | ||
rng = tf.random.Generator.from_seed(GLOBAL_SEED) | ||
inter = rng.normal(shape=[4, 5, 3, 3]) | ||
phys = inter2phys(inter, self.cell) | ||
for ii in range(4): | ||
for jj in range(5): | ||
expected_phys = tnp.matmul(inter[ii, jj], self.cell[ii, jj]) | ||
self.assertAllClose( | ||
phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec | ||
) | ||
|
||
def test_to_face_dist(self): | ||
cell0 = self.cell[0][0] | ||
vol = tf.linalg.det(cell0) | ||
# area of surfaces xy, xz, yz | ||
sxy = tf.linalg.norm(tnp.cross(cell0[0], cell0[1])) | ||
sxz = tf.linalg.norm(tnp.cross(cell0[0], cell0[2])) | ||
syz = tf.linalg.norm(tnp.cross(cell0[1], cell0[2])) | ||
# vol / area gives distance | ||
dz = vol / sxy | ||
dy = vol / sxz | ||
dx = vol / syz | ||
expected = tnp.array([dx, dy, dz]) | ||
dists = to_face_distance(self.cell) | ||
for ii in range(4): | ||
for jj in range(5): | ||
self.assertAllClose( | ||
dists[ii][jj], expected, rtol=self.prec, atol=self.prec | ||
) |