v1-5 docs updates #921

Merged (13 commits) on Oct 24, 2022
113 changes: 96 additions & 17 deletions README.md
@@ -64,44 +64,54 @@ In order to get started, we recommend taking a look at two notebooks:
- The [Training a diffusers model](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook summarizes diffusion model training methods. This notebook takes a step-by-step approach to training your
diffusion models on an image dataset, with explanatory graphics.

## **New** Stable Diffusion is now fully compatible with `diffusers`!
## Stable Diffusion is fully compatible with `diffusers`!
Reviewer comment (Member): 👍


Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM.
Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [LAION](https://laion.ai/) and [RunwayML](https://runwayml.com/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM.
Reviewer comment (Member): ~4 GB now, I think, with attention slicing, but probably better to keep this simple.

See the [model card](https://huggingface.co/CompVis/stable-diffusion) for more information.

You need to accept the model license before downloading or using the Stable Diffusion weights. Please, visit the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section](https://huggingface.co/docs/hub/security-tokens) of the documentation.
You need to accept the model license before downloading or using the Stable Diffusion weights. Please visit the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section](https://huggingface.co/docs/hub/security-tokens) of the documentation.


### Text-to-Image generation with Stable Diffusion

First, let's install the required packages:
```bash
pip install --upgrade diffusers transformers scipy
```

Run this command to log in with your HF Hub token if you haven't before (you can skip this step if you prefer to run the model locally; follow [these instructions](#running-the-model-locally) instead):
```bash
huggingface-cli login
```

We recommend using the model in [half-precision (`fp16`)](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/) as it almost always gives the same results as full
precision while being roughly twice as fast and requiring half the GPU RAM.

```python
# make sure you're logged in with `huggingface-cli login`
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, revision="fp16")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, revision="fp16")
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
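As a quick follow-up (a sketch, not part of the original snippet): the pipeline call accepts a `generator` argument, so you can seed the RNG to make runs reproducible and then save the resulting PIL image.

```python
import torch

# Seed the RNG so the same prompt reproduces the same image across runs.
generator = torch.Generator("cuda").manual_seed(42)
image = pipe(prompt, generator=generator).images[0]

# The pipeline returns PIL images, so saving is a one-liner.
image.save("astronaut_rides_horse.png")
```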

**Note**: If you don't want to use the token, you can also simply download the model weights
(after having [accepted the license](https://huggingface.co/CompVis/stable-diffusion-v1-4)) and pass
#### Running the model locally
If you don't want to log in to Hugging Face, you can also simply download the model folder
(after having [accepted the license](https://huggingface.co/runwayml/stable-diffusion-v1-5)) and pass
the path to the local folder to the `StableDiffusionPipeline`.

```
git lfs install
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
```

Assuming the folder is stored locally under `./stable-diffusion-v1-4`, you can also run stable diffusion
Assuming the folder is stored locally under `./stable-diffusion-v1-5`, you can also run stable diffusion
without requiring an authentication token:

```python
pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
@@ -114,7 +124,7 @@ The following snippet should result in less than 4GB VRAM.

```python
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
)
@@ -125,7 +135,7 @@ pipe.enable_attention_slicing()
image = pipe(prompt).images[0]
```
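If you want to verify the memory figure on your own GPU, a quick check with PyTorch's CUDA memory statistics (a sketch that continues the snippet above) looks like this:

```python
import torch

torch.cuda.reset_peak_memory_stats()
image = pipe(prompt).images[0]
print(f"Peak VRAM: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
```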

If you wish to use a different scheduler, you can simply instantiate
If you wish to use a different scheduler (e.g.: DDIM, LMS, PNDM/PLMS), you can instantiate
it before the pipeline and pass it to `from_pretrained`.

```python
@@ -138,7 +148,7 @@ lms = LMSDiscreteScheduler(
)

pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
scheduler=lms,
@@ -158,7 +168,7 @@ please run the model in the default *full-precision* setting:
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# disable the following line if you run on CPU
pipe = pipe.to("cuda")
@@ -169,6 +179,75 @@ image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
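CPU inference is much slower than GPU inference. If you want a rough sense of the cost before committing to the default 50 steps, you can time a shorter run first (a sketch continuing the snippet above; `num_inference_steps` trades speed for quality):

```python
import time

start = time.time()
image = pipe(prompt, num_inference_steps=10).images[0]  # fewer steps: faster, lower quality
print(f"10 steps took {time.time() - start:.1f} s")
```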

### JAX/Flax

To use Stable Diffusion on TPUs and GPUs for faster inference, you can leverage JAX/Flax.

Running the pipeline with the default PNDMScheduler:

```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", revision="flax", dtype=jax.numpy.bfloat16
)

prompt = "a photo of an astronaut riding a horse on mars"

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
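As a small follow-up (not in the original snippet), `images` is a list of PIL images at this point, so writing them to disk is straightforward:

```python
# Save one file per device/sample.
for i, img in enumerate(images):
    img.save(f"astronaut_rides_horse_{i}.png")
```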

**Note**:
If you are limited by TPU memory, please make sure to load the `FlaxStableDiffusionPipeline` in `bfloat16` precision instead of the default `float32` precision as done above. You can do so by telling diffusers to load the weights from the `"bf16"` branch.

```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16
)

prompt = "a photo of an astronaut riding a horse on mars"

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```

Reviewer comment (Contributor): Unrelated to this PR: I think we should move the reshape functionality into the `numpy_to_pil` function (maybe we could open an issue for this).
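A minimal sketch of what that suggested refactor could look like. The helper below is illustrative only, not the actual `diffusers` API; it folds the reshape into the NumPy-to-PIL conversion so callers can pass the sharded output directly:

```python
import numpy as np
from PIL import Image


def numpy_to_pil_flat(images):
    """Convert (devices, batch, H, W, C) or (batch, H, W, C) float arrays in [0, 1] to PIL images."""
    images = np.asarray(images)
    images = images.reshape((-1,) + images.shape[-3:])  # collapse any leading device/batch dims
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(img) for img in images]
```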

### Image-to-Image text-guided generation with Stable Diffusion

The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
@@ -183,14 +262,14 @@ from diffusers import StableDiffusionImg2ImgPipeline

# load the pipeline
device = "cuda"
model_id_or_path = "CompVis/stable-diffusion-v1-4"
model_id_or_path = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id_or_path,
revision="fp16",
torch_dtype=torch.float16,
)
# or download via git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
# and pass `model_id_or_path="./stable-diffusion-v1-4"`.
# or download via git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
# and pass `model_id_or_path="./stable-diffusion-v1-5"`.
pipe = pipe.to(device)

# let's download an initial image
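# (the rest of this snippet is collapsed in the diff; the lines below are a sketch of the usual flow
# with a placeholder image URL; `init_image`, `strength` and `guidance_scale` are the img2img
# pipeline's call arguments in this release)
import requests
from io import BytesIO
from PIL import Image

url = "https://example.com/sketch-mountains-input.jpg"  # placeholder URL: substitute any RGB image
init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
init_image = init_image.resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")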
8 changes: 4 additions & 4 deletions docs/source/api/pipelines/overview.mdx
@@ -67,8 +67,8 @@ Diffusion models often consist of multiple independently-trained models or other components.
Each model has been trained independently on a different task and the scheduler can easily be swapped out and replaced with a different one.
During inference, however, we want to be able to easily load all components and use them - even if one component, *e.g.* CLIP's text encoder, originates from a different library, such as [Transformers](https://github.com/huggingface/transformers). To that end, all pipelines provide the following functionality:

- [`from_pretrained` method](../diffusion_pipeline) that accepts a Hugging Face Hub repository id, *e.g.* [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) or a path to a local directory, *e.g.*
"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [CompVis/stable-diffusion-v1-4/model_index.json](https://huggingface.co/CompVis/stable-diffusion-v1-4/blob/main/model_index.json), which defines all components that should be
- [`from_pretrained` method](../diffusion_pipeline) that accepts a Hugging Face Hub repository id, *e.g.* [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) or a path to a local directory, *e.g.*
"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [runwayml/stable-diffusion-v1-5/model_index.json](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), which defines all components that should be
loaded into the pipelines. More specifically, for each model/component one needs to define the format `<name>: ["<library>", "<class name>"]`. `<name>` is the attribute name given to the loaded instance of `<class name>` which can be found in the library or pipeline folder called `"<library>"`.
- [`save_pretrained`](../diffusion_pipeline) that accepts a local path, *e.g.* `./stable-diffusion` under which all models/components of the pipeline will be saved. For each component/model a folder is created inside the local path that is named after the given attribute name, *e.g.* `./stable_diffusion/unet`.
In addition, a `model_index.json` file is created at the root of the local path, *e.g.* `./stable_diffusion/model_index.json` so that the complete pipeline can again be instantiated
@@ -100,7 +100,7 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
@@ -123,7 +123,7 @@ from diffusers import StableDiffusionImg2ImgPipeline
# load the pipeline
device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
"runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16
).to(device)

# let's download an initial image
10 changes: 5 additions & 5 deletions docs/source/optimization/fp16.mdx
@@ -56,7 +56,7 @@ If you use a CUDA GPU, you can take advantage of `torch.autocast` to perform inference:
from torch import autocast
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
@@ -72,7 +72,7 @@ To save more GPU memory and get even more speed, you can load and run the model weights directly in half precision.

```python
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
)
@@ -97,7 +97,7 @@ import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
)
@@ -152,7 +152,7 @@ def generate_inputs():


pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
).to("cuda")
@@ -216,7 +216,7 @@ class UNet2DConditionOutput:


pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
).to("cuda")
2 changes: 1 addition & 1 deletion docs/source/optimization/mps.mdx
@@ -31,7 +31,7 @@ We recommend "priming" the pipeline with an additional one-time pass through it.
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("mps")

prompt = "a photo of an astronaut riding a horse on mars"
2 changes: 1 addition & 1 deletion docs/source/optimization/onnx.mdx
@@ -28,7 +28,7 @@ The snippet below demonstrates how to use the ONNX runtime. You need to use `StableDiffusionOnnxPipeline` instead of `StableDiffusionPipeline`.
from diffusers import StableDiffusionOnnxPipeline

pipe = StableDiffusionOnnxPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
revision="onnx",
provider="CUDAExecutionProvider",
)
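# (the rest of this snippet is collapsed in the diff; usage then mirrors the PyTorch pipeline, e.g.)
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]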
14 changes: 7 additions & 7 deletions docs/source/quicktour.mdx
@@ -68,7 +68,7 @@ You can save the image by simply calling:
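The saving call itself is collapsed in this diff view; a minimal sketch, with an assumed filename:

```python
>>> image.save("generated_image.png")
```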

More advanced models, like [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion), require you to accept a [license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) before running the model.
This is due to the improved image generation capabilities of the model and the potentially harmful content that could be produced with it.
Long story short: Head over to your stable diffusion model of choice, *e.g.* [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4), read through the license and click-accept to get
Long story short: Head over to your stable diffusion model of choice, *e.g.* [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5), read through the license and click-accept to get
access to the model.
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
Having "click-accepted" the license, you can save your token:
@@ -77,13 +77,13 @@
AUTH_TOKEN = "<please-fill-with-your-token>"
```

You can then load [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)
You can then load [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)
just like we did before only that now you need to pass your `AUTH_TOKEN`:

```python
>>> from diffusers import DiffusionPipeline

>>> generator = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=AUTH_TOKEN)
>>> generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
```

If you do not pass your authentication token you will see that the diffusion system will not be correctly
@@ -95,15 +95,15 @@ the weights locally via:

```
git lfs install
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
```

and then load locally saved weights into the pipeline. This way, you do not need to pass an authentication
token. Assuming that `"./stable-diffusion-v1-4"` is the local path to the cloned stable-diffusion-v1-4 repo,
token. Assuming that `"./stable-diffusion-v1-5"` is the local path to the cloned stable-diffusion-v1-5 repo,
you can also load the pipeline as follows:

```python
>>> generator = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
>>> generator = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
```

Running the pipeline is then identical to the code above as it's the same model architecture.
@@ -125,7 +125,7 @@ you could use it as follows:
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

>>> generator = StableDiffusionPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=AUTH_TOKEN
... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
... )
```

4 changes: 2 additions & 2 deletions docs/source/training/text_inversion.mdx
@@ -64,7 +64,7 @@ accelerate config

### Cat toy example

You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
Reviewer comment (Contributor): Let's revert this until we've tested it with v1-5.

Reviewer comment (Contributor): +1


You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).

@@ -83,7 +83,7 @@ Now let's get our dataset. Download 3-4 images from [here](https://drive.google.c
And launch the training using

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="path-to-dir-containing-images"

accelerate launch textual_inversion.py \
2 changes: 1 addition & 1 deletion docs/source/using-diffusers/custom_pipelines.mdx
@@ -58,7 +58,7 @@ feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
clip_model = CLIPModel.from_pretrained(clip_model_id)

pipeline = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
Reviewer comment (Contributor): @patil-suraj should we change or leave this?

Reviewer comment (Contributor): Okay to change this!

custom_pipeline="clip_guided_stable_diffusion",
clip_model=clip_model,
feature_extractor=feature_extractor,