From f86777c85f4e14641b89c121277c11ae9cd5a9a7 Mon Sep 17 00:00:00 2001
From: Jiahao Li
Date: Mon, 24 Jun 2024 13:15:47 +0800
Subject: [PATCH] Fix nan by rescheduling attention scaling (#322)

---
 README.md                 | 20 ++++----
 chatglm.cpp               | 25 +++++-----
 chatglm.h                 | 99 ++++++++++++++++++++++++---------------
 chatglm_cpp/_C.pyi        |  2 +-
 chatglm_cpp/__init__.py   |  2 +-
 chatglm_cpp/convert.py    |  4 +-
 chatglm_pybind.cpp        |  2 +-
 chatglm_test.cpp          | 40 +++++++---------
 tests/test_chatglm_cpp.py | 18 +++----
 9 files changed, 115 insertions(+), 97 deletions(-)

diff --git a/README.md b/README.md
index dd163b7..d65c70f 100644
--- a/README.md
+++ b/README.md
@@ -59,13 +59,15 @@ The original model (`-i <model_name_or_path>`) can be a Hugging Face model name
 * CodeGeeX2: `THUDM/codegeex2-6b`, `THUDM/codegeex2-6b-int4`
 
 You are free to try any of the below quantization types by specifying `-t <type>`:
-* `q4_0`: 4-bit integer quantization with fp16 scales.
-* `q4_1`: 4-bit integer quantization with fp16 scales and minimum values.
-* `q5_0`: 5-bit integer quantization with fp16 scales.
-* `q5_1`: 5-bit integer quantization with fp16 scales and minimum values.
-* `q8_0`: 8-bit integer quantization with fp16 scales.
-* `f16`: half precision floating point weights without quantization.
-* `f32`: single precision floating point weights without quantization.
+| type   | precision | symmetric |
+| ------ | --------- | --------- |
+| `q4_0` | int4      | true      |
+| `q4_1` | int4      | false     |
+| `q5_0` | int5      | true      |
+| `q5_1` | int5      | false     |
+| `q8_0` | int8      | true      |
+| `f16`  | half      |           |
+| `f32`  | float     |           |
 
 For LoRA models, add `-l <lora_model_name_or_path>` flag to merge your LoRA weights into the base model. For example, run `python3 chatglm_cpp/convert.py -i THUDM/chatglm3-6b -t q4_0 -o models/chatglm3-ggml-lora.bin -l shibing624/chatglm3-6b-csc-chinese-lora` to merge public LoRA weights from Hugging Face.
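The new table compresses the prose from the old bullet list: `q4_0`/`q5_0`/`q8_0` are symmetric schemes (one fp16 scale per block), while `q4_1`/`q5_1` additionally store a per-block minimum. Below is a minimal sketch of the round-trip math behind the two schemes, using a 32-element block as ggml does; it deliberately omits the fp16 scale storage and nibble packing, so it illustrates the idea rather than ggml's actual block layout.

```cpp
// Illustrative only: symmetric (q4_0-style) vs asymmetric (q4_1-style) block quantization.
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>

constexpr int kBlockSize = 32; // ggml quantizes weights in fixed-size blocks

// Symmetric: map [-max|x|, +max|x|] onto the signed 4-bit range, keep one scale.
std::array<float, kBlockSize> quantize_symmetric(const std::array<float, kBlockSize> &x) {
    float amax = 0.f;
    for (float v : x) amax = std::max(amax, std::fabs(v));
    const float scale = amax / 7.f; // signed 4-bit range is roughly [-8, 7]
    std::array<float, kBlockSize> out{};
    for (int i = 0; i < kBlockSize; i++) {
        int q = scale ? (int)std::lround(x[i] / scale) : 0;
        q = std::clamp(q, -8, 7);
        out[i] = q * scale; // dequantized value
    }
    return out;
}

// Asymmetric: map [min, max] onto the unsigned 4-bit range, keep scale and min.
std::array<float, kBlockSize> quantize_asymmetric(const std::array<float, kBlockSize> &x) {
    const auto [mn, mx] = std::minmax_element(x.begin(), x.end());
    const float scale = (*mx - *mn) / 15.f; // unsigned 4-bit range is [0, 15]
    std::array<float, kBlockSize> out{};
    for (int i = 0; i < kBlockSize; i++) {
        int q = scale ? (int)std::lround((x[i] - *mn) / scale) : 0;
        q = std::clamp(q, 0, 15);
        out[i] = q * scale + *mn;
    }
    return out;
}

int main() {
    std::array<float, kBlockSize> x{};
    for (int i = 0; i < kBlockSize; i++) x[i] = 0.1f * i - 0.5f; // toy weights
    const auto sym = quantize_symmetric(x), asym = quantize_asymmetric(x);
    std::printf("x[5]=%.3f  symmetric=%.3f  asymmetric=%.3f\n", x[5], sym[5], asym[5]);
}
```

The asymmetric variants spend a few extra bytes per block on the minimum, which is why `q4_1`/`q5_1` files are slightly larger but can represent skewed weight distributions with less error.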
@@ -551,8 +553,8 @@ Download and unzip the dataset from [link](https://s3.amazonaws.com/research.met
 
 |                         | Q4_0  | Q4_1  | Q5_0  | Q5_1  | Q8_0  | F16   |
 |-------------------------|-------|-------|-------|-------|-------|-------|
-| [ChatGLM3-6B-Base][1]   | 6.215 | 6.184 | 5.997 | 6.015 | 5.965 | 5.971 |
-| [ChatGLM4-9B-Base][2]   | 6.851 | 6.793 | 6.652 | 6.635 | 6.582 | 6.586 |
+| [ChatGLM3-6B-Base][1]   | 6.215 | 6.188 | 6.006 | 6.022 | 5.971 | 5.972 |
+| [ChatGLM4-9B-Base][2]   | 6.834 | 6.780 | 6.645 | 6.624 | 6.576 | 6.577 |
 
 [1]: https://huggingface.co/THUDM/chatglm3-6b-base
 [2]: https://huggingface.co/THUDM/glm-4-9b
diff --git a/chatglm.cpp b/chatglm.cpp
index 6cf5c32..a7fc336 100644
--- a/chatglm.cpp
+++ b/chatglm.cpp
@@ -624,7 +624,7 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
     const int hidden_size = hidden_states->ne[0];
     const int qlen = hidden_states->ne[1];
     const int head_size = hidden_size / num_attention_heads;
-    const int num_shared_q_heads = num_attention_heads / num_kv_heads;
+    const int num_shared_q_heads = num_attention_heads / num_key_value_heads;
 
     ggml_tensor *qkv = query_key_value.forward(mctx, hidden_states); // [sq, (#h + 2 * #kvh) * d]
 
@@ -645,10 +645,11 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
     } else {
         query_layer = ggml_view_3d(ctx, qkv, head_size, num_attention_heads, qlen,
                                    head_size * ggml_element_size(qkv), qkv->nb[1], 0);
-        key_layer = ggml_view_3d(ctx, qkv, head_size, num_kv_heads, qlen, head_size * ggml_element_size(qkv),
+        key_layer = ggml_view_3d(ctx, qkv, head_size, num_key_value_heads, qlen, head_size * ggml_element_size(qkv),
                                  qkv->nb[1], hidden_size * ggml_element_size(qkv));
-        value_layer = ggml_view_3d(ctx, qkv, head_size, num_kv_heads, qlen, head_size * ggml_element_size(qkv),
-                                   qkv->nb[1], (hidden_size + head_size * num_kv_heads) * ggml_element_size(qkv));
+        value_layer =
+            ggml_view_3d(ctx, qkv, head_size, num_key_value_heads, qlen, head_size * ggml_element_size(qkv), qkv->nb[1],
+                         (hidden_size + head_size * num_key_value_heads) * ggml_element_size(qkv));
     }
 
     query_layer = apply_rotary_emb(mctx, query_layer, position_ids, rope_type, rope_theta);
@@ -657,7 +658,7 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
     query_layer = ggml_cont(ctx, ggml_permute(ctx, query_layer, 0, 2, 1, 3)); // [#h, s, d]
     if (num_shared_q_heads > 1) {
         query_layer = ggml_reshape_3d(ctx, query_layer, head_size, num_shared_q_heads * qlen,
-                                      num_kv_heads); // [#kvh, (#h/#kvh) * s, d]
+                                      num_key_value_heads); // [#kvh, (#h/#kvh) * s, d]
     }
 
     key_layer = ggml_permute(ctx, key_layer, 0, 2, 1, 3); // [#kvh, s, d]
@@ -665,25 +666,25 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
 
     // store key & value to cache
     ggml_tensor *k_cache_view =
-        ggml_view_3d(ctx, k_cache, head_size, qlen, num_kv_heads, k_cache->nb[1], k_cache->nb[2],
+        ggml_view_3d(ctx, k_cache, head_size, qlen, num_key_value_heads, k_cache->nb[1], k_cache->nb[2],
                      (num_virtual_tokens + n_past) * head_size * ggml_element_size(k_cache)); // [#kvh, s, d]
     ggml_build_forward_expand(mctx->gf, ggml_cpy(ctx, key_layer, k_cache_view));
     ggml_tensor *v_cache_view =
-        ggml_view_3d(ctx, v_cache, qlen, head_size, num_kv_heads, v_cache->nb[1], v_cache->nb[2],
+        ggml_view_3d(ctx, v_cache, qlen, head_size, num_key_value_heads, v_cache->nb[1], v_cache->nb[2],
                      (num_virtual_tokens + n_past) * ggml_element_size(v_cache)); // [#kvh, d, s]
     ggml_build_forward_expand(mctx->gf, ggml_cpy(ctx, value_layer, v_cache_view));
 
     // concat key & value with past kv
-    key_layer = ggml_view_3d(ctx, k_cache, head_size, num_virtual_tokens + n_past + qlen, num_kv_heads, k_cache->nb[1],
-                             k_cache->nb[2],
+    key_layer = ggml_view_3d(ctx, k_cache, head_size, num_virtual_tokens + n_past + qlen, num_key_value_heads,
+                             k_cache->nb[1], k_cache->nb[2],
                              0); // [#kvh, kvs, d]
-    value_layer = ggml_view_3d(ctx, v_cache, num_virtual_tokens + n_past + qlen, head_size, num_kv_heads,
+    value_layer = ggml_view_3d(ctx, v_cache, num_virtual_tokens + n_past + qlen, head_size, num_key_value_heads,
                                v_cache->nb[1], v_cache->nb[2],
                                0); // [#kvh, d, kvs]
 
     // attention
+    query_layer = ggml_scale_inplace(ctx, query_layer, 1.f / std::sqrt(head_size));
     ggml_tensor *attn_scores = ggml_mul_mat(ctx, key_layer, query_layer); // [#kvh, (#h/#kvh) * s, kvs]
-    attn_scores = ggml_scale_inplace(ctx, attn_scores, 1.f / std::sqrt(head_size));
     if (n_past == 0) {
         // build attention mask for context input
@@ -701,7 +702,7 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
         if (num_shared_q_heads > 1) {
             attn_scores = ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, num_shared_q_heads * qlen,
-                                          num_kv_heads); // [#kvh, (#h/#kvh) * s, kvs]
+                                          num_key_value_heads); // [#kvh, (#h/#kvh) * s, kvs]
         }
     }
diff --git a/chatglm.h b/chatglm.h
index c817157..f42be53 100644
--- a/chatglm.h
+++ b/chatglm.h
@@ -65,7 +65,7 @@ struct ConfigRecordV1 {
 
 // For compatibility
 struct ConfigRecordV1GQA : public ConfigRecordV1 {
-    int num_kv_heads;
+    int num_key_value_heads;
 };
 
 // TODO: use json to serialize config
@@ -109,15 +109,15 @@ class ModelConfig {
     ModelConfig() = default;
 
     ModelConfig(ModelType model_type, ggml_type dtype, int vocab_size, int hidden_size, int num_attention_heads,
-                int num_kv_heads, int num_hidden_layers, int intermediate_size, float norm_eps, float rope_theta,
+                int num_key_value_heads, int num_hidden_layers, int intermediate_size, float norm_eps, float rope_theta,
                 int num_virtual_tokens, int max_length, int bos_token_id, int eos_token_id, int pad_token_id,
                 int sep_token_id, std::vector extra_eos_token_ids)
         : model_type(model_type), dtype(dtype), vocab_size(vocab_size), hidden_size(hidden_size),
-          num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), num_hidden_layers(num_hidden_layers),
-          intermediate_size(intermediate_size), norm_eps(norm_eps), rope_theta(rope_theta),
-          num_virtual_tokens(num_virtual_tokens), max_length(max_length), bos_token_id(bos_token_id),
-          eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id),
-          extra_eos_token_ids(std::move(extra_eos_token_ids)) {
+          num_attention_heads(num_attention_heads), num_key_value_heads(num_key_value_heads),
+          num_hidden_layers(num_hidden_layers), intermediate_size(intermediate_size), norm_eps(norm_eps),
+          rope_theta(rope_theta), num_virtual_tokens(num_virtual_tokens), max_length(max_length),
+          bos_token_id(bos_token_id), eos_token_id(eos_token_id), pad_token_id(pad_token_id),
+          sep_token_id(sep_token_id), extra_eos_token_ids(std::move(extra_eos_token_ids)) {
         if (model_type == ModelType::CHATGLM) {
             hidden_act = ActivationType::GELU;
             use_qkv_bias = true;
@@ -146,9 +146,10 @@ class ModelConfig {
 
     ModelConfig(ModelType model_type, const ConfigRecordV1GQA &rec, float norm_eps, float rope_theta,
                 int num_virtual_tokens)
-        : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads,
-                      rec.num_hidden_layers, rec.intermediate_size, norm_eps, rope_theta,
num_virtual_tokens, - rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {} + : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, + rec.num_key_value_heads, rec.num_hidden_layers, rec.intermediate_size, norm_eps, rope_theta, + num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, + rec.sep_token_id, {}) {} ModelConfig(ModelType model_type, const ConfigRecordV2 &rec) : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, @@ -158,13 +159,33 @@ class ModelConfig { std::string model_type_name() const { return to_string(model_type); } + friend std::ostream &operator<<(std::ostream &os, const ModelConfig &self) { + os << "ModelConfig(model_type=" << (int)self.model_type << ", dtype=" << self.dtype + << ", vocab_size=" << self.vocab_size << ", hidden_size=" << self.hidden_size + << ", num_attention_heads=" << self.num_attention_heads + << ", num_key_value_heads=" << self.num_key_value_heads << ", num_hidden_layers=" << self.num_hidden_layers + << ", intermediate_size=" << self.intermediate_size << ", norm_eps=" << self.norm_eps + << ", hidden_act=" << (int)self.hidden_act << ", use_qkv_bias=" << self.use_qkv_bias + << ", use_dense_bias=" << self.use_dense_bias << ", interleaved_qkv=" << self.interleaved_qkv + << ", tie_word_embeddings=" << self.tie_word_embeddings << ", rope_type=" << (int)self.rope_type + << ", rope_theta=" << self.rope_theta << ", attn_mask_type=" << (int)self.attn_mask_type + << ", num_virtual_tokens=" << self.num_virtual_tokens << ", max_length=" << self.max_length + << ", bos_token_id=" << self.bos_token_id << ", eos_token_id=" << self.eos_token_id + << ", pad_token_id=" << self.pad_token_id << ", sep_token_id=" << self.sep_token_id + << ", extra_eos_token_ids={"; + for (size_t i = 0; i < self.extra_eos_token_ids.size(); i++) { + os << (i > 0 ? 
", " : "") << self.extra_eos_token_ids[i]; + } + return os << "})"; + } + public: ModelType model_type; ggml_type dtype; int vocab_size; int hidden_size; int num_attention_heads; - int num_kv_heads; + int num_key_value_heads; int num_hidden_layers; int intermediate_size; float norm_eps; @@ -419,26 +440,26 @@ class BasicGLU { class BasicAttention { public: BasicAttention() = default; - BasicAttention(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length, - bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta, - AttentionMaskType attn_mask_type, int num_virtual_tokens) - : num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), interleaved_qkv(interleaved_qkv), - rope_type(rope_type), rope_theta(rope_theta), attn_mask_type(attn_mask_type), - num_virtual_tokens(num_virtual_tokens), - query_key_value(mctx, hidden_size, hidden_size + 2 * (hidden_size / num_attention_heads) * num_kv_heads, - use_qkv_bias), + BasicAttention(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_key_value_heads, + int max_length, bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, + float rope_theta, AttentionMaskType attn_mask_type, int num_virtual_tokens) + : num_attention_heads(num_attention_heads), num_key_value_heads(num_key_value_heads), + interleaved_qkv(interleaved_qkv), rope_type(rope_type), rope_theta(rope_theta), + attn_mask_type(attn_mask_type), num_virtual_tokens(num_virtual_tokens), + query_key_value(mctx, hidden_size, + hidden_size + 2 * (hidden_size / num_attention_heads) * num_key_value_heads, use_qkv_bias), dense(mctx, hidden_size, hidden_size, use_dense_bias), k_cache(ggml_new_tensor_3d(mctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads, - max_length + num_virtual_tokens, num_kv_heads)), + max_length + num_virtual_tokens, num_key_value_heads)), v_cache(ggml_new_tensor_3d(mctx->ctx_kv.get(), GGML_TYPE_F16, max_length + num_virtual_tokens, - hidden_size / num_attention_heads, num_kv_heads)) {} + hidden_size / num_attention_heads, num_key_value_heads)) {} ggml_tensor *forward(ModelContext *mctx, ggml_tensor *hidden_states, ggml_tensor *attention_mask, ggml_tensor *position_ids, int n_past) const; public: int num_attention_heads; - int num_kv_heads; + int num_key_value_heads; bool interleaved_qkv; RopeType rope_type; float rope_theta; @@ -454,13 +475,13 @@ template class BasicBlock { public: BasicBlock() = default; - BasicBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size, - int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias, - bool interleaved_qkv, RopeType rope_type, float rope_theta, AttentionMaskType attn_mask_type, - int num_virtual_tokens) + BasicBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_key_value_heads, + int intermediate_size, int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, + bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta, + AttentionMaskType attn_mask_type, int num_virtual_tokens) : input_layernorm(mctx, hidden_size, false, norm_eps), - attention(mctx, hidden_size, num_attention_heads, num_kv_heads, max_length, use_qkv_bias, use_dense_bias, - interleaved_qkv, rope_type, rope_theta, attn_mask_type, num_virtual_tokens), + attention(mctx, hidden_size, num_attention_heads, num_key_value_heads, max_length, use_qkv_bias, + use_dense_bias, 
interleaved_qkv, rope_type, rope_theta, attn_mask_type, num_virtual_tokens), post_attention_layernorm(mctx, hidden_size, false, norm_eps), mlp(mctx, hidden_size, intermediate_size, hidden_act) {} @@ -572,20 +593,20 @@ class BasicModel { auto &attn = layers[i].attention; ggml_tensor *virtual_key = ggml_view_3d(mctx.ctx_b.get(), past_key_values, head_size, config.num_virtual_tokens, - config.num_kv_heads, past_key_values->nb[1], past_key_values->nb[2], + config.num_key_value_heads, past_key_values->nb[1], past_key_values->nb[2], i * 2 * past_key_values->nb[3]); // [#h, v, d] ggml_tensor *k_cache_view = - ggml_view_3d(mctx.ctx_b.get(), attn.k_cache, head_size, config.num_virtual_tokens, config.num_kv_heads, - attn.k_cache->nb[1], attn.k_cache->nb[2], 0); // [#h, v, d] + ggml_view_3d(mctx.ctx_b.get(), attn.k_cache, head_size, config.num_virtual_tokens, + config.num_key_value_heads, attn.k_cache->nb[1], attn.k_cache->nb[2], 0); // [#h, v, d] ggml_build_forward_expand(mctx.gf, ggml_cpy(mctx.ctx_b.get(), virtual_key, k_cache_view)); ggml_tensor *virtual_value = ggml_view_3d( - mctx.ctx_b.get(), past_key_values, head_size, config.num_virtual_tokens, config.num_kv_heads, + mctx.ctx_b.get(), past_key_values, head_size, config.num_virtual_tokens, config.num_key_value_heads, past_key_values->nb[1], past_key_values->nb[2], (i * 2 + 1) * past_key_values->nb[3]); // [#h, v, d] virtual_value = ggml_permute(mctx.ctx_b.get(), virtual_value, 1, 0, 2, 3); // [#h, d, v] ggml_tensor *v_cache_view = - ggml_view_3d(mctx.ctx_b.get(), attn.v_cache, config.num_virtual_tokens, head_size, config.num_kv_heads, - attn.v_cache->nb[1], attn.v_cache->nb[2], 0); // [#h, d, v] + ggml_view_3d(mctx.ctx_b.get(), attn.v_cache, config.num_virtual_tokens, head_size, + config.num_key_value_heads, attn.v_cache->nb[1], attn.v_cache->nb[2], 0); // [#h, d, v] ggml_build_forward_expand(mctx.gf, ggml_cpy(mctx.ctx_b.get(), virtual_value, v_cache_view)); } @@ -598,7 +619,7 @@ class BasicModel { std::vector layers; layers.reserve(config.num_hidden_layers); for (int layer_id = 0; layer_id < config.num_hidden_layers; layer_id++) { - layers.emplace_back(mctx, config.hidden_size, config.num_attention_heads, config.num_kv_heads, + layers.emplace_back(mctx, config.hidden_size, config.num_attention_heads, config.num_key_value_heads, config.intermediate_size, config.max_length, config.norm_eps, config.hidden_act, config.use_qkv_bias, config.use_dense_bias, config.interleaved_qkv, config.rope_type, config.rope_theta, config.attn_mask_type, config.num_virtual_tokens); @@ -858,10 +879,10 @@ class ChatGLMTokenizer : public BaseTokenizer { class GLMBlock : public BasicBlock { public: GLMBlock() = default; - GLMBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size, - int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias, - bool interleaved_qkv, RopeType rope_type, float rope_theta, AttentionMaskType attn_mask_type, - int num_virtual_tokens) + GLMBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_key_value_heads, + int intermediate_size, int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, + bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta, + AttentionMaskType attn_mask_type, int num_virtual_tokens) : BasicBlock(LayerNorm(mctx, hidden_size, false, norm_eps), BasicAttention(mctx, hidden_size, num_attention_heads, num_attention_heads, max_length, use_qkv_bias, use_dense_bias, 
interleaved_qkv, rope_type, rope_theta, diff --git a/chatglm_cpp/_C.pyi b/chatglm_cpp/_C.pyi index c1457f3..cae0f2f 100644 --- a/chatglm_cpp/_C.pyi +++ b/chatglm_cpp/_C.pyi @@ -104,7 +104,7 @@ class ModelConfig: def num_hidden_layers(self) -> int: ... @property - def num_kv_heads(self) -> int: + def num_key_value_heads(self) -> int: ... @property def pad_token_id(self) -> int: diff --git a/chatglm_cpp/__init__.py b/chatglm_cpp/__init__.py index 11f2beb..48da0b5 100644 --- a/chatglm_cpp/__init__.py +++ b/chatglm_cpp/__init__.py @@ -6,7 +6,7 @@ import chatglm_cpp._C as _C from chatglm_cpp._C import ChatMessage -__version__ = "0.4.0" +__version__ = "0.4.1" @dataclass diff --git a/chatglm_cpp/convert.py b/chatglm_cpp/convert.py index bf69c4e..f24a813 100644 --- a/chatglm_cpp/convert.py +++ b/chatglm_cpp/convert.py @@ -209,13 +209,13 @@ def convert(cls, f, model, tokenizer, ggml_type): cls.dump_model(f, model, ggml_type) -def get_prefix_cache(prefix_encoder, pre_seq_len, num_layers, num_kv_heads, head_size): +def get_prefix_cache(prefix_encoder, pre_seq_len, num_layers, num_key_value_heads, head_size): prefix_tokens = torch.arange(pre_seq_len, dtype=torch.long) with torch.no_grad(): past_key_values = prefix_encoder(prefix_tokens) past_key_values = ( past_key_values.to(torch.half) - .view(pre_seq_len, num_layers * 2, num_kv_heads, head_size) + .view(pre_seq_len, num_layers * 2, num_key_value_heads, head_size) .permute(1, 2, 0, 3) .contiguous() ) diff --git a/chatglm_pybind.cpp b/chatglm_pybind.cpp index e6aab45..5143e04 100644 --- a/chatglm_pybind.cpp +++ b/chatglm_pybind.cpp @@ -63,7 +63,7 @@ PYBIND11_MODULE(_C, m) { .def_readonly("vocab_size", &ModelConfig::vocab_size) .def_readonly("hidden_size", &ModelConfig::hidden_size) .def_readonly("num_attention_heads", &ModelConfig::num_attention_heads) - .def_readonly("num_kv_heads", &ModelConfig::num_kv_heads) + .def_readonly("num_key_value_heads", &ModelConfig::num_key_value_heads) .def_readonly("num_hidden_layers", &ModelConfig::num_hidden_layers) .def_readonly("intermediate_size", &ModelConfig::intermediate_size) .def_readonly("norm_eps", &ModelConfig::norm_eps) diff --git a/chatglm_test.cpp b/chatglm_test.cpp index 64c247d..96c66fd 100644 --- a/chatglm_test.cpp +++ b/chatglm_test.cpp @@ -308,7 +308,7 @@ class ChatGLMTest : public ::testing::Test { const int head_size = config.hidden_size / config.num_attention_heads; past_key_values = ggml_new_tensor_4d(mctx_->ctx_b.get(), GGML_TYPE_F16, head_size, config.num_virtual_tokens, - config.num_kv_heads, config.num_hidden_layers * 2); // [l * 2, #h, v, d] + config.num_key_value_heads, config.num_hidden_layers * 2); // [l * 2, #h, v, d] } auto buf_b = @@ -596,7 +596,7 @@ TEST_F(ChatGLMTest, GLMModel) { ModelConfig config( ModelType::CHATGLM, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, - /*num_kv_heads=*/8, /*num_hidden_layers=*/1, /*intermediate_size=*/128, /*norm_eps=*/1e-5f, + /*num_key_value_heads=*/8, /*num_hidden_layers=*/1, /*intermediate_size=*/128, /*norm_eps=*/1e-5f, /*rope_theta=*/10000.f, /*num_virtual_tokens=*/0, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, @@ -630,7 +630,7 @@ TEST_F(ChatGLMTest, GLMPTuningV2Model) { ModelConfig config( ModelType::CHATGLM, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, - /*num_kv_heads=*/8, /*num_hidden_layers=*/1, /*intermediate_size=*/128, /*norm_eps=*/1e-5f, + /*num_key_value_heads=*/8, /*num_hidden_layers=*/1, 
/*intermediate_size=*/128, /*norm_eps=*/1e-5f, /*rope_theta=*/10000.f, /*num_virtual_tokens=*/5, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, @@ -664,7 +664,7 @@ TEST_F(ChatGLMTest, GLM2Model) { ModelConfig config( ModelType::CHATGLM2, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, - /*num_kv_heads=*/2, /*num_hidden_layers=*/1, /*intermediate_size=*/48, /*norm_eps=*/1e-5f, + /*num_key_value_heads=*/2, /*num_hidden_layers=*/1, /*intermediate_size=*/48, /*norm_eps=*/1e-5f, /*rope_theta=*/10000.f, /*num_virtual_tokens=*/0, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, @@ -693,7 +693,7 @@ TEST_F(ChatGLMTest, GLM3Model) { ModelConfig config( ModelType::CHATGLM3, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, - /*num_kv_heads=*/2, /*num_hidden_layers=*/1, /*intermediate_size=*/48, /*norm_eps=*/1e-5f, + /*num_key_value_heads=*/2, /*num_hidden_layers=*/1, /*intermediate_size=*/48, /*norm_eps=*/1e-5f, /*rope_theta=*/10000.f, /*num_virtual_tokens=*/0, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, @@ -722,7 +722,7 @@ TEST_F(ChatGLMTest, GLM3PTuningV2Model) { ModelConfig config( ModelType::CHATGLM3, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, - /*num_kv_heads=*/2, /*num_hidden_layers=*/1, /*intermediate_size=*/48, /*norm_eps=*/1e-5f, + /*num_key_value_heads=*/2, /*num_hidden_layers=*/1, /*intermediate_size=*/48, /*norm_eps=*/1e-5f, /*rope_theta=*/10000.f, /*num_virtual_tokens=*/5, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, @@ -751,7 +751,7 @@ TEST_F(ChatGLMTest, GLM4Model) { ModelConfig config( ModelType::CHATGLM4, GGML_TYPE_F32, /*vocab_size=*/5, /*hidden_size=*/32, /*num_attention_heads=*/8, - /*num_kv_heads=*/2, /*num_hidden_layers=*/1, /*intermediate_size=*/48, /*norm_eps=*/1e-5f, + /*num_key_value_heads=*/2, /*num_hidden_layers=*/1, /*intermediate_size=*/48, /*norm_eps=*/1e-5f, /*rope_theta=*/10000.f, /*num_virtual_tokens=*/0, /*max_length=*/8, /*bos_token_id=*/-1, /*eos_token_id=*/-1, /*pad_token_id=*/-1, /*sep_token_id=*/-1, @@ -1208,8 +1208,7 @@ TEST(Pipeline, ChatGLM3) { { ChatMessage output = pipeline.chat(messages, gen_config); EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT); - EXPECT_EQ(output.content, - "根据您的要求,我使用随机数生成器API生成了一个随机数。根据API返回的结果,生成的随机数为22。"); + EXPECT_EQ(output.content, "根据您的要求,我使用随机数生成器API生成了一个在0和100之间的随机数,结果为22。"); } } @@ -1226,9 +1225,7 @@ TEST(Pipeline, ChatGLM3) { EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT); EXPECT_EQ(output.content, R"(好的,我会为您列出100以内的所有质数。 -质数是指只能被1和它本身整除的大于1的整数。例如,2、3、5、7等都是质数。 - -让我们开始吧!)"); +(Note: 质数是指只能被1和它本身整除的正整数。))"); EXPECT_EQ(output.tool_calls.front().code.input, R"(```python def is_prime(n): """Check if a number is prime.""" @@ -1245,7 +1242,6 @@ def is_prime(n): i += 6 return True -# Get all prime numbers up to 100 primes_upto_100 = [i for i in range(2, 101) if is_prime(i)] primes_upto_100 ```)"); @@ -1259,9 +1255,7 @@ primes_upto_100 EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT); EXPECT_EQ(output.content, R"(100以内的所有质数为: -$$ -2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97 -$$)"); +$$2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97$$)"); } } } @@ -1406,7 +1400,7 @@ if __name__ == '__main__': 
gen_config.do_sample = false; std::vector messages{{ChatMessage::ROLE_USER, "你好"}}; ChatMessage output = pipeline.chat(messages, gen_config); - EXPECT_EQ(output.content, "你好👋!有什么可以帮助你的吗?"); + EXPECT_EQ(output.content, "你好👋!我是人工智能助手,很高兴见到你,有什么可以帮助你的吗?"); } } @@ -1436,12 +1430,12 @@ TEST(Pipeline, CodeGeeX2) { std::string prompt = "# language: Python\n# write a bubble sort function\n"; std::string target = R"( -def bubble_sort(list): - for i in range(len(list) - 1): - for j in range(len(list) - 1 - i): - if list[j] > list[j + 1]: - list[j], list[j + 1] = list[j + 1], list[j] - return list +def bubble_sort(lst): + for i in range(len(lst) - 1): + for j in range(len(lst) - 1 - i): + if lst[j] > lst[j + 1]: + lst[j], lst[j + 1] = lst[j + 1], lst[j] + return lst print(bubble_sort([5, 4, 3, 2, 1])))"; diff --git a/tests/test_chatglm_cpp.py b/tests/test_chatglm_cpp.py index 482c860..a490740 100644 --- a/tests/test_chatglm_cpp.py +++ b/tests/test_chatglm_cpp.py @@ -76,7 +76,7 @@ def test_chatglm4_pipeline(): check_pipeline( model_path=CHATGLM4_MODEL_PATH, prompt="你好", - target="你好👋!有什么可以帮助你的吗?", + target="你好👋!我是人工智能助手,很高兴见到你,有什么可以帮助你的吗?", ) @@ -85,12 +85,12 @@ def test_codegeex2_pipeline(): prompt = "# language: Python\n# write a bubble sort function\n" target = """ -def bubble_sort(list): - for i in range(len(list) - 1): - for j in range(len(list) - 1 - i): - if list[j] > list[j + 1]: - list[j], list[j + 1] = list[j + 1], list[j] - return list +def bubble_sort(lst): + for i in range(len(lst) - 1): + for j in range(len(lst) - 1 - i): + if lst[j] > lst[j + 1]: + lst[j], lst[j + 1] = lst[j + 1], lst[j] + return lst print(bubble_sort([5, 4, 3, 2, 1]))""" @@ -117,7 +117,7 @@ def test_langchain_api(): client = TestClient(app) response = client.post("/", json={"prompt": "你好", "temperature": 0}) assert response.status_code == 200 - assert response.json()["response"] == "你好👋!有什么可以帮助你的吗?" + assert response.json()["response"] == "你好👋!我是人工智能助手,很高兴见到你,有什么可以帮助你的吗?" @pytest.mark.skipif(not CHATGLM4_MODEL_PATH.exists(), reason="model file not found") @@ -137,4 +137,4 @@ def test_openai_api(): assert response.status_code == 200 response_message = response.json()["choices"][0]["message"] assert response_message["role"] == "assistant" - assert response_message["content"] == "你好👋!有什么可以帮助你的吗?" + assert response_message["content"] == "你好👋!我是人工智能助手,很高兴见到你,有什么可以帮助你的吗?"
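The functional change behind the patch title is the small chatglm.cpp hunk above: the 1/√d attention scaling moves from the scores onto the query, i.e. `query_layer` is scaled by `1.f / std::sqrt(head_size)` before `ggml_mul_mat`, rather than scaling `attn_scores` afterwards. Mathematically (Q/√d)·Kᵀ equals (Q·Kᵀ)/√d, but the KV cache is stored in fp16, and with large activations the unscaled dot products can presumably exceed the fp16 range (about ±65504), turning scores into inf and later nan. Below is a minimal sketch of that rationale; it simulates fp16 by flushing values beyond the fp16 range to infinity and uses made-up magnitudes, so it illustrates the numeric argument rather than ggml's actual kernels.

```cpp
// Sketch only: why scaling Q before Q·K^T can avoid inf/nan in half precision.
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Crude fp16 stand-in: anything beyond fp16's largest finite value becomes infinity.
float to_half(float x) {
    const float kHalfMax = 65504.f;
    if (std::fabs(x) > kHalfMax)
        return x > 0 ? std::numeric_limits<float>::infinity() : -std::numeric_limits<float>::infinity();
    return x;
}

// Dot product with every partial sum squeezed through the simulated fp16 range.
float dot_half(const std::vector<float> &q, const std::vector<float> &k) {
    float acc = 0.f;
    for (size_t i = 0; i < q.size(); i++) acc = to_half(acc + to_half(q[i] * k[i]));
    return acc;
}

int main() {
    const int head_size = 128;
    std::vector<float> q(head_size, 30.f), k(head_size, 30.f); // hypothetical large activations

    // Old order: compute raw scores, then scale -> the partial sums overflow first.
    float raw = dot_half(q, k);
    float scaled_after = raw / std::sqrt((float)head_size); // inf divided by ~11.3 is still inf

    // New order (this patch): scale the query first, then take the dot product.
    std::vector<float> q_scaled = q;
    for (float &v : q_scaled) v /= std::sqrt((float)head_size);
    float scaled_before = dot_half(q_scaled, k);

    std::printf("scale after matmul:  %f\n", scaled_after);  // inf; softmax over inf scores yields nan
    std::printf("scale before matmul: %f\n", scaled_before); // finite
}
```

Both orders produce the same attention scores when nothing overflows, which is why the fix is a pure reschedule rather than a change in the math.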
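The updated expectations in chatglm_test.cpp and tests/test_chatglm_cpp.py (the ChatGLM4 greeting, the CodeGeeX2 bubble sort, the ChatGLM3 tool-calling transcripts, and the perplexity table) reflect that greedy decoding now follows a slightly different numerical path. Below is a minimal sketch of reproducing the ChatGLM4 check outside the test harness, mirroring how the tests drive the pipeline; the model path is a placeholder and assumes a GGML file produced by chatglm_cpp/convert.py.

```cpp
// Sketch: greedy chat with the C++ pipeline, modeled on the TEST(Pipeline, ...) cases above.
#include <iostream>
#include <vector>
#include "chatglm.h"

int main() {
    chatglm::Pipeline pipeline("models/chatglm4-ggml.bin"); // placeholder path

    chatglm::GenerationConfig gen_config;
    gen_config.do_sample = false; // greedy decoding, as in the tests

    std::vector<chatglm::ChatMessage> messages{{chatglm::ChatMessage::ROLE_USER, "你好"}};
    chatglm::ChatMessage output = pipeline.chat(messages, gen_config);

    // With this patch the expected reply becomes:
    // "你好👋!我是人工智能助手,很高兴见到你,有什么可以帮助你的吗?"
    std::cout << output.content << std::endl;
    return 0;
}
```

The Python bindings exercise the same code path, which is why the pytest expectations and the LangChain/OpenAI API tests change in lockstep.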