Commit
fix training bug
fuxiao0719 committed Jul 13, 2024
1 parent deb1260 commit ee814b7
Showing 8 changed files with 45 additions and 137 deletions.
Binary file modified .DS_Store
Binary file modified geowizard/.DS_Store
Binary file modified geowizard/training/.DS_Store
4 changes: 0 additions & 4 deletions geowizard/training/scripts/train_depth_normal.sh
@@ -1,8 +1,6 @@
# accelerate config
root_path=''
output_dir=''
- maddr=''
- PORT=''

pretrained_model_name_or_path='stabilityai/stable-diffusion-2'
train_batch_size=4
@@ -16,8 +14,6 @@ tracker_project_name='pretrain_tracker'
seed=1234

accelerate launch --config_file ../node_config/8gpu.yaml \
- --main_process_ip ${maddr} \
- --main_process_port $PORT \
../training/train_depth_normal.py \
--pretrained_model_name_or_path $pretrained_model_name_or_path \
--dataset_path $root_path \
6 changes: 1 addition & 5 deletions geowizard/training/scripts/train_depth_normal_v2.sh
@@ -1,8 +1,6 @@
# accelerate config
root_path=''
output_dir=''
- maddr=''
- PORT=''

pretrained_model_name_or_path='stabilityai/stable-diffusion-2'
train_batch_size=4
@@ -16,8 +14,6 @@ tracker_project_name='pretrain_tracker'
seed=1234

accelerate launch --config_file ../node_config/8gpu.yaml \
- --main_process_ip ${maddr} \
- --main_process_port $PORT \
../training/train_depth_normal_v2.py \
--pretrained_model_name_or_path $pretrained_model_name_or_path \
--dataset_path $root_path \
@@ -33,4 +29,4 @@ accelerate launch --config_file ../node_config/8gpu.yaml \
--dataloader_num_workers $dataloader_num_workers \
--tracker_project_name $tracker_project_name \
--enable_xformers_memory_efficient_attention \
- --use_ema
\ No newline at end of file
+ --use_ema
81 changes: 0 additions & 81 deletions geowizard/training/training/dataset_configuration.py

This file was deleted.

37 changes: 18 additions & 19 deletions geowizard/training/training/train_depth_normal.py
@@ -21,7 +20,6 @@

from accelerate import Accelerator
import transformers
- import datasets
import numpy as np
from accelerate.logging import get_logger
from accelerate.utils import set_seed
@@ -52,7 +51,7 @@
from utils.de_normalized import align_scale_shift
from utils.depth2normal import *

- from training.dataset_configuration import prepare_dataset, depth_scale_shift_normalization, resize_max_res_tensor
+ from utils.dataset_configuration import prepare_dataset, depth_scale_shift_normalization, resize_max_res_tensor

from PIL import Image

@@ -77,21 +76,21 @@ def parse_args():
help="Path to pretrained model or model identifier from huggingface.co/models.",
)

- parser.add_argument(
- "--input_rgb_path",
- type=str,
- required=True,
- help="Path to the input image.",
- )
+ # parser.add_argument(
+ # "--input_rgb_path",
+ # type=str,
+ # required=True,
+ # help="Path to the input image.",
+ # )

parser.add_argument(
"--dataset_path",
type=str,
default="/data1/liu",
default="/data/",
required=True,
help="The Root Dataset Path.",
)

parser.add_argument(
"--max_train_samples",
type=int,
@@ -462,10 +461,10 @@ def load_model_hook(models, input_dir):
# get the training dataset
with accelerator.main_process_first():
train_loader, dataset_config_dict = prepare_dataset(data_dir=args.dataset_path,
- batch_size=args.train_batch_size,
- test_batch=1,
- datathread=args.dataloader_num_workers,
- logger=logger)
+ batch_size=args.train_batch_size,
+ test_batch=1,
+ datathread=args.dataloader_num_workers,
+ logger=logger)

# because the optimizer not optimized every time, so we need to calculate how many steps it optimizes,
# it is usually optimized by
@@ -632,7 +631,7 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

- batch_imgs_embed = imgs_embed.repeat((bsz, 1, 1)) # [B*2, 1, 768]
+ batch_imgs_embed = imgs_embed.repeat((2, 1, 1)) # [B*2, 1, 768]

# hybrid hierarchical switcher
geo_class = torch.tensor([[0, 1], [1, 0]], dtype=weight_dtype, device=device)
@@ -647,9 +646,9 @@ def load_model_hook(models, input_dir):
unet_input = torch.cat((rgb_latents.repeat(2,1,1,1), noisy_geo_latents), dim=1)

noise_pred = unet(unet_input,
- timesteps,
- encoder_hidden_states=batch_imgs_embed,
- class_labels=class_embedding).sample # [B, 4, h, w]
+ timesteps,
+ encoder_hidden_states=batch_imgs_embed,
+ class_labels=class_embedding).sample # [B, 4, h, w]
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

# Gather the losses across all processes for logging (if we use distributed training).
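A side note on the repeat fix above: imgs_embed is the per-sample CLIP image embedding, and the UNet batch is doubled because the depth and normal latents are stacked along the batch axis, so the conditioning has to be tiled by a factor of 2 rather than by bsz (the inline comment expects a [B*2, 1, 768] result). A minimal shape sketch, with toy tensor sizes standing in for the real encoder outputs:

import torch

B = 4                                    # per-process batch size (bsz in the script)
imgs_embed = torch.randn(B, 1, 768)      # CLIP image embedding, one token per sample
rgb_latents = torch.randn(B, 4, 32, 32)  # VAE latents of the RGB input (toy spatial size)

# Depth and normal latents are stacked along the batch axis, so every per-sample
# tensor fed to the UNet must be tiled by 2, giving 2*B rows.
unet_rgb = rgb_latents.repeat(2, 1, 1, 1)      # [2B, 4, 32, 32]

# Old: imgs_embed.repeat((bsz, 1, 1)) -> [B*B, 1, 768], which mismatches the 2*B rows.
# New: tile by 2 so each stacked latent row keeps its own conditioning row.
batch_imgs_embed = imgs_embed.repeat(2, 1, 1)  # [2B, 1, 768]

assert batch_imgs_embed.shape[0] == unet_rgb.shape[0] == 2 * B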
54 changes: 26 additions & 28 deletions geowizard/training/training/train_depth_normal_v2.py
@@ -21,7 +20,6 @@

from accelerate import Accelerator
import transformers
- import datasets
import numpy as np
from accelerate.logging import get_logger
from accelerate.utils import set_seed
@@ -48,8 +47,7 @@
import cv2
from utils.de_normalized import align_scale_shift
from utils.depth2normal import *

- from training.dataset_configuration import prepare_dataset, depth_scale_shift_normalization, depth_scale_normalization, resize_max_res_tensor
+ from utils.dataset_configuration import prepare_dataset, depth_scale_shift_normalization, resize_max_res_tensor

from PIL import Image

@@ -73,17 +71,17 @@ def parse_args():
help="Path to pretrained model or model identifier from huggingface.co/models.",
)

- parser.add_argument(
- "--input_rgb_path",
- type=str,
- required=True,
- help="Path to the input image.",
- )
+ # parser.add_argument(
+ # "--input_rgb_path",
+ # type=str,
+ # required=True,
+ # help="Path to the input image.",
+ # )

parser.add_argument(
"--dataset_path",
type=str,
default="/data1/liu",
default="/data/",
required=True,
help="The Root Dataset Path.",
)
@@ -457,10 +455,10 @@ def load_model_hook(models, input_dir):
# get the training dataset
with accelerator.main_process_first():
train_loader, dataset_config_dict = prepare_dataset(data_dir=args.dataset_path,
- batch_size=args.train_batch_size,
- test_batch=1,
- datathread=args.dataloader_num_workers,
- logger=logger)
+ batch_size=args.train_batch_size,
+ test_batch=1,
+ datathread=args.dataloader_num_workers,
+ logger=logger)

# because the optimizer not optimized every time, so we need to calculate how many steps it optimizes,
# it is usually optimized by
@@ -511,7 +509,7 @@ def load_model_hook(models, input_dir):
if accelerator.is_main_process:
tracker_config = dict(vars(args))
accelerator.init_trackers(args.tracker_project_name, tracker_config)

# Here is the DDP training: actually is 4
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

@@ -562,7 +560,7 @@ def load_model_hook(models, input_dir):
)

# Encode text embedding for prompt
- prompt_list = ['', 'indoor geometry', 'outdoor geometry', 'object geometry']
+ prompt_list = ['indoor geometry', 'outdoor geometry', 'object geometry']
text_embed_list = []
for prompt in prompt_list:
text_inputs =tokenizer(
@@ -572,10 +570,10 @@
truncation=True,
return_tensors="pt",
)
- text_input_ids = text_inputs.input_ids.to(text_encoder.device) #[1,2]
+ text_input_ids = text_inputs.input_ids.to(text_encoder.device)
text_embed = text_encoder(text_input_ids)[0].to(weight_dtype)
text_embed_list.append(text_embed)

# using the epochs to training the model
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
@@ -631,15 +629,15 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

- batch_text_embed = text_embed_list[0].repeat((bsz, 1, 1)) # [B, 2, 1024]
+ batch_text_embed = torch.zeros_like(text_embed_list[0]).repeat((bsz, 1, 1)) # [B, 4, 1024]
for i in range(len(batch['domain'])):
- if batch['domain'][i] == torch.Tensor([1., 0., 0.]):
+ if batch['domain'][i][0].item() == 1:
+ batch_text_embed[i] = text_embed_list[0]
+ elif batch['domain'][i][1].item() == 1:
batch_text_embed[i] = text_embed_list[1]
- elif batch['domain'][i] == torch.Tensor([0., 1., 0.]):
+ elif batch['domain'][i][2].item() == 1:
batch_text_embed[i] = text_embed_list[2]
- elif batch['domain'][i] == torch.Tensor([0., 0., 1.]):
- batch_text_embed[i] = text_embed_list[3]
- batch_text_embed = batch_text_embed.repeat((2, 1, 1)) # [B*2, 2, 1024]
+ batch_text_embed = batch_text_embed.repeat((2, 1, 1)) # [B*2, 4, 1024]

# hybrid hierarchical switcher
geo_class = torch.tensor([[0, 1], [1, 0]], dtype=weight_dtype, device=device)
@@ -651,9 +649,9 @@ def load_model_hook(models, input_dir):
unet_input = torch.cat((rgb_latents.repeat(2,1,1,1), noisy_geo_latents), dim=1)

noise_pred = unet(unet_input,
- timesteps,
- encoder_hidden_states=batch_text_embed,
- class_labels=class_embedding).sample # [B, 4, h, w]
+ timesteps,
+ encoder_hidden_states=batch_text_embed,
+ class_labels=class_embedding).sample
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

# Gather the losses across all processes for logging (if we use distributed training).
@@ -744,4 +742,4 @@ def load_model_hook(models, input_dir):


if __name__=="__main__":
- main()
\ No newline at end of file
+ main()
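The core of the training bug fixed in this file is the domain switcher above: comparing a one-hot domain vector with == against a freshly constructed torch.Tensor yields an element-wise boolean tensor that cannot serve as an if condition; the fix checks the one-hot entries with .item(), drops the empty default prompt, and indexes the three remaining embeddings 0..2. A self-contained sketch of the corrected selection, with hypothetical toy embeddings standing in for the CLIP text embeddings built earlier in the script:

import torch

# Toy stand-ins for the three prompt embeddings ('indoor geometry', 'outdoor geometry',
# 'object geometry'), each shaped [1, 4, 1024] like the encoded prompts in the diff.
text_embed_list = [torch.full((1, 4, 1024), float(k)) for k in range(3)]

# Toy one-hot domain labels: sample 0 is indoor, sample 1 is object-centric.
domain = torch.tensor([[1., 0., 0.],
                       [0., 0., 1.]])
bsz = domain.shape[0]

# Buggy pattern: domain[i] == torch.Tensor([1., 0., 0.]) returns a length-3 bool tensor,
# which raises when Python asks for its truth value in an if statement.
batch_text_embed = torch.zeros_like(text_embed_list[0]).repeat(bsz, 1, 1)  # [B, 4, 1024]
for i in range(bsz):
    if domain[i][0].item() == 1:
        batch_text_embed[i] = text_embed_list[0]
    elif domain[i][1].item() == 1:
        batch_text_embed[i] = text_embed_list[1]
    elif domain[i][2].item() == 1:
        batch_text_embed[i] = text_embed_list[2]

# Duplicate for the stacked depth/normal branches, as in the training loop.
batch_text_embed = batch_text_embed.repeat(2, 1, 1)  # [2B, 4, 1024]
assert batch_text_embed.shape == (2 * bsz, 4, 1024)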
