
Commit

[core / Quantization] Fix for 8bit serialization tests (huggingface#27234)

* fix for 8bit serialization

* added regression tests.

* fixup
younesbelkada authored Nov 2, 2023
1 parent c52e429 commit 9b25c16
Showing 2 changed files with 34 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/transformers/modeling_utils.py
@@ -2110,7 +2110,13 @@ def save_pretrained(
         # We're going to remove aliases before saving
         ptrs = collections.defaultdict(list)
         for name, tensor in state_dict.items():
-            ptrs[id_tensor_storage(tensor)].append(name)
+            # Sometimes in the state_dict we have non-tensor objects.
+            # e.g. in bitsandbytes we have some `str` objects in the state_dict
+            if isinstance(tensor, torch.Tensor):
+                ptrs[id_tensor_storage(tensor)].append(name)
+            else:
+                # In the non-tensor case, fall back to the pointer of the object itself
+                ptrs[id(tensor)].append(name)
 
         # These are all the pointers of shared tensors.
         shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
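
For context, a minimal sketch of the failure mode this hunk fixes is shown below. It is illustrative only: `id_tensor_storage` here is a simplified stand-in for the real helper in `transformers.pytorch_utils`, and the `"weight_format"` string is a hypothetical example of the non-tensor entries bitsandbytes can leave in an 8-bit state_dict.

import collections

import torch

def id_tensor_storage(tensor):
    # Simplified stand-in: the real helper inspects tensor storage, so a plain
    # `str` entry would raise an AttributeError before this fix.
    return tensor.device, tensor.untyped_storage().data_ptr()

state_dict = {
    "weight": torch.zeros(2, 2),
    "weight_format": "row",  # hypothetical non-tensor entry, as bitsandbytes may store
}

ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
    if isinstance(tensor, torch.Tensor):
        ptrs[id_tensor_storage(tensor)].append(name)
    else:
        # Key non-tensors by the object's own pointer, mirroring the fix above
        ptrs[id(tensor)].append(name)

# Only keys that collect more than one name are true aliases
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
print(shared_ptrs)  # {} -- the str entry no longer crashes the loop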
27 changes: 27 additions & 0 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -369,6 +369,33 @@ def test_int8_serialization(self):
             self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
         )
 
+    def test_int8_serialization_regression(self):
+        r"""
+        Test whether it is possible to serialize a model in 8-bit without using safetensors.
+        """
+        from bitsandbytes.nn import Int8Params
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.model_8bit.save_pretrained(tmpdirname, safe_serialization=False)
+
+            # check that `quantization_config` is present in the saved config
+            config = AutoConfig.from_pretrained(tmpdirname)
+            self.assertTrue(hasattr(config, "quantization_config"))
+
+            model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
+
+            linear = get_some_linear_layer(model_from_saved)
+            self.assertTrue(linear.weight.__class__ == Int8Params)
+            self.assertTrue(hasattr(linear.weight, "SCB"))
+
+            # generate
+            encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+            output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
+
+            self.assertEqual(
+                self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
+            )
+
     def test_int8_serialization_sharded(self):
         r"""
         Test whether it is possible to serialize a model in 8-bit - sharded version.
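
Outside the test suite, the round trip exercised by the regression test looks roughly like the sketch below. The model id, the prompt, and the single-GPU `.to(0)` call are assumptions for illustration, and running it requires a CUDA device with bitsandbytes installed.

import tempfile

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloom-560m"  # illustrative small causal LM
tokenizer = AutoTokenizer.from_pretrained(model_id)
model_8bit = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

with tempfile.TemporaryDirectory() as tmpdirname:
    # Before this fix, saving without safetensors crashed because the
    # alias-removal loop in save_pretrained hit non-tensor state_dict entries.
    model_8bit.save_pretrained(tmpdirname, safe_serialization=False)
    reloaded = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")

inputs = tokenizer("Hello my name is", return_tensors="pt")
outputs = reloaded.generate(input_ids=inputs["input_ids"].to(0), max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))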
