Decoding latents independently using AnimateDiffVideoToVideoPipeline takes more memory than outputting images directly #7378

Closed
AbhinavGopal opened this issue Mar 19, 2024 · 5 comments
Labels
bug Something isn't working

Comments

@AbhinavGopal
Contributor

AbhinavGopal commented Mar 19, 2024

Describe the bug

When I use AnimateDiffVideoToVideoPipeline to output images directly, my server does not run out of memory. However, if I output the latents and then manually call pipe.decode_latents, I run out of GPU memory.

Reproduction

Code with manual latent decoding:

import imageio
from io import BytesIO
import PIL.Image
import requests
import torch
from diffusers import AnimateDiffVideoToVideoPipeline, AutoencoderKL, MotionAdapter
def load_video(file_path: str):
    images = []

    if file_path.startswith(("http://", "https://")):
        # If the file_path is a URL
        response = requests.get(file_path, timeout=10)
        response.raise_for_status()
        content = BytesIO(response.content)
        vid = imageio.get_reader(content)
    else:
        # Assuming it's a local file path
        vid = imageio.get_reader(file_path)

    for frame in vid:
        pil_image = PIL.Image.fromarray(frame).convert("RGB")
        images.append(pil_image)

    return images
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16).to('cuda')
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema", torch_dtype=torch.float16).to('cuda')
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16).to('cuda')
video = load_video("https://i.makeagif.com/media/8-25-2022/qpOyQo.gif")[:16]
prompt = "red countdown timer"
guidance_scale = 7.5
strength = 0.8
width, height = 512, 512
output_type = "latent"
combined_outputs = pipe(prompt=prompt, video=video, guidance_scale=guidance_scale, strength=strength, width=width, height=height, output_type=output_type, num_inference_steps=5).frames
images = pipe.decode_latents(combined_outputs)

This fails with the out-of-memory error shown in the logs below.

If I instead run the pipeline so that it outputs images directly, there is no error:

import imageio
from io import BytesIO
import PIL.Image
import requests
import torch
from diffusers import AnimateDiffVideoToVideoPipeline, AutoencoderKL, MotionAdapter
def load_video(file_path: str):
    images = []

    if file_path.startswith(("http://", "https://")):
        # If the file_path is a URL
        response = requests.get(file_path, timeout=10)
        response.raise_for_status()
        content = BytesIO(response.content)
        vid = imageio.get_reader(content)
    else:
        # Assuming it's a local file path
        vid = imageio.get_reader(file_path)

    for frame in vid:
        pil_image = PIL.Image.fromarray(frame).convert("RGB")
        images.append(pil_image)

    return images

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16).to('cuda')
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema", torch_dtype=torch.float16).to('cuda')
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16).to('cuda')
video = load_video("https://i.makeagif.com/media/8-25-2022/qpOyQo.gif")[:16]
prompt = "red countdown timer"
guidance_scale = 7.5
strength = 0.8
width, height = 512, 512
output_type = "np"
images = pipe(prompt=prompt, video=video, guidance_scale=guidance_scale, strength=strength, width=width, height=height, output_type=output_type, num_inference_steps=5).frames

I have about 22.5 GB of VRAM available on my server.

Logs

Traceback (most recent call last):
  File "/app/temp.py", line 40, in <module>
    images = pipe.decode_latents(combined_outputs)
  File "/app/diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py", line 508, in decode_latents
    image = self.vae.decode(latents).sample
  File "/app/diffusers/src/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/app/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 304, in decode
    decoded = self._decode(z).sample
  File "/app/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 275, in _decode
    dec = self.decoder(z)
  File "/opt/conda/envs/abhinavg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/abhinavg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/diffusers/src/diffusers/models/autoencoders/vae.py", line 338, in forward
    sample = up_block(sample, latent_embeds)
  File "/opt/conda/envs/abhinavg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/abhinavg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/diffusers/src/diffusers/models/unets/unet_2d_blocks.py", line 2737, in forward
    hidden_states = resnet(hidden_states, temb=temb)
  File "/opt/conda/envs/abhinavg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/abhinavg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/diffusers/src/diffusers/models/resnet.py", line 366, in forward
    hidden_states = self.norm2(hidden_states)
  File "/opt/conda/envs/abhinavg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/abhinavg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/abhinavg/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 287, in forward
    return F.group_norm(
  File "/opt/conda/envs/abhinavg/lib/python3.10/site-packages/torch/nn/functional.py", line 2561, in group_norm
    return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 21.99 GiB of which 57.00 MiB is free. Process 1099 has 21.92 GiB memory in use. Of the allocated memory 20.02 GiB is allocated by PyTorch, and 1.59 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
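As a rough sanity check on the log, the size of individual decoder feature maps can be computed directly. This is a sketch under stated assumptions, not taken from the issue: the standard Stable Diffusion 1.x VAE decoder layout (512, 512, 256, 128 channels going from the 64x64 latent resolution up to 512x512), fp16 activations, and all 16 frames decoded as one batch. Under those assumptions, the 256-channel map at 256x256 is exactly 512 MiB, which is consistent with the failed allocation; with autograd recording, many maps of this size can be kept alive at once.

```python
# Rough size of one VAE-decoder activation for a batch of frames.
# Assumptions (not from the issue): SD 1.x VAE decoder channel layout,
# latents upscaled 8x to pixels, fp16 storage (2 bytes per element).

BYTES_PER_ELEMENT = 2  # float16
MIB = 1024 ** 2


def activation_mib(frames: int, channels: int, height: int, width: int) -> float:
    """Size in MiB of a (frames, channels, height, width) fp16 tensor."""
    return frames * channels * height * width * BYTES_PER_ELEMENT / MIB


num_frames = 16
# (channels, spatial size) of the resnet feature maps in each decoder up block
# for a 512x512 output, starting from the 64x64 latent resolution.
decoder_stages = [(512, 64), (512, 128), (256, 256), (128, 512)]

for channels, size in decoder_stages:
    mib = activation_mib(num_frames, channels, size, size)
    print(f"{channels:>4} ch @ {size}x{size}: {mib:.0f} MiB")
```

Under these assumptions the four stages come out to 64, 256, 512 and 1024 MiB respectively.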

System Info

  • diffusers version: 0.28.0.dev0
  • Platform: Linux-4.14.322-246.539.amzn2.x86_64-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Huggingface_hub version: 0.21.4
  • Transformers version: 4.38.2
  • Accelerate version: 0.28.0
  • xFormers version: not installed
  • Using GPU in script?:Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@DN6 @saya

@AbhinavGopal AbhinavGopal added the bug Something isn't working label Mar 19, 2024
@sayakpaul
Member

Cc: @DN6

@sayakpaul
Member

What happens when you delete the pipe object and then decode the latents?

@AbhinavGopal
Copy link
Contributor Author

AbhinavGopal commented Mar 19, 2024

Even after deleting the pipe object, I still run out of memory. Here's what I'm running:

import imageio
from io import BytesIO
import PIL.Image
import requests
import torch
from diffusers import AnimateDiffVideoToVideoPipeline, AutoencoderKL, MotionAdapter
def load_video(file_path: str):
    images = []

    if file_path.startswith(("http://", "https://")):
        # If the file_path is a URL
        response = requests.get(file_path, timeout=10)
        response.raise_for_status()
        content = BytesIO(response.content)
        vid = imageio.get_reader(content)
    else:
        # Assuming it's a local file path
        vid = imageio.get_reader(file_path)

    for frame in vid:
        pil_image = PIL.Image.fromarray(frame).convert("RGB")
        images.append(pil_image)

    return images
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16).to('cuda')
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema", torch_dtype=torch.float16).to('cuda')
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16).to('cuda')
video = load_video("https://i.makeagif.com/media/8-25-2022/qpOyQo.gif")[:16]
prompt = "red countdown timer"
guidance_scale = 7.5
strength = 0.8
width, height = 512, 512
output_type = "latent"
combined_outputs = pipe(prompt=prompt, video=video, guidance_scale=guidance_scale, strength=strength, width=width, height=height, output_type=output_type, num_inference_steps=5).frames
del pipe
import gc
gc.collect()
torch.cuda.empty_cache()

# copied from AnimateDiffVideoToVideoPipeline.decode_latents in diffusers
latents = 1 / vae.config.scaling_factor * combined_outputs
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
image = vae.decode(latents).sample
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
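The reshape in the copied decoding code folds the frame axis into the batch axis so that the 2D VAE can decode every frame in one call, then unfolds the result back into video layout. A minimal sketch of that round trip, assuming small made-up tensor sizes and a hypothetical `fake_decode` stand-in for `vae.decode(...).sample` (a nearest-neighbour 8x upsample to 3 channels, not the real VAE) so it runs on CPU:

```python
import torch
import torch.nn.functional as F


def fake_decode(latents: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for vae.decode(...).sample: keeps 3 of the 4 latent
    channels and upsamples 8x. Only the shapes matter here, not the pixels."""
    return F.interpolate(latents[:, :3], scale_factor=8, mode="nearest")


# (batch, channels, frames, height, width), the layout returned with output_type="latent"
batch_size, channels, num_frames, height, width = 1, 4, 16, 8, 8
latents = torch.randn(batch_size, channels, num_frames, height, width)

# Fold frames into the batch axis: (B, C, F, H, W) -> (B*F, C, H, W)
frames = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
image = fake_decode(frames)  # (B*F, 3, 8H, 8W)

# Unfold back into video layout: (B*F, 3, 8H, 8W) -> (B, 3, F, 8H, 8W)
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
print(video.shape)  # torch.Size([1, 3, 16, 64, 64])
```

Because every frame goes through the decoder in a single batch, peak decoder memory scales with the number of frames.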

@a-r-r-o-w
Member

a-r-r-o-w commented Mar 22, 2024

Just taking a wild guess, but can you try something like:

with torch.no_grad():
  images = pipe.decode_latents(combined_outputs)

Additionally, it might help to run garbage collection for both Python and the CUDA cache (gc.collect() and torch.cuda.empty_cache()) before decoding.
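Why this helps: outside torch.no_grad(), calling a module whose weights require grad records an autograd graph, and the intermediate activations needed for a backward pass stay referenced by the output until it is freed. Diffusers pipelines typically decorate `__call__` with `@torch.no_grad()`, which would explain why `output_type="np"` avoids the problem while a manual `decode_latents` call does not. The sketch below shows the difference on a tiny CPU module standing in for the VAE; the module and the helper `decode_latents_no_grad` are hypothetical, for illustration only.

```python
import gc

import torch
import torch.nn as nn

# Tiny hypothetical stand-in for the VAE decoder. Its weights require grad by
# default, as do the weights of a model loaded with from_pretrained.
decoder = nn.Sequential(nn.Conv2d(4, 8, 3, padding=1), nn.SiLU(), nn.Conv2d(8, 3, 3, padding=1))
latents = torch.randn(2, 4, 16, 16)

# Without no_grad: autograd records a graph, and the intermediate activations
# saved for backward stay alive through the output's grad_fn.
out_tracked = decoder(latents)
print(out_tracked.requires_grad, out_tracked.grad_fn is not None)  # True True

# With no_grad: no graph is recorded, so intermediates can be freed immediately.
with torch.no_grad():
    out_plain = decoder(latents)
print(out_plain.requires_grad, out_plain.grad_fn is None)  # False True


def decode_latents_no_grad(pipe, latents):
    """Hypothetical helper: release cached memory, then decode without autograd."""
    gc.collect()
    torch.cuda.empty_cache()  # no-op if CUDA was never initialized
    with torch.no_grad():
        return pipe.decode_latents(latents)
```

With the reproduction above, this would be called as `images = decode_latents_no_grad(pipe, combined_outputs)`.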

@AbhinavGopal
Contributor Author

Works! Good guess HAHA
