
Fix: local variable 'dataset' referenced before assignment #1600

Merged
merged 1 commit into huggingface:main on Jan 26, 2024

Conversation

@hiyouga (Contributor) commented on Dec 15, 2023

What does this PR do?

Fixes `UnboundLocalError: local variable 'dataset' referenced before assignment` in `GPTQQuantizer.quantize_model`.

Reproduce

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = [{"input_ids": [0, 1, 2, 3, 4], "attention_mask": [1, 1, 1, 1, 1]}]
gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=tokenizer)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=gptq_config)

The logic is broken when a pre-tokenized dataset is passed to the `GPTQQuantizer`: the first branch below never assigns the local variable `dataset`, so the final `prepare_dataset` call reads an unbound name.

# Step 1: Prepare the data
if isinstance(self.dataset, list) and not isinstance(self.dataset[0], str):
    # BUG: this branch never binds the local name `dataset`
    logger.info("GPTQQuantizer dataset appears to be already tokenized. Skipping tokenization.")
else:
    if isinstance(tokenizer, str):
        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        except Exception:
            raise ValueError(
                f"""We were not able to get the tokenizer using `AutoTokenizer.from_pretrained`
                with the string that you have passed {tokenizer}. If you have a custom tokenizer, you can pass it as input.
                For now, we only support quantization for text model. Support for vision, speech and multimodel will come later."""
            )
    if self.dataset is None:
        raise ValueError("You need to pass `dataset` in order to quantize your model")
    elif isinstance(self.dataset, str):
        dataset = get_dataset(self.dataset, tokenizer, seqlen=self.model_seqlen, split="train")
    elif isinstance(self.dataset, list):
        dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset]
    else:
        raise ValueError(
            f"You need to pass a list of string, a list of tokenized data or a string for `dataset`. Found: {type(self.dataset)}."
        )
dataset = prepare_dataset(dataset, pad_token_id=self.pad_token_id, batch_size=self.batch_size)

Traceback (most recent call last):
  File "lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 566, in from_pretrained
    return model_class.from_pretrained(
  File "lib/python3.10/site-packages/transformers/modeling_utils.py", line 3768, in from_pretrained
    quantizer.quantize_model(model, quantization_config.tokenizer)
  File "lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "lib/python3.10/site-packages/optimum/gptq/quantizer.py", line 382, in quantize_model
    dataset = prepare_dataset(dataset, pad_token_id=self.pad_token_id, batch_size=self.batch_size)
UnboundLocalError: local variable 'dataset' referenced before assignment
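The control-flow bug can be distilled into a few lines of plain Python. The `prepare_broken`/`prepare_fixed` helpers below are hypothetical stand-ins for the optimum code, not the actual implementation: on the pre-tokenized path, the broken version never binds the local name `dataset`, while the fixed version binds it in both branches (which is the shape of fix this PR applies).

```python
def prepare_broken(raw):
    """Mirrors the broken control flow: nothing is bound on the tokenized path."""
    if isinstance(raw, list) and not isinstance(raw[0], str):
        pass  # pre-tokenized input: `dataset` is never assigned here
    else:
        dataset = [s.split() for s in raw]  # stand-in for real tokenization
    return dataset  # UnboundLocalError when the first branch was taken


def prepare_fixed(raw):
    """Same flow, but the pre-tokenized branch binds the local name too."""
    if isinstance(raw, list) and not isinstance(raw[0], str):
        dataset = raw  # bind the local so the code below can use it
    else:
        dataset = [s.split() for s in raw]
    return dataset


tokenized = [{"input_ids": [0, 1, 2], "attention_mask": [1, 1, 1]}]
try:
    prepare_broken(tokenized)
except UnboundLocalError as err:
    print(err)  # exact message varies by Python version

print(prepare_fixed(tokenized) is tokenized)  # True
```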

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?


@fxmarty (Contributor) left a comment:


LGTM!

@fxmarty fxmarty merged commit a2fd011 into huggingface:main Jan 26, 2024
41 of 46 checks passed