Skip to content

Commit

Permalink
nopbc
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Nov 6, 2024
1 parent 933e4df commit 94d2054
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 9 deletions.
6 changes: 6 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def __init__(
stablehlo_atomic_virial=model_data["@variables"][
"stablehlo_atomic_virial"
].tobytes(),
stablehlo_no_ghost=model_data["@variables"][
"stablehlo_no_ghost"
].tobytes(),
stablehlo_atomic_virial_no_ghost=model_data["@variables"][
"stablehlo_atomic_virial_no_ghost"
].tobytes(),
model_def_script=model_data["model_def_script"],
**model_data["constants"],
)
Expand Down
18 changes: 15 additions & 3 deletions deepmd/jax/model/hlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def __init__(
self,
stablehlo,
stablehlo_atomic_virial,
stablehlo_no_ghost,
stablehlo_atomic_virial_no_ghost,
model_def_script,
type_map,
rcut,
Expand All @@ -62,6 +64,10 @@ def __init__(
self._call_lower_atomic_virial = jax_export.deserialize(
stablehlo_atomic_virial
).call
self._call_lower_no_ghost = jax_export.deserialize(stablehlo_no_ghost).call
self._call_lower_atomic_virial_no_ghost = jax_export.deserialize(
stablehlo_atomic_virial_no_ghost
).call
self.stablehlo = stablehlo
self.type_map = type_map
self.rcut = rcut
Expand Down Expand Up @@ -174,10 +180,16 @@ def call_lower(
aparam: Optional[jnp.ndarray] = None,
do_atomic_virial: bool = False,
):
if do_atomic_virial:
call_lower = self._call_lower_atomic_virial
if extended_coord.shape[1] > nlist.shape[1]:
if do_atomic_virial:
call_lower = self._call_lower_atomic_virial
else:
call_lower = self._call_lower
else:
call_lower = self._call_lower
if do_atomic_virial:
call_lower = self._call_lower_atomic_virial_no_ghost
else:
call_lower = self._call_lower_no_ghost
return call_lower(
extended_coord,
extended_atype,
Expand Down
35 changes: 29 additions & 6 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def deserialize_to_file(model_file: str, data: dict) -> None:

nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost")

def exported_whether_do_atomic_virial(do_atomic_virial):
def exported_whether_do_atomic_virial(
do_atomic_virial: bool, has_ghost_atoms: bool
):
def call_lower_with_fixed_do_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
):
Expand All @@ -67,13 +69,18 @@ def call_lower_with_fixed_do_atomic_virial(
do_atomic_virial=do_atomic_virial,
)

if has_ghost_atoms:
nghost_ = nghost
else:
nghost_ = 0

return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))(
jax.ShapeDtypeStruct(
(nf, nloc + nghost, 3), jnp.float64
(nf, nloc + nghost_, 3), jnp.float64
), # extended_coord
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype
jax.ShapeDtypeStruct((nf, nloc + nghost_), jnp.int32), # extended_atype
jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping
jax.ShapeDtypeStruct((nf, nloc + nghost_), jnp.int64), # mapping
jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64)
if model.get_dim_fparam()
else None, # fparam
Expand All @@ -82,18 +89,34 @@ def call_lower_with_fixed_do_atomic_virial(
else None, # aparam
)

exported = exported_whether_do_atomic_virial(do_atomic_virial=False)
exported = exported_whether_do_atomic_virial(
do_atomic_virial=False, has_ghost_atoms=True
)
exported_atomic_virial = exported_whether_do_atomic_virial(
do_atomic_virial=True
do_atomic_virial=True, has_ghost_atoms=True
)
serialized: bytearray = exported.serialize()
serialized_atomic_virial = exported_atomic_virial.serialize()

exported_no_ghost = exported_whether_do_atomic_virial(
do_atomic_virial=False, has_ghost_atoms=False
)
exported_atomic_virial_no_ghost = exported_whether_do_atomic_virial(
do_atomic_virial=True, has_ghost_atoms=False
)
serialized_no_ghost: bytearray = exported_no_ghost.serialize()
serialized_atomic_virial_no_ghost = exported_atomic_virial_no_ghost.serialize()

data = data.copy()
data.setdefault("@variables", {})
data["@variables"]["stablehlo"] = np.void(serialized)
data["@variables"]["stablehlo_atomic_virial"] = np.void(
serialized_atomic_virial
)
data["@variables"]["stablehlo_no_ghost"] = np.void(serialized_no_ghost)
data["@variables"]["stablehlo_atomic_virial_no_ghost"] = np.void(
serialized_atomic_virial_no_ghost
)
data["constants"] = {
"type_map": model.get_type_map(),
"rcut": model.get_rcut(),
Expand Down
27 changes: 27 additions & 0 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def test_deep_eval(self):
nframes = self.atype.shape[0]
prefix = "test_consistent_io_" + self.__class__.__name__.lower()
rets = []
rets_nopbc = []
for backend_name, suffix_idx in (
# unfortunately, jax2tf cannot work with tf v1 behaviors
("jax", 2) if DP_TEST_TF2_ONLY else ("tensorflow", 0),
Expand Down Expand Up @@ -182,13 +183,39 @@ def test_deep_eval(self):
atomic=True,
)
rets.append(ret)
ret = deep_eval.eval(
self.coords,
None,
self.atype,
fparam=fparam,
aparam=aparam,
)
rets_nopbc.append(ret)
ret = deep_eval.eval(
self.coords,
None,
self.atype,
fparam=fparam,
aparam=aparam,
atomic=True,
)
rets_nopbc.append(ret)
for ret in rets[1:]:
for vv1, vv2 in zip(rets[0], ret):
if np.isnan(vv2).all():
# expect all nan if not supported
continue
np.testing.assert_allclose(vv1, vv2, rtol=1e-12, atol=1e-12)

for idx, ret in enumerate(rets_nopbc[1:]):
for vv1, vv2 in zip(rets_nopbc[0], ret):
if np.isnan(vv2).all():
# expect all nan if not supported
continue
np.testing.assert_allclose(
vv1, vv2, rtol=1e-12, atol=1e-12, err_msg=f"backend {idx+1}"
)


class TestDeepPot(unittest.TestCase, IOTest):
def setUp(self):
Expand Down

0 comments on commit 94d2054

Please sign in to comment.