Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing inference support for GPTNeoXForCausalLM (Pythia and GPT-NeoX base models) #7461

Merged
merged 7 commits into from
May 23, 2024

Conversation

fairydreaming
Copy link
Collaborator

This pull request adds missing pieces to support inference for GPT-NeoX-based models like the GPT-NeoX and the Pythia family. Fixes #742. It also adds model types for all Pythia model sizes.
Added use_par_res hparams field corresponds to the use_parallel_residual parameter from config.json.

@github-actions github-actions bot added the python python script changes label May 22, 2024
@ggerganov
Copy link
Owner

Tested with https://huggingface.co/EleutherAI/pythia-1.4b/tree/main

Seems to work. PPL on wiki.test is 12.8692 +/- 0.09260:

./perplexity -m models/pythia-1b/ggml-model-f16.gguf -f build/wikitext-2-raw/wiki.test.raw

I guess it's normal for 1.4B model that is 1 year old. Thanks for implementing this

@cebtenzzre cebtenzzre linked an issue May 22, 2024 that may be closed by this pull request
@fairydreaming
Copy link
Collaborator Author

fairydreaming commented May 22, 2024

It seems that the perplexity is a little higher compared to the HF transformers implementation because there are differences in tokenization output between llama.cpp and GPTNeoXTokenizerFast.
Edit: It looks that there are differences in dataset files that I used for measuring perplexity for transformers and llama.cpp, will have to recheck.

Copy link
Contributor

github-actions bot commented May 22, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 537 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8705.59ms p(95)=22299.26ms fails=, finish reason: stop=480 truncated=57
  • Prompt processing (pp): avg=103.68tk/s p(95)=455.27tk/s
  • Token generation (tg): avg=31.86tk/s p(95)=47.22tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=gpt-neox commit=7e171de882ca16fbd75f72d7d1dd4afef75c04d6

prompt_tokens_seconds

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 537 iterations"
    y-axis "llamacpp:prompt_tokens_seconds"
    x-axis "llamacpp:prompt_tokens_seconds" 1716460765 --> 1716461391
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 330.75, 330.75, 330.75, 330.75, 330.75, 691.29, 691.29, 691.29, 691.29, 691.29, 690.47, 690.47, 690.47, 690.47, 690.47, 731.07, 731.07, 731.07, 731.07, 731.07, 765.01, 765.01, 765.01, 765.01, 765.01, 774.28, 774.28, 774.28, 774.28, 774.28, 779.03, 779.03, 779.03, 779.03, 779.03, 802.67, 802.67, 802.67, 802.67, 802.67, 783.81, 783.81, 783.81, 783.81, 783.81, 785.81, 785.81, 785.81, 785.81, 785.81, 807.75, 807.75, 807.75, 807.75, 807.75, 849.72, 849.72, 849.72, 849.72, 849.72, 876.13, 876.13, 876.13, 876.13, 876.13, 883.55, 883.55, 883.55, 883.55, 883.55, 884.21, 884.21, 884.21, 884.21, 884.21, 887.5, 887.5, 887.5, 887.5, 887.5, 890.74, 890.74, 890.74, 890.74, 890.74, 887.0, 887.0, 887.0, 887.0, 887.0, 893.05, 893.05, 893.05, 893.05, 893.05, 892.89, 892.89, 892.89, 892.89, 892.89, 899.33, 899.33, 899.33, 899.33, 899.33, 894.13, 894.13, 894.13, 894.13, 894.13, 895.46, 895.46, 895.46, 895.46, 895.46, 912.73, 912.73, 912.73, 912.73, 912.73, 906.99, 906.99, 906.99, 906.99, 906.99, 906.97, 906.97, 906.97, 906.97, 906.97, 908.35, 908.35, 908.35, 908.35, 908.35, 854.22, 854.22, 854.22, 854.22, 854.22, 850.01, 850.01, 850.01, 850.01, 850.01, 850.49, 850.49, 850.49, 850.49, 850.49, 855.41, 855.41, 855.41, 855.41, 855.41, 853.77, 853.77, 853.77, 853.77, 853.77, 857.9, 857.9, 857.9, 857.9, 857.9, 861.87, 861.87, 861.87, 861.87, 861.87, 873.08, 873.08, 873.08, 873.08, 873.08, 880.87, 880.87, 880.87, 880.87, 880.87, 880.4, 880.4, 880.4, 880.4, 880.4, 878.5, 878.5, 878.5, 878.5, 878.5, 873.8, 873.8, 873.8, 873.8, 873.8, 876.65, 876.65, 876.65, 876.65, 876.65, 878.44, 878.44, 878.44, 878.44, 878.44, 878.46, 878.46, 878.46, 878.46, 878.46, 847.66, 847.66, 847.66, 847.66, 847.66, 850.26, 850.26, 850.26, 850.26, 850.26, 849.57, 849.57, 849.57, 849.57, 849.57, 848.81, 848.81, 848.81, 848.81, 848.81, 851.85, 851.85, 851.85, 851.85, 851.85, 843.14, 843.14, 843.14, 843.14, 843.14, 841.89, 841.89, 841.89, 841.89, 841.89, 844.43, 844.43, 844.43, 844.43, 844.43, 843.23, 843.23, 843.23, 843.23, 843.23, 845.33, 845.33, 845.33, 845.33, 845.33, 850.24, 850.24, 850.24, 850.24, 850.24, 850.0, 850.0, 850.0, 850.0, 850.0, 841.26, 841.26, 841.26, 841.26, 841.26, 840.65, 840.65, 840.65, 840.65, 840.65, 840.81, 840.81, 840.81, 840.81, 840.81, 840.85, 840.85, 840.85, 840.85, 840.85, 841.22, 841.22, 841.22, 841.22, 841.22, 840.66, 840.66, 840.66, 840.66, 840.66, 842.45, 842.45, 842.45, 842.45]
                    
