fix for "two devices" issue due to RoPE changes #630
Conversation
moves the rotary embeddings to the appropriate device to avoid having tensors on multiple devices; adds compatibility with transformers versions 4.43.0 and newer
awq/quantize/quantizer.py
Outdated
@@ -542,6 +542,7 @@ def init_quant(self, n_samples=128, max_seq_len=512):
         best_device = get_best_device()
         modules[0] = modules[0].to(best_device)
         self.awq_model.move_embed(self.model, best_device)
+        self.awq_model.model.model.rotary_emb.to(best_device)
This won't work for all models; you need to create a separate method, just like move_embed does for the embeddings.
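For illustration, the kind of per-model change being suggested might look roughly like the following, mirroring the existing static move_embed helpers in awq/models/*.py. The class name and attribute paths below follow the llama-style layout and are assumptions for the sketch, not a drop-in patch:

```python
# Sketch only: extends the existing move_embed pattern used by the per-model
# classes in awq/models/. Attribute paths are assumptions based on the
# llama-style architecture and transformers >= 4.43, where the rotary
# embedding is owned by the top-level model object.
from awq.models.base import BaseAWQForCausalLM


class LlamaAWQForCausalLM(BaseAWQForCausalLM):
    @staticmethod
    def move_embed(model, device: str):
        # Token embeddings (existing behavior).
        model.model.embed_tokens = model.model.embed_tokens.to(device)
        # Rotary embeddings (new): moving them alongside the embeddings
        # avoids the "two devices" RoPE error during calibration.
        model.model.rotary_emb = model.model.rotary_emb.to(device)
```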
Ah, I see what you mean now @SunMarc (you mentioned this over in the other PR). I'll go through each of the model files and check how this needs to be adjusted for each one. Sorry for not testing more fully: I tried this simple fix you proposed for several different models and it worked fine, but I see that it's not guaranteed to work depending on the model architecture.
Thanks again @SunMarc! I checked (as in tried quantizing) all 26 of the models currently supported (including several variants where relevant) and fixed each one separately, only in cases where a fix was needed, ensuring the code was adjusted to respect the model-specific move_embed method.
Note that a couple of models are broken for various reasons, and I also noticed a somewhat tangential issue with llava and llava_next (seemingly related to inputs_embeds on this line in transformers for llava and llava_next). This can be tested with llava-hf/llava-1.5-7b-hf and llava-hf/llava-v1.6-mistral-7b-hf, respectively, as long as device_map = None is used; I'm not sure how this ought to be handled on the AWQ side specifically.
Notes
(parentheses indicate models/variants tested)
fixed:
- cohere (command r)
- gpt_neox (but model currently not working due to a separate issue: prev_op <class 'transformers.activations.FastGELUActivation'> not supported yet!)
- llama (2, 3, 3.1)
- qwen2 (2, 2.5)
- stablelm (1; stablelm 2 not working for unrelated reasons)
- starcoder2
no issues:
- aquila
- baichuan
- gemma (1, 1.1)
- gemma2
- gpt_bigcode (bigcode/gpt_bigcode-santacoder)
- gptj
- internlm2
- minicpm
- mistral (0.1, 0.2, 0.3)
- mixtral
- opt
- phi3 (3, 3.5)
- qwen (1, 1.5)
- yi (1, 1.5)
unfixed:
current issue not noted, but model not working for unrelated reasons:
- bloom (see bloomz_7b1 error message TypeError: forward() missing 1 required positional argument: 'alibi' #288)
- deepseek_v2.py (deepseek-ai/DeepSeek-V2-Lite returns AssertionError: [blank])
- falcon (see bloomz_7b1 error message TypeError: forward() missing 1 required positional argument: 'alibi' #288 (comment))
- llava (1.5; seems to be affected by a similar but slightly different issue: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select))
- llava-next (1.6)
- mpt (see MPT-7B: 'NoneType' object has no attribute 'bool' #293)
@SunMarc: Turns out this issue is more complicated than I thought. I was testing this patch further, and unfortunately it does not work reliably, as it can introduce a secondary error depending on the exact size of the model being quantized. For example, I didn't have issues with Qwen2.5 7B, but 32B will fail about halfway through due to some issue with accelerate, likely related to the device_map:
NotImplementedError: Cannot copy out of meta tensor; no data!
(similar to the issue reported here on StackOverflow)
EDIT: I think I may have jumped to conclusions here -- I still get the NotImplementedError: Cannot copy out of meta tensor; no data! error even with the other PR in place. I need to test this further, but it looks like there may be two different issues going on here at once, which obviously complicates testing. This one may be related to this recent change in AutoAWQ (#607), since setting device_map = None (the old behavior in the 0.2.6 release) seems to resolve the problem.
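For concreteness, the workaround described above looks roughly like this (the model path is only an example; device_map=None is the setting that restores the pre-#607 / v0.2.6 loading behavior):

```python
# Illustrative sketch of the workaround discussed above, not a guaranteed fix.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "Qwen/Qwen2.5-32B-Instruct"  # example; substitute your own model

# device_map=None disables accelerate's automatic placement, avoiding the
# "Cannot copy out of meta tensor" error seen with offloaded weights.
model = AutoAWQForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    use_cache=False,
    device_map=None,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
```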
remove changes from quantizer.py and test with qwen2 first
update move_embed
cohere, gpt_neox, stablelm, starcoder2
Hi @davedgd, I have not had the time to follow up on this specific issue. My understanding is that it went away after we set
This PR causes the exact problem it was trying to fix. I think I will have to revert this and publish a new version. I wish I had a CI pipeline to check for this, but currently don't have GPUs to do that
@casper-hansen: That's interesting -- I did notice that error once as well since this was merged, but I was having no issues as long as I ran my conversions with:
model = AutoAWQForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, use_cache=False, device_map=None)
I do have 2 x A6000 Ada 48GB I can test changes with, so please let me know if I can be of any help. I'm sorry for any inconvenience this has led to!
@davedgd I think CPU offloading can be achieved with the auto setting AFAIK. Either way, we need a bit more testing to ensure these types of changes work in many scenarios; maybe it could be something triggered only when
@casper-hansen: To clarify, the reason I was using
In any case, I saw you already reverted the PR a few minutes ago -- I will update to that version and run some tests (using both None, "auto", etc.) with some larger models and report back. There have been some changes to the transformers package as well since this issue cropped up, so it'll be good to know where things stand.
UPDATE 2
I went ahead and ran the conversion in full after noticing the issue where conversion can fail after a certain number of layers complete. It does turn out that with the current latest pip version (i.e., v0.2.7.post2), using
To sum it up, the only approach that reliably works -- and this one I am certain of since I have converted dozens of models using it -- is using
UPDATE 1
@casper-hansen: Update -- I was wrong earlier -- with the latest patch applied, single GPU conversion fails after being nearly 2/3 complete with
PS. I would argue that reverting this patch is worse in some ways. While it's true that you can complete quantization if you can fully load it into GPUs, this reversion also makes it impossible to do the conversion with CPU offloading if you don't have sufficient VRAM (i.e., it's win some, lose some).
Initial Comments
@casper-hansen: Starting with a TLDR summary, you are correct that this patch is better off reverted, especially for multi-GPU. When
With the prior (current) release, the
Long story short, yes, reverting it is better -- sorry for the headache...
PS. I will allow the multi-GPU run to fully process using the new commit, and then go back and confirm this with the single-GPU run as well. I will follow up in ~2 hours, but I fully expect everything below to be accurate based on significant prior testing (i.e., if the issue does not crop up by the 2nd or 3rd layer, it won't crop up at all).
Follow-Up: Multi-GPU completed successfully; still waiting on single GPU results.
Following up with my findings (using Qwen2.5-32B-Instruct, which requires more than 48GB VRAM to process and either two GPUs or one GPU plus CPU offloading). All results use this setup (with model = AutoAWQForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, use_cache=False, device_map=...)).
Note that SUCCESS is based on reaching at least the third layer, which generally means everything will work (I did not wait out the full hour for each of these, but I plan to go back to do so).
AutoAWQ v0.2.7.post2
One GPU
Two GPUs
AutoAWQ v0.2.7.post2 with Commit 9f13358
One GPU
Two GPUs
Adding the full failure for
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[2], line 13
10 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
12 # Quantize
---> 13 model.quantize(tokenizer, quant_config=quant_config)
15 # Save quantized model
16 model.save_quantized(quant_path)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/awq/models/base.py:239, in BaseAWQForCausalLM.quantize(self, tokenizer, quant_config, calib_data, split, text_column, duo_scaling, export_compatible, apply_clip, n_parallel_calib_samples, max_calib_samples, max_calib_seq_len, max_chunk_memory, quantizer_cls, **kwargs)
216 self.quant_config.modules_to_not_convert = self.modules_to_not_convert
218 self.quantizer = quantizer_cls(
219 self,
220 self.model,
(...)
237 **kwargs,
238 )
--> 239 self.quantizer.quantize()
241 self.is_quantized = True
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/awq/quantize/quantizer.py:159, in AwqQuantizer.quantize(self)
154 # Filter out the linear layers we don't want to exclude
155 named_linears = exclude_layers_to_not_quantize(
156 named_linears, self.modules_to_not_convert
157 )
--> 159 input_feat = self._get_input_feat(self.modules[i], named_linears)
160 clear_memory()
162 # [STEP 2]: Compute and apply scale list
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/awq/quantize/quantizer.py:633, in AwqQuantizer._get_input_feat(self, layer, named_linears)
626 # get output as next layer's input
627
628 # Sanitize the kwargs in case we use transformers version that contains
629 # kwargs that are not handled by the module.
630 # Useful for trust_remote_code models.
631 module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)
--> 633 self.inps = self._module_forward(self.inps, layer, module_kwargs)
634 for h in handles:
635 h.remove()
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/awq/quantize/quantizer.py:247, in AwqQuantizer._module_forward(self, x, module, module_kwargs)
241 @torch.no_grad()
242 def _module_forward(
243 self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict
244 ) -> torch.Tensor:
245 if self.n_parallel_calib_samples is None:
246 # runs through all samples at once
--> 247 module_output = module(x, **module_kwargs)
248 if isinstance(module_output, tuple):
249 module_output = module_output[0]
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 def new_forward(module, *args, **kwargs):
--> 165 args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
166 if module._hf_hook.no_grad:
167 with torch.no_grad():
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/accelerate/hooks.py:364, in AlignDevicesHook.pre_forward(self, module, *args, **kwargs)
353 self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
355 set_module_tensor_to_device(
356 module,
357 name,
(...)
361 tied_params_map=self.tied_params_map,
362 )
--> 364 return send_to_device(args, self.execution_device), send_to_device(
365 kwargs, self.execution_device, skip_keys=self.skip_keys
366 )
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/accelerate/utils/operations.py:175, in send_to_device(tensor, device, non_blocking, skip_keys)
173 return tensor.to(device)
174 elif isinstance(tensor, (tuple, list)):
--> 175 return honor_type(
176 tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
177 )
178 elif isinstance(tensor, Mapping):
179 if isinstance(skip_keys, str):
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/accelerate/utils/operations.py:82, in honor_type(obj, generator)
80 return type(obj)(*list(generator))
81 else:
---> 82 return type(obj)(generator)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/accelerate/utils/operations.py:176, in <genexpr>(.0)
173 return tensor.to(device)
174 elif isinstance(tensor, (tuple, list)):
175 return honor_type(
--> 176 tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
177 )
178 elif isinstance(tensor, Mapping):
179 if isinstance(skip_keys, str):
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/accelerate/utils/operations.py:156, in send_to_device(tensor, device, non_blocking, skip_keys)
154 device = "xpu:0"
155 try:
--> 156 return tensor.to(device, non_blocking=non_blocking)
157 except TypeError: # .to() doesn't accept non_blocking as kwarg
158 return tensor.to(device)
NotImplementedError: Cannot copy out of meta tensor; no data!
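As a side note on the error itself, a meta tensor carries only shape and dtype metadata with no backing storage, so any attempt to copy it to a real device raises exactly this exception. A tiny standalone illustration (independent of AutoAWQ):

```python
import torch

# A tensor on the "meta" device has shape/dtype but no actual data.
t = torch.empty(4, 4, device="meta")

try:
    t.to("cpu")  # copying out of a meta tensor is impossible
except NotImplementedError as err:
    print(err)  # e.g. "Cannot copy out of meta tensor; no data!"
```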
One more update... when trying to convert Qwen2.5-72B-Instruct using my two A6000 Adas, this results in some CPU offloading occurring, and the kernel dying every time with no processing occurring using
I'm about to try constraining the run to a single GPU with
Long story short, here's where things stand:
Personally, I feel the second situation (i.e., with the patch that reverts this pull) is far worse for me, as I can at least quantize larger models with CPU offload and some extra swap space, but with the latest commit, I'd no longer be able to work with models larger than 32B unless I had more than 96GB VRAM to spare. More broadly, I feel more people would be able to actually perform greater-than-VRAM conversions without the commit in place, as anyone with a single GPU could still do them, whereas with the commit, you're constrained to models that fully fit into your available VRAM.
---------------------------------------------------------------------------
OutOfMemoryError Traceback (most recent call last)
Cell In[2], line 13
10 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
12 # Quantize
---> 13 model.quantize(tokenizer, quant_config=quant_config)
15 # Save quantized model
16 model.save_quantized(quant_path)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/awq/models/base.py:239, in BaseAWQForCausalLM.quantize(self, tokenizer, quant_config, calib_data, split, text_column, duo_scaling, export_compatible, apply_clip, n_parallel_calib_samples, max_calib_samples, max_calib_seq_len, max_chunk_memory, quantizer_cls, **kwargs)
216 self.quant_config.modules_to_not_convert = self.modules_to_not_convert
218 self.quantizer = quantizer_cls(
219 self,
220 self.model,
(...)
237 **kwargs,
238 )
--> 239 self.quantizer.quantize()
241 self.is_quantized = True
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/awq/quantize/quantizer.py:166, in AwqQuantizer.quantize(self)
162 # [STEP 2]: Compute and apply scale list
163 module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
164 self.modules[i], input_feat, self.module_kwargs
165 )
--> 166 scales_list = [
167 self._search_best_scale(self.modules[i], **layer)
168 for layer in module_config
169 ]
170 apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
171 scales_list = append_str_prefix(
172 scales_list, get_op_name(self.model, self.modules[i]) + "."
173 )
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/awq/quantize/quantizer.py:167, in <listcomp>(.0)
162 # [STEP 2]: Compute and apply scale list
163 module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
164 self.modules[i], input_feat, self.module_kwargs
165 )
166 scales_list = [
--> 167 self._search_best_scale(self.modules[i], **layer)
168 for layer in module_config
169 ]
170 apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
171 scales_list = append_str_prefix(
172 scales_list, get_op_name(self.model, self.modules[i]) + "."
173 )
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/awq/quantize/quantizer.py:330, in AwqQuantizer._search_best_scale(self, module, prev_op, layers, inp, module2inspect, kwargs)
327 fp16_output = self._module_forward(inp, module2inspect, module_kwargs)
329 # [STEP 4]: Compute loss
--> 330 best_scales = self._compute_best_scale(
331 inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
332 )
334 return (
335 get_op_name(module, prev_op),
336 tuple([get_op_name(module, m) for m in layers]),
337 best_scales,
338 )
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/awq/quantize/quantizer.py:395, in AwqQuantizer._compute_best_scale(self, x, w_mean, x_mean, module2inspect, linears2scale, fp16_output, kwargs)
390 fc.weight.data = (
391 self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
392 )
394 # W * X
--> 395 int_w_output = self._module_forward(x, module2inspect, kwargs)
397 # compute mean squared error (L2 norm)
398 loss = self._compute_loss(fp16_output, int_w_output, device)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/awq/quantize/quantizer.py:247, in AwqQuantizer._module_forward(self, x, module, module_kwargs)
241 @torch.no_grad()
242 def _module_forward(
243 self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict
244 ) -> torch.Tensor:
245 if self.n_parallel_calib_samples is None:
246 # runs through all samples at once
--> 247 module_output = module(x, **module_kwargs)
248 if isinstance(module_output, tuple):
249 module_output = module_output[0]
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/miniforge3/envs/awq/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:223, in Qwen2MLP.forward(self, hidden_state)
222 def forward(self, hidden_state):
--> 223 return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
OutOfMemoryError: CUDA out of memory. Tried to allocate 1.66 GiB. GPU 0 has a total capacity of 47.40 GiB of which 1.33 GiB is free. Including non-PyTorch memory, this process has 46.07 GiB memory in use. Of the allocated memory 44.12 GiB is allocated by PyTorch, and 1.45 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Ok, I see that both solutions are not sufficient. The best solution would be if we can revert to
The question is how to achieve this after HF broke this behavior?
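For reference, the behavior being described -- streaming one decoder block through the GPU at a time -- amounts to something like the following simplified sketch (this is not the actual AutoAWQ quantizer code, and the layer call signature is simplified):

```python
import torch


@torch.no_grad()
def calibrate_layer_by_layer(layers, hidden_states, device="cuda:0"):
    """Sketch of the one-layer-at-a-time pattern: only the block currently
    being calibrated lives on the GPU, so peak VRAM stays near the size of
    a single decoder block plus activations."""
    for layer in layers:
        layer.to(device)                                     # move just this block to GPU
        hidden_states = layer(hidden_states.to(device))[0]   # forward pass to get next inputs
        # ... scale search / weight quantization for this block would happen here ...
        layer.to("cpu")                                      # move it back to free VRAM
        torch.cuda.empty_cache()
    return hidden_states
```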
So the question is really how we can use your PR while fixing the multi-GPU usage and switch the |
@casper-hansen: Your summary is spot on, and I agree that one-layer-at-a-time is the safe bet in many ways, since it allows for the lowest GPU requirement to still process large models.
Quick update: in my earlier post, I only evaluated the runs for a couple of layers since each conversion takes ~50 minutes. Long story short, as per the update above, please note that
For the time being, the solution I'm certain works reliably for both smaller- and larger-than-VRAM models is
In terms of a temporary workaround, I would encourage recommending a default of
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
(or calling python in the command line in a similar fashion, e.g.,
Lastly, regarding the broken HF behavior, I'm curious if @SunMarc can take another look at this, as I'm also at a bit of a loss on how this can ultimately be fixed to allow for more variety in terms of multi-GPU configurations.
I have a fix ready: #668. Just need to add support for the remaining models.
I am using the latest package autoawq-0.2.7.post3 to quantize Qwen 72B, but I still encounter this issue.
Please use the updated example found in examples/quantize.py
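For anyone hitting this thread later, the flow in that example is roughly as follows (model/output paths are placeholders and the quant_config values are the commonly used defaults; defer to examples/quantize.py in the repo for the authoritative version):

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "Qwen/Qwen2.5-72B-Instruct"   # example input model
quant_path = "Qwen2.5-72B-Instruct-awq"    # example output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load model and tokenizer
model = AutoAWQForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, use_cache=False)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model and tokenizer
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```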
This PR moves the rotary embeddings to the appropriate device to avoid having tensors on multiple devices; adds compatibility with transformers versions 4.43.0 and newer. Developed thanks to the help of @SunMarc (see history in this closed transformers PR: huggingface/transformers#33742).
This fixes several issues here for AutoAWQ: #510, #558, #571, #585
Another relevant issue that this would close over in transformers is here: huggingface/transformers#32420