From e4bac358381cab0fef18157ca89d957fcbe567e2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 9 Nov 2024 02:05:37 -0500 Subject: [PATCH] clean up Signed-off-by: Jinzhe Zeng --- source/jax2tf_tests/test_nlist.py | 24 +++++++----------------- source/jax2tf_tests/test_region.py | 20 ++++---------------- 2 files changed, 11 insertions(+), 33 deletions(-) diff --git a/source/jax2tf_tests/test_nlist.py b/source/jax2tf_tests/test_nlist.py index 28c5a0dfc8..5b13e4231c 100644 --- a/source/jax2tf_tests/test_nlist.py +++ b/source/jax2tf_tests/test_nlist.py @@ -1,30 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import unittest import tensorflow as tf import tensorflow.experimental.numpy as tnp -from ...utils import ( - DP_TEST_TF2_ONLY, +from deepmd.jax.jax2tf.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.jax.jax2tf.region import ( + inter2phys, ) -if DP_TEST_TF2_ONLY: - from deepmd.jax.jax2tf.nlist import ( - build_neighbor_list, - extend_coord_with_ghosts, - ) - from deepmd.jax.jax2tf.region import ( - inter2phys, - ) - - dtype = tnp.float64 +dtype = tnp.float64 -@unittest.skipIf( - not DP_TEST_TF2_ONLY, - reason="TF2 conflicts with TF1", -) class TestNeighList(tf.test.TestCase): def setUp(self): self.nf = 3 diff --git a/source/jax2tf_tests/test_region.py b/source/jax2tf_tests/test_region.py index a6baffcb33..2becf08c94 100644 --- a/source/jax2tf_tests/test_region.py +++ b/source/jax2tf_tests/test_region.py @@ -1,29 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import unittest - import tensorflow as tf import tensorflow.experimental.numpy as tnp -from ...seed import ( - GLOBAL_SEED, -) -from ...utils import ( - DP_TEST_TF2_ONLY, +from deepmd.jax.jax2tf.region import ( + inter2phys, + to_face_distance, ) -if DP_TEST_TF2_ONLY: - from deepmd.jax.jax2tf.region import ( - inter2phys, - to_face_distance, - ) +GLOBAL_SEED = 20241109 -@unittest.skipIf( - not DP_TEST_TF2_ONLY, - reason="TF2 conflicts with TF1", -) class TestRegion(tf.test.TestCase): def setUp(self): self.cell = tnp.array(