Loading
predicted_tokens_seconds
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 537 iterations"
    y-axis "llamacpp:predicted_tokens_seconds"
    x-axis "llamacpp:predicted_tokens_seconds" 1716460765 --> 1716461391
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 45.31, 45.31, 45.31, 45.31, 45.31, 43.69, 43.69, 43.69, 43.69, 43.69, 28.62, 28.62, 28.62, 28.62, 28.62, 33.0, 33.0, 33.0, 33.0, 33.0, 34.71, 34.71, 34.71, 34.71, 34.71, 34.05, 34.05, 34.05, 34.05, 34.05, 33.36, 33.36, 33.36, 33.36, 33.36, 34.25, 34.25, 34.25, 34.25, 34.25, 34.69, 34.69, 34.69, 34.69, 34.69, 34.64, 34.64, 34.64, 34.64, 34.64, 34.76, 34.76, 34.76, 34.76, 34.76, 34.17, 34.17, 34.17, 34.17, 34.17, 34.12, 34.12, 34.12, 34.12, 34.12, 33.14, 33.14, 33.14, 33.14, 33.14, 32.5, 32.5, 32.5, 32.5, 32.5, 30.43, 30.43, 30.43, 30.43, 30.43, 29.68, 29.68, 29.68, 29.68, 29.68, 29.75, 29.75, 29.75, 29.75, 29.75, 29.95, 29.95, 29.95, 29.95, 29.95, 29.83, 29.83, 29.83, 29.83, 29.83, 29.81, 29.81, 29.81, 29.81, 29.81, 29.86, 29.86, 29.86, 29.86, 29.86, 30.09, 30.09, 30.09, 30.09, 30.09, 30.14, 30.14, 30.14, 30.14, 30.14, 30.02, 30.02, 30.02, 30.02, 30.02, 30.01, 30.01, 30.01, 30.01, 30.01, 30.2, 30.2, 30.2, 30.2, 30.2, 30.25, 30.25, 30.25, 30.25, 30.25, 30.19, 30.19, 30.19, 30.19, 30.19, 30.43, 30.43, 30.43, 30.43, 30.43, 30.74, 30.74, 30.74, 30.74, 30.74, 30.84, 30.84, 30.84, 30.84, 30.84, 31.07, 31.07, 31.07, 31.07, 31.07, 31.14, 31.14, 31.14, 31.14, 31.14, 31.01, 31.01, 31.01, 31.01, 31.01, 30.95, 30.95, 30.95, 30.95, 30.95, 30.54, 30.54, 30.54, 30.54, 30.54, 30.44, 30.44, 30.44, 30.44, 30.44, 30.53, 30.53, 30.53, 30.53, 30.53, 30.65, 30.65, 30.65, 30.65, 30.65, 30.76, 30.76, 30.76, 30.76, 30.76, 30.89, 30.89, 30.89, 30.89, 30.89, 30.9, 30.9, 30.9, 30.9, 30.9, 30.61, 30.61, 30.61, 30.61, 30.61, 30.15, 30.15, 30.15, 30.15, 30.15, 29.49, 29.49, 29.49, 29.49, 29.49, 29.36, 29.36, 29.36, 29.36, 29.36, 29.31, 29.31, 29.31, 29.31, 29.31, 29.28, 29.28, 29.28, 29.28, 29.28, 29.18, 29.18, 29.18, 29.18, 29.18, 29.18, 29.18, 29.18, 29.18, 29.18, 29.25, 29.25, 29.25, 29.25, 29.25, 29.26, 29.26, 29.26, 29.26, 29.26, 29.22, 29.22, 29.22, 29.22, 29.22, 29.29, 29.29, 29.29, 29.29, 29.29, 29.24, 29.24, 29.24, 29.24, 29.24, 29.2, 29.2, 29.2, 29.2, 29.2, 29.2, 29.2, 29.2, 29.2, 29.2, 29.24, 29.24, 29.24, 29.24, 29.24, 29.35, 29.35, 29.35, 29.35, 29.35, 29.47, 29.47, 29.47, 29.47]
                    
