From b1faf56ba22727f1e97eeb6e4362dac02e8a1654 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:15:18 +0800 Subject: [PATCH] add optional cast --- deepmd/pt/model/network/mlp.py | 5 +++++ deepmd/pt/utils/env.py | 1 + 2 files changed, 6 insertions(+) diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 2b8383806b..582abf4d69 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -32,6 +32,7 @@ ) from deepmd.pt.utils.env import ( DEFAULT_PRECISION, + DP_DTYPE_PROMOTION_STRICT, PRECISION_DICT, ) from deepmd.pt.utils.utils import ( @@ -200,6 +201,8 @@ def forward( The output. """ ori_prec = xx.dtype + if not DP_DTYPE_PROMOTION_STRICT: + xx = xx.to(self.prec) yy = ( torch.matmul(xx, self.matrix) + self.bias if self.bias is not None @@ -214,6 +217,8 @@ def forward( yy += torch.concat([xx, xx], dim=-1) else: yy = yy + if not DP_DTYPE_PROMOTION_STRICT: + yy = yy.to(ori_prec) return yy def serialize(self) -> dict: diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 3ee0b7b54d..81dce669ff 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -15,6 +15,7 @@ ) SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) +DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1" try: # only linux ncpus = len(os.sched_getaffinity(0))