From 8d9700557996e9848252dba177c91f65a0aa2ac7 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 22 Nov 2024 16:11:47 +0800 Subject: [PATCH] Update network.py --- deepmd/pt/model/network/network.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py index 7f18ff4d53..353ed0c063 100644 --- a/deepmd/pt/model/network/network.py +++ b/deepmd/pt/model/network/network.py @@ -297,6 +297,20 @@ def forward(self, atype): return self.embedding(atype.device)[atype] def get_full_embedding(self, device: torch.device): + """ + Get the type embeddings of all types. + + Parameters + ---------- + device : torch.device + The device on which to perform the computation. + + Returns + ------- + type_embedding : torch.Tensor + The full type embeddings of all types. The last index corresponds to the zero padding. + Shape: (ntypes + 1) x tebd_dim + """ return self.embedding(device) def share_params(self, base_class, shared_level, resume=False) -> None: