Skip to content

Commit

Permalink
feat: update sel by statistics (#3348)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Feb 29, 2024
1 parent 84d0576 commit d09af56
Show file tree
Hide file tree
Showing 25 changed files with 597 additions and 260 deletions.
16 changes: 16 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ def deserialize(cls, data: dict) -> "BD":
return BD.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

@classmethod
@abstractmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    This base implementation does no statistics itself: it looks up the
    class registered for the type found in ``local_jdata`` and forwards
    the call to that class's own ``update_sel``.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data referring to the current class

    Returns
    -------
    The value returned by the resolved class's ``update_sel``.
    """
    # Dispatch to the registered class (not a subprocess): resolve the
    # type name first, then the class, then delegate.
    descriptor_type = j_get_type(local_jdata, cls.__name__)
    target_cls = cls.get_class_by_type(descriptor_type)
    return target_cls.update_sel(global_jdata, local_jdata)

setattr(BD, fwd_method_name, BD.fwd)
delattr(BD, "fwd")

Expand Down
17 changes: 17 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import numpy as np

from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
Expand Down Expand Up @@ -388,3 +391,17 @@ def deserialize(cls, data: dict) -> "DescrptSeA":
obj.embeddings = NetworkCollection.deserialize(embeddings)
obj.env_mat = EnvMat.deserialize(env_mat)
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    Delegates to ``UpdateSel.update_one_sel`` on a shallow copy of
    ``local_jdata``.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data referring to the current class

    Returns
    -------
    The value returned by ``UpdateSel.update_one_sel``.
    """
    # Hand a shallow copy to the updater instead of the caller's own dict.
    descriptor_jdata = local_jdata.copy()
    sel_updater = UpdateSel()
    # NOTE(review): the third positional argument is presumably the
    # mixed-type flag -- confirm against BaseUpdateSel.update_one_sel.
    return sel_updater.update_one_sel(global_jdata, descriptor_jdata, False)
17 changes: 17 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np

from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -324,3 +327,17 @@ def deserialize(cls, data: dict) -> "DescrptSeR":
obj.embeddings = NetworkCollection.deserialize(embeddings)
obj.env_mat = EnvMat.deserialize(env_mat)
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    Delegates to ``UpdateSel.update_one_sel`` on a shallow copy of
    ``local_jdata``.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data referring to the current class

    Returns
    -------
    The value returned by ``UpdateSel.update_one_sel``.
    """
    # Hand a shallow copy to the updater instead of the caller's own dict.
    descriptor_jdata = local_jdata.copy()
    sel_updater = UpdateSel()
    # NOTE(review): the third positional argument is presumably the
    # mixed-type flag -- confirm against BaseUpdateSel.update_one_sel.
    return sel_updater.update_one_sel(global_jdata, descriptor_jdata, False)
15 changes: 15 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,21 @@ def get_nsel(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
pass

@classmethod
@abstractmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    This base implementation only dispatches: it reads the ``"type"``
    entry of ``local_jdata`` (falling back to ``"standard"`` when the key
    is absent), resolves the class registered for it and forwards the
    call to that class's ``update_sel``.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data referring to the current class

    Returns
    -------
    The value returned by the resolved class's ``update_sel``.
    """
    model_type = local_jdata.get("type", "standard")
    target_cls = cls.get_class_by_type(model_type)
    return target_cls.update_sel(global_jdata, local_jdata)

return BaseBaseModel


Expand Down
20 changes: 19 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
Expand All @@ -14,4 +17,19 @@
# use "class" to resolve "Variable not allowed in type expression"
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
    """Model built from ``DPAtomicModel``, registered as ``"standard"``."""

    @classmethod
    def update_sel(cls, global_jdata: dict, local_jdata: dict):
        """Update the selection and perform neighbor statistics.

        Only the ``"descriptor"`` entry is processed; it is replaced by
        the result of ``BaseDescriptor.update_sel``.

        Parameters
        ----------
        global_jdata : dict
            The global data, containing the training section
        local_jdata : dict
            The local data referring to the current class

        Returns
        -------
        dict
            A shallow copy of ``local_jdata`` with the ``"descriptor"``
            entry replaced.
        """
        descriptor_jdata = BaseDescriptor.update_sel(
            global_jdata, local_jdata["descriptor"]
        )
        # Return a new top-level dict rather than writing into the input.
        model_jdata = local_jdata.copy()
        model_jdata["descriptor"] = descriptor_jdata
        return model_jdata
21 changes: 21 additions & 0 deletions deepmd/dpmodel/utils/update_sel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Type,
)

