HIVE

Shu Zhang*1, Xinyi Yang*1, Yihao Feng*1, Can Qin3, Chia-Chih Chen1, Ning Yu1, Zeyuan Chen1, Huan Wang1, Silvio Savarese1,2, Stefano Ermon2, Caiming Xiong1, and Ran Xu1
1Salesforce AI, 2Stanford University, 3Northeastern University
*denotes equal contribution
arXiv 2023

This is a PyTorch implementation of HIVE: Harnessing Human Feedback for Instructional Visual Editing. Most of the code follows InstructPix2Pix. This repo supports both Stable Diffusion v1.5-base and Stable Diffusion v2.1-base as the backbone.

Updates

  • 07/08/23: Training code and training data are public.😊
  • 03/26/24: HIVE will appear in CVPR 2024.😊

Usage

Preparation

First, set up the hive environment and download the pretrained model as shown below. This has only been verified on CUDA 11.0 and CUDA 11.3 with an NVIDIA A100 GPU.

conda env create -f environment.yaml
conda activate hive
bash scripts/download_checkpoints.sh

To fine-tune a Stable Diffusion model, you need to obtain the pre-trained Stable Diffusion weights following their instructions. For SD-V1.5, download the weights from HuggingFace SD 1.5; for SD-V2.1, download them from HuggingFace SD 2.1. You can choose which checkpoint version to use; we use v2-1_512-ema-pruned.ckpt. Download the model to checkpoints/.
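Before launching training or inference, it can help to confirm that the checkpoints are where the commands expect them. The following is a minimal sketch, not part of the repo: the helper name `missing_checkpoints` is hypothetical, the filenames are copied from the commands in this README, and it assumes every file sits directly under checkpoints/. You only need the files for the backbone and reward type you actually use.

```python
from pathlib import Path

# Filenames copied from the commands in this README. That they all live
# directly under checkpoints/ is an assumption, not something the repo states.
EXPECTED_CHECKPOINTS = [
    "v2-1_512-ema-pruned.ckpt",   # SD v2.1 base weights used for fine-tuning
    "hive_v2_rw_condition.ckpt",  # SD v2.1, conditional reward
    "hive_v2_rw.ckpt",            # SD v2.1, weighted reward
    "hive_rw_condition.ckpt",     # SD v1.5, conditional reward
    "hive_rw.ckpt",               # SD v1.5, weighted reward
]


def missing_checkpoints(ckpt_dir="checkpoints", expected=EXPECTED_CHECKPOINTS):
    """Return the expected checkpoint filenames not found in ckpt_dir."""
    root = Path(ckpt_dir)
    return [name for name in expected if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    if missing:
        print("Missing checkpoints:", ", ".join(missing))
    else:
        print("All expected checkpoints found.")
```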

Data

We suggest installing the Gcloud CLI following Gcloud download. To obtain both training and evaluation data, run

bash scripts/download_hive_data.sh

Alternatively, download the data directly through Evaluation data and Evaluation instructions.
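The batch inference commands in this README read edit instructions from data/test.jsonl. The field names in that file are not documented here, so the sketch below makes no assumption about them: it parses the first few records as JSON Lines (one JSON value per line) so you can inspect them after downloading. The helper name `peek_jsonl` is hypothetical and not part of the repo.

```python
import json
import os


def peek_jsonl(path, limit=3):
    """Return up to `limit` parsed records from a JSON Lines file."""
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            records.append(json.loads(line))
            if len(records) >= limit:
                break
    return records


if __name__ == "__main__":
    # Path taken from the batch inference commands; skipped if not downloaded yet.
    path = "data/test.jsonl"
    if os.path.exists(path):
        for record in peek_jsonl(path):
            print(record)
```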

Step-1 Training

For SD v2.1, we run

python main.py --name step1 --base configs/train_v21_base.yaml --train --gpus 0,1,2,3,4,5,6,7

Inference

Samples can be obtained by running one of the commands below, depending on the backbone and the reward type.

For SD v2.1, if we use the conditional reward, we run

python edit_cli_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" --ckpt checkpoints/hive_v2_rw_condition.ckpt \
--config configs/generate_v21_base.yaml

or run batch inference on our inference data:

python edit_cli_batch_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv21_rw_label/ --ckpt checkpoints/hive_v2_rw_condition.ckpt \
--config configs/generate_v21_base.yaml --image_dir data/evaluation/

For SD v2.1, if we use the weighted reward, we can run

python edit_cli.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" \
--ckpt checkpoints/hive_v2_rw.ckpt --config configs/generate_v21_base.yaml

or run batch inference on our inference data:

python edit_cli_batch.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv21/ --ckpt checkpoints/hive_v2_rw.ckpt \
--config configs/generate_v21_base.yaml --image_dir data/evaluation/

For SD v1.5, if we use the conditional reward, we can run

python edit_cli_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" \
--ckpt checkpoints/hive_rw_condition.ckpt --config configs/generate.yaml

or run batch inference on our inference data:

python edit_cli_batch_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv15_rw_label/ \
--ckpt checkpoints/hive_rw_condition.ckpt --config configs/generate.yaml \
--image_dir data/evaluation/

For SD v1.5, if we use the weighted reward, we run

python edit_cli.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 --input imgs/example1.jpg \
--output imgs/output.jpg --edit "move it to Mars" \
--ckpt checkpoints/hive_rw.ckpt --config configs/generate.yaml

or run batch inference on our inference data:

python edit_cli_batch.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv15/ \
--ckpt checkpoints/hive_rw.ckpt --config configs/generate.yaml \
--image_dir data/evaluation/
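The single-image commands above differ only in the script, checkpoint, and config they use. As a convenience, that mapping can be captured in a small lookup table. This is a sketch rather than part of the repo: the helper name `build_edit_command` and the keys "v2.1", "v1.5", "conditional", and "weighted" are made up for illustration, while the script names, paths, and flag values are copied from the commands above.

```python
import shlex

# (backbone, reward) -> (script, checkpoint, config), copied from the
# single-image commands in this README. The dictionary keys are illustrative.
VARIANTS = {
    ("v2.1", "conditional"): ("edit_cli_rw_label.py",
                              "checkpoints/hive_v2_rw_condition.ckpt",
                              "configs/generate_v21_base.yaml"),
    ("v2.1", "weighted"): ("edit_cli.py",
                           "checkpoints/hive_v2_rw.ckpt",
                           "configs/generate_v21_base.yaml"),
    ("v1.5", "conditional"): ("edit_cli_rw_label.py",
                              "checkpoints/hive_rw_condition.ckpt",
                              "configs/generate.yaml"),
    ("v1.5", "weighted"): ("edit_cli.py",
                           "checkpoints/hive_rw.ckpt",
                           "configs/generate.yaml"),
}


def build_edit_command(backbone, reward, input_path, output_path, edit,
                       steps=100, resolution=512, seed=100,
                       cfg_text=7.5, cfg_image=1.5):
    """Return the argument list for a single-image edit (hypothetical helper)."""
    script, ckpt, config = VARIANTS[(backbone, reward)]
    return [
        "python", script,
        "--steps", str(steps),
        "--resolution", str(resolution),
        "--seed", str(seed),
        "--cfg-text", str(cfg_text),
        "--cfg-image", str(cfg_image),
        "--input", input_path,
        "--output", output_path,
        "--edit", edit,
        "--ckpt", ckpt,
        "--config", config,
    ]


if __name__ == "__main__":
    cmd = build_edit_command("v2.1", "weighted", "imgs/example1.jpg",
                             "imgs/output.jpg", "move it to Mars")
    # Print a shell-quoted command line; nothing is executed here.
    print(shlex.join(cmd))
```

To actually run the edit from Python, the returned list can be passed to `subprocess.run(cmd, check=True)` from the repo root.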

Citation

@article{zhang2023hive,
	title={HIVE: Harnessing Human Feedback for Instructional Visual Editing},
	author={Zhang, Shu and Yang, Xinyi and Feng, Yihao and Qin, Can and Chen, Chia-Chih and Yu, Ning and Chen, Zeyuan and Wang, Huan and Savarese, Silvio and Ermon, Stefano and Xiong, Caiming and Xu, Ran},
	journal={arXiv preprint arXiv:2303.09618},
	year={2023}
}