From aa850aa531b0e396b6f2fbd68cd1e6f1319d1d0b Mon Sep 17 00:00:00 2001
From: gesen2egee <79357052+gesen2egee@users.noreply.github.com>
Date: Sun, 4 Aug 2024 17:34:20 +0800
Subject: [PATCH] Update train_network.py

---
 train_network.py | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/train_network.py b/train_network.py
index fa6407eef..938e41938 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1034,25 +1034,25 @@ def remove_model(old_ckpt_name):
                     logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
                     accelerator.log(logs, step=global_step)
 
-                    if len(val_dataloader) > 0:
-                        if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
-                            accelerator.print("Validating バリデーション処理...")
-                            total_loss = 0.0
-                            with torch.no_grad():
-                                validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
-                                for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
-                                    batch = next(cyclic_val_dataloader)
-                                    loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
-                                    total_loss += loss.detach().item()
-                            current_loss = total_loss / validation_steps
-                            val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
-
-                            if args.logging_dir is not None:
-                                logs = {"loss/current_val_loss": current_loss}
-                                accelerator.log(logs, step=global_step)
-                                avr_loss: float = val_loss_recorder.moving_average
-                                logs = {"loss/average_val_loss": avr_loss}
-                                accelerator.log(logs, step=global_step)
+                if len(val_dataloader) > 0:
+                    if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
+                        accelerator.print("Validating バリデーション処理...")
+                        total_loss = 0.0
+                        with torch.no_grad():
+                            validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
+                            for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
+                                batch = next(cyclic_val_dataloader)
+                                loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
+                                total_loss += loss.detach().item()
+                        current_loss = total_loss / validation_steps
+                        val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
+
+                        if args.logging_dir is not None:
+                            logs = {"loss/current_val_loss": current_loss}
+                            accelerator.log(logs, step=global_step)
+                            avr_loss: float = val_loss_recorder.moving_average
+                            logs = {"loss/average_val_loss": avr_loss}
+                            accelerator.log(logs, step=global_step)
 
                 if global_step >= args.max_train_steps:
                     break
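
The hunk draws validation batches from cyclic_val_dataloader via next() rather than iterating val_dataloader directly, so validation can run for a fixed number of steps even when that exceeds one pass over the validation set. Below is a minimal standalone sketch of that pattern, assuming the cyclic iterator simply restarts the DataLoader when exhausted; make_cyclic, the dummy dataset, and the squared-error placeholder loss are illustrative assumptions, not part of the patch.

import torch
from torch.utils.data import DataLoader, TensorDataset

def make_cyclic(dataloader):
    # Re-iterate the dataloader forever so next() never raises StopIteration,
    # even when validation runs more steps than len(dataloader).
    while True:
        for batch in dataloader:
            yield batch

# Toy validation set: 8 samples of 4 features, batched in pairs (illustrative).
val_dataloader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
cyclic_val_dataloader = make_cyclic(val_dataloader)

max_validation_steps = 16  # stand-in for args.max_validation_steps
validation_steps = min(max_validation_steps, len(val_dataloader))

total_loss = 0.0
with torch.no_grad():
    for _ in range(validation_steps):
        (batch,) = next(cyclic_val_dataloader)
        loss = batch.pow(2).mean()  # placeholder for self.process_val_batch(...)
        total_loss += loss.detach().item()
current_loss = total_loss / validation_steps
print(f"val loss averaged over {validation_steps} steps: {current_loss:.4f}")

Averaging loss.detach().item() over the loop, as the patch does, keeps the validation pass out of the autograd graph and frees each batch's activations before the next one is processed.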