From 24c6afe3bafb5f54f28997767e59c2980796c178 Mon Sep 17 00:00:00 2001 From: Yinzhanghao Zhou <64253517+floatingCatty@users.noreply.github.com> Date: Tue, 31 Dec 2024 13:54:39 -0500 Subject: [PATCH] Update loss.py align unit --- dptb/nnops/loss.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dptb/nnops/loss.py b/dptb/nnops/loss.py index 43223f3f..f1a2d154 100644 --- a/dptb/nnops/loss.py +++ b/dptb/nnops/loss.py @@ -323,6 +323,7 @@ def __init__( self.device = device self.onsite_shift = onsite_shift self.coeff_ham = coeff_ham + assert self.coeff_ham <= 1. self.coeff_ovp = coeff_ovp if basis is not None: @@ -393,7 +394,7 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict): eigloss = self.eigloss(data, ref_data) - return self.coeff_ham * ham_loss + eigloss + return self.coeff_ham * ham_loss + (1 - self.coeff_ham) eigloss @@ -1106,4 +1107,4 @@ def __cal_norm__(self, irreps: Irreps, x: torch.Tensor): tensor = tensor.norm(dim=-1) out.append(tensor) - return torch.cat(out, dim=-1).squeeze(0) \ No newline at end of file + return torch.cat(out, dim=-1).squeeze(0)