System Info

transformers version: 4.44.0

Who can help?

@ArthurZucker @gante

Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)
Reproduction
Setting the cache type to static with model.generation_config.cache_implementation = "static", or using StaticCache directly, breaks with multi-GPU inference: it throws the error below, probably because the cache is not placed on the right device for some layers. A minimal repro sketch and the full traceback follow.
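A minimal sketch of the kind of script that hits this, assuming a machine with 2 GPUs and device_map="auto"; the checkpoint, dtype, and prompt are illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any model large enough to be sharded across 2 GPUs reproduces it.
model_id = "meta-llama/Meta-Llama-3-70B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",  # shards the decoder layers across cuda:0 and cuda:1
)

# Switching from the (working) dynamic cache to the static cache triggers the error.
model.generation_config.cache_implementation = "static"

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```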
File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167     output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:640, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
    637 if past_key_value is not None:
    638     # sin and cos are specific to RoPE models; cache_position needed for the static cache
    639     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 640     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    642 key_states = repeat_kv(key_states, self.num_key_value_groups)
    643 value_states = repeat_kv(value_states, self.num_key_value_groups)

File /opt/conda/lib/python3.10/site-packages/transformers/cache_utils.py:1083, in StaticCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
   1080 try:
   1081     # If using several devices (e.g.: multiple GPUs), we need to ensure everything is on the same one
   1082     cache_position.to(device=k_out.device)
-> 1083     k_out.index_copy_(2, cache_position, key_states)
   1084     v_out.index_copy_(2, cache_position, value_states)
   1085 except NotImplementedError:
   1086     # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA_index_copy_)
I can reproduce this with Llama3 70B on 2x A100. The dynamic cache works fine, and the static cache also works fine on a single GPU with a smaller model (the same model, but quantized).
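The mismatch is consistent with the decoder layers being sharded across both GPUs while the static cache tensors all live on a single device. The sharding can be confirmed by inspecting the device map that accelerate produced (hf_device_map is set when the model is loaded with device_map="auto"):

```python
from collections import Counter

# With device_map="auto", the decoder layers are split across the available GPUs,
# so some layers run on cuda:0 and the rest on cuda:1.
print(model.hf_device_map)                    # per-module device assignment
print(Counter(model.hf_device_map.values()))  # e.g. {0: 41, 1: 40}
```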
Expected behavior
Static cache generation should work in a multi-GPU runtime.
We currently initialize the whole cache on the same device, which is obviously wrong in multi-gpu settings :) We will add a fix for it.
@SunMarc, to sync on the multi-device front before I start potentially redundant work: was this issue already on your radar? If so, has any work been done on it?
Hey @gante! I've added a quick fix for this issue here: #32543
Basically, we do the same thing that was done in the HybridCache. Can you try it, @mobicham?
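For anyone following along, a rough sketch of the general idea behind that kind of fix (not the actual patch in #32543): allocate each layer's static key/value buffers on the device that layer was dispatched to, the way HybridCache handles it, so that StaticCache.update() only ever indexes tensors on the local device. The helper below is hypothetical:

```python
import torch

def allocate_per_layer_kv(model, batch_size, max_cache_len, num_kv_heads, head_dim):
    """Hypothetical helper: one pair of static K/V buffers per layer, on that layer's device."""
    key_cache, value_cache = [], []
    for layer in model.model.layers:
        # Device the layer was dispatched to by accelerate (cuda:0 for the first shard, cuda:1 for the second, ...).
        device = next(layer.parameters()).device
        shape = (batch_size, num_kv_heads, max_cache_len, head_dim)
        key_cache.append(torch.zeros(shape, dtype=model.dtype, device=device))
        value_cache.append(torch.zeros(shape, dtype=model.dtype, device=device))
    return key_cache, value_cache
```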
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.