from deepmd.dpmodel.utils.neighbor_stat import (
NeighborStat,
)
from deepmd.utils.update_sel import (
BaseUpdateSel,
)


class UpdateSel(BaseUpdateSel):
    """``BaseUpdateSel`` specialization that supplies this backend's ``NeighborStat``."""

    @property
    def neighbor_stat(self) -> Type[NeighborStat]:
        """Return the ``NeighborStat`` class (the class itself, not an instance)."""
        return NeighborStat

    def hook(self, min_nbor_dist, max_nbor_size):
        """Receive the neighbor statistics; currently does nothing.

        Parameters
        ----------
        min_nbor_dist
            Unused for now.
        max_nbor_size
            Unused for now.
        """
        # TODO: save to the model
        return None
2 changes: 1 addition & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def main_parser() -> argparse.ArgumentParser:
parser_train.add_argument(
"--skip-neighbor-stat",
action="store_true",
help="(Supported backend: TensorFlow) Skip calculating neighbor statistics. Sel checking, automatic sel, and model compression will be disabled.",
help="Skip calculating neighbor statistics. Sel checking, automatic sel, and model compression will be disabled.",
)
parser_train.add_argument(
# -m has been used by mpi-log
Expand Down
9 changes: 9 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from deepmd.pt.infer import (
inference,
)
from deepmd.pt.model.model import (
BaseModel,
)
from deepmd.pt.train import (
training,
)
Expand Down Expand Up @@ -249,6 +252,12 @@ def train(FLAGS):
SummaryPrinter()()
with open(FLAGS.INPUT) as fin:
config = json.load(fin)
if not FLAGS.skip_neighbor_stat:
log.info(
"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
)
config["model"] = BaseModel.update_sel(config, config["model"])

trainer = get_trainer(
config,
FLAGS.init_model,
Expand Down
17 changes: 17 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from deepmd.pt.model.network.network import (
TypeEmbedNet,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -215,3 +218,17 @@ def forward(
g1 = torch.cat([g1, g1_inp], dim=-1)

return g1, rot_mat, g2, h2, sw

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    Delegates to ``UpdateSel.update_one_sel`` on a shallow copy of
    ``local_jdata``.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data referring to the current class

    Returns
    -------
    The value returned by ``UpdateSel.update_one_sel``.
    """
    # Hand a shallow copy to the updater instead of the caller's own dict.
    descriptor_jdata = local_jdata.copy()
    sel_updater = UpdateSel()
    # NOTE(review): the third positional argument is presumably the
    # mixed-type flag -- confirm against BaseUpdateSel.update_one_sel.
    return sel_updater.update_one_sel(global_jdata, descriptor_jdata, True)
32 changes: 32 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
build_multiple_neighbor_list,
get_multiple_nlist_key,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -396,3 +399,32 @@ def forward(
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, rot_mat, g2, h2, sw

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    ``UpdateSel.update_one_sel`` is applied twice with a single
    ``UpdateSel`` instance: first for the ``repinit_*`` keys, then for
    the ``repformer_*`` keys. The result of the first call is the input
    of the second.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data referring to the current class

    Returns
    -------
    The value returned by the second ``update_one_sel`` call.
    """
    # (rcut_key, sel_key) pairs, processed in this order.
    key_pairs = (
        ("repinit_rcut", "repinit_nsel"),
        ("repformer_rcut", "repformer_nsel"),
    )
    sel_updater = UpdateSel()
    descriptor_jdata = local_jdata.copy()
    for rcut_name, sel_name in key_pairs:
        descriptor_jdata = sel_updater.update_one_sel(
            global_jdata,
            descriptor_jdata,
            True,
            rcut_key=rcut_name,
            sel_key=sel_name,
        )
    return descriptor_jdata
17 changes: 17 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
Expand Down Expand Up @@ -228,6 +231,20 @@ def t_cvt(xx):
obj.sea.filter_layers = NetworkCollection.deserialize(embeddings)
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    Delegates to ``UpdateSel.update_one_sel`` on a shallow copy of
    ``local_jdata``.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data referring to the current class

    Returns
    -------
    The value returned by ``UpdateSel.update_one_sel``.
    """
    # Hand a shallow copy to the updater instead of the caller's own dict.
    descriptor_jdata = local_jdata.copy()
    sel_updater = UpdateSel()
    # NOTE(review): the third positional argument is presumably the
    # mixed-type flag -- confirm against BaseUpdateSel.update_one_sel.
    return sel_updater.update_one_sel(global_jdata, descriptor_jdata, False)


@DescriptorBlock.register("se_e2_a")
class DescrptBlockSeA(DescriptorBlock):
Expand Down
17 changes: 17 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
Expand Down Expand Up @@ -319,3 +322,17 @@ def t_cvt(xx):
obj["dstd"] = t_cvt(variables["dstd"])
obj.filter_layers = NetworkCollection.deserialize(embeddings)
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    Delegates to ``UpdateSel.update_one_sel`` on a shallow copy of
    ``local_jdata``.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data referring to the current class

    Returns
    -------
    The value returned by ``UpdateSel.update_one_sel``.
    """
    # Hand a shallow copy to the updater instead of the caller's own dict.
    descriptor_jdata = local_jdata.copy()
    sel_updater = UpdateSel()
    # NOTE(review): the third positional argument is presumably the
    # mixed-type flag -- confirm against BaseUpdateSel.update_one_sel.
    return sel_updater.update_one_sel(global_jdata, descriptor_jdata, False)
20 changes: 20 additions & 0 deletions deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)
from deepmd.pt.model.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.pt.model.model.model import (
BaseModel,
)
Expand Down Expand Up @@ -47,3 +50,20 @@ def __new__(cls, descriptor, fitting, *args, **kwargs):
cls = PolarModel
# else: unknown fitting type, fall back to DPModel
return super().__new__(cls)

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    Only the ``"descriptor"`` entry is processed; it is replaced by the
    result of ``BaseDescriptor.update_sel``.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data referring to the current class

    Returns
    -------
    dict
        A shallow copy of ``local_jdata`` with the ``"descriptor"`` entry
        replaced.
    """
    descriptor_jdata = BaseDescriptor.update_sel(
        global_jdata, local_jdata["descriptor"]
    )
    # Return a new top-level dict rather than writing into the input.
    model_jdata = local_jdata.copy()
    model_jdata["descriptor"] = descriptor_jdata
    return model_jdata
20 changes: 20 additions & 0 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import torch

from deepmd.dpmodel.model.dp_model import (
DPModel,
)
from deepmd.pt.model.atomic_model import (
DPZBLLinearAtomicModel,
)
Expand Down Expand Up @@ -97,3 +100,20 @@ def forward_lower(
model_predict["dforce"] = model_ret["dforce"]
model_predict = model_ret
return model_predict

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    Only the ``"dpmodel"`` entry is processed; it is replaced by the
    result of ``DPModel.update_sel``.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data referring to the current class

    Returns
    -------
    dict
        A shallow copy of ``local_jdata`` with the ``"dpmodel"`` entry
        replaced.
    """
    # NOTE(review): in this file DPModel is imported from
    # deepmd.dpmodel.model.dp_model, not from the deepmd.pt package --
    # confirm that the non-pt class is really the intended delegate here.
    dp_jdata = DPModel.update_sel(global_jdata, local_jdata["dpmodel"])
    # Return a new top-level dict rather than writing into the input.
    model_jdata = local_jdata.copy()
    model_jdata["dpmodel"] = dp_jdata
    return model_jdata
21 changes: 21 additions & 0 deletions deepmd/pt/utils/update_sel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Type,
)

from deepmd.pt.utils.neighbor_stat import (
NeighborStat,
)
from deepmd.utils.update_sel import (
BaseUpdateSel,
)


class UpdateSel(BaseUpdateSel):
    """``BaseUpdateSel`` specialization that supplies this backend's ``NeighborStat``."""

    @property
    def neighbor_stat(self) -> Type[NeighborStat]:
        """Return the ``NeighborStat`` class (the class itself, not an instance)."""
        return NeighborStat

    def hook(self, min_nbor_dist, max_nbor_size):
        """Receive the neighbor statistics; currently does nothing.

        Parameters
        ----------
        min_nbor_dist
            Unused for now.
        max_nbor_size
            Unused for now.
        """
        # TODO: save to the model
        return None
Loading

0 comments on commit d09af56

Please sign in to comment.