Loading

Details

kv_cache_usage_ratio

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 537 iterations"
    y-axis "llamacpp:kv_cache_usage_ratio"
    x-axis "llamacpp:kv_cache_usage_ratio" 1716460765 --> 1716461391
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.08, 0.08, 0.08, 0.08, 0.08, 0.35, 0.35, 0.35, 0.35, 0.35, 0.09, 0.09, 0.09, 0.09, 0.09, 0.13, 0.13, 0.13, 0.13, 0.13, 0.22, 0.22, 0.22, 0.22, 0.22, 0.26, 0.26, 0.26, 0.26, 0.26, 0.09, 0.09, 0.09, 0.09, 0.09, 0.21, 0.21, 0.21, 0.21, 0.21, 0.2, 0.2, 0.2, 0.2, 0.2, 0.19, 0.19, 0.19, 0.19, 0.19, 0.16, 0.16, 0.16, 0.16, 0.16, 0.28, 0.28, 0.28, 0.28, 0.28, 0.39, 0.39, 0.39, 0.39, 0.39, 0.4, 0.4, 0.4, 0.4, 0.4, 0.35, 0.35, 0.35, 0.35, 0.35, 0.26, 0.26, 0.26, 0.26, 0.26, 0.18, 0.18, 0.18, 0.18, 0.18, 0.14, 0.14, 0.14, 0.14, 0.14, 0.26, 0.26, 0.26, 0.26, 0.26, 0.23, 0.23, 0.23, 0.23, 0.23, 0.21, 0.21, 0.21, 0.21, 0.21, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.35, 0.35, 0.35, 0.35, 0.35, 0.17, 0.17, 0.17, 0.17, 0.17, 0.11, 0.11, 0.11, 0.11, 0.11, 0.19, 0.19, 0.19, 0.19, 0.19, 0.17, 0.17, 0.17, 0.17, 0.17, 0.14, 0.14, 0.14, 0.14, 0.14, 0.11, 0.11, 0.11, 0.11, 0.11, 0.13, 0.13, 0.13, 0.13, 0.13, 0.09, 0.09, 0.09, 0.09, 0.09, 0.19, 0.19, 0.19, 0.19, 0.19, 0.19, 0.19, 0.19, 0.19, 0.19, 0.25, 0.25, 0.25, 0.25, 0.25, 0.33, 0.33, 0.33, 0.33, 0.33, 0.22, 0.22, 0.22, 0.22, 0.22, 0.2, 0.2, 0.2, 0.2, 0.2, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.19, 0.19, 0.19, 0.19, 0.19, 0.45, 0.45, 0.45, 0.45, 0.45, 0.49, 0.49, 0.49, 0.49, 0.49, 0.47, 0.47, 0.47, 0.47, 0.47, 0.31, 0.31, 0.31, 0.31, 0.31, 0.26, 0.26, 0.26, 0.26, 0.26, 0.31, 0.31, 0.31, 0.31, 0.31, 0.28, 0.28, 0.28, 0.28, 0.28, 0.19, 0.19, 0.19, 0.19, 0.19, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.22, 0.22, 0.22, 0.22, 0.22, 0.21, 0.21, 0.21, 0.21, 0.21, 0.23, 0.23, 0.23, 0.23, 0.23, 0.25, 0.25, 0.25, 0.25, 0.25, 0.14, 0.14, 0.14, 0.14, 0.14, 0.2, 0.2, 0.2, 0.2, 0.2, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.15, 0.15, 0.15, 0.15]
                    
