change bnb tests #34713

Merged (16 commits) on Dec 18, 2024
tests/quantization/bnb/test_4bit.py (22 changes: 21 additions & 1 deletion)
@@ -53,6 +53,8 @@ def get_some_linear_layer(model):
except AttributeError:
# for AutoModelforCausalLM
return model.model.decoder.layers[0].fc1
elif model.config.model_type == "llama":
return model.model.layers[0].mlp.gate_proj
else:
return model.transformer.h[0].mlp.dense_4h_to_h

@@ -106,6 +108,7 @@ class Base4bitTest(unittest.TestCase):
EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
EXPECTED_OUTPUTS.add("Hello my name is John Doe, I am a student at the University")
EXPECTED_OUTPUTS.add("Hello my name is John and I am 25 years old.")
MAX_NEW_TOKENS = 10

def setUp(self):
@@ -555,6 +558,8 @@ def test_training(self):

if torch.cuda.is_available():
self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})
elif torch.xpu.is_available():
self.assertEqual(set(model.hf_device_map.values()), {f"xpu:{torch.xpu.current_device()}"})
else:
self.assertTrue(all(param.device.type == "cpu" for param in model.parameters()))

@@ -588,11 +593,18 @@ def test_training(self):


@apply_skip_if_not_implemented
@unittest.skipIf(torch_device == "xpu", reason="XPU has precision issue on gpt model, will test it once fixed")
class Bnb4BitGPT2Test(Bnb4BitTest):
model_name = "openai-community/gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187


@apply_skip_if_not_implemented
class Bnb4BitLlamaTest(Bnb4BitTest):
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
EXPECTED_RELATIVE_DIFFERENCE = 2.9461410686392764


@require_bitsandbytes
@require_accelerate
@require_torch
@@ -672,7 +684,7 @@ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serializa
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
out_0 = model_0(**encoded_input)
out_1 = model_1(**encoded_input)
self.assertTrue(torch.equal(out_0["logits"], out_1["logits"]))
self.assertTrue(torch.allclose(out_0["logits"], out_1["logits"], atol=0.05))

# comparing generate() outputs
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
@@ -734,6 +746,14 @@ class GPTSerializationTest(BaseSerializationTest):
model_name = "openai-community/gpt2-xl"


class LlamaSerializationTest(BaseSerializationTest):
"""
default BaseSerializationTest config tested with Llama family model
"""

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"


@require_bitsandbytes
@require_accelerate
@require_torch_gpu_if_bnb_not_multi_backend_enabled
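
For context, the serialization tests touched above follow a save-and-reload round trip and now compare logits with a small tolerance instead of exact equality, since quantized matmuls are not guaranteed to be bit-identical across a reload. Below is a minimal sketch of that pattern, written outside the test suite; the config values and prompt are illustrative assumptions, not lifted from the PR.

import tempfile

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model_0 = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=quant_config, device_map="auto"
)

with tempfile.TemporaryDirectory() as tmpdir:
    # round trip: serialize the quantized weights and load them back
    model_0.save_pretrained(tmpdir)
    model_1 = AutoModelForCausalLM.from_pretrained(tmpdir, device_map="auto")

encoded = tokenizer("Hello my name is", return_tensors="pt").to(model_0.device)
with torch.no_grad():
    logits_0 = model_0(**encoded).logits
    logits_1 = model_1(**encoded).logits

# tolerance-based comparison, mirroring the atol=0.05 check in the diff
assert torch.allclose(logits_0, logits_1, atol=0.05)
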
tests/quantization/bnb/test_mixed_int8.py (55 changes: 46 additions & 9 deletions)
@@ -48,6 +48,8 @@
def get_some_linear_layer(model):
if model.config.model_type == "gpt2":
return model.transformer.h[0].mlp.c_fc
elif model.config.model_type == "llama":
return model.model.layers[0].mlp.gate_proj
return model.transformer.h[0].mlp.dense_4h_to_h


@@ -65,12 +67,12 @@ def get_some_linear_layer(model):
class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only"""

def __init__(self, module: nn.Module, rank: int):
def __init__(self, module: nn.Module, rank: int, dtype: torch.dtype):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
nn.Linear(module.in_features, rank, bias=False, dtype=dtype),
nn.Linear(rank, module.out_features, bias=False, dtype=dtype),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
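
A note on the new dtype argument: nn.Linear does not cast its input, so when the frozen base model runs in float16 (as the quantized models in these tests do), an adapter built in the default float32 fails with a dtype mismatch as soon as it sees the model's hidden states. The following standalone illustration of that failure mode is separate from the test file and uses made-up sizes.

import torch
import torch.nn as nn

hidden = torch.randn(1, 8, 16, dtype=torch.float16)  # stand-in for fp16 hidden states

# adapter left in the default float32: the matmul rejects the fp16 input
adapter_fp32 = nn.Sequential(
    nn.Linear(16, 4, bias=False),
    nn.Linear(4, 16, bias=False),
)
try:
    adapter_fp32(hidden)
except RuntimeError as err:
    print("dtype mismatch:", err)

# building the adapter directly in the model's dtype avoids the mismatch,
# which is what passing dtype=model.dtype to LoRALayer accomplishes above
adapter_fp16 = nn.Sequential(
    nn.Linear(16, 4, bias=False, dtype=torch.float16),
    nn.Linear(4, 16, bias=False, dtype=torch.float16),
)
out = adapter_fp16(hidden)
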
@@ -858,29 +860,36 @@ def test_training(self):

if torch.cuda.is_available():
self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})
elif torch.xpu.is_available():
self.assertEqual(set(model.hf_device_map.values()), {f"xpu:{torch.xpu.current_device()}"})
else:
self.assertTrue(all(param.device.type == "cpu" for param in model.parameters()))

for param in model.parameters():
param.requires_grad = False # freeze the model - train adapters later
if param.ndim == 1:
# cast the small parameters (e.g. layernorm) to fp32 for stability
# cast all non INT8 parameters to fp32
if param.dtype in (torch.float16, torch.bfloat16) and param.__class__.__name__ != "Params4bit":
param.data = param.data.to(torch.float32)

# Step 2: add adapters
for _, module in model.named_modules():
if isinstance(module, OPTAttention):
module.q_proj = LoRALayer(module.q_proj, rank=16)
module.k_proj = LoRALayer(module.k_proj, rank=16)
module.v_proj = LoRALayer(module.v_proj, rank=16)
module.q_proj = LoRALayer(module.q_proj, rank=16, dtype=model.dtype)
module.k_proj = LoRALayer(module.k_proj, rank=16, dtype=model.dtype)
module.v_proj = LoRALayer(module.v_proj, rank=16, dtype=model.dtype)

# Step 3: dummy batch
batch = self.tokenizer("Test batch ", return_tensors="pt").to(torch_device)

# Step 4: Check if the gradient is not None
with torch.autocast(torch_device):
if torch_device in {"xpu", "cpu"}:
# XPU and CPU finetune do not support autocast for now.
out = model.forward(**batch)
out.logits.norm().backward()
else:
with torch.autocast(torch_device):
out = model.forward(**batch)
out.logits.norm().backward()

for module in model.modules():
if isinstance(module, LoRALayer):
@@ -891,6 +900,7 @@


@apply_skip_if_not_implemented
@unittest.skipIf(torch_device == "xpu", reason="XPU has precision issue on gpt model, will test it once fixed")
class MixedInt8GPT2Test(MixedInt8Test):
model_name = "openai-community/gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
@@ -922,3 +932,30 @@ def test_int8_from_pretrained(self):
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)

self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)


class MixedInt8LlamaTest(MixedInt8Test):
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
EXPECTED_RELATIVE_DIFFERENCE = 1.7869331026479096
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is John Smith and I am a software engineer. I")

def test_int8_from_pretrained(self):
r"""
Test whether loading a 8bit model from the Hub works as expected
"""
from bitsandbytes.nn import Int8Params

model_id = "Jiqing/TinyLlama-1.1B-Chat-v1.0-bnb-8bit"

model = AutoModelForCausalLM.from_pretrained(model_id)

linear = get_some_linear_layer(model)
self.assertTrue(linear.weight.__class__ == Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))

# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)

self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
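

The device-map assertions edited in both files encode the same expectation. The sketch below restates that branching as a small helper; it assumes a PyTorch build where torch.xpu may or may not exist, and note that the CPU-only branch in the actual tests checks parameter devices directly rather than hf_device_map.

import torch

def expected_hf_device_map_values():
    # CUDA device maps store the bare device index, e.g. {0};
    # XPU maps store strings such as {"xpu:0"}.
    if torch.cuda.is_available():
        return {torch.cuda.current_device()}
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return {f"xpu:{torch.xpu.current_device()}"}
    # CPU-only fallback: every parameter is expected to live on "cpu"
    return {"cpu"}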