[MISC] Save sharded (#40)
* rename for cp

* `save_quantized` method supports sharded checkpoints (see the usage sketch below)

* add the new `save_quantized` arguments to the `push_to_hub` method

pick save sharded

* fix import

* remove metadata that is not needed there

* add comment

* update import

* format

* update sharded test

* rename

---------

Co-authored-by: student686 <student686_2e5042963e864558@code.jdcloud.com>
CSY-ModelCloud and student686 authored Jun 21, 2024
1 parent 717e357 commit 95f2de2
Showing 4 changed files with 144 additions and 78 deletions.
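
The core change is that `save_quantized` can now split a quantized checkpoint across multiple files via a new `max_shard_size` argument. A minimal usage sketch based on the API shown in the diff and test below; the model id, shard size, and `calibration_data` variable are illustrative, not part of this commit:

```python
from gptqmodel import GPTQModel
from gptqmodel.quantization import FORMAT, QuantizeConfig

# Illustrative model id; any causal LM supported by GPTQModel should work.
model = GPTQModel.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    quantize_config=QuantizeConfig(bits=4, group_size=128, format=FORMAT.GPTQ_V2),
)
model.quantize(calibration_data)  # calibration_data: a list of tokenized examples

# New in this commit: max_shard_size splits the checkpoint into
# gptq_model-4bit-128g-00001-of-0000N.safetensors shards plus a
# gptq_model-4bit-128g.safetensors.index.json file mapping weights to shards.
model.save_quantized("./quantized-model", max_shard_size="2GB")
```
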
4 changes: 2 additions & 2 deletions .github/workflows/run_tests.yml
@@ -48,8 +48,8 @@ jobs:
- name: test_serialization.py
run: pytest tests/test_serialization.py

- name: test_sharded_loading.py
run: pytest tests/test_sharded_loading.py
- name: test_sharded.py
run: pytest tests/test_sharded.py

- name: test_triton.py
run: pytest tests/test_triton.py
118 changes: 78 additions & 40 deletions gptqmodel/models/base.py
@@ -1,7 +1,9 @@
import copy
import json
import logging
import os
from os.path import join
import re
from os.path import isfile, join
from typing import Dict, List, Optional, Union

import accelerate
@@ -12,7 +14,7 @@
from safetensors.torch import save_file as safe_save
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import no_init_weights
from transformers.modeling_utils import no_init_weights, shard_checkpoint
from transformers.utils.generic import ContextManagers

from ..quantization import GPTQ, QuantizeConfig
@@ -465,6 +467,8 @@ def save_quantized(
safetensors_metadata: Optional[Dict[str, str]] = None,
format: Optional[FORMAT] = None,
use_safetensors: bool = True,
max_shard_size: str = "10GB",
model_base_name: Optional[str] = None
):
"""save quantized model and configs to local disk"""
os.makedirs(save_dir, exist_ok=True)
@@ -484,6 +488,14 @@
if not self.quantized:
raise EnvironmentError("can only save quantized model, please execute .quantize first.")

if model_base_name is None:
model_base_name = (
self.quantize_config.model_file_base_name or
f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
)

state_dict = self.model.state_dict()

if format == FORMAT.GPTQ_V2 or (format is None and quantize_config.format == FORMAT.GPTQ_V2):
logger.warning(
f"Using 'format = {FORMAT.GPTQ_V2}': the serialized model is only supported by GPTQModel version >= {MIN_VERSION_WITH_V2}."
@@ -539,47 +551,73 @@ def save_quantized(
model_base_name = quantize_config.model_file_base_name

if use_safetensors:
model_save_name = model_base_name + ".safetensors"
state_dict = model.state_dict()
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
if safetensors_metadata is None:
safetensors_metadata = {}
elif not isinstance(safetensors_metadata, dict):
raise TypeError("safetensors_metadata must be a dictionary.")
else:
logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
new_safetensors_metadata = {}
converted_keys = False
for key, value in safetensors_metadata.items():
if not isinstance(key, str) or not isinstance(value, str):
converted_keys = True
try:
new_key = str(key)
new_value = str(value)
except Exception as e:
raise TypeError(
f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
)
if new_key in new_safetensors_metadata:
logger.warning(
f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
)
new_safetensors_metadata[new_key] = new_value
safetensors_metadata = new_safetensors_metadata
if converted_keys:
logger.debug(
f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
)

# Format is required to enable Accelerate to load the metadata
# otherwise it raises an OSError
safetensors_metadata["format"] = "pt"
safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
model_save_name = model_base_name + ".safetensors"
else:
logger.warning("We highly suggest saving quantized model using safetensors format for security reasons. Please set `use_safetensors=True` whenever possible.")
model_save_name = model_base_name + ".bin"
torch.save(model.state_dict(), join(save_dir, model_save_name))

# Shard checkpoint
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=model_save_name)

# Clean the folder from a previous save
for filename in os.listdir(save_dir):
full_filename = join(save_dir, filename)

# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")

if (
filename.startswith(model_base_name)
and isfile(full_filename)
and filename not in shards.keys()
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)

# Save the model
for shard_file, shard in shards.items():
if use_safetensors:
if safetensors_metadata is None:
safetensors_metadata = {}
elif not isinstance(safetensors_metadata, dict):
raise TypeError("safetensors_metadata must be a dictionary.")
else:
logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
new_safetensors_metadata = {}
converted_keys = False
for key, value in safetensors_metadata.items():
if not isinstance(key, str) or not isinstance(value, str):
converted_keys = True
try:
new_key = str(key)
new_value = str(value)
except Exception as e:
raise TypeError(
f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
if new_key in new_safetensors_metadata:
logger.warning(
f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
new_safetensors_metadata[new_key] = new_value
safetensors_metadata = new_safetensors_metadata
if converted_keys:
logger.debug(
f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")

# Format is required to enable Accelerate to load the metadata
# otherwise it raises an OSError
safetensors_metadata["format"] = "pt"

safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
else:
torch.save(shard, join(save_dir, shard_file))

if index is not None:
index_save_name = model_save_name + ".index.json"
index_save_path = join(save_dir, index_save_name)
# Save the index as well
with open(index_save_path, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
config.quantization_config = quantize_config.to_dict()
config.save_pretrained(save_dir)
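
For reference, the sharding above is delegated to `shard_checkpoint` from `transformers.modeling_utils` (available at the time of this commit; newer transformers releases deprecate it in favor of a `huggingface_hub` helper). It splits a state dict into a mapping of shard filename to partial state dict, plus an index that `save_quantized` writes out as `<model_save_name>.index.json`. A small self-contained sketch, with made-up tensor names and sizes:

```python
import torch
from transformers.modeling_utils import shard_checkpoint  # deprecated in newer transformers

# Illustrative state dict: eight 4 MiB tensors (~32 MiB total).
state_dict = {f"layer.{i}.weight": torch.zeros(1024, 1024) for i in range(8)}

# Split into shards of at most ~10 MB each.
shards, index = shard_checkpoint(
    state_dict, max_shard_size="10MB", weights_name="model.safetensors"
)

print(list(shards))
# e.g. ['model-00001-of-00004.safetensors', ..., 'model-00004-of-00004.safetensors']

# index is None when everything fits in a single file; otherwise it holds
# {"metadata": {"total_size": ...},
#  "weight_map": {"layer.0.weight": "model-00001-of-00004.safetensors", ...}}
print(index["weight_map"]["layer.0.weight"])
```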

64 changes: 64 additions & 0 deletions tests/test_sharded.py
@@ -0,0 +1,64 @@
import os
import tempfile
import unittest

from gptqmodel import GPTQModel
from gptqmodel.quantization import FORMAT, QuantizeConfig

from transformers import AutoTokenizer


class TestSharded(unittest.TestCase):

def get_wikitext2_data(self, tokenizer, n_samples=1):
from datasets import load_dataset
traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
traindata = traindata.filter(lambda x: len(x['text']) >= 512)

ds = traindata

traindataset = []
for example in ds:
if len(traindataset) == n_samples:
break

traindataset.append(tokenizer(example["text"]))

return traindataset

def test_save_and_load(self):
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

model = GPTQModel.from_pretrained(
model_name,
quantize_config=QuantizeConfig(
bits=4,
group_size=128,
format=FORMAT.GPTQ_V2,
))

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

cal_data = self.get_wikitext2_data(tokenizer)

model.quantize(cal_data)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_quantized(
tmp_dir,
max_shard_size="10MB"
)

files_and_dirs = os.listdir(tmp_dir)

self.assertTrue(len(files_and_dirs) == 72)

model = GPTQModel.from_quantized(
tmp_dir,
device="cuda:0",
)

tokens = model.generate(**tokenizer("1337", return_tensors="pt").to(model.device), max_new_tokens=20)[0]
result = tokenizer.decode(tokens)

self.assertTrue(result == "<s> 1337 \n- 1437 \n- 1537 \n- ")
36 changes: 0 additions & 36 deletions tests/test_sharded_loading.py

This file was deleted.
