
[BUG] [!4bit] save_quantized TypeError: cannot pickle 'module' object #47

Closed
FrederikHandberg opened this issue Jun 22, 2024 · 18 comments · Fixed by #49
Labels: bug (Something isn't working)

@FrederikHandberg

Describe the bug

When saving the quantized model, I get this error:

INFO - {'layer': 40, 'module': 'mlp.gate_proj', 'avg_loss': '0.1468', 'time': '1.2835'}
INFO - {'layer': 40, 'module': 'mlp.down_proj', 'avg_loss': '0.4185', 'time': '5.4308'}
INFO - Packing model...
Packing model.layers.39.mlp.down_proj: 100%|██████████████████████████████████████████████████████████████████████████████████████| 280/280 [07:56<00:00,  1.70s/it]
INFO - Model packed.
Traceback (most recent call last):
  File "/workspace/GPTQModel/fred_quant.py", line 98, in <module>
    model.save_quantized(quant_output_dir)
  File "/workspace/GPTQModel/gptqmodel/models/base.py", line 545, in save_quantized
    model = copy.deepcopy(self.model)
  File "/usr/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.10/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.10/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.10/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.10/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.10/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.10/copy.py", line 161, in deepcopy
    rv = reductor(4)
TypeError: cannot pickle 'module' object
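For context, copy.deepcopy falls back to pickle-style reduction for objects it cannot copy structurally, and module objects are not picklable, so any attribute inside the model that holds a raw module reference triggers exactly this failure. A minimal sketch (illustrative names only, not GPTQModel code) that reproduces the same error class:

import copy
import math  # any module object will do


class Holder:
    """Stand-in for a model object that keeps a module reference as an attribute."""
    def __init__(self):
        self.backend = math  # a module stored on the instance


copy.deepcopy(Holder())  # raises: TypeError: cannot pickle 'module' object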

GPU Info

A600

Here is the script I used:

from transformers import AutoTokenizer
from gptqmodel import GPTQModel, QuantizeConfig
import json
import random
import torch
from datasets import Dataset

pretrained_model_dir = "FrederikH/Numina-Base-v2"
quant_output_dir = "./numina-base-v2-gptq-8bit"
num_samples = 1000
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

def load_data(data_path, tokenizer, n_samples):
    with open(data_path, "r", encoding="utf-8") as f:
        raw_data = json.load(f)

    raw_data = random.sample(raw_data, k=min(n_samples, len(raw_data)))

    def dummy_gen():
        return raw_data

    def tokenize(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]

        prompts = []
        texts = []
        input_ids = []
        attention_mask = []
        for istr, inp, opt in zip(instructions, inputs, outputs):
            if inp:
                prompt = f"Instruction:\n{istr}\nInput:\n{inp}\nOutput:\n"
                text = prompt + opt
            else:
                prompt = f"Instruction:\n{istr}\nOutput:\n"
                text = prompt + opt
            if len(tokenizer(prompt)["input_ids"]) >= tokenizer.model_max_length:
                continue

            tokenized_data = tokenizer(text)

            input_ids.append(tokenized_data["input_ids"][: tokenizer.model_max_length])
            attention_mask.append(tokenized_data["attention_mask"][: tokenizer.model_max_length])
            prompts.append(prompt)
            texts.append(text)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompt": prompts,
        }

    dataset = Dataset.from_generator(dummy_gen)

    dataset = dataset.map(
        tokenize,
        batched=True,
        batch_size=len(dataset),
        num_proc=1,
        keep_in_memory=True,
        load_from_cache_file=False,
        remove_columns=["instruction", "input"],
    )

    dataset = dataset.to_list()

    for sample in dataset:
        sample["input_ids"] = torch.LongTensor(sample["input_ids"])
        sample["attention_mask"] = torch.LongTensor(sample["attention_mask"])

    return dataset


calibration_dataset = load_data("alpaca_data_cleaned.json", tokenizer, num_samples)
examples_for_quant = [
    {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
    for example in calibration_dataset
]

quant_config = QuantizeConfig(
    bits=8,  # bit
    group_size=128,  # 128 is good balance between quality and performance
    desc_act=True
)

# load un-quantized model, by default, the model will always be loaded into CPU memory
model = GPTQModel.from_pretrained(pretrained_model_dir, quant_config)

# quantize model, the calibration_dataset should be list of dict whose keys can only be "input_ids" and "attention_mask"
model.quantize(calibration_dataset)

# save quantized model
model.save_quantized(quant_output_dir)
FrederikHandberg added the bug label on Jun 22, 2024
@Qubitium
Collaborator

@FupsGamer We are going to check this.

Assigned to @CSY-ModelCloud

Our unit tests only ran for 4bit quants so we need to expand the tests to 8bit.

@Qubitium
Collaborator

@FupsGamer Can you check if you get a similar error with a 4bit quant?

@FrederikHandberg
Author

checking now

@FrederikHandberg
Author

yeah works for 4bit, seems like it's a problem with 8bit

@Qubitium
Collaborator

yeah works for 4bit, seems like it's a problem with 8bit

Thank you for confirming the 8bit issue. We will get to the bottom of this.

@FrederikHandberg
Author

yeah works for 4bit, seems like it's a problem with 8bit

Thank you for confirming the 8bit issue. We will get to the bottom of this.

Thanks! Let me know if you fix it or if there is anything else I can assist with.

@Qubitium Qubitium changed the title [BUG] save_quantized TypeError: cannot pickle 'module' object [BUG] 8bit - save_quantized TypeError: cannot pickle 'module' object Jun 24, 2024
@Qubitium Qubitium changed the title [BUG] 8bit - save_quantized TypeError: cannot pickle 'module' object [BUG] [8bit] save_quantized TypeError: cannot pickle 'module' object Jun 24, 2024
@Qubitium Qubitium changed the title [BUG] [8bit] save_quantized TypeError: cannot pickle 'module' object [BUG] [!4bit] save_quantized TypeError: cannot pickle 'module' object Jun 24, 2024
@Qubitium
Collaborator

Qubitium commented Jun 24, 2024

Status update: the source of the bug has been found and a fix is undergoing testing. This bug affected all non-4bit quantization processes. Expect a resolution within the next 12 hours.

Qubitium added a commit that referenced this issue Jun 24, 2024
…49)

* fix cannot pickle 'module' object for 8 bit

* remove unused import

* remove print

* check with tuple

* revert to len check

* add test for 8bit

* set same QuantizeConfig

* check if it's 4 bit

* fix grammar

* remove params

* it's not a list

* set gptqmodel_cuda back

* check is tuple

* format

* set desc_act=True

* set desc_act=True

* format

* format

* Refractor fix

* desc_act=True

---------

Co-authored-by: Qubitium <Qubitium@modelcloud.ai>
@Qubitium Qubitium reopened this Jun 24, 2024
@Qubitium
Collaborator

@FrederikHandberg Fix merged to main. Please recompile from main and test again. Thanks.

@FrederikHandberg
Author

@Qubitium Thanks! I have tested it, and true, there is no longer an error, but it says "Killed" after the packing finishes; it creates the output dir, but it's empty...

@Qubitium
Collaborator

@FrederikHandberg That's not good news. Can you confirm your Python, CUDA, and torch versions? I will have @CSY-ModelCloud reproduce the issue, mimicking your setup.

@Qubitium Qubitium reopened this Jun 25, 2024
@FrederikHandberg
Author

CUDA 11.8 paired with torch 2.1.0, just like with the other error; this one is when I use yesterday's commit.

@Qubitium
Collaborator

@FrederikHandberg v0.9.1 has been released with all our CI unit tests passing. Please try it now and let us know. For environments with CUDA < 12.1 and with bitblas enabled in quantize_config, you will be prompted to manually compile bitblas from source.

@Qubitium
Collaborator

Closing this as resolved with the 0.9.1 release. If the issue persists, feel free to re-open this issue.

@lukehare

lukehare commented Jul 9, 2024

I am experiencing the same error ("Killed") when trying to run 8-bit quantization on v0.9.7.

I see that the 8-bit test was removed in #169 - is 8-bit quantization no longer supported?

Thanks!

@Qubitium
Collaborator

Qubitium commented Jul 10, 2024

8bit should still be supported by the gptq and gptq v2 formats, using backend.TritonV2 for inference.
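For reference, a minimal loading sketch assuming the 8bit quant was saved to the directory used in the script earlier in this thread; the exact import and enum names (for example BACKEND.TRITON_V2) may differ between gptqmodel versions, so treat this as illustrative rather than the library's exact API:

from gptqmodel import GPTQModel, BACKEND

model = GPTQModel.from_quantized(
    "./numina-base-v2-gptq-8bit",  # path from the quantization script above
    device="cuda:0",
    backend=BACKEND.TRITON_V2,     # Triton v2 kernel handles 8bit inference
)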

Can you provide:

  1. How much RAM?
  2. How much VRAM? What is your GPU?
  3. What is the full quantization script/code used?

We need the full info to check for an OS-level OOM. The OS kills your process if swap and RAM cannot satisfy the allocation; a VRAM OOM would produce a CUDA stacktrace instead.
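As a quick way to check the RAM side before the packing/saving step, a small standard-library sketch (Linux-only; the threshold is illustrative and model-size dependent):

import os

page = os.sysconf("SC_PAGE_SIZE")
total_gib = page * os.sysconf("SC_PHYS_PAGES") / 1024**3
avail_gib = page * os.sysconf("SC_AVPHYS_PAGES") / 1024**3  # Linux-only counters

print(f"RAM total: {total_gib:.1f} GiB, available: {avail_gib:.1f} GiB")
if avail_gib < 32:  # rough, illustrative threshold
    print("Warning: low free RAM; the OS may OOM-kill the save step.")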

@Qubitium
Collaborator

@lukehare Please create a new issue and provide the info we asked for so we can properly track and fix your issue. This issue is closed, and your error is unrelated to the original post.

@FrederikHandberg
Author

FrederikHandberg commented Jul 10, 2024

@Qubitium Unfortunately I get a new error now.

INFO - {'layer': 38, 'module': 'mlp.down_proj', 'avg_loss': '0.2514', 'time': '5.4949'}
INFO - {'layer': 39, 'module': 'self_attn.k_proj', 'avg_loss': '0.0118', 'time': '1.1979'}
INFO - {'layer': 39, 'module': 'self_attn.v_proj', 'avg_loss': '0.0159', 'time': '1.2004'}
INFO - {'layer': 39, 'module': 'self_attn.q_proj', 'avg_loss': '0.0362', 'time': '1.2369'}
INFO - {'layer': 39, 'module': 'self_attn.o_proj', 'avg_loss': '0.0485', 'time': '1.2130'}
INFO - {'layer': 39, 'module': 'mlp.up_proj', 'avg_loss': '0.1912', 'time': '1.3559'}
INFO - {'layer': 39, 'module': 'mlp.gate_proj', 'avg_loss': '0.1762', 'time': '1.3453'}
INFO - {'layer': 39, 'module': 'mlp.down_proj', 'avg_loss': '1.4219', 'time': '5.4637'}
INFO - {'layer': 40, 'module': 'self_attn.k_proj', 'avg_loss': '0.0086', 'time': '1.2064'}
INFO - {'layer': 40, 'module': 'self_attn.v_proj', 'avg_loss': '0.0077', 'time': '1.2014'}
INFO - {'layer': 40, 'module': 'self_attn.q_proj', 'avg_loss': '0.0228', 'time': '1.2299'}
INFO - {'layer': 40, 'module': 'self_attn.o_proj', 'avg_loss': '0.0241', 'time': '1.2327'}
INFO - {'layer': 40, 'module': 'mlp.up_proj', 'avg_loss': '0.1526', 'time': '1.3569'}
INFO - {'layer': 40, 'module': 'mlp.gate_proj', 'avg_loss': '0.1451', 'time': '1.3582'}
INFO - {'layer': 40, 'module': 'mlp.down_proj', 'avg_loss': '0.4116', 'time': '5.7132'}
INFO - Packing model...
Traceback (most recent call last):
  File "/workspace/GPTQModel/fred_quant.py", line 96, in <module>
    model.quantize(calibration_dataset)
  File "/workspace/GPTQModel/gptqmodel/models/base.py", line 164, in quantize
    self._quantize(calibration_dataset, batch_size, autotune_warmup_after_quantized, calibration_enable_gpu_cache)
  File "/workspace/GPTQModel/gptqmodel/models/base.py", line 545, in _quantize
    self.qlinear_kernel = pack_model(
  File "/workspace/GPTQModel/gptqmodel/utils/model.py", line 280, in pack_model
    make_quant(
  File "/workspace/GPTQModel/gptqmodel/utils/model.py", line 149, in make_quant
    new_layer = QuantLinear(
  File "/workspace/GPTQModel/gptqmodel/nn_modules/qlinear/qlinear_exllama.py", line 44, in __init__
    super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
  File "/workspace/GPTQModel/gptqmodel/nn_modules/qlinear/__init__.py", line 19, in __init__
    raise NotImplementedError(err)
NotImplementedError: <class 'gptqmodel.nn_modules.qlinear.qlinear_exllama.ExllamaQuantLinear'> only supports `[4]` bits: actual bits = `8`

I am using the same config as before and the checkout 6359c59, as I got a vLLM error on the latest commit from yesterday:

Traceback (most recent call last):
  File "/workspace/GPTQModel/quant.py", line 2, in <module>
    from gptqmodel import GPTQModel, QuantizeConfig
  File "/workspace/GPTQModel/gptqmodel/__init__.py", line 1, in <module>
    from .models import GPTQModel
  File "/workspace/GPTQModel/gptqmodel/models/__init__.py", line 1, in <module>
    from .auto import MODEL_MAP, GPTQModel
  File "/workspace/GPTQModel/gptqmodel/models/auto.py", line 5, in <module>
    from .baichuan import BaiChuanGPTQ
  File "/workspace/GPTQModel/gptqmodel/models/baichuan.py", line 1, in <module>
    from .base import BaseGPTQModel
  File "/workspace/GPTQModel/gptqmodel/models/base.py", line 36, in <module>
    from ..utils.vllm import load_model_by_vllm, vllm_generate
  File "/workspace/GPTQModel/gptqmodel/utils/vllm.py", line 12, in <module>
    def convert_hf_params_to_vllm(hf_params: Dict[str, Any]) -> SamplingParams:
NameError: name 'SamplingParams' is not defined

@Qubitium
Collaborator

@FrederikHandberg You have hit 2 issues in the tip/main code. You are trying to quantize 8bit but a 4bit-only kernel/packing path was selected instead. We need to test switching packing to Triton, which supports 8bit.

The second issue is the vllm dependency on main, not in an official release. vllm should be optional and should not throw this error.

We shall fix both of these issues. I will open 2 new issues regarding this, as the two issues are separate and no longer applicable to this old issue.
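Not the actual GPTQModel code, but a hypothetical sketch of the kind of bits-support guard behind the NotImplementedError above, where the packing path must pick a kernel whose supported bit widths include the requested one:

SUPPORTED_BITS = {
    "exllama": [4],          # 4bit-only, as the traceback shows
    "triton_v2": [2, 4, 8],  # wider support, hence the planned switch
}

def select_kernel(bits: int) -> str:
    """Return the first kernel name that supports the requested bit width."""
    for name, supported in SUPPORTED_BITS.items():
        if bits in supported:
            return name
    raise NotImplementedError(f"no kernel supports bits = {bits}")

select_kernel(8)  # -> "triton_v2"; an exllama-only path would raise instead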

DeJoker pushed a commit to DeJoker/GPTQModel that referenced this issue Jul 19, 2024
* Update model list

* Update README.md

---------

Co-authored-by: Qubitium-modelcloud <qubitium@modelcloud.ai>
DeJoker pushed a commit to DeJoker/GPTQModel that referenced this issue Jul 19, 2024
…Cloud#47) (ModelCloud#49)
