F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") breaks for large bsz #984

Closed · NouamaneTazi opened this issue Oct 25, 2022 · 3 comments · Fixed by #1006
Labels: bug (Something isn't working)

Comments

@NouamaneTazi (Member) commented Oct 25, 2022

Describe the bug

Thanks to the amazing work done in the memory-efficient PR, I can now run Stable Diffusion in fp16 on a TITAN RTX (24 GB VRAM) up to a batch size of 31 without issue.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

prompt = "..."  # any prompt
batch_size = 32

with torch.inference_mode():
    image = pipe([prompt] * batch_size, num_inference_steps=5).images[0]

When I try a batch size of 32, I get the following error:

Traceback (most recent call last):
  File "/home/nouamane/projects/diffusers/a.py", line 45, in <module>
    image = pipe([prompt] * batch_size, num_inference_steps=5).images[0]
  File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/nouamane/projects/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 353, in __call__
    image = self.vae.decode(latents).sample
  File "/home/nouamane/projects/diffusers/src/diffusers/models/vae.py", line 577, in decode
    dec = self.decoder(z)
  File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nouamane/projects/diffusers/src/diffusers/models/vae.py", line 217, in forward
    sample = up_block(sample)
  File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nouamane/projects/diffusers/src/diffusers/models/unet_blocks.py", line 1281, in forward
    hidden_states = upsampler(hidden_states)
  File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nouamane/projects/diffusers/src/diffusers/models/resnet.py", line 54, in forward
    hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/functional.py", line 3910, in interpolate
    return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements
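
For reference, a rough element-count check suggests why batch size 31 squeaks by while 32 fails. The shape below assumes the standard SD v1-4 VAE decoder, whose final nearest-neighbor upsample produces a [bsz, 256, 512, 512] tensor for 512x512 images; verify against your own config:

INT_MAX = 2**31 - 1  # limit enforced by the upsample_nearest_nhwc kernel
for bsz in (31, 32):
    # assumed output shape of the decoder's last upsample: [bsz, 256, 512, 512]
    n_elements = bsz * 256 * 512 * 512
    print(bsz, n_elements, "fits" if n_elements <= INT_MAX else "exceeds INT_MAX")
# bsz=31 -> 2_080_374_784 (fits); bsz=32 -> 2_147_483_648 = 2**31 (one past INT_MAX)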

Is there a way to fix this issue?
@patrickvonplaten @patil-suraj

System Info

  • diffusers version: 0.7.0.dev0
  • Platform: Linux-5.3.0-64-generic-x86_64-with-glibc2.30
  • Python version: 3.9.13
  • PyTorch version (GPU?): 1.12.1 (True)
  • Huggingface_hub version: 0.10.0
  • Transformers version: 4.24.0.dev0
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no
NouamaneTazi added the bug label on Oct 25, 2022
@patil-suraj (Contributor) commented Oct 26, 2022

Interesting, I didn't know this limitation existed. For large batch sizes, maybe we could split the batch inside the upsampling/downsampling module (see the sketch below).
But I think it would also make sense to open an issue in PyTorch to ask why this limitation exists and whether there's another alternative.
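
A minimal sketch of what that batch splitting could look like (a hypothetical helper, not actual diffusers code; the chunk size is arbitrary and just needs to keep each output under INT_MAX elements):

import torch
import torch.nn.functional as F

def interpolate_chunked(hidden_states, chunk_size=16):
    # Upsample the batch in chunks so that no single output tensor
    # exceeds the INT_MAX element limit of the NHWC kernel.
    upsampled = [
        F.interpolate(chunk, scale_factor=2.0, mode="nearest")
        for chunk in hidden_states.split(chunk_size, dim=0)
    ]
    return torch.cat(upsampled, dim=0)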

@patrickvonplaten (Contributor) commented

Same here; I think, however, that this is more a PyTorch bug than a diffusers bug.

@NouamaneTazi (Member, Author) commented

It seems @pcuenca has already come across this issue, raised it upstream in PyTorch, and suggested a fix: make the tensors contiguous before the interpolation. See pytorch/pytorch#81665 👏👏
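
For anyone hitting this before a fix lands in diffusers, a minimal sketch of that suggested workaround (not necessarily the exact change merged in #1006):

import torch.nn.functional as F

def upsample_nearest_2x(hidden_states):
    # .contiguous() forces a standard NCHW layout, so PyTorch dispatches to
    # the regular kernel instead of the channels-last upsample_nearest_nhwc
    # kernel that rejects outputs with more than INT_MAX elements.
    hidden_states = hidden_states.contiguous()
    return F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")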
