Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA: fix FA out-of-bounds writes #7465

Merged
merged 1 commit into from
May 22, 2024

Conversation

JohannesGaessler
Copy link
Collaborator

Fixes #7421 .

The issue seems to be that all FlashAttention kernels except for the one using tensor cores are missing a check to avoid out-of-bounds writes for batch sizes that are not powers of 2.

@m18coppola
Copy link
Contributor

This resolved my issues regarding FA. Tested on 1x p40 and 2x p40 using llama3-8B @ Q8_0.

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels May 22, 2024
@JohannesGaessler JohannesGaessler merged commit 38c0347 into ggml-org:master May 22, 2024
61 of 72 checks passed
@mofosyne mofosyne added the Review Complexity : Medium Generally require more time to grok but manageable by beginner to medium expertise level label May 22, 2024
Copy link
Contributor

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 536 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8726.62ms p(95)=21076.63ms fails=, finish reason: stop=477 truncated=59
  • Prompt processing (pp): avg=93.93tk/s p(95)=354.8tk/s
  • Token generation (tg): avg=46.92tk/s p(95)=45.15tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=cuda-fix-fa-oob commit=f9357395bb5f7aa0c1e5a2f85aab6981dc51b1a2

prompt_tokens_seconds

More
Loading
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 536 iterations"
    y-axis "llamacpp:prompt_tokens_seconds"
    x-axis "llamacpp:prompt_tokens_seconds" 1716396323 --> 1716396953
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 678.43, 678.43, 678.43, 678.43, 678.43, 629.29, 629.29, 629.29, 629.29, 629.29, 655.74, 655.74, 655.74, 655.74, 655.74, 719.34, 719.34, 719.34, 719.34, 719.34, 738.7, 738.7, 738.7, 738.7, 738.7, 738.0, 738.0, 738.0, 738.0, 738.0, 741.15, 741.15, 741.15, 741.15, 741.15, 759.13, 759.13, 759.13, 759.13, 759.13, 759.01, 759.01, 759.01, 759.01, 759.01, 779.27, 779.27, 779.27, 779.27, 779.27, 804.36, 804.36, 804.36, 804.36, 804.36, 845.51, 845.51, 845.51, 845.51, 845.51, 883.04, 883.04, 883.04, 883.04, 883.04, 897.93, 897.93, 897.93, 897.93, 897.93, 881.37, 881.37, 881.37, 881.37, 881.37, 882.26, 882.26, 882.26, 882.26, 882.26, 887.12, 887.12, 887.12, 887.12, 887.12, 896.2, 896.2, 896.2, 896.2, 896.2, 894.44, 894.44, 894.44, 894.44, 894.44, 889.52, 889.52, 889.52, 889.52, 889.52, 895.12, 895.12, 895.12, 895.12, 895.12, 896.84, 896.84, 896.84, 896.84, 896.84, 917.69, 917.69, 917.69, 917.69, 917.69, 909.8, 909.8, 909.8, 909.8, 909.8, 904.49, 904.49, 904.49, 904.49, 904.49, 904.42, 904.42, 904.42, 904.42, 904.42, 886.4, 886.4, 886.4, 886.4, 886.4, 885.69, 885.69, 885.69, 885.69, 885.69, 884.89, 884.89, 884.89, 884.89, 884.89, 886.4, 886.4, 886.4, 886.4, 886.4, 885.49, 885.49, 885.49, 885.49, 885.49, 882.41, 882.41, 882.41, 882.41, 882.41, 886.7, 886.7, 886.7, 886.7, 886.7, 882.92, 882.92, 882.92, 882.92, 882.92, 886.63, 886.63, 886.63, 886.63, 886.63, 874.82, 874.82, 874.82, 874.82, 874.82, 873.14, 873.14, 873.14, 873.14, 873.14, 870.18, 870.18, 870.18, 870.18, 870.18, 867.18, 867.18, 867.18, 867.18, 867.18, 871.03, 871.03, 871.03, 871.03, 871.03, 871.48, 871.48, 871.48, 871.48, 871.48, 877.82, 877.82, 877.82, 877.82, 877.82, 852.94, 852.94, 852.94, 852.94, 852.94, 854.26, 854.26, 854.26, 854.26, 854.26, 853.05, 853.05, 853.05, 853.05, 853.05, 851.91, 851.91, 851.91, 851.91, 851.91, 846.22, 846.22, 846.22, 846.22, 846.22, 851.34, 851.34, 851.34, 851.34, 851.34, 851.91, 851.91, 851.91, 851.91, 851.91, 850.91, 850.91, 850.91, 850.91, 850.91, 853.32, 853.32, 853.32, 853.32, 853.32, 855.0, 855.0, 855.0, 855.0, 855.0, 856.94, 856.94, 856.94, 856.94, 856.94, 858.87, 858.87, 858.87, 858.87, 858.87, 859.77, 859.77, 859.77, 859.77, 859.77, 852.12, 852.12, 852.12, 852.12, 852.12, 851.99, 851.99, 851.99, 851.99, 851.99, 851.65, 851.65, 851.65, 851.65, 851.65, 853.07, 853.07, 853.07, 853.07, 853.07, 854.75, 854.75, 854.75, 854.75, 854.75, 854.78]
                    
