
CUDA: implement __hmax and __hmax2 for CUDA < 11.7 #7019

Merged
JohannesGaessler merged 1 commit into ggml-org:master from cuda-hmax-fix on May 1, 2024

Conversation

JohannesGaessler
Collaborator

This PR implements __hmax and __hmax2 for CUDA < 11.7. I don't know how well they perform relative to the built-in functions, but without them FlashAttention will not work at all.
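For context, a minimal sketch of what such a fallback could look like, comparing in float precision and guarding on the toolkit version (the guard value and exact code here are illustrative, not necessarily the code merged in this PR):

```cpp
#include <cuda_fp16.h>

// Fallback definitions for toolkits that do not ship __hmax/__hmax2 as device
// functions. The 11.7 cutoff mirrors the PR title and is an assumption here.
#if defined(CUDART_VERSION) && CUDART_VERSION < 11070

static __device__ __forceinline__ __half __hmax(const __half a, const __half b) {
    // Compare in float precision; may be slower than the built-in intrinsic.
    return __half2float(a) > __half2float(b) ? a : b;
}

static __device__ __forceinline__ __half2 __hmax2(const __half2 a, const __half2 b) {
    // Element-wise maximum of the two packed half values.
    return __halves2half2(__hmax(__low2half(a),  __low2half(b)),
                          __hmax(__high2half(a), __high2half(b)));
}

#endif // CUDART_VERSION < 11070
```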

@JohannesGaessler JohannesGaessler mentioned this pull request May 1, 2024
@LostRuins
Collaborator

Ah, I actually just did an awful hack to do the same thing and got it to work, and was about to PR it.

LostRuins@cea4675

Let me test yours though, as it seems like a cleaner and better impl.

@LostRuins
Collaborator

@JohannesGaessler unfortunately, your version fails to compile with this error:

common.cuh(299): error : more than one conversion function from "const half" to a built-in type applies

also

common.cuh(303): error : more than one conversion function from "__half" to a built-in type applies

@JohannesGaessler JohannesGaessler force-pushed the cuda-hmax-fix branch 2 times, most recently from bc8ac98 to 24ea3c6 on May 1, 2024 09:25
@JohannesGaessler
Collaborator Author

I think the issue was the half comparison; I've changed it to convert the values to float for the comparison.
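Illustratively, the ambiguity and the float-based workaround look roughly like this (the helper name below is hypothetical, used only to show the pattern):

```cpp
// Hypothetical helper, for illustration only.
static __device__ __forceinline__ __half hmax_via_float(const __half a, const __half b) {
    // return a > b ? a : b;  // on older toolkits this can be ambiguous:
    //                        // "more than one conversion function from '__half' to a built-in type applies"
    return __half2float(a) > __half2float(b) ? a : b; // explicit float comparison sidesteps the ambiguity
}
```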

@LostRuins
Collaborator

Alright, let me give it another test.

@LostRuins
Collaborator

Hmm, no, it's still not working for me. Now I am getting a very weird error:

ggml-cuda/common.cuh(298): error: function "__hmax(__half, __half)" has already been defined
ggml-cuda/common.cuh(301): error: function "__hmax2(__half2, __half2)" has already been defined

You can view my build environment and CI error logs here (sorry, I know this is kobold, but it's compiling the same CUDA files):
https://github.com/LostRuins/koboldcpp/actions/runs/8907906401/job/24462587092#step:4:226

@JohannesGaessler
Collaborator Author

I've changed it to be similar to how you originally did it. The performance for old CUDA versions will potentially be worse, but if people want to use such an old CUDA version, my stance is that they'll just have to live with it.
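One way to sidestep both the missing intrinsic and the redefinition clash is to dispatch through a wrapper with its own name; the sketch below is an assumption about the general pattern, not the exact code that was merged:

```cpp
// Illustrative wrapper: the name and version cutoff are assumptions, not the merged code.
static __device__ __forceinline__ __half ggml_cuda_hmax(const __half a, const __half b) {
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11070
    return __hmax(a, b);                               // use the built-in intrinsic where available
#else
    return __half2float(a) > __half2float(b) ? a : b;  // float-precision fallback, potentially slower
#endif
}
```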

Contributor

github-actions bot commented May 1, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 560 iterations 🚀

Expand details (performance-related PRs only)
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8320.38ms p(95)=20339.41ms fails=, finish reason: stop=509 truncated=51
  • Prompt processing (pp): avg=96.24tk/s p(95)=454.53tk/s
  • Token generation (tg): avg=34.21tk/s p(95)=49.44tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=cuda-hmax-fix commit=859734eecc2605a24b5478680082b106483142d6

[Benchmark charts omitted: llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 560 iterations]
  • prompt_tokens_seconds
  • predicted_tokens_seconds
  • kv_cache_usage_ratio
  • requests_processing

@JohannesGaessler JohannesGaessler merged commit 1613ef8 into ggml-org:master May 1, 2024
62 checks passed
@JohannesGaessler
Collaborator Author

Sorry, I only noticed after I had already pushed the button, but does the code actually work for you, @LostRuins?

@LostRuins
Collaborator

LostRuins commented May 1, 2024

Yes, it does. I have tried with both flash attn on and off, and the model outputs are coherent - so I presume it must be working.

Let me merge your latest changes and try again.

@LostRuins
Collaborator

@JohannesGaessler I can confirm that your changes now build successfully and appear to work. Thanks!

nopperl pushed a commit to nopperl/llama.cpp that referenced this pull request May 5, 2024
teleprint-me pushed a commit to teleprint-me/llama.cpp that referenced this pull request May 7, 2024