Add support for DeepSeek V3 #11049

Merged: 10 commits into ggerganov:master on Jan 4, 2025

Conversation

@fairydreaming (Collaborator) commented Jan 2, 2025

This PR adds support for the recently released DeepSeek V3 model (MoE, 671B parameters).

The model is architecturally very similar to DeepSeek V2; there are only minor changes in the expert weights calculation.

Summary of changes:

  • added a boolean expert_weights_norm model parameter indicating whether expert weights shall be normalized or not - they were not normalized in DeepSeek V2, but they are in DeepSeek V3,
  • added a numerical expert_gating_func model parameter corresponding to an enum value indicating the function used to calculate expert probs - usually it's softmax, but DeepSeek V3 uses sigmoid for this purpose,
  • added the exp_probs_b tensor type (initially named expert_weights_b) containing expert weights bias tensors - DeepSeek V3 introduced a bias term added to the calculated expert probs; the biased probs are the input to the top-k expert selection process,
  • updated the llm_build_moe_ffn() API and implementation to handle the mentioned differences (see the sketch below),
  • added a new pre-tokenization regex for DeepSeek V3 - someone wiser could take a look and check whether it needs any modifications to work correctly.

Note: DeepSeek V3 also introduced multi-token prediction (MTP), but I decided to skip this feature for now. The MTP layer is ignored during model conversion and is not present in the resulting GGUF file.
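
For reference, here is a minimal sketch of the routing changes described above, written against the ggml API. This is illustrative only, not the PR's exact code: the helper name build_deepseek_v3_routing, its signature, and the tensor names are assumptions.

// illustrative sketch, not the code from this PR
#include "ggml.h"

static ggml_tensor * build_deepseek_v3_routing(
        ggml_context * ctx,
        ggml_tensor  * logits,       // [n_expert, n_tokens] router logits
        ggml_tensor  * exp_probs_b,  // [n_expert] expert probs bias (nullptr for DeepSeek V2)
        int n_expert, int n_expert_used, int n_tokens) {
    // expert_gating_func: DeepSeek V3 uses sigmoid where DeepSeek V2 used softmax
    ggml_tensor * probs = ggml_sigmoid(ctx, logits);                                  // [n_expert, n_tokens]

    // the bias is applied only to the probs used for expert *selection*
    ggml_tensor * selection_probs = probs;
    if (exp_probs_b != nullptr) {
        selection_probs = ggml_add(ctx, probs, exp_probs_b);
    }

    // pick the top-k experts based on the biased probs ...
    ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]

    // ... but the gating weights multiplied with the experts' output come from the
    // original, unbiased probs
    ggml_tensor * weights = ggml_get_rows(ctx,
            ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts);

    // expert_weights_norm: DeepSeek V3 normalizes the selected weights, V2 does not
    weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
    ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);                          // [1, n_tokens]
    weights = ggml_div(ctx, weights, weights_sum);                                    // [n_expert_used, n_tokens]

    return weights;
}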

@github-actions bot added the python (python script changes) label on Jan 2, 2025
@fairydreaming linked an issue on Jan 2, 2025 that may be closed by this pull request
src/llama.cpp Outdated
Comment on lines 10299 to 10304
// add experts selection bias - introduced in DeepSeek V3
ggml_tensor * selection_probs = probs;
if (expert_weights_b != nullptr) {
    selection_probs = ggml_add(ctx, probs, expert_weights_b);
    cb(selection_probs, "ffn_moe_sigm_biased", il);
}

@ggerganov (Owner):

Can be simplified to:

Suggested change:
- // add experts selection bias - introduced in DeepSeek V3
- ggml_tensor * selection_probs = probs;
- if (expert_weights_b != nullptr) {
-     selection_probs = ggml_add(ctx, probs, expert_weights_b);
-     cb(selection_probs, "ffn_moe_sigm_biased", il);
- }
+ // add experts selection bias - introduced in DeepSeek V3
+ if (expert_weights_b != nullptr) {
+     probs = ggml_add(ctx, probs, expert_weights_b);
+     cb(probs, "ffn_moe_sigm_b", il);
+ }

@fairydreaming (Collaborator, Author) Jan 3, 2025:

I'm afraid this won't work correctly, as the original unmodified weights are still needed for multiplication with the experts output at the end of the function. Biased weights are used only for expert selection. See the DeepSeek V3 technical report:

> Note that the bias term is only used for routing. The gating value, which will be multiplied with the FFN output, is still derived from the original affinity score.

Edit: I'm going to add a comment in the code to make it clear

@ggerganov (Owner):

I see - I missed this.

@@ -312,6 +314,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_SHEXP = auto()
     FFN_DOWN_SHEXP = auto()
     FFN_UP_SHEXP = auto()
+    FFN_EXPERT_WEIGHTS_B = auto()

@ggerganov (Owner):

For more consistency in the names, let's change EXPERT to EXP. Also, it seems that PROBS is a better name, since this is a bias for the computed expert probabilities:

Suggested change:
- FFN_EXPERT_WEIGHTS_B = auto()
+ FFN_EXP_PROBS_B = auto()

@fairydreaming (Collaborator, Author):

Done

@@ -496,6 +499,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
     MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
     MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
+    MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: "blk.{bid}.expert_weights_b",

@ggerganov (Owner):

Suggested change:
- MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: "blk.{bid}.expert_weights_b",
+ MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",

@fairydreaming (Collaborator, Author):

Done

src/llama.cpp Outdated
@@ -2912,6 +2934,7 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_b = nullptr; // b2
     struct ggml_tensor * ffn_up_b = nullptr; // b3
     struct ggml_tensor * ffn_act = nullptr;
+    struct ggml_tensor * ffn_expert_weights_bias = nullptr;

@ggerganov (Owner):

Suggested change:
- struct ggml_tensor * ffn_expert_weights_bias = nullptr;
+ struct ggml_tensor * ffn_exp_probs_b = nullptr;

@fairydreaming (Collaborator, Author):

Done

src/llama.cpp Outdated
Comment on lines 10283 to 10297
ggml_tensor * probs = nullptr;
switch (gating_op) {
    case LLM_EXPERT_GATING_FUNC_SOFTMAX:
        {
            probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
            cb(probs, "ffn_moe_probs", il);
        } break;
    case LLM_EXPERT_GATING_FUNC_SIGMOID:
        {
            probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
            cb(probs, "ffn_moe_sigm", il);
        } break;
    default:
        GGML_ABORT("fatal error");
}

@ggerganov (Owner):

Don't set names here:

Suggested change:
- ggml_tensor * probs = nullptr;
- switch (gating_op) {
-     case LLM_EXPERT_GATING_FUNC_SOFTMAX:
-         {
-             probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
-             cb(probs, "ffn_moe_probs", il);
-         } break;
-     case LLM_EXPERT_GATING_FUNC_SIGMOID:
-         {
-             probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
-             cb(probs, "ffn_moe_sigm", il);
-         } break;
-     default:
-         GGML_ABORT("fatal error");
- }
+ ggml_tensor * probs = nullptr;
+ switch (gating_op) {
+     case LLM_EXPERT_GATING_FUNC_SOFTMAX:
+         {
+             probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+         } break;
+     case LLM_EXPERT_GATING_FUNC_SIGMOID:
+         {
+             probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
+         } break;
+     default:
+         GGML_ABORT("fatal error");
+ }

Instead, after applying the probs bias, call cb(probs, "ffn_moe_probs", il); for the final probs result.

@fairydreaming (Collaborator, Author):

I moved the name setting after the switch, but I kept it separate from the biased probs for the reasons mentioned earlier.
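
As a rough sketch of this resolution (assuming the names used earlier in this thread; the biased-probs callback name is an assumption):

ggml_tensor * probs = nullptr;
switch (gating_op) {
    case LLM_EXPERT_GATING_FUNC_SOFTMAX: probs = ggml_soft_max(ctx, logits); break; // [n_expert, n_tokens]
    case LLM_EXPERT_GATING_FUNC_SIGMOID: probs = ggml_sigmoid (ctx, logits); break; // [n_expert, n_tokens]
    default: GGML_ABORT("fatal error");
}
cb(probs, "ffn_moe_probs", il);

// the biased probs keep their own name, since they are used only for expert selection
ggml_tensor * selection_probs = probs;
if (exp_probs_b != nullptr) {
    selection_probs = ggml_add(ctx, probs, exp_probs_b);
    cb(selection_probs, "ffn_moe_probs_biased", il);
}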

@fairydreaming (Collaborator, Author) commented:

@ggerganov I extended your "collapsed" regex workaround with \p{M} and \p{S} - DeepSeek V3 has these in its pre-tokenizer regex. Take a look and check whether it looks sane when you have a moment. I checked with test-tokenizer-0, and the tokenization of wiki.test.raw now matches the original.

@fairydreaming (Collaborator, Author) commented Jan 3, 2025:

@ggerganov Also, since you merged #10902, I had to put the expert_gating_func enum in a file included by llama-hparams.h, llama.cpp, and llama-model.cpp. I put it in llama.h; let me know if you have other plans for this enum.

@ggerganov (Owner) commented:

> @ggerganov Also since you merged #10902 I had to put expert_gating_func enum in a file included in both llama-hparams.h, llama.cpp and llama-model.cpp. I put it in llama.h, let me know if you have other plans for this enum.

Let's place it in llama-hparams.h for now. We can potentially make it public if we find some utility in the future, but for now it's better to try to hide more things from the public API - there are some other enums in llama.h that could also be moved to the implementation.

We can merge after you move llama_expert_gating_func_type to llama-hparams.h.
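
A minimal sketch of what the enum might look like once moved into llama-hparams.h (the NONE member and the exact numeric values are assumptions):

// sketch only - member names follow the discussion above, values are assumed
enum llama_expert_gating_func_type {
    LLAMA_EXPERT_GATING_FUNC_TYPE_NONE    = 0,
    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
    LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
};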

@fairydreaming merged commit 9394bbd into ggerganov:master on Jan 4, 2025 (51 checks passed)
netrunnereve pushed a commit to netrunnereve/llama.cpp that referenced this pull request on Jan 5, 2025
* convert : extend DEEPSEEK2 model architecture to support DeepseekV3ForCausalLM by adding EXPERT_WEIGHTS_NORM and EXPERT_GATING_FUNC model parameters and FFN_EXP_PROBS_B tensor type

* vocab : add DeepSeek V3 pre-tokenizer regexes

* unicode : handle ACCENT_MARK and SYMBOL categories in regex

* llama : add DeepSeek V3 chat template, handle new model parameters and tensor types

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
@x66ccff mentioned this pull request on Jan 5, 2025

@emuchogu commented Jan 8, 2025:

Related to #11141

While DeepSeek V3 support has been added, there appears to be an ongoing issue specifically with the ROCm backend. When attempting to run DeepSeek models (both V2 and V3) with ROCm:

  • Models load successfully into VRAM
  • No output is generated
  • One GPU becomes pinned at 100% utilization
  • Other GPUs remain idle

This behavior is consistent across both DeepSeek V2 and V3 models. I would appreciate it if this ROCm-specific issue could be investigated.

@fairydreaming (Collaborator, Author) commented:

@emuchogu Does it happen even with DeepSeek-V2-Lite?

@emuchogu commented Jan 8, 2025:

Yes. Same behavior with deepseek-v2-16b-lite-chat-q4_K_M.
I tested this using Ollama and got the same 100% GPU pegging with no output.

Labels: python (python script changes)

Successfully merging this pull request may close these issues.

Feature Request: add DeepSeek-v3 support
4 participants