predicted_tokens_seconds
More
Loading
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 536 iterations"
    y-axis "llamacpp:predicted_tokens_seconds"
    x-axis "llamacpp:predicted_tokens_seconds" 1716396323 --> 1716396953
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 46.17, 46.17, 46.17, 46.17, 46.17, 33.8, 33.8, 33.8, 33.8, 33.8, 31.9, 31.9, 31.9, 31.9, 31.9, 33.33, 33.33, 33.33, 33.33, 33.33, 33.16, 33.16, 33.16, 33.16, 33.16, 33.41, 33.41, 33.41, 33.41, 33.41, 33.97, 33.97, 33.97, 33.97, 33.97, 34.37, 34.37, 34.37, 34.37, 34.37, 34.54, 34.54, 34.54, 34.54, 34.54, 34.37, 34.37, 34.37, 34.37, 34.37, 34.62, 34.62, 34.62, 34.62, 34.62, 34.59, 34.59, 34.59, 34.59, 34.59, 33.54, 33.54, 33.54, 33.54, 33.54, 33.0, 33.0, 33.0, 33.0, 33.0, 32.06, 32.06, 32.06, 32.06, 32.06, 31.37, 31.37, 31.37, 31.37, 31.37, 31.72, 31.72, 31.72, 31.72, 31.72, 31.27, 31.27, 31.27, 31.27, 31.27, 31.08, 31.08, 31.08, 31.08, 31.08, 30.45, 30.45, 30.45, 30.45, 30.45, 30.33, 30.33, 30.33, 30.33, 30.33, 30.58, 30.58, 30.58, 30.58, 30.58, 30.73, 30.73, 30.73, 30.73, 30.73, 30.58, 30.58, 30.58, 30.58, 30.58, 30.62, 30.62, 30.62, 30.62, 30.62, 30.76, 30.76, 30.76, 30.76, 30.76, 30.6, 30.6, 30.6, 30.6, 30.6, 30.65, 30.65, 30.65, 30.65, 30.65, 30.9, 30.9, 30.9, 30.9, 30.9, 31.03, 31.03, 31.03, 31.03, 31.03, 31.12, 31.12, 31.12, 31.12, 31.12, 31.26, 31.26, 31.26, 31.26, 31.26, 31.42, 31.42, 31.42, 31.42, 31.42, 31.49, 31.49, 31.49, 31.49, 31.49, 31.35, 31.35, 31.35, 31.35, 31.35, 31.12, 31.12, 31.12, 31.12, 31.12, 30.73, 30.73, 30.73, 30.73, 30.73, 30.23, 30.23, 30.23, 30.23, 30.23, 30.28, 30.28, 30.28, 30.28, 30.28, 30.43, 30.43, 30.43, 30.43, 30.43, 30.48, 30.48, 30.48, 30.48, 30.48, 30.68, 30.68, 30.68, 30.68, 30.68, 30.61, 30.61, 30.61, 30.61, 30.61, 30.33, 30.33, 30.33, 30.33, 30.33, 30.17, 30.17, 30.17, 30.17, 30.17, 29.78, 29.78, 29.78, 29.78, 29.78, 28.65, 28.65, 28.65, 28.65, 28.65, 28.68, 28.68, 28.68, 28.68, 28.68, 28.66, 28.66, 28.66, 28.66, 28.66, 28.63, 28.63, 28.63, 28.63, 28.63, 28.63, 28.63, 28.63, 28.63, 28.63, 28.66, 28.66, 28.66, 28.66, 28.66, 28.72, 28.72, 28.72, 28.72, 28.72, 28.73, 28.73, 28.73, 28.73, 28.73, 28.69, 28.69, 28.69, 28.69, 28.69, 28.72, 28.72, 28.72, 28.72, 28.72, 28.73, 28.73, 28.73, 28.73, 28.73, 28.76, 28.76, 28.76, 28.76, 28.76, 28.95, 28.95, 28.95, 28.95, 28.95, 29.01, 29.01, 29.01, 29.01, 29.01, 29.13]
                    

