diff --git a/source/tests/jax/jax2tf/test_nlist.py b/source/tests/jax/jax2tf/test_nlist.py index c3fe3d034d..3b380bc2cc 100644 --- a/source/tests/jax/jax2tf/test_nlist.py +++ b/source/tests/jax/jax2tf/test_nlist.py @@ -1,20 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import pytest +import tensorflow as tf +import tensorflow.experimental.numpy as tnp from ...utils import ( DP_TEST_TF2_ONLY, ) -pytest.mark.skipif( - not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True -) - - if DP_TEST_TF2_ONLY: - import tensorflow as tf - import tensorflow.experimental.numpy as tnp - from deepmd.jax.jax2tf.nlist import ( build_neighbor_list, extend_coord_with_ghosts, @@ -26,6 +20,11 @@ dtype = tnp.float64 +pytest.mark.skipif( + not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True +) + + class TestNeighList(tf.test.TestCase): def setUp(self): self.nf = 3 diff --git a/source/tests/jax/jax2tf/test_region.py b/source/tests/jax/jax2tf/test_region.py index b7e3a7d89f..ce2536ddb0 100644 --- a/source/tests/jax/jax2tf/test_region.py +++ b/source/tests/jax/jax2tf/test_region.py @@ -2,28 +2,25 @@ import pytest +import tensorflow as tf +import tensorflow.experimental.numpy as tnp +from ...seed import ( + GLOBAL_SEED, +) from ...utils import ( DP_TEST_TF2_ONLY, ) -pytest.mark.skipif( - not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True -) - - if DP_TEST_TF2_ONLY: - import tensorflow as tf - import tensorflow.experimental.numpy as tnp - from deepmd.jax.jax2tf.region import ( inter2phys, to_face_distance, ) - from ...seed import ( - GLOBAL_SEED, - ) +pytest.mark.skipif( + not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", allow_module_level=True +) class TestRegion(tf.test.TestCase):