
Vulkan Mixture of Experts (MoE) support #7628

Merged: 8 commits from 0cc4m/vulkan-moe into master on Jun 3, 2024
Conversation

@0cc4m (Collaborator) commented May 29, 2024

Here's a basic version that can run MoE models. There's a bottleneck in the matrix-vector MUL_MAT_ID shaders (or in another MoE-specific op), so generation is rather slow, but at least it runs now.

I had to implement MUL_MAT_ID, SUM_ROWS and DIV for MoE to work. MUL_MAT_ID was the complicated one, obviously.

@github-actions github-actions bot added Vulkan Issues specific to the Vulkan backend python python script changes labels May 29, 2024
@lin72h commented May 29, 2024

Thanks for improving the Vulkan backend. Cross-platform support is very important for making LLMs available to everyone.

@MaggotHATE (Contributor) commented:

Compilation fails due to uint; replacing it with unsigned int works. I'm using w64devkit on Windows 10, but maybe it's because of C++20.

As for performance, it is noticeably slow: the more layers I offload, the slower tg and pp are. However, it finally works, thank you!

Due to how MoE models work, I had to download a 2x7B model in Q3_K_S to test this PR, and suddenly there's no difference in RAM consumption between CPU-only and Vulkan. Another 2x7B model in Q4_K_S still uses more RAM on Vulkan (too much for my system).

I don't normally use Q3_K_S, so I tried three other non-MoE models in the same Q3_K_S, and that didn't seem to help: there's still a difference in RAM consumption. It seems the combination of 2x7B with Q3_K_S works without memory overhead, with both 4096 and 8192 context sizes.

@mofosyne mofosyne added the Review Complexity : High Generally require indepth knowledge of LLMs or GPUs label May 30, 2024
@0cc4m (Collaborator, PR author) commented Jun 1, 2024

Thank you for testing it. The occurrences of uint were unintentional; I'll correct those. The bad generation performance is because the matrix-matrix multiplication shader is being used for matrix-vector multiplications. I'll look into that.
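As background on the build failure above: `uint` is a POSIX typedef (from `<sys/types.h>`), not a type in standard C or C++, so host code spelling it bare can fail to compile with Windows toolchains such as w64devkit. A minimal sketch of the portable alternatives (the names here are just for illustration):

```cpp
#include <cstdint>
#include <type_traits>

// `unsigned int` is what the reporter substituted; std::uint32_t from
// <cstdint> is the fixed-width alternative when an exact size is wanted.
constexpr unsigned int plain_count = 42u;
using shader_index_t = std::uint32_t; // hypothetical alias name

static_assert(std::is_unsigned<shader_index_t>::value,
              "uint32_t is an unsigned type");
```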

@slaren (Collaborator) commented Jun 1, 2024

I get this error with -sm none -mg 1 to use only the second GPU:

ggml_vulkan: Found 2 Vulkan devices:
Vulkan0: NVIDIA GeForce RTX 3080 | uma: 0 | fp16: 1 | warp size: 32
Vulkan1: NVIDIA GeForce RTX 3090 Ti | uma: 0 | fp16: 1 | warp size: 32
ggml_backend_sched_backend_from_buffer: error: no backend supports buffer type Vulkan1 used in tensor blk.0.attn_q.weight
GGML_ASSERT: C:\CODE\llama.cpp\ggml-backend.c:1115: false

@slaren (Collaborator) commented Jun 1, 2024

The offload_op needs to handle MUL_MAT_ID separately, since the batch size is in a different dimension. This is the CUDA implementation:

llama.cpp/ggml-cuda.cu

Lines 2924 to 2931 in 750f60c

GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
    const int min_batch_size = 32;
    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
    GGML_UNUSED(backend);
}
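The point of the heuristic can be demonstrated with a toy model of the tensor shapes (hypothetical types, not the ggml API): for plain ops the batch size lives in ne[1], but for MUL_MAT_ID the routing moves it to ne[2], so checking only ne[1] would wrongly skip offloading large MoE batches:

```cpp
#include <cstdint>

// Toy stand-ins for ggml's op codes and tensor struct, for illustration only.
enum toy_op { TOY_OP_MUL_MAT, TOY_OP_MUL_MAT_ID, TOY_OP_GET_ROWS };

struct toy_tensor {
    int64_t ne[4]; // shape: ne[1] is the batch for plain ops, ne[2] for MUL_MAT_ID
    toy_op  op;
};

// Mirrors the logic quoted above: offload only when the batch is large
// enough, reading the batch size from the dimension that op actually uses.
static bool toy_offload_op(const toy_tensor *op) {
    const int min_batch_size = 32;
    return (op->ne[1] >= min_batch_size && op->op != TOY_OP_GET_ROWS) ||
           (op->ne[2] >= min_batch_size && op->op == TOY_OP_MUL_MAT_ID);
}
```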

@0cc4m (Collaborator, PR author) commented Jun 2, 2024

@slaren Thank you. I fixed the offload_op code and the split mode none + main GPU case for Vulkan.

@MaggotHATE (Contributor) commented Jun 2, 2024

Thanks for the update: inference speed is quite good now! However, prompt processing is still slow, and I noticed that adding more threads (for example, going from 3 to 6) boosts performance greatly. At the same time, I'm also using the OpenMP PR, so maybe that affects the result.

Memory usage is fixed for MoE models now: 2x7B in Q4_K_S uses the same amount of RAM in CPU-only and Vulkan. For non-MoE models the RAM difference is still there. Update: it looks like it's all about context size; going from 4096 to 8192 costs much more memory on Vulkan, to the point of not fitting into 16GB RAM and 3GB VRAM.

@0cc4m 0cc4m merged commit 3d7ebf6 into master Jun 3, 2024
68 of 72 checks passed
@0cc4m 0cc4m deleted the 0cc4m/vulkan-moe branch June 3, 2024 08:59
github-actions bot commented Jun 3, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 536 iterations 🚀

  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=8753.83ms p(95)=22160.16ms fails=, finish reason: stop=479 truncated=57
  • Prompt processing (pp): avg=98.78tk/s p(95)=458.41tk/s
  • Token generation (tg): avg=47.37tk/s p(95)=46.2tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=0cc4m/vulkan-moe commit=fe3f6958bd64ecce4ac6548f69af7594fb8913db

[Benchmark charts omitted: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, and requests_processing over the 10-minute run.]

Labels: Vulkan (Issues specific to the Vulkan backend), python (python script changes), Review Complexity : High (Generally require in-depth knowledge of LLMs or GPUs)

5 participants