Static Cache is broken with multi-gpu inference #32624

Closed

mobicham opened this issue Aug 12, 2024 · 4 comments
Comments

@mobicham
Contributor

System Info

  • transformers version: 4.44.0
  • Platform: Linux-6.5.0-15-generic-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.4
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.0.dev20240812+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100 80GB PCIe

Who can help?

@ArthurZucker @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Currently, setting the cache type to static via model.generation_config.cache_implementation = "static", or using StaticCache directly, breaks with multi-GPU inference, most likely because the cache is not placed on the correct device for some layers.
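
A minimal sketch of a setup that triggers it (the checkpoint name and generation arguments are illustrative; the model just needs to be sharded across two GPUs, e.g. with accelerate's device_map="auto"):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; the report used Llama 3 70B sharded over 2x A100 80GB.
model_id = "meta-llama/Meta-Llama-3-70B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",  # shards the layers across all visible GPUs via accelerate
)

# Switching from the default dynamic cache to the static cache is what breaks.
model.generation_config.cache_implementation = "static"

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32)  # raises the RuntimeError shown below
print(tokenizer.decode(out[0], skip_special_tokens=True))

This throws the following error: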

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:640, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
    637 if past_key_value is not None:
    638     # sin and cos are specific to RoPE models; cache_position needed for the static cache
    639     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 640     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    642 key_states = repeat_kv(key_states, self.num_key_value_groups)
    643 value_states = repeat_kv(value_states, self.num_key_value_groups)

File /opt/conda/lib/python3.10/site-packages/transformers/cache_utils.py:1083, in StaticCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
   1080 try:
   1081     # If using several devices (e.g.: multiple GPUs), we need to ensure everything is on the same one
   1082     cache_position.to(device=k_out.device)
-> 1083     k_out.index_copy_(2, cache_position, key_states)
   1084     v_out.index_copy_(2, cache_position, value_states)
   1085 except NotImplementedError:
   1086     # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA_index_copy_)

I can reproduce this with Llama 3 70B on 2x A100. The dynamic cache works fine, and the static cache also works fine on a single GPU with a smaller model (the same model, but quantized).
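
The failure mode can be seen in isolation: index_copy_ requires the destination, the index, and the source tensors to be on the same device, but the static cache buffers are all allocated on one device while layers dispatched to another GPU produce their key/value states there. Note also that the cache_position.to(device=k_out.device) call in the traceback is not in-place, so its result is discarded. A minimal sketch, assuming two visible GPUs (shapes and values are illustrative):

import torch

# The static cache buffer for every layer ends up on cuda:0 ...
k_out = torch.zeros(1, 8, 128, 64, device="cuda:0")

# ... while a layer dispatched to cuda:1 produces its states and cache_position there.
key_states = torch.randn(1, 8, 1, 64, device="cuda:1")
cache_position = torch.tensor([0], device="cuda:1")

# Not in-place: this returns a moved copy that is immediately discarded,
# mirroring the cache_position.to(device=k_out.device) line in cache_utils.py.
cache_position.to(device=k_out.device)

# RuntimeError: Expected all tensors to be on the same device, ...
k_out.index_copy_(2, cache_position, key_states)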

Expected behavior

Static cache generation should work in a multi-GPU runtime.

@gante
Member

gante commented Aug 12, 2024

Hi @mobicham 👋 Thank you for raising this issue!

We currently initialize the whole cache on the same device, which is obviously wrong in multi-gpu settings :) We will add a fix for it.

@SunMarc for synchronization with the multi-device department before I start potentially redundant work: was this issue already on your radar? If so, was any work done on it?

@SunMarc
Member

SunMarc commented Aug 12, 2024

Hey @gante! I've added a quick fix for this issue here: #32543
Basically, we do the same as what was done in the HybridCache. Can you try it, @mobicham?
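
For context, the general idea (a rough, hypothetical sketch of per-layer placement, not the actual code from #32543): allocate each layer's key/value buffers on the device that layer's weights were dispatched to, instead of putting the whole cache on a single device.

import torch

# Hypothetical helper sketching per-layer cache placement; not the fix as implemented in #32543.
def init_static_cache_per_layer(model, batch_size, num_kv_heads, max_cache_len, head_dim, dtype=torch.float16):
    key_cache, value_cache = [], []
    for layer in model.model.layers:  # assumes a Llama-like module layout
        # Put each layer's buffers on the device accelerate dispatched that layer to.
        device = next(layer.parameters()).device
        shape = (batch_size, num_kv_heads, max_cache_len, head_dim)
        key_cache.append(torch.zeros(shape, dtype=dtype, device=device))
        value_cache.append(torch.zeros(shape, dtype=dtype, device=device))
    return key_cache, value_cache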

@ArthurZucker
Collaborator

Is there a way for us to hint accelerate, or name update like forward, so that it handles the transfers?


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