Details

kv_cache_usage_ratio

More
Loading
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 536 iterations"
    y-axis "llamacpp:kv_cache_usage_ratio"
    x-axis "llamacpp:kv_cache_usage_ratio" 1716396323 --> 1716396953
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.34, 0.34, 0.34, 0.34, 0.34, 0.24, 0.24, 0.24, 0.24, 0.24, 0.27, 0.27, 0.27, 0.27, 0.27, 0.22, 0.22, 0.22, 0.22, 0.22, 0.21, 0.21, 0.21, 0.21, 0.21, 0.16, 0.16, 0.16, 0.16, 0.16, 0.17, 0.17, 0.17, 0.17, 0.17, 0.18, 0.18, 0.18, 0.18, 0.18, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.2, 0.2, 0.2, 0.2, 0.2, 0.27, 0.27, 0.27, 0.27, 0.27, 0.32, 0.32, 0.32, 0.32, 0.32, 0.37, 0.37, 0.37, 0.37, 0.37, 0.33, 0.33, 0.33, 0.33, 0.33, 0.11, 0.11, 0.11, 0.11, 0.11, 0.22, 0.22, 0.22, 0.22, 0.22, 0.33, 0.33, 0.33, 0.33, 0.33, 0.4, 0.4, 0.4, 0.4, 0.4, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.16, 0.16, 0.16, 0.16, 0.16, 0.3, 0.3, 0.3, 0.3, 0.3, 0.18, 0.18, 0.18, 0.18, 0.18, 0.14, 0.14, 0.14, 0.14, 0.14, 0.24, 0.24, 0.24, 0.24, 0.24, 0.24, 0.24, 0.24, 0.24, 0.24, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12, 0.12, 0.12, 0.12, 0.12, 0.19, 0.19, 0.19, 0.19, 0.19, 0.17, 0.17, 0.17, 0.17, 0.17, 0.13, 0.13, 0.13, 0.13, 0.13, 0.16, 0.16, 0.16, 0.16, 0.16, 0.26, 0.26, 0.26, 0.26, 0.26, 0.35, 0.35, 0.35, 0.35, 0.35, 0.49, 0.49, 0.49, 0.49, 0.49, 0.34, 0.34, 0.34, 0.34, 0.34, 0.26, 0.26, 0.26, 0.26, 0.26, 0.12, 0.12, 0.12, 0.12, 0.12, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12, 0.12, 0.12, 0.12, 0.12, 0.28, 0.28, 0.28, 0.28, 0.28, 0.51, 0.51, 0.51, 0.51, 0.51, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.53, 0.53, 0.53, 0.53, 0.53, 0.13, 0.13, 0.13, 0.13, 0.13, 0.22, 0.22, 0.22, 0.22, 0.22, 0.28, 0.28, 0.28, 0.28, 0.28, 0.2, 0.2, 0.2, 0.2, 0.2, 0.17, 0.17, 0.17, 0.17, 0.17, 0.18, 0.18, 0.18, 0.18, 0.18, 0.28, 0.28, 0.28, 0.28, 0.28, 0.27, 0.27, 0.27, 0.27, 0.27, 0.12, 0.12, 0.12, 0.12, 0.12, 0.25, 0.25, 0.25, 0.25, 0.25, 0.13, 0.13, 0.13, 0.13, 0.13, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.13]
                    
requests_processing
More
Loading
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 536 iterations"
    y-axis "llamacpp:requests_processing"
    x-axis "llamacpp:requests_processing" 1716396323 --> 1716396953
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 2.0, 2.0, 2.0, 2.0, 2.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 1.0, 1.0, 1.0, 1.0, 1.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 6.0, 6.0, 6.0, 6.0, 6.0, 3.0, 3.0, 3.0, 3.0, 3.0, 8.0, 8.0, 8.0, 8.0, 8.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 1.0]
                    

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs Review Complexity : Medium Generally require more time to grok but manageable by beginner to medium expertise level
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Generating the same token (token '1') over and over, after a few successful messages?
4 participants