From 10c3af85df98234ea2a0645a329f9e702a42e7c0 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Sat, 5 Oct 2024 16:06:39 +0800 Subject: [PATCH] fix Update train_util.py Update train_util.py --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 1c9f07bfa..3df01e3e7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -19,6 +19,7 @@ Sequence, Tuple, Union, + Callable, ) from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob @@ -5239,7 +5240,6 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, timesteps = time_shift(mu, 1.0, timesteps) else: timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) - t = timesteps.view(-1, 1, 1, 1) timesteps = min_timestep + (timesteps * (max_timestep - min_timestep)) else: timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")