fix(pt/dp): make strip more efficient (#4400)
The strip methods differ between se_atten and se_t_tebd, so it is not necessary to merge them.
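
For context, the following is a minimal NumPy sketch (not the project code) of the idea behind the strip optimization in its simplest, one-side form: evaluate the strip embedding network once per atom type on the full type-embedding table, then gather per-neighbor outputs by type index, instead of running the network on tiled per-neighbor type embeddings at every call. The names and the linear stand-in for the network are illustrative.

```python
import numpy as np

ntypes, nt, ng = 4, 8, 16                        # toy sizes: padded types, tebd dim, strip width
rng = np.random.default_rng(0)
type_embedding = rng.normal(size=(ntypes, nt))   # full type-embedding table
strip_net = rng.normal(size=(nt, ng))            # linear stand-in for the strip network

# Old pattern (schematically): tile type embeddings per neighbor and run the
# network on an (natoms * nnei, nt) batch at every call.
# New pattern: run the network once per type, then gather by neighbor type.
per_type_out = type_embedding @ strip_net        # (ntypes, ng), one evaluation per type
nei_type = np.array([[1, 2, 3], [0, 0, 2]])      # toy neighbor types, (natoms, nnei)
gg_t = per_type_out[nei_type]                    # (natoms, nnei, ng), just an index lookup
```

The two-side variants used by se_atten and se_t_tebd extend this to pairs of types; see the sketches alongside the respective diffs below.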

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Introduced a new optional parameter `type_embedding` in various methods across descriptor classes to enhance handling of atomic types and embeddings.
  - Added a method `get_full_embedding` in the `TypeEmbedNet` class for easier access to complete embeddings (see the sketch after this summary).

- **Bug Fixes**
  - Improved error handling and assertions for the new `type_embedding` parameter in multiple classes to prevent runtime errors.

- **Documentation**
  - Updated method signatures and docstrings to reflect the addition of `type_embedding`.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
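
To make the summary above concrete, here is a small self-contained PyTorch sketch of how the new pieces fit together on the caller side. `ToyTypeEmbedNet`, its extra padding row, and the toy sizes are illustrative assumptions; only the `get_full_embedding(device)` accessor and the strip-mode condition mirror the actual changes in this commit.

```python
import torch


class ToyTypeEmbedNet(torch.nn.Module):
    """Toy stand-in for TypeEmbedNet, only to illustrate the accessor contract."""

    def __init__(self, ntypes: int, tebd_dim: int) -> None:
        super().__init__()
        # one row per type plus a trailing padding row (assumption, suggested by
        # `ntypes_with_padding = type_embedding.shape[0]` in the dpmodel diff)
        self.table = torch.nn.Parameter(torch.randn(ntypes + 1, tebd_dim))

    def forward(self, extended_atype: torch.Tensor) -> torch.Tensor:
        # per-atom embedding rows, as before
        return self.table[extended_atype]

    def get_full_embedding(self, device: torch.device) -> torch.Tensor:
        # full (padded) table, fetched once per forward pass in strip mode
        return self.table.to(device)


tebd = ToyTypeEmbedNet(ntypes=4, tebd_dim=8)
extended_atype = torch.tensor([[0, 1, 3, 2]])
g1_ext = tebd(extended_atype)                    # (nf, nall, tebd_dim)
tebd_input_mode = "strip"
type_embedding = (
    tebd.get_full_embedding(g1_ext.device) if tebd_input_mode in ["strip"] else None
)
# `type_embedding` is then forwarded to the descriptor block alongside g1_ext.
```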

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
iProzd and njzjz authored Nov 23, 2024
1 parent 7bd2e5a commit 5d589da
Showing 14 changed files with 295 additions and 97 deletions.
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
@@ -123,6 +123,7 @@ def call(
extended_atype: np.ndarray,
extended_atype_embd: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
type_embedding: Optional[np.ndarray] = None,
):
"""Calculate DescriptorBlock."""
pass
83 changes: 61 additions & 22 deletions deepmd/dpmodel/descriptor/dpa1.py
@@ -494,9 +494,10 @@ def call(
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
nf, nloc, nnei = nlist.shape
nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3
type_embedding = self.type_embedding.call()
# nf x nall x tebd_dim
atype_embd_ext = xp.reshape(
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
(nf, nall, self.tebd_dim),
)
# nfnl x tebd_dim
@@ -507,6 +508,7 @@ def call(
atype_ext,
atype_embd_ext,
mapping=None,
type_embedding=type_embedding,
)
# nf x nloc x (ng x ng1 + tebd_dim)
if self.concat_output_tebd:
@@ -874,10 +876,6 @@ def cal_g_strip(
embedding_idx,
):
assert self.embeddings_strip is not None
xp = array_api_compat.array_namespace(ss)
nfnl, nnei = ss.shape[0:2]
shape2 = math.prod(ss.shape[2:])
ss = xp.reshape(ss, (nfnl, nnei, shape2))
# nfnl x nnei x ng
gg = self.embeddings_strip[embedding_idx].call(ss)
return gg
@@ -889,13 +887,15 @@ def call(
atype_ext: np.ndarray,
atype_embd_ext: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
type_embedding: Optional[np.ndarray] = None,
):
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
# nf x nloc x nnei x 4
dmatrix, diff, sw = self.env_mat.call(
coord_ext, atype_ext, nlist, self.mean, self.stddev
)
nf, nloc, nnei, _ = dmatrix.shape
atype = atype_ext[:, :nloc]
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# nfnl x nnei
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
@@ -906,28 +906,33 @@ def call(
dmatrix = xp.reshape(dmatrix, (nf * nloc, nnei, 4))
# nfnl x nnei x 1
sw = xp.reshape(sw, (nf * nloc, nnei, 1))
# nfnl x tebd_dim
atype_embd = xp.reshape(atype_embd_ext[:, :nloc, :], (nf * nloc, self.tebd_dim))
# nfnl x nnei x tebd_dim
atype_embd_nnei = xp.tile(atype_embd[:, xp.newaxis, :], (1, nnei, 1))
# nfnl x nnei
nlist_mask = nlist != -1
# nfnl x nnei x 1
sw = xp.where(nlist_mask[:, :, None], sw, xp.full_like(sw, 0.0))
nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim))
# nfnl x nnei x tebd_dim
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
ng = self.neuron[-1]
nt = self.tebd_dim
# nfnl x nnei x 4
rr = xp.reshape(dmatrix, (nf * nloc, nnei, 4))
rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype)
# nfnl x nnei x 1
ss = rr[..., 0:1]
if self.tebd_input_mode in ["concat"]:
# nfnl x tebd_dim
atype_embd = xp.reshape(
atype_embd_ext[:, :nloc, :], (nf * nloc, self.tebd_dim)
)
# nfnl x nnei x tebd_dim
atype_embd_nnei = xp.tile(atype_embd[:, xp.newaxis, :], (1, nnei, 1))
index = xp.tile(
xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)
)
# nfnl x nnei x tebd_dim
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
if not self.type_one_side:
# nfnl x nnei x (1 + 2 * tebd_dim)
ss = xp.concat([ss, atype_embd_nlist, atype_embd_nnei], axis=-1)
@@ -941,14 +946,48 @@
# nfnl x nnei x ng
gg_s = self.cal_g(ss, 0)
assert self.embeddings_strip is not None
if not self.type_one_side:
# nfnl x nnei x (tebd_dim * 2)
tt = xp.concat([atype_embd_nlist, atype_embd_nnei], axis=-1)
assert type_embedding is not None
ntypes_with_padding = type_embedding.shape[0]
# nf x (nl x nnei)
nlist_index = xp.reshape(nlist_masked, (nf, nloc * nnei))
# nf x (nl x nnei)
nei_type = xp_take_along_axis(atype_ext, nlist_index, axis=1)
# (nf x nl x nnei) x ng
nei_type_index = xp.tile(xp.reshape(nei_type, (-1, 1)), (1, ng))
if self.type_one_side:
tt_full = self.cal_g_strip(type_embedding, 0)
# (nf x nl x nnei) x ng
gg_t = xp_take_along_axis(tt_full, nei_type_index, axis=0)
else:
# nfnl x nnei x tebd_dim
tt = atype_embd_nlist
# nfnl x nnei x ng
gg_t = self.cal_g_strip(tt, 0)
idx_i = xp.reshape(
xp.tile(
(xp.reshape(atype, (-1, 1)) * ntypes_with_padding), (1, nnei)
),
(-1),
)
idx_j = xp.reshape(nei_type, (-1,))
# (nf x nl x nnei) x ng
idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng))
# (ntypes) * ntypes * nt
type_embedding_nei = xp.tile(
xp.reshape(type_embedding, (1, ntypes_with_padding, nt)),
(ntypes_with_padding, 1, 1),
)
# ntypes * (ntypes) * nt
type_embedding_center = xp.tile(
xp.reshape(type_embedding, (ntypes_with_padding, 1, nt)),
(1, ntypes_with_padding, 1),
)
# (ntypes * ntypes) * (nt+nt)
two_side_type_embedding = xp.reshape(
xp.concat([type_embedding_nei, type_embedding_center], axis=-1),
(-1, nt * 2),
)
tt_full = self.cal_g_strip(two_side_type_embedding, 0)
# (nf x nl x nnei) x ng
gg_t = xp_take_along_axis(tt_full, idx, axis=0)
# (nf x nl) x nnei x ng
gg_t = xp.reshape(gg_t, (nf * nloc, nnei, ng))
if self.smooth:
gg_t = gg_t * xp.reshape(sw, (-1, self.nnei, 1))
# nfnl x nnei x ng
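
The two-side branch of the se_atten strip path above builds an (ntypes_with_padding x ntypes_with_padding) pair table once and then gathers rows with a flat index center_type * ntypes_with_padding + neighbor_type. Below is a minimal NumPy check of that index layout; the toy sizes and the linear stand-in for `cal_g_strip` are assumptions (the real code uses `xp_take_along_axis` with a tiled index).

```python
import numpy as np

ntypes, nt, ng = 3, 4, 5                              # toy: padded types, tebd dim, strip width
rng = np.random.default_rng(1)
te = rng.normal(size=(ntypes, nt))                    # type_embedding, incl. padding row
W = rng.normal(size=(2 * nt, ng))                     # linear stand-in for cal_g_strip(..., 0)

# Pair table laid out as in the diff: row (c * ntypes + n) holds
# the strip network applied to concat(te[n], te[c]).
te_nei = np.broadcast_to(te[None, :, :], (ntypes, ntypes, nt))      # [c, n] -> te[n]
te_center = np.broadcast_to(te[:, None, :], (ntypes, ntypes, nt))   # [c, n] -> te[c]
pair_table = np.concatenate([te_nei, te_center], axis=-1).reshape(-1, 2 * nt) @ W

# Gather with idx = center_type * ntypes + neighbor_type.
atype = np.array([0, 2])                              # toy center types, (nloc,)
nei_type = np.array([[1, 2], [0, 1]])                 # toy neighbor types, (nloc, nnei)
idx = atype[:, None] * ntypes + nei_type
gg_t = pair_table[idx]                                # (nloc, nnei, ng)

# Check against a direct per-neighbor evaluation.
direct = np.concatenate(
    [te[nei_type], np.broadcast_to(te[atype][:, None, :], te[nei_type].shape)], axis=-1
) @ W
assert np.allclose(gg_t, direct)
```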
5 changes: 4 additions & 1 deletion deepmd/dpmodel/descriptor/dpa2.py
@@ -811,9 +811,10 @@ def call(
self.rcut_list,
self.nsel_list,
)
type_embedding = self.type_embedding.call()
# repinit
g1_ext = xp.reshape(
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
(nframes, nall, self.tebd_dim),
)
g1_inp = g1_ext[:, :nloc, :]
@@ -825,6 +826,7 @@
atype_ext,
g1_ext,
mapping,
type_embedding=type_embedding,
)
if use_three_body:
assert self.repinit_three_body is not None
@@ -839,6 +841,7 @@
atype_ext,
g1_ext,
mapping,
type_embedding=type_embedding,
)
g1 = xp.concat([g1, g1_three_body], axis=-1)
# linear to change shape
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/repformers.py
@@ -389,6 +389,7 @@ def call(
atype_ext: np.ndarray,
atype_embd_ext: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
type_embedding: Optional[np.ndarray] = None,
):
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
77 changes: 62 additions & 15 deletions deepmd/dpmodel/descriptor/se_t_tebd.py
@@ -332,9 +332,10 @@ def call(
del mapping
nf, nloc, nnei = nlist.shape
nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3
type_embedding = self.type_embedding.call()
# nf x nall x tebd_dim
atype_embd_ext = xp.reshape(
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
(nf, nall, self.tebd_dim),
)
# nfnl x tebd_dim
@@ -345,6 +346,7 @@
atype_ext,
atype_embd_ext,
mapping=None,
type_embedding=type_embedding,
)
# nf x nloc x (ng + tebd_dim)
if self.concat_output_tebd:
@@ -667,6 +669,7 @@ def call(
atype_ext: np.ndarray,
atype_embd_ext: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
type_embedding: Optional[np.ndarray] = None,
):
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
# nf x nloc x nnei x 4
@@ -703,20 +706,26 @@ def call(
env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1)
# nfnl x nt_i x nt_j x 1
ss = env_ij[..., None]

nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim))
# nfnl x nnei x tebd_dim
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
# nfnl x nt_i x nt_j x tebd_dim
nlist_tebd_i = xp.tile(atype_embd_nlist[:, :, None, :], (1, 1, self.nnei, 1))
nlist_tebd_j = xp.tile(atype_embd_nlist[:, None, :, :], (1, self.nnei, 1, 1))
ng = self.neuron[-1]
nt = self.tebd_dim

if self.tebd_input_mode in ["concat"]:
index = xp.tile(
xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)
)
# nfnl x nnei x tebd_dim
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
# nfnl x nt_i x nt_j x tebd_dim
nlist_tebd_i = xp.tile(
atype_embd_nlist[:, :, None, :], (1, 1, self.nnei, 1)
)
nlist_tebd_j = xp.tile(
atype_embd_nlist[:, None, :, :], (1, self.nnei, 1, 1)
)
# nfnl x nt_i x nt_j x (1 + tebd_dim * 2)
ss = xp.concat([ss, nlist_tebd_i, nlist_tebd_j], axis=-1)
# nfnl x nt_i x nt_j x ng
@@ -725,10 +734,48 @@
# nfnl x nt_i x nt_j x ng
gg_s = self.cal_g(ss, 0)
assert self.embeddings_strip is not None
# nfnl x nt_i x nt_j x (tebd_dim * 2)
tt = xp.concat([nlist_tebd_i, nlist_tebd_j], axis=-1)
# nfnl x nt_i x nt_j x ng
gg_t = self.cal_g_strip(tt, 0)
assert type_embedding is not None
ntypes_with_padding = type_embedding.shape[0]
# nf x (nl x nnei)
nlist_index = xp.reshape(nlist_masked, (nf, nloc * nnei))
# nf x (nl x nnei)
nei_type = xp_take_along_axis(atype_ext, nlist_index, axis=1)
# nfnl x nnei
nei_type = xp.reshape(nei_type, (nf * nloc, nnei))

# nfnl x nnei x nnei
nei_type_i = xp.tile(nei_type[:, :, np.newaxis], (1, 1, nnei))
nei_type_j = xp.tile(nei_type[:, np.newaxis, :], (1, nnei, 1))

idx_i = nei_type_i * ntypes_with_padding
idx_j = nei_type_j

# (nf x nl x nt_i x nt_j) x ng
idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng))

# ntypes * (ntypes) * nt
type_embedding_i = xp.tile(
xp.reshape(type_embedding, (ntypes_with_padding, 1, nt)),
(1, ntypes_with_padding, 1),
)

# (ntypes) * ntypes * nt
type_embedding_j = xp.tile(
xp.reshape(type_embedding, (1, ntypes_with_padding, nt)),
(ntypes_with_padding, 1, 1),
)

# (ntypes * ntypes) * (nt+nt)
two_side_type_embedding = xp.reshape(
xp.concat([type_embedding_i, type_embedding_j], axis=-1), (-1, nt * 2)
)
tt_full = self.cal_g_strip(two_side_type_embedding, 0)

# (nfnl x nt_i x nt_j) x ng
gg_t = xp_take_along_axis(tt_full, idx, axis=0)

# (nfnl x nt_i x nt_j) x ng
gg_t = xp.reshape(gg_t, (nf * nloc, nnei, nnei, ng))
if self.smooth:
gg_t = (
gg_t
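
The se_t_tebd strip path above does the analogous precompute-and-gather, but over pairs of neighbor types (i, j) rather than (center, neighbor), which is one reason the two strip methods are kept separate. A toy NumPy sketch of that index layout (names, sizes, and the linear stand-in are illustrative):

```python
import numpy as np

ntypes, nt, ng = 3, 4, 5                          # toy: padded types, tebd dim, strip width
rng = np.random.default_rng(2)
te = rng.normal(size=(ntypes, nt))                # full type-embedding table
W = rng.normal(size=(2 * nt, ng))                 # linear stand-in for the se_t_tebd strip network

# Row (i * ntypes + j) of the table holds the network applied to concat(te[i], te[j]).
te_i = np.broadcast_to(te[:, None, :], (ntypes, ntypes, nt))
te_j = np.broadcast_to(te[None, :, :], (ntypes, ntypes, nt))
pair_table = np.concatenate([te_i, te_j], axis=-1).reshape(-1, 2 * nt) @ W

nei_type = np.array([[1, 2, 0]])                  # toy neighbor types, (nloc, nnei)
idx = nei_type[:, :, None] * ntypes + nei_type[:, None, :]   # (nloc, nnei, nnei)
gg_t = pair_table[idx]                            # (nloc, nnei, nnei, ng) via a single gather
```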
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/descriptor.py
@@ -174,6 +174,7 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
type_embedding: Optional[torch.Tensor] = None,
):
"""Calculate DescriptorBlock."""
pass
5 changes: 5 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
@@ -687,12 +687,17 @@ def forward(
nall = extended_coord.view(nframes, -1).shape[1] // 3
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:, :nloc, :]
if self.tebd_input_mode in ["strip"]:
type_embedding = self.type_embedding.get_full_embedding(g1_ext.device)
else:
type_embedding = None
g1, g2, h2, rot_mat, sw = self.se_atten(
nlist,
extended_coord,
extended_atype,
g1_ext,
mapping=None,
type_embedding=type_embedding,
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
9 changes: 8 additions & 1 deletion deepmd/pt/model/descriptor/dpa2.py
@@ -162,6 +162,7 @@ def init_subclass_params(sub_data, sub_class):

self.repinit_args = init_subclass_params(repinit, RepinitArgs)
self.repformer_args = init_subclass_params(repformer, RepformerArgs)
self.tebd_input_mode = self.repinit_args.tebd_input_mode

self.repinit = DescrptBlockSeAtten(
self.repinit_args.rcut,
@@ -765,6 +766,10 @@ def forward(
# repinit
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:, :nloc, :]
if self.tebd_input_mode in ["strip"]:
type_embedding = self.type_embedding.get_full_embedding(g1_ext.device)
else:
type_embedding = None
g1, _, _, _, _ = self.repinit(
nlist_dict[
get_multiple_nlist_key(self.repinit.get_rcut(), self.repinit.get_nsel())
@@ -773,6 +778,7 @@
extended_atype,
g1_ext,
mapping,
type_embedding,
)
if use_three_body:
assert self.repinit_three_body is not None
@@ -787,6 +793,7 @@
extended_atype,
g1_ext,
mapping,
type_embedding,
)
g1 = torch.cat([g1, g1_three_body], dim=-1)
# linear to change shape
@@ -813,7 +820,7 @@ def forward(
extended_atype,
g1,
mapping,
comm_dict,
comm_dict=comm_dict,
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/repformers.py
@@ -389,6 +389,7 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
type_embedding: Optional[torch.Tensor] = None,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
):
if comm_dict is None:
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/se_a.py
@@ -725,6 +725,7 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
type_embedding: Optional[torch.Tensor] = None,
):
"""Calculate decoded embedding for each atom.