Loading
requests_processing
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 537 iterations"
    y-axis "llamacpp:requests_processing"
    x-axis "llamacpp:requests_processing" 1716460765 --> 1716461391
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 1.0, 1.0, 1.0, 1.0, 1.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 7.0, 7.0, 7.0, 7.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0]
                    
Loading

@ggerganov
Copy link
Owner

The tokenization differences on wiki.test are minimal and related to slightly different way that we handle added tokens:

diff ./build/wikitext-2-raw/wiki.test.raw.tok ./build/wikitext-2-raw/wiki.test.raw.tokcpp
245413,245414c245413,245414
< 50276
< 6285
---
> 209
> 20589
245440,245441c245440,245441
< 50276
< 6285
---
> 209
> 20589
246660,246661c246660,246661
< 50276
< 6285
---
> 209
> 20589
246687,246688c246687,246688
< 50276
< 6285
---
> 209
> 20589

Likely the perplexity computation used in the HF transformers differs from llama.cpp (i.e. different context size, strided evaluation, etc.)

For Pythia 2.8b I get PPL 10.9294 +/- 0.07654

@ggerganov
Copy link
Owner

Edit: It looks that there are differences in dataset files that I used for measuring perplexity for transformers and llama.cpp, will have to recheck.

Feel free to merge this when ready - I think it works

@mofosyne mofosyne added model Model specific Review Complexity : Medium Generally require more time to grok but manageable by beginner to medium expertise level labels May 23, 2024
@fairydreaming fairydreaming merged commit 9b82476 into ggerganov:master May 23, 2024
60 of 71 checks passed
@felladrin
Copy link
Contributor

Thank you for this, @fairydreaming! I have wanted it for so long!
And thanks @ggerganov for reviewing it so quickly!

teleprint-me pushed a commit to teleprint-me/llama.cpp that referenced this pull request May 23, 2024
…NeoX base models) (ggerganov#7461)

* convert-hf : add conversion of bloom-style qkv tensor to gpt-style qkv (code borrowed from BloomModel)

* llama : add inference support for LLM_ARCH_GPTNEOX

* llama : add model types for every Pythia variant and GPT-NeoX

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
model Model specific python python script changes Review Complexity : Medium Generally require more time to grok but manageable by beginner to medium expertise level
Projects
None yet
Development

Successfully merging this pull request may close these issues.

GPT-NeoX has only minimal inference support Pythia Support?
5 participants