diff --git a/diffengine/datasets/transforms/dump_image.py b/diffengine/datasets/transforms/dump_image.py index de583cc..0c8e58d 100644 --- a/diffengine/datasets/transforms/dump_image.py +++ b/diffengine/datasets/transforms/dump_image.py @@ -48,7 +48,7 @@ def __call__(self, results): out_file = osp.join(self.dump_dir, f'{dump_id}_image.png') cv2.imwrite(out_file, img.numpy().astype(np.uint8)) - if 'condition_img': + if 'condition_img' in results: condition_img = results['condition_img'] if condition_img.shape[0] in [1, 3]: condition_img = condition_img.permute(1, 2, 0) * 255 diff --git a/projects/README.md b/projects/README.md new file mode 100644 index 0000000..a6dac68 --- /dev/null +++ b/projects/README.md @@ -0,0 +1,12 @@ +# Welcome to Projects of DiffEngine + +In this folder, we welcome all contribution of diffusion model training from community. + +Here, these requirements, e.g. code standards, are not that strict as in core package. Thus, developers from the community can implement their algorithms much more easily and efficiently in DiffEngine. We appreciate all contributions from community to make DiffEngine greater. + +Here is an [example project](./face_expression/) about how to add your algorithms easily. + +We also provide some documentation listed below: + +- [🙌 Contributing \[🔝\]](../README.md#🙌-contributing-🔝) + The guides for new contributors. diff --git a/projects/face_expression/README.md b/projects/face_expression/README.md index 6fa3583..a2be1c3 100644 --- a/projects/face_expression/README.md +++ b/projects/face_expression/README.md @@ -70,6 +70,4 @@ You can see more details on [LoRA docs](../../docs/source/run_guides/run_lora.md #### stable_diffusion_xl_lora_face_expression -![example1](https://github.com/okotaku/diffengine/assets/24734142/68c7569b-f62c-4228-a00d-997f2d963ad0) - -Note that training failed. We should improve SDXL LoRA training. +![example1](https://github.com/okotaku/diffengine/assets/24734142/0d6e3434-4ba6-4c90-8457-06595ad9183d) diff --git a/projects/face_expression/_base_/face_expression_xl_dataset.py b/projects/face_expression/_base_/face_expression_xl_dataset.py index 96b539d..75a871c 100644 --- a/projects/face_expression/_base_/face_expression_xl_dataset.py +++ b/projects/face_expression/_base_/face_expression_xl_dataset.py @@ -5,6 +5,7 @@ dict(type='RandomHorizontalFlip', p=0.5), dict(type='ComputeTimeIds'), dict(type='torchvision/ToTensor'), + dict(type='DumpImage', max_imgs=10, dump_dir='work_dirs/dump'), dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]), dict(type='PackInputs', input_keys=['img', 'text', 'time_ids']), ] diff --git a/projects/face_expression/stable_diffusion_xl_lora_face_expression.py b/projects/face_expression/stable_diffusion_xl_lora_face_expression.py index 2abd337..b1296af 100644 --- a/projects/face_expression/stable_diffusion_xl_lora_face_expression.py +++ b/projects/face_expression/stable_diffusion_xl_lora_face_expression.py @@ -5,11 +5,10 @@ '../../configs/_base_/default_runtime.py' ] -train_dataloader = dict(batch_size=1) +train_dataloader = dict(batch_size=2) -optim_wrapper = dict(optimizer=dict(lr=1e-4)) +optim_wrapper = dict(optimizer=dict(lr=1e-4), accumulative_counts=2) -model = dict( - model='Linaqruf/animagine-xl', vae_model=None, lora_config=dict(rank=128)) +model = dict(model='gsdf/CounterfeitXL', lora_config=dict(rank=32)) -train_cfg = dict(by_epoch=True, max_epochs=100) +train_cfg = dict(by_epoch=True, max_epochs=50)