Skip to content

Commit

Permalink
Update network.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 22, 2024
1 parent 65da179 commit 8d97005
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions deepmd/pt/model/network/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,20 @@ def forward(self, atype):
return self.embedding(atype.device)[atype]

def get_full_embedding(self, device: torch.device):
"""
Get the type embeddings of all types.
Parameters
----------
device : torch.device
The device on which to perform the computation.
Returns
-------
type_embedding : torch.Tensor
The full type embeddings of all types. The last index corresponds to the zero padding.
Shape: (ntypes + 1) x tebd_dim
"""
return self.embedding(device)

def share_params(self, base_class, shared_level, resume=False) -> None:
Expand Down

0 comments on commit 8d97005

Please sign in to comment.