[bnb] Fix bnb skip modules (#24043)
* fix skip modules test

* oops

* address comments
younesbelkada authored Jun 7, 2023
1 parent a116018 commit 4795219
Showing 2 changed files with 27 additions and 3 deletions.
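For context, the option this commit fixes is `llm_int8_skip_modules` on `BitsAndBytesConfig`. A minimal usage sketch, mirroring the new test added below (the checkpoint name and the skipped module come from that test, not from the commit itself):

import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig

# Quantize all linear layers to int8 except the classification head.
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli", quantization_config=quantization_config
)

# Encoder projections are int8; the skipped classifier remains a regular nn.Linear.
assert model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8
assert isinstance(model.classifier.dense, nn.Linear)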
10 changes: 7 additions & 3 deletions src/transformers/utils/bitsandbytes.py
@@ -109,16 +109,18 @@ class `Int8Params` from `bitsandbytes`.
module._parameters[tensor_name] = new_value


def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
def _replace_with_bnb_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
has_been_replaced = False
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)

if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
@@ -151,14 +151,16 @@ def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_nam
has_been_replaced = True
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# Remove the last key for recursion
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_bnb_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced


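The functional change above is that `has_been_replaced` is now threaded through the recursive call instead of being re-initialized to `False` on every invocation, so a replacement recorded at one level is no longer overwritten when a later recursion into an unconverted subtree returns `False`. A standalone sketch of that accumulator pattern (a simplified walker for illustration, not the transformers implementation):

import torch.nn as nn

def _walk_and_flag(model, has_been_replaced=False):
    # The flag arrives as an argument and is only ever promoted to True,
    # so a True set earlier in the loop is not clobbered by a later recursive call.
    for module in model.children():
        if isinstance(module, nn.Linear):
            has_been_replaced = True
        if len(list(module.children())) > 0:
            _, has_been_replaced = _walk_and_flag(module, has_been_replaced=has_been_replaced)
    return model, has_been_replaced

Before this change, each recursive call started from `False` and its return value overwrote the caller's flag, so a conversion performed earlier in the loop could end up reported as if it never happened.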
20 changes: 20 additions & 0 deletions tests/bitsandbytes/test_mixed_int8.py
@@ -146,6 +146,26 @@ def test_linear_are_8bit(self):
if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
self.assertTrue(module.weight.dtype == torch.int8)

def test_llm_skip(self):
r"""
A simple test to check if `llm_int8_skip_modules` works as expected
"""
import bitsandbytes as bnb

quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
"roberta-large-mnli", quantization_config=quantization_config
)
self.assertTrue(seq_classification_model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8)
self.assertTrue(
isinstance(seq_classification_model.roberta.encoder.layer[0].output.dense, bnb.nn.Linear8bitLt)
)

self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
self.assertTrue(seq_classification_model.classifier.dense.weight.dtype != torch.int8)
self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))
self.assertTrue(seq_classification_model.classifier.out_proj.weight.dtype != torch.int8)

def test_generate_quality(self):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.
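Beyond the assertions in the new test, a small helper along these lines can make it easier to inspect which layers were converted and which were skipped; this is an illustrative sketch (the helper name is made up here), not part of the commit or the library:

import bitsandbytes as bnb
import torch.nn as nn

def summarize_quantized_modules(model):
    # Print every linear layer together with whether it was converted to 8-bit.
    # Linear8bitLt subclasses nn.Linear, so it must be checked first.
    for name, module in model.named_modules():
        if isinstance(module, bnb.nn.Linear8bitLt):
            print(f"{name}: bnb.nn.Linear8bitLt (int8)")
        elif isinstance(module, nn.Linear):
            print(f"{name}: nn.Linear ({module.weight.dtype})")

Called on the model from the test, this should list the encoder projections as `Linear8bitLt` and `classifier.dense` / `classifier.out_proj` as plain `nn.Linear`.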
