-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
v1-5 docs updates #921
v1-5 docs updates #921
Changes from 4 commits
118b5e4
7778c2b
355746f
af65886
3d31560
35086d2
6a0eb1c
e943d9c
a2ff989
6f0d34e
0af885c
f48fc0d
a77496b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,44 +64,54 @@ In order to get started, we recommend taking a look at two notebooks: | |
- The [Training a diffusers model](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Open In Colab"](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook summarizes diffusion models training methods. This notebook takes a step-by-step approach to training your | ||
diffusion models on an image dataset, with explanatory graphics. | ||
|
||
## **New** Stable Diffusion is now fully compatible with `diffusers`! | ||
## Stable Diffusion is fully compatible with `diffusers`! | ||
|
||
Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM. | ||
Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [LAION](https://laion.ai/) and [RunwayML](https://runwayml.com/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ~4 now I think, with attention slicing, but probably better to keep this simple.
patrickvonplaten marked this conversation as resolved.
Show resolved
Hide resolved
|
||
See the [model card](https://huggingface.co/CompVis/stable-diffusion) for more information. | ||
|
||
You need to accept the model license before downloading or using the Stable Diffusion weights. Please, visit the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section](https://huggingface.co/docs/hub/security-tokens) of the documentation. | ||
You need to accept the model license before downloading or using the Stable Diffusion weights. Please, visit the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section](https://huggingface.co/docs/hub/security-tokens) of the documentation. | ||
apolinario marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
### Text-to-Image generation with Stable Diffusion | ||
|
||
First let's install | ||
```bash | ||
pip install --upgrade diffusers transformers scipy | ||
``` | ||
|
||
Run this command to log in with your HF Hub token if you haven't before (you can skip this step if you prefer to run the model locally, follow [this](#running-the-model-locally) instead) | ||
```bash | ||
huggingface-cli login | ||
``` | ||
|
||
We recommend using the model in [half-precision (`fp16`)](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/) as it gives almost always the same results as full | ||
precision while being roughly twice as fast and requiring half the amount of GPU RAM. | ||
|
||
```python | ||
# make sure you're logged in with `huggingface-cli login` | ||
from diffusers import StableDiffusionPipeline | ||
|
||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_type=torch.float16, revision="fp16") | ||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_type=torch.float16, revision="fp16") | ||
pipe = pipe.to("cuda") | ||
|
||
prompt = "a photo of an astronaut riding a horse on mars" | ||
image = pipe(prompt).images[0] | ||
``` | ||
|
||
**Note**: If you don't want to use the token, you can also simply download the model weights | ||
(after having [accepted the license](https://huggingface.co/CompVis/stable-diffusion-v1-4)) and pass | ||
#### Running the model locally | ||
If you don't want to login to Hugging Face, you can also simply download the model folder | ||
(after having [accepted the license](https://huggingface.co/runwayml/stable-diffusion-v1-5)) and pass | ||
the path to the local folder to the `StableDiffusionPipeline`. | ||
|
||
``` | ||
git lfs install | ||
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 | ||
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 | ||
``` | ||
|
||
Assuming the folder is stored locally under `./stable-diffusion-v1-4`, you can also run stable diffusion | ||
Assuming the folder is stored locally under `./stable-diffusion-v1-5`, you can also run stable diffusion | ||
without requiring an authentication token: | ||
|
||
```python | ||
pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4") | ||
pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5") | ||
pipe = pipe.to("cuda") | ||
|
||
prompt = "a photo of an astronaut riding a horse on mars" | ||
|
@@ -114,7 +124,7 @@ The following snippet should result in less than 4GB VRAM. | |
|
||
```python | ||
pipe = StableDiffusionPipeline.from_pretrained( | ||
"CompVis/stable-diffusion-v1-4", | ||
"runwayml/stable-diffusion-v1-5", | ||
revision="fp16", | ||
torch_dtype=torch.float16, | ||
) | ||
|
@@ -125,7 +135,7 @@ pipe.enable_attention_slicing() | |
image = pipe(prompt).images[0] | ||
``` | ||
|
||
If you wish to use a different scheduler, you can simply instantiate | ||
If you wish to use a different scheduler (e.g.: DDIM, LMS, PNDM/PLMS), you can instantiate | ||
it before the pipeline and pass it to `from_pretrained`. | ||
|
||
```python | ||
|
@@ -138,7 +148,7 @@ lms = LMSDiscreteScheduler( | |
) | ||
|
||
pipe = StableDiffusionPipeline.from_pretrained( | ||
"CompVis/stable-diffusion-v1-4", | ||
"runwayml/stable-diffusion-v1-5", | ||
revision="fp16", | ||
torch_dtype=torch.float16, | ||
scheduler=lms, | ||
|
@@ -158,7 +168,7 @@ please run the model in the default *full-precision* setting: | |
# make sure you're logged in with `huggingface-cli login` | ||
from diffusers import StableDiffusionPipeline | ||
|
||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") | ||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") | ||
|
||
# disable the following line if you run on CPU | ||
pipe = pipe.to("cuda") | ||
|
@@ -169,6 +179,75 @@ image = pipe(prompt).images[0] | |
image.save("astronaut_rides_horse.png") | ||
``` | ||
|
||
### JAX/Flax | ||
|
||
To use StableDiffusion on TPUs and GPUs for faster inference you can leverage JAX/Flax. | ||
|
||
Running the pipeline with default PNDMScheduler | ||
|
||
```python | ||
import jax | ||
import numpy as np | ||
from flax.jax_utils import replicate | ||
from flax.training.common_utils import shard | ||
|
||
from diffusers import FlaxStableDiffusionPipeline | ||
|
||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( | ||
"runwayml/stable-diffusion-v1-5", revision="flax", dtype=jax.numpy.bfloat16 | ||
) | ||
|
||
prompt = "a photo of an astronaut riding a horse on mars" | ||
|
||
prng_seed = jax.random.PRNGKey(0) | ||
num_inference_steps = 50 | ||
|
||
num_samples = jax.device_count() | ||
prompt = num_samples * [prompt] | ||
prompt_ids = pipeline.prepare_inputs(prompt) | ||
|
||
# shard inputs and rng | ||
params = replicate(params) | ||
prng_seed = jax.random.split(prng_seed, 8) | ||
apolinario marked this conversation as resolved.
Show resolved
Hide resolved
|
||
prompt_ids = shard(prompt_ids) | ||
|
||
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images | ||
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) | ||
``` | ||
|
||
**Note**: | ||
If you are limited by TPU memory, please make sure to load the `FlaxStableDiffusionPipeline` in `bfloat16` precision instead of the default `float32` precision as done above. You can do so by telling diffusers to load the weights from "bf16" branch. | ||
|
||
```python | ||
import jax | ||
import numpy as np | ||
from flax.jax_utils import replicate | ||
from flax.training.common_utils import shard | ||
|
||
from diffusers import FlaxStableDiffusionPipeline | ||
|
||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( | ||
"runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16 | ||
) | ||
|
||
prompt = "a photo of an astronaut riding a horse on mars" | ||
|
||
prng_seed = jax.random.PRNGKey(0) | ||
num_inference_steps = 50 | ||
|
||
num_samples = jax.device_count() | ||
prompt = num_samples * [prompt] | ||
prompt_ids = pipeline.prepare_inputs(prompt) | ||
|
||
# shard inputs and rng | ||
params = replicate(params) | ||
prng_seed = jax.random.split(prng_seed, 8) | ||
apolinario marked this conversation as resolved.
Show resolved
Hide resolved
|
||
prompt_ids = shard(prompt_ids) | ||
|
||
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images | ||
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unrelated to this PR: Think we should move the reshape functionality in the |
||
``` | ||
|
||
### Image-to-Image text-guided generation with Stable Diffusion | ||
|
||
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images. | ||
|
@@ -183,14 +262,14 @@ from diffusers import StableDiffusionImg2ImgPipeline | |
|
||
# load the pipeline | ||
device = "cuda" | ||
model_id_or_path = "CompVis/stable-diffusion-v1-4" | ||
model_id_or_path = "runwayml/stable-diffusion-v1-5" | ||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained( | ||
model_id_or_path, | ||
revision="fp16", | ||
torch_dtype=torch.float16, | ||
) | ||
# or download via git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 | ||
# and pass `model_id_or_path="./stable-diffusion-v1-4"`. | ||
# or download via git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 | ||
# and pass `model_id_or_path="./stable-diffusion-v1-5"`. | ||
pipe = pipe.to(device) | ||
|
||
# let's download an initial image | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,7 +64,7 @@ accelerate config | |
|
||
### Cat toy example | ||
|
||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. | ||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's revert this until we've tested it with v1-5 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||
|
||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). | ||
|
||
|
@@ -83,7 +83,7 @@ Now let's get our dataset.Download 3-4 images from [here](https://drive.google.c | |
And launch the training using | ||
|
||
```bash | ||
export MODEL_NAME="CompVis/stable-diffusion-v1-4" | ||
export MODEL_NAME="runwayml/stable-diffusion-v1-5" | ||
export DATA_DIR="path-to-dir-containing-images" | ||
|
||
accelerate launch textual_inversion.py \ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,7 +58,7 @@ feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id) | |
clip_model = CLIPModel.from_pretrained(clip_model_id) | ||
|
||
pipeline = DiffusionPipeline.from_pretrained( | ||
"CompVis/stable-diffusion-v1-4", | ||
"runwayml/stable-diffusion-v1-5", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @patil-suraj should we change or leave this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay to chnage this! |
||
custom_pipeline="clip_guided_stable_diffusion", | ||
clip_model=clip_model, | ||
feature_extractor=feature_extractor, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