TEST: Add llama logits tests #30835
@@ -28,7 +28,6 @@ | |
require_flash_attn, | ||
require_read_token, | ||
require_torch, | ||
require_torch_accelerator, | ||
require_torch_gpu, | ||
require_torch_sdpa, | ||
slow, | ||
|
@@ -45,7 +44,6 @@ | |
import torch | ||
|
||
from transformers import ( | ||
CodeLlamaTokenizer, | ||
LlamaForCausalLM, | ||
LlamaForQuestionAnswering, | ||
LlamaForSequenceClassification, | ||
|
@@ -605,76 +603,79 @@ def setUpClass(cls): | |
# 8 is for A100 / A10 and 7 for T4 | ||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] | ||
|
||
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!") | ||
@slow | ||
def test_model_7b_logits(self): | ||
@require_read_token | ||
def test_model_7b_logits_bf16(self): | ||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] | ||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto") | ||
out = model(torch.tensor([input_ids])) | ||
# Expected mean on dim = -1 | ||
EXPECTED_MEAN = torch.tensor([[-6.6550, -4.1227, -4.9859, -3.2406, 0.8262, -3.0033, 1.2964, -3.3699]]) | ||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) | ||
# slicing logits[0, 0, 0:30] | ||
EXPECTED_SLICE = torch.tensor([-12.8281, -7.4453, -0.4639, -8.0625, -7.2500, -8.0000, -6.4883, -7.7695, -7.8438, -7.0312, -6.2188, -7.1328, -1.8496, 1.9961, -8.6250, -6.7227, -12.8281, -6.9492, -7.0742, -7.7852, -7.5820, -7.9062, -6.9375, -7.9805, -8.3438, -8.1562, -8.0469, -7.6250, -7.7422, -7.3398,]) # fmt: skip | ||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5) | ||
|
||
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!") | ||
@slow | ||
def test_model_13b_logits(self): | ||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] | ||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf", device_map="auto") | ||
out = model(torch.tensor(input_ids)) | ||
# Expected mean on dim = -1 | ||
EXPECTED_MEAN = torch.tensor([[-2.0622, -1.2794, -1.1638, -0.9788, -1.4603, -1.0238, -1.7893, -1.4411]]) | ||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) | ||
# slicing logits[0, 0, 0:30] | ||
EXPECTED_SLICE = torch.tensor([-8.1406, -8.0547, 2.7461, -1.2344, -0.1448, -1.8262, -1.0020, -1.8154, -1.6895, -1.8516, -2.3574, -0.9277, 3.7598, 6.5742, -1.2998, -0.1177, -8.1406, -2.9688, -2.9199, -3.1699, -3.5254, -2.3555, -2.7988, -3.4141, -2.8262, -4.5195, -3.3379, -3.3164, -2.7832, -3.0273]) # fmt: skip | ||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5) | ||
model = LlamaForCausalLM.from_pretrained( | ||
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" | ||
) | ||
|
||
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!") | ||
@slow | ||
def test_model_13bf_logits(self): | ||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] | ||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-13b-chat-hf", device_map="auto") | ||
out = model(torch.tensor(input_ids)) | ||
with torch.no_grad(): | ||
out = model(torch.tensor([input_ids]).to(torch_device)) | ||
# Expected mean on dim = -1 | ||
EXPECTED_MEAN = torch.tensor([[-0.8562, -1.8520, -0.7551, -0.4162, -1.5161, -1.2038, -2.4823, -2.3254]]) | ||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) | ||
# slicing logits[0, 0, 0:30] | ||
EXPECTED_SLICE = torch.tensor([-2.2227, 4.8828, 0.9023, -0.4578, -0.7871, -0.1033, -0.6221, -0.5786, -0.7803, -1.0674, -1.2920, -0.1570, 0.8008, 2.0723, -0.9497, 0.2771, -2.2227, -0.7612, -1.4346, -1.2061, -1.6426, -0.3000, -0.7139, -1.1934, -1.8691, -1.6973, -1.5947, -1.2705, -0.3523, -0.5513]) # fmt: skip | ||
torch.testing.assert_close(out.mean(-1), EXPECTED_SLICE, atol=1e-2, rtol=1e-2) | ||
|
||
@unittest.skip( | ||
"Logits are not exactly the same, once we fix the instabalities somehow, will update! Also it is gonna be a `too_slow` test" | ||
) | ||
@slow | ||
def test_model_70b_logits(self): | ||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] | ||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-70b-hf", device_map="auto") | ||
out = model(torch.tensor(input_ids)) | ||
|
||
EXPECTED_MEAN = torch.tensor( | ||
[[-4.2327, -3.3360, -4.6665, -4.7631, -1.8180, -3.4170, -1.4211, -3.1810]], dtype=torch.float32 | ||
# fmt: off | ||
EXPECTED_MEAN = { | ||
7: torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]), # fmt: skip | ||
8: torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]) # fmt: skip | ||
} | ||
|
||
self.assertTrue(torch.allclose(EXPECTED_MEAN.to(torch_device), out.logits.mean(-1), atol=1e-2, rtol=1e-2)) | ||
|
||
|
||
# slicing logits[0, 0, 0:15] | ||
EXPECTED_SLICE = { | ||
7: torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]), # fmt: skip | ||
8: torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]) # fmt: skip | ||
} | ||
# fmt: on | ||
|
||
self.assertTrue( | ||
torch.allclose( | ||
EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device), | ||
out.logits[0, 0, :15], | ||
atol=1e-3, | ||
rtol=1e-3, | ||
) | ||
) | ||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) | ||
EXPECTED_SLICE = torch.tensor([-9.4922, -3.9551, 1.7998, -5.6758, -5.1055, -5.8984, -4.8320, -6.8086, -6.5391, -5.6172, -5.5820, -5.5352, 1.7881, 3.6289, -6.5117, -3.4785, -9.5000, -6.0352, -6.8125, -6.0195, -6.6836, -5.4727, -6.2812, -6.0391, -7.3398, -7.4297, -7.4844, -6.5820, -5.8789, -5.5312]) # fmt: skip | ||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5) | ||
|
||
@unittest.skip("Model is curently gated") | ||
@slow | ||
def test_model_13b_greedy_generation(self): | ||
EXPECTED_TEXT_COMPLETION = """Simply put, the theory of relativity states that 1) the laws of physics are the same everywhere in the universe and 2) the passage of time and the length of objects can vary depending on the observer\'s frame of reference.\n\nThe first part of the theory, that the laws of physics are the same everywhere, is known as the "princi""" | ||
prompt = "Simply put, the theory of relativity states that " | ||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf") | ||
input_ids = tokenizer.encode(prompt, return_tensors="pt") | ||
@require_read_token | ||
def test_model_7b_logits(self): | ||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] | ||
|
||
model = LlamaForCausalLM.from_pretrained( | ||
"meta-llama/Llama-2-13b-chat-hf", device_map="sequential", use_safetensors=False | ||
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16 | ||
) | ||
|
||
# greedy generation outputs | ||
generated_ids = model.generate(input_ids, max_new_tokens=64, top_p=None, temperature=1, do_sample=False) | ||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) | ||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text) | ||
with torch.no_grad(): | ||
out = model(torch.tensor([input_ids]).to(torch_device)) | ||
|
||
# fmt: off | ||
# Expected mean on dim = -1 | ||
EXPECTED_MEAN = { | ||
7: torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]), # fmt: skip | ||
8: torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]) # fmt: skip | ||
} | ||
|
||
self.assertTrue(torch.allclose(EXPECTED_MEAN[self.cuda_compute_capability_major_version].to(torch_device), out.logits.mean(-1), atol=1e-2, rtol=1e-2)) | ||
|
||
# slicing logits[0, 0, 0:15] | ||
EXPECTED_SLICE = { | ||
7: torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]), # fmt: skip | ||
8: torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]) # fmt: skip | ||
} | ||
Review comment: `# fmt: skip` is not needed in between `# fmt: off` and `# fmt: on`. (A short illustration of these formatter directives follows this test.)
Reply: nice catch!
||
# fmt: on | ||
|
||
self.assertTrue( | ||
torch.allclose( | ||
EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device), | ||
out.logits[0, 0, :15], | ||
atol=1e-3, | ||
rtol=1e-3, | ||
) | ||
) | ||
|
||
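For reference on the review comment above, here is a minimal sketch (not code from this PR) of how black's formatter directives interact: a region wrapped in `# fmt: off` / `# fmt: on` is already left untouched by the formatter, so a per-line `# fmt: skip` inside that region is redundant; `# fmt: skip` is only useful on a line that sits outside such a region. The tensor values below are placeholders.

```python
import torch

# fmt: off
# black leaves everything until "# fmt: on" untouched, so no per-line
# "# fmt: skip" is needed inside this region.
EXPECTED_MEAN = {
    7: torch.tensor([[-6.64, -4.12, -4.98]]),
    8: torch.tensor([[-6.65, -4.13, -4.94]]),
}
# fmt: on

# Outside an off/on region, a single statement can opt out on its own:
EXPECTED_SLICE = torch.tensor([-12.81, -7.34, -0.48, -8.02])  # fmt: skip
```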
@slow | ||
@require_torch_gpu | ||
|
@@ -736,89 +737,6 @@ def test_compile_static_cache(self): | |
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) | ||
|
||
|
||
@require_torch | ||
class CodeLlamaIntegrationTest(unittest.TestCase): | ||
PROMPTS = [ | ||
'''def remove_non_ascii(s: str) -> str: | ||
""" <FILL_ME> | ||
return result | ||
''', | ||
"""# Installation instructions: | ||
```bash | ||
<FILL_ME> | ||
``` | ||
This downloads the LLaMA inference code and installs the repository as a local pip package. | ||
""", | ||
"""class InterfaceManagerFactory(AbstractManagerFactory): | ||
def __init__(<FILL_ME> | ||
def main(): | ||
factory = InterfaceManagerFactory(start=datetime.now()) | ||
managers = [] | ||
for i in range(10): | ||
managers.append(factory.build(id=i)) | ||
""", | ||
"""/-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/ | ||
theorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) : | ||
π₁ P = 0 ↔ <FILL_ME> = 0 := | ||
begin | ||
split, | ||
{ intros h f, | ||
rw pi_1_etalisation at h, | ||
simp [h], | ||
refl | ||
}, | ||
{ intro h, | ||
have := @quasi_adjoint C D P, | ||
simp [←pi_1_etalisation, this, h], | ||
refl | ||
} | ||
end | ||
""", | ||
] | ||
|
||
@require_torch_accelerator | ||
@slow | ||
@unittest.skip("Model is too large") | ||
def test_model_7b_logits(self): | ||
Review comment: As long as the tokenizer is tested with this, LGTM to remove.
||
model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf").to(torch_device) | ||
tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf") | ||
# Tokenize and prepare for the model a list of sequences or a list of pairs of sequences. | ||
# meaning by default this supports passing splitted list of inputs | ||
processed_text = tokenizer.batch_decode(tokenizer(self.PROMPTS)["input_ids"], add_special_tokens=False) | ||
# fmt: off | ||
EXPECTED_TEXT = [ | ||
'<s> <PRE> def remove_non_ascii(s: str) -> str:\n """ <SUF>\n return result\n <MID>', | ||
'<s> <PRE> # Installation instructions:\n ```bash\n <SUF>\n ```\nThis downloads the LLaMA inference code and installs the repository as a local pip package.\n <MID>', | ||
'<s> <PRE> class InterfaceManagerFactory(AbstractManagerFactory):\n def __init__( <SUF>\ndef main():\n factory = InterfaceManagerFactory(start=datetime.now())\n managers = []\n for i in range(10):\n managers.append(factory.build(id=i))\n <MID>', | ||
'<s> <PRE> /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ <SUF> = 0 :=\nbegin\nsplit,\n{ intros h f,\n rw pi_1_etalisation at h,\n simp [h],\n refl\n},\n{ intro h,\n have := @quasi_adjoint C D P,\n simp [←pi_1_etalisation, this, h],\n refl\n}\nend\n <MID>' | ||
] | ||
# fmt: on | ||
self.assertEqual(processed_text, EXPECTED_TEXT) | ||
processed_text_suffix_first = tokenizer.batch_decode( | ||
tokenizer(self.PROMPTS, suffix_first=True, add_special_tokens=False)["input_ids"] | ||
) | ||
|
||
# fmt: off | ||
EXPECTED_TEXT = [ | ||
'<PRE> <SUF>\n return result\n <MID> def remove_non_ascii(s: str) -> str:\n """ ', | ||
'<PRE> <SUF>\n ```\nThis downloads the LLaMA inference code and installs the repository as a local pip package.\n <MID> # Installation instructions:\n ```bash\n', | ||
'<PRE> <SUF>\ndef main():\n factory = InterfaceManagerFactory(start=datetime.now())\n managers = []\n for i in range(10):\n managers.append(factory.build(id=i))\n <MID> class InterfaceManagerFactory(AbstractManagerFactory):\n def __init__(', | ||
'<PRE> <SUF> = 0 :=\nbegin\nsplit,\n{ intros h f,\n rw pi_1_etalisation at h,\n simp [h],\n refl\n},\n{ intro h,\n have := @quasi_adjoint C D P,\n simp [←pi_1_etalisation, this, h],\n refl\n}\nend\n <MID> /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ ' | ||
] | ||
EXPECTED_IDS = torch.tensor([[1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898, 29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]]) | ||
# fmt: on | ||
self.assertEqual(processed_text_suffix_first, EXPECTED_TEXT) | ||
input_ids = tokenizer(self.PROMPTS[0], return_tensors="pt")["input_ids"] | ||
generated_ids = model.generate(input_ids.to(torch_device), max_new_tokens=128) | ||
torch.testing.assert_close(generated_ids, EXPECTED_IDS) | ||
|
||
EXPECTED_INFILLING = [ | ||
'<s> <PRE> def remove_non_ascii(s: str) -> str:\n """ <SUF>\n return result\n <MID>Remove non-ASCII characters from a string.\n\n Args:\n s: The string to remove non-ASCII characters from.\n\n Returns:\n The string with non-ASCII characters removed.\n """\n result = ""\n for c in s:\n if ord(c) < 128:\n result += c <EOT></s>' | ||
] | ||
infilling = tokenizer.batch_decode(generated_ids) | ||
self.assertEqual(infilling, EXPECTED_INFILLING) | ||
|
||
|
||
@slow | ||
@require_torch_gpu | ||
class Mask4DTestHard(unittest.TestCase): | ||
|
Review comment: we need to use eager, otherwise bf16 + SDPA fails.
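To make this concrete, here is a minimal, self-contained sketch of the pattern the new bf16 test follows (paraphrased from the diff above, not copied verbatim): the 7B checkpoint is loaded in bfloat16 with `attn_implementation="eager"`, since bf16 + SDPA does not reproduce the recorded logits, and the expected mean is keyed by the CUDA compute capability major version (7 for T4, 8 for A100 / A10). The expected means below are the diff's recorded bf16 values rounded to two decimals.

```python
import torch
from transformers import LlamaForCausalLM

# Load in bf16 with eager attention: per the comment above, bf16 + SDPA
# gives logits that drift from the recorded expectations.
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)

input_ids = torch.tensor([[1, 306, 4658, 278, 6593, 310, 2834, 338]], device=model.device)
with torch.no_grad():
    out = model(input_ids)

# Expected means differ slightly per GPU generation: 8 is for A100 / A10, 7 for T4.
major = torch.cuda.get_device_capability()[0]
expected_mean = {
    7: torch.tensor([[-6.51, -4.11, -4.97, -3.20, 0.81, -2.97, 1.29, -3.38]]),
    8: torch.tensor([[-6.52, -4.12, -4.94, -3.25, 0.81, -2.98, 1.29, -3.38]]),
}[major]
torch.testing.assert_close(out.logits.float().mean(-1).cpu(), expected_mean, atol=1e-2, rtol=1e-2)
```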