# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import List

import pytest
import torch
from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from huggingface_hub import snapshot_download

import vllm._custom_ops as ops
from vllm.platforms import current_platform
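
# Fetch the pre-quantized GGUF sample files from the Hugging Face Hub once
# for the whole module.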
GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
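

# Each sample file is named `Quant_{quant_type}_{hidden_size}.gguf`; its
# tensors encode their dequantized shape in the tensor name (a trailing
# "..._<rows>x<cols>" suffix).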
def get_gguf_sample_tensors(
        hidden_size: int,
        quant_type: GGMLQuantizationType) -> List[ReaderTensor]:
    sample_dir = GGUF_SAMPLE
    filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
    sample_file = Path(sample_dir) / filename
    return GGUFReader(sample_file).tensors


DTYPES = [torch.half]
# Hidden sizes for testing; these must match the sample files in the HF repo
# (the repo currently provides samples for hidden_size = 256 and 1024).
HIDDEN_SIZES = [256, 1024]
NUM_TOKENS = [7, 83, 128, 2048]  # Arbitrary token counts for testing
SEEDS = [0]
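
# Quantization types exercised by the dequantize and matrix-vector tests below.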
QUANT_TYPES = [
    # i-matrix
    GGMLQuantizationType.IQ1_M,
    GGMLQuantizationType.IQ1_S,
    GGMLQuantizationType.IQ2_S,
    GGMLQuantizationType.IQ2_XS,
    GGMLQuantizationType.IQ3_S,
    GGMLQuantizationType.IQ3_XXS,
    GGMLQuantizationType.IQ4_NL,
    GGMLQuantizationType.IQ4_XS,
    # k-quants
    GGMLQuantizationType.Q2_K,
    GGMLQuantizationType.Q3_K,
    GGMLQuantizationType.Q4_K,
    GGMLQuantizationType.Q5_K,
    GGMLQuantizationType.Q6_K,
    # standard quantization
    GGMLQuantizationType.Q4_0,
    GGMLQuantizationType.Q5_0,
    GGMLQuantizationType.Q8_0,
]
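

# The CUDA dequantize kernel (`ops.ggml_dequantize`) should match the
# reference `gguf.dequantize` output for every supported quantization type.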
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_dequantize(hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
for tensor in tensors:
shape_str = tensor.name.split("_")[-1]
shape = map(int, shape_str.split("x"))
ref_output = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype)
output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"),
quant_type, *list(shape)).to(dtype)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)
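

# Quantized matrix-vector multiply (`ggml_mul_mat_vec_a8`) against a
# dequantize-then-matmul reference. Tolerances are loose because the
# quantized kernel does not reproduce the fp16 reference exactly.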
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_mmvq(hidden_size: int, dtype: torch.dtype,
              quant_type: GGMLQuantizationType):
    current_platform.seed_everything(0)
    tensors = get_gguf_sample_tensors(hidden_size, quant_type)
    x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
    for tensor in tensors:
        # Reference: dequantize the weight, then do a regular fp16 matmul.
        weight = torch.tensor(dequantize(tensor.data, quant_type),
                              device="cuda").to(dtype)
        ref_output = x @ weight.T
        qweight = torch.tensor(tensor.data, device="cuda")
        output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type,
                                         qweight.shape[0]).to(dtype)
        torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
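

# Batched quantized matmul (`ggml_mul_mat_a8`). Only the k-quants and
# standard quant types are parametrized here; the i-matrix types from
# QUANT_TYPES above are not covered by this test.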
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
    "quant_type",
    [
        # k-quants
        GGMLQuantizationType.Q2_K,
        GGMLQuantizationType.Q3_K,
        GGMLQuantizationType.Q4_K,
        GGMLQuantizationType.Q5_K,
        GGMLQuantizationType.Q6_K,
        # standard quants
        GGMLQuantizationType.Q4_0,
        GGMLQuantizationType.Q5_0,
        GGMLQuantizationType.Q8_0,
    ])
@torch.inference_mode()
def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
             quant_type: GGMLQuantizationType):
    current_platform.seed_everything(0)
    tensors = get_gguf_sample_tensors(hidden_size, quant_type)
    x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
    for tensor in tensors:
        # Reference: dequantize the weight, then do a regular fp16 matmul.
        weight = torch.tensor(dequantize(tensor.data, quant_type),
                              device="cuda").to(dtype)
        ref_output = x @ weight.T
        qweight = torch.tensor(tensor.data, device="cuda")
        output = ops.ggml_mul_mat_a8(qweight, x, quant_type,
                                     qweight.shape[0]).to(dtype)
        torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
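

# A minimal sketch for running a single case by hand (not part of the test
# suite; assumes a CUDA device and network access to download the sample
# repo, and test ids may vary by pytest version):
#
#   pytest test_gguf.py -k "test_dequantize and Q4_0 and 256"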