diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py index feb58c74cb..d21fc998b5 100644 --- a/deepmd/jax/jax2tf/make_model.py +++ b/deepmd/jax/jax2tf/make_model.py @@ -40,9 +40,9 @@ def model_call_from_call_lower( model_output_def: ModelOutputDef, coord: tnp.ndarray, atype: tnp.ndarray, - box: tnp.ndarray = None, - fparam: tnp.ndarray = None, - aparam: tnp.ndarray = None, + box: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, do_atomic_virial: bool = False, ): """Return model prediction from lower interface.