[MISC] Save sharded #40

Merged: 11 commits · Jun 21, 2024
4 changes: 2 additions & 2 deletions .github/workflows/run_tests.yml
@@ -48,8 +48,8 @@ jobs:
      - name: test_serialization.py
        run: pytest tests/test_serialization.py

      - name: test_sharded_loading.py
        run: pytest tests/test_sharded_loading.py
      - name: test_sharded.py
        run: pytest tests/test_sharded.py

      - name: test_triton.py
        run: pytest tests/test_triton.py
118 changes: 78 additions & 40 deletions gptqmodel/models/base.py
@@ -1,7 +1,9 @@
import copy
import json
import logging
import os
from os.path import join
import re
from os.path import isfile, join
from typing import Dict, List, Optional, Union

import accelerate
@@ -12,7 +14,7 @@
from safetensors.torch import save_file as safe_save
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import no_init_weights
from transformers.modeling_utils import no_init_weights, shard_checkpoint
from transformers.utils.generic import ContextManagers

from ..quantization import GPTQ, QuantizeConfig
@@ -465,6 +467,8 @@ def save_quantized(
        safetensors_metadata: Optional[Dict[str, str]] = None,
        format: Optional[FORMAT] = None,
        use_safetensors: bool = True,
        max_shard_size: str = "10GB",
        model_base_name: Optional[str] = None
    ):
        """save quantized model and configs to local disk"""
        os.makedirs(save_dir, exist_ok=True)
@@ -484,6 +488,14 @@
        if not self.quantized:
            raise EnvironmentError("can only save quantized model, please execute .quantize first.")

        if model_base_name is None:
            model_base_name = (
                self.quantize_config.model_file_base_name or
                f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
            )

        state_dict = self.model.state_dict()

        if format == FORMAT.GPTQ_V2 or (format is None and quantize_config.format == FORMAT.GPTQ_V2):
            logger.warning(
                f"Using 'format = {FORMAT.GPTQ_V2}': the serialized model is only supported by GPTQModel version >= {MIN_VERSION_WITH_V2}."
@@ -539,47 +551,73 @@ def save_quantized(
        model_base_name = quantize_config.model_file_base_name

        if use_safetensors:
            model_save_name = model_base_name + ".safetensors"
            state_dict = model.state_dict()
            state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
            if safetensors_metadata is None:
                safetensors_metadata = {}
            elif not isinstance(safetensors_metadata, dict):
                raise TypeError("safetensors_metadata must be a dictionary.")
            else:
                logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
                new_safetensors_metadata = {}
                converted_keys = False
                for key, value in safetensors_metadata.items():
                    if not isinstance(key, str) or not isinstance(value, str):
                        converted_keys = True
                        try:
                            new_key = str(key)
                            new_value = str(value)
                        except Exception as e:
                            raise TypeError(
                                f"safetensors_metadata: both keys and values must be strings and an error occurred when trying to convert them: {e}"
                            )
                        if new_key in new_safetensors_metadata:
                            logger.warning(
                                f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
                            )
                        new_safetensors_metadata[new_key] = new_value
                safetensors_metadata = new_safetensors_metadata
                if converted_keys:
                    logger.debug(
                        f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
                    )

            # Format is required to enable Accelerate to load the metadata
            # otherwise it raises an OSError
            safetensors_metadata["format"] = "pt"
            safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
            model_save_name = model_base_name + ".safetensors"
        else:
            logger.warning("We highly suggest saving quantized model using safetensors format for security reasons. Please set `use_safetensors=True` whenever possible.")
            model_save_name = model_base_name + ".bin"
            torch.save(model.state_dict(), join(save_dir, model_save_name))

        # Shard checkpoint
        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=model_save_name)

        # Clean the folder from a previous save
        for filename in os.listdir(save_dir):
            full_filename = join(save_dir, filename)

            # make sure the file to be deleted matches the sharded-file format, e.g. pytorch_model-00001-of-00005
            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
            reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")

            if (
                filename.startswith(model_base_name)
                and isfile(full_filename)
                and filename not in shards.keys()
                and reg.fullmatch(filename_no_suffix) is not None
            ):
                os.remove(full_filename)

        # Save the model
        for shard_file, shard in shards.items():
            if use_safetensors:
                if safetensors_metadata is None:
                    safetensors_metadata = {}
                elif not isinstance(safetensors_metadata, dict):
                    raise TypeError("safetensors_metadata must be a dictionary.")
                else:
                    logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
                    new_safetensors_metadata = {}
                    converted_keys = False
                    for key, value in safetensors_metadata.items():
                        if not isinstance(key, str) or not isinstance(value, str):
                            converted_keys = True
                            try:
                                new_key = str(key)
                                new_value = str(value)
                            except Exception as e:
                                raise TypeError(
                                    f"safetensors_metadata: both keys and values must be strings and an error occurred when trying to convert them: {e}")
                            if new_key in new_safetensors_metadata:
                                logger.warning(
                                    f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
                            new_safetensors_metadata[new_key] = new_value
                    safetensors_metadata = new_safetensors_metadata
                    if converted_keys:
                        logger.debug(
                            f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")

                # Format is required to enable Accelerate to load the metadata
                # otherwise it raises an OSError
                safetensors_metadata["format"] = "pt"

                safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
            else:
                torch.save(shard, join(save_dir, shard_file))

        if index is not None:
            index_save_name = model_save_name + ".index.json"
            index_save_path = join(save_dir, index_save_name)
            # Save the index as well
            with open(index_save_path, "w", encoding="utf-8") as f:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                f.write(content)
        config.quantization_config = quantize_config.to_dict()
        config.save_pretrained(save_dir)

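For context, a minimal, self-contained sketch of the save path this diff introduces: shard the state dict with transformers' shard_checkpoint (the helper imported above), remove stale shard files, write each shard, then write the index. The directory, base name, dummy tensors, and shard size below are illustrative assumptions, not values taken from the PR.

# Sketch only: mirrors the sharded-save flow above with a dummy state dict.
# Assumes transformers exposes shard_checkpoint (it did when this PR was
# merged); file and directory names here are hypothetical.
import json
import os
import re
from os.path import isfile, join

import torch
from safetensors.torch import save_file as safe_save
from transformers.modeling_utils import shard_checkpoint

save_dir = "sharded-out"
model_save_name = "gptq_model-4bit-128g.safetensors"
state_dict = {f"layer.{i}.weight": torch.zeros(256, 256) for i in range(8)}

# Split into shards no larger than max_shard_size. Returns a dict of
# {shard_filename: sub_state_dict} plus an index (None if one file suffices).
shards, index = shard_checkpoint(state_dict, max_shard_size="1MB", weights_name=model_save_name)

os.makedirs(save_dir, exist_ok=True)

# Delete stale shards from an earlier save (e.g. "...-00001-of-00005"),
# as the cleanup loop above does, so old and new shard sets cannot mix.
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
for filename in os.listdir(save_dir):
    no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
    if isfile(join(save_dir, filename)) and filename not in shards and reg.fullmatch(no_suffix):
        os.remove(join(save_dir, filename))

# Write each shard; the "format": "pt" metadata is required for Accelerate
# to load safetensors files without raising an OSError.
for shard_file, shard in shards.items():
    safe_save(shard, join(save_dir, shard_file), {"format": "pt"})

if index is not None:
    # The index maps each weight name to the shard file holding it.
    with open(join(save_dir, model_save_name + ".index.json"), "w", encoding="utf-8") as f:
        f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")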
64 changes: 64 additions & 0 deletions tests/test_sharded.py
@@ -0,0 +1,64 @@
import os
import tempfile
import unittest

from gptqmodel import GPTQModel
from gptqmodel.quantization import FORMAT, QuantizeConfig

from transformers import AutoTokenizer


class TestSharded(unittest.TestCase):

    def get_wikitext2_data(self, tokenizer, n_samples=1):
        from datasets import load_dataset
        traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
        traindata = traindata.filter(lambda x: len(x['text']) >= 512)

        ds = traindata

        traindataset = []
        for example in ds:
            if len(traindataset) == n_samples:
                break

            traindataset.append(tokenizer(example["text"]))

        return traindataset

    def test_save_and_load(self):
        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

        model = GPTQModel.from_pretrained(
            model_name,
            quantize_config=QuantizeConfig(
                bits=4,
                group_size=128,
                format=FORMAT.GPTQ_V2,
            ))

        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

        cal_data = self.get_wikitext2_data(tokenizer)

        model.quantize(cal_data)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_quantized(
                tmp_dir,
                max_shard_size="10MB"
            )

            files_and_dirs = os.listdir(tmp_dir)

            self.assertTrue(len(files_and_dirs) == 72)

            model = GPTQModel.from_quantized(
                tmp_dir,
                device="cuda:0",
            )

            tokens = model.generate(**tokenizer("1337", return_tensors="pt").to(model.device), max_new_tokens=20)[0]
            result = tokenizer.decode(tokens)

            self.assertTrue(result == "<s> 1337 \n- 1437 \n- 1537 \n- ")
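For reference, a minimal sketch of manually reassembling a sharded safetensors checkpoint like the one this test produces, using only the index file's weight_map. Paths and the base name are hypothetical; GPTQModel.from_quantized handles this internally, so the sketch only illustrates the on-disk layout.

# Sketch only: rebuild a full state dict from a sharded safetensors
# checkpoint. Directory and base name below are hypothetical.
import json
from os.path import join

from safetensors.torch import load_file

save_dir = "sharded-out"
index_path = join(save_dir, "gptq_model-4bit-128g.safetensors.index.json")

with open(index_path, encoding="utf-8") as f:
    weight_map = json.load(f)["weight_map"]  # tensor name -> shard file

state_dict = {}
for shard_file in sorted(set(weight_map.values())):
    # Each shard is an ordinary safetensors file holding a subset of tensors.
    state_dict.update(load_file(join(save_dir, shard_file)))

print(f"reassembled {len(state_dict)} tensors from {len(set(weight_map.values()))} shards")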
36 changes: 0 additions & 36 deletions tests/test_sharded_loading.py

This file was deleted.