llama : refactor model loader with backend registry #10026

Merged: 2 commits into master on Oct 30, 2024

Conversation

@slaren (Collaborator) commented Oct 24, 2024

  • The model loader can choose the buffer type to use with each tensor automatically, based on its type and the available backends. It will avoid offloading tensors to backends that cannot use them; for example, when trying to load a mamba model with CUDA, the weights that cannot be used by the CUDA backend will be kept on the CPU (a rough sketch of this selection logic follows this list).
  • Simplified model loading code: no longer requires specifying a context for each tensor, which was a common source of errors
  • Removed support for split tensor MoE models (only affects very old Mixtral models)
  • Avoid duplicating tensors if not necessary, e.g. if a model has a shared tok embd and output tensor, and the input and output layers are on the same backend, then the same tensor is reused for both layers
  • Row split mode (-sm row): KV and other non-matrix weights are split among the available GPUs in the same way as in split-by-layer mode.
  • The AMX backend no longer requires using -ngl; it is considered a CPU accelerator and is used automatically in CPU layers. Unsupported types automatically use the CPU backend.
  • Bug fix: in models with a rope frequencies tensor, the rope operation would always be offloaded to the GPU even in non-offloaded layers, causing a significant degradation in performance when using partial offloading with these models
    • During context creation, the number of graph splits for both pp and tg are printed to make it easier to find these bugs in the future
  • Replaced some asserts in the model loader with exceptions to avoid crashing the application with bad models
  • CPU and Metal from_ptr buffers have a _Mapped suffix so that it is easier to tell when a model is being loaded with mmap
  • llama-bench now uses the backend registry to determine the backend used and obtain device descriptions
  • ggml-backend:
    • Removed deprecated backend interface functions
    • Removed buffer names; the buffer type name is used instead
    • Simplified device types: now there are only the CPU, GPU, and ACCEL types. llama.cpp will treat ACCEL devices as accelerators intended to be used together with the CPU backend, while GPU devices will be used only for the layers offloaded with -ngl.
    • When allocating a buffer of size 0, the alloc_buffer function is no longer called (backends no longer need to handle this case separately)
    • The environment variable GGML_SCHED_DEBUG now takes an integer value: 0 will not print any debugging trace, 1 will print only the split headers, 2 will print the entire graph
    • Added ggml_backend_dev_get_extra_bufts optional function (returned with get_proc_address) for backends that have multiple buffer types. llama.cpp will automatically use these buffer types if available. This is intended to be used with buffer types that change tensor layouts, e.g. for automatic conversion of Q4_0 to the aarch64 types.
    • The optional function ggml_backend_split_buffer_type now takes a device parameter, which represents the main device intended to be used with this split buffer. Only this device should be reported as supported in the supports_buft function. Backends that implement split buffer types (CUDA and SYCL) should update this function to support the changes to -sm row.
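
A rough sketch of the load-time buffer type selection mentioned in the first bullet, assuming a prioritized list of (device, buffer type) candidates has already been built. The helper name select_weight_buft and the dummy-op parameter are hypothetical; this illustrates the idea, not the actual llama.cpp code.

#include <vector>
#include <utility>
#include "ggml.h"
#include "ggml-backend.h"

// Sketch: given a dummy op that consumes the weight, keep the weight on the
// first candidate whose device supports both the buffer type and the op;
// otherwise fall back to the CPU buffer type.
static ggml_backend_buffer_type_t select_weight_buft(
        const struct ggml_tensor * op, // hypothetical dummy op referencing the weight
        const std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>> & candidates,
        ggml_backend_buffer_type_t cpu_buft) {
    for (const auto & c : candidates) {
        if (ggml_backend_dev_supports_buft(c.first, c.second) && ggml_backend_dev_supports_op(c.first, op)) {
            return c.second; // e.g. a CUDA buffer type for weights the CUDA backend can use
        }
    }
    return cpu_buft; // e.g. mamba-specific weights stay on the CPU when using CUDA
}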

Note: when this is merged, support for backends that do not implement the reg/device interfaces will be dropped. This will affect the Kompute backend until #10045 is merged.

@slaren (Collaborator Author) commented Oct 25, 2024

@JohannesGaessler I would appreciate your opinion regarding the changes to -sm row. I can think of three options:

  • Revert the change completely
  • Make it the new behavior of -sm row
  • Add a new split mode

@slp (Collaborator) commented Oct 25, 2024

Regarding the Kompute backend, I've started adding support for the backend/device register interface in #10045.

I expect it to be ready for review in a couple of days.

@JohannesGaessler (Collaborator) commented:

I haven't looked at the code yet because of this statement:

Please do not review the code at this point, there are still a lot of changes left.

Is this still accurate?

@slaren (Collaborator Author) commented Oct 25, 2024

I still expect to make significant changes, so I don't see the point of doing a code review at this point, but the change to -sm row is already implemented and should work as intended. I am only asking for feedback about that change, not the entire PR.

@JohannesGaessler (Collaborator) commented:

I think it would be fine to replace the current -sm row behavior with the one described in this PR (I think --main-gpu would then also be obsolete). The only situation where I think the behavior on master would be better is when you have plenty of VRAM and just want to use multiple GPUs for better speed (where -sm row is actually beneficial), but the GPUs don't all have the same interconnect speed. For example, with x16/x8/x8 PCIe lanes, you want the main GPU to preferentially be the one with 16 PCIe lanes. But nowadays the memory needed for the context is so large that I think it would be preferable to distribute the main GPU role and the KV cache.

Long-term I think the correct way to parallelize a transformer is to split the attention by heads since each head can be computed independently of the others. That would then also naturally distribute the KV cache.

@slaren slaren marked this pull request as ready for review October 28, 2024 20:52
@slaren (Collaborator Author) commented Oct 28, 2024

This should be ready now. I will leave the flash attention changes for a different PR, since this is already becoming too big to review.

@ggerganov (Owner) commented:

Added ggml_backend_dev_get_extra_bufts optional function (returned with get_proc_address) for backends that have multiple buffer types. llama.cpp will automatically use these buffer types if available. This is intended to be used with buffer types that change tensor layouts, eg. for automatic conversion of Q4_0 to the aarch64 types.

Do I understand the idea correctly that, for example, the CPU backend can "export" an extra "aarch64" buffer type that would be added at the beginning of the CPU buffer types list (after the ACCEL buffer types)? When we test whether a weight can be allocated with this extra buffer type, we would check, for example, whether the hardware supports SVE or SME in the ggml_backend_cpu_device_supports_op() function? And this way we would be able to convert Q4_0 at runtime to the interleaved format currently required by the Q4_0_X_Y formats.

@slaren (Collaborator Author) commented Oct 29, 2024

Yes, that's exactly it. The check for hardware support may also be done in ggml_backend_dev_get_extra_bufts directly to avoid returning a buffer type that cannot be used, rather than in the supports_op function.
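
As an illustration of that approach, here is a hypothetical sketch; the hardware-check and buffer-type helpers are placeholders, and the NULL-terminated return convention is an assumption, so the real CPU backend may structure this differently.

#include <vector>
#include "ggml-backend.h"

// Hypothetical helpers (placeholders here), not part of ggml:
static bool cpu_supports_repacked_q4_0(void) { return false; }                        // would check for SVE/SME etc.
static ggml_backend_buffer_type_t cpu_aarch64_buffer_type(void) { return nullptr; }   // the repacking buffer type

// Sketch: only return the extra buffer type when the hardware can actually use
// it, so llama.cpp never considers it on unsupported machines.
static ggml_backend_buffer_type_t * cpu_get_extra_bufts_sketch(ggml_backend_dev_t dev) {
    (void) dev; // unused in this sketch
    static std::vector<ggml_backend_buffer_type_t> bufts = [] {
        std::vector<ggml_backend_buffer_type_t> v;
        if (cpu_supports_repacked_q4_0()) {
            v.push_back(cpu_aarch64_buffer_type());
        }
        v.push_back(nullptr); // assumed: the list is NULL-terminated
        return v;
    }();
    return bufts.data();
}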

@slp (Collaborator) commented Oct 29, 2024

Please consider merging #10045 first to avoid breaking the kompute backend upstream.

@8XXD8 commented Oct 29, 2024

VRAM usage seems to be a lot higher.
I can run Mistral Large with an IQ3 quant and 16k context on master, but this PR fills every GPU to 99% with 10k context.
It is a couple of percent slower than master, likely because some tensors now run on the CPU:
llm_load_tensors: tensor 'token_embd.weight' (iq3_s) (and 177 others) cannot be used with preferred buffer type ROCm_Host, using CPU instead
Q8 and Q6 have the same warning with other models.
It's up to 7% slower with FA enabled, and up to 30% slower with FA disabled; larger models are more affected.
I'm using 4X AMD GPUs.

@slaren (Collaborator Author) commented Oct 29, 2024

@8XXD8 I cannot reproduce that, and I can't even imagine how any of that could happen. Without more details, I am just going to assume that it is an AMD driver issue.

@slaren (Collaborator Author) commented Oct 29, 2024

Please consider merging #10045 first to avoid breaking the kompute backend upstream.

There will be additional changes needed to the Kompute backend after this is merged to adapt to the interface changes here, which I will not be able to test. Ultimately, the Kompute backend needs a maintainer that is willing to keep it updated with these changes; I cannot be responsible for updating every backend.

@ggerganov (Owner) commented:

On macOS, the compare-commits.sh script gives this error:

./scripts/compare-commits.sh master sl/load-time-supports-op -m ./models/llama-3.2-1b-instruct/ggml-model-q4_0.gguf -r 1 -n 0

...

build: 63c47ab8 (3984)
+ ./scripts/compare-llama-bench.py -b master -c sl/load-time-supports-op
Traceback (most recent call last):
  File "./scripts/compare-llama-bench.py", line 307, in <module>
    gpu_blas = bool(rows_full[0][KEY_PROPERTIES.index("gpu_blas")])
                    ~~~~~~~~~^^^
IndexError: list index out of range

@slaren (Collaborator Author) commented Oct 29, 2024

Due to the changes to llama-bench, there are some additional details exposed when running on macOS, which causes compare-llama-bench.py to fail to match the runs between the different branches, since some fields have different values. E.g.:

$ sqlite3 llama-bench.sqlite
sqlite> select * from test;
8f275a7c|3989|0|0|0|1|0|0|0|1|||./models/tinyllama-1.1b-intermediate-step-480k-1t.Q8_0.gguf|llama 1B Q8_0|1169072128|1100048384|2048|512|12|0x0|0|50|f16|f16|99|layer|0|0|0|0.00|1|0|512|0|2024-10-29T11:53:36Z|111724875|0|4582.685816|0.0
63c47ab8|3984|0|0|0|1|0|0|0|1|Accelerate, Apple M3 Max|Apple M3 Max|./models/tinyllama-1.1b-intermediate-step-480k-1t.Q8_0.gguf|llama 1B Q8_0|1169072128|1100048384|2048|512|12|0x0|0|50|f16|f16|99|layer|0|0|0|0.00|1|0|512|0|2024-10-29T11:53:47Z|110440167|0|4635.994439|0.0

In short: compare-llama-bench.py groups the results by the fields in KEY_PROPERTIES. These are usually the same between branches, but in this case they aren't, so it fails to group the runs correctly.

@ggerganov (Owner) left a comment

Very cool work!

@slp (Collaborator) commented Oct 29, 2024

Please consider merging #10045 first to avoid breaking the kompute backend upstream.

There will be additional changes needed to the Kompute backend after this is merged to adapt to the interface changes here, which I will not be able to test. Ultimately, the Kompute backend needs a maintainer that is willing to keep it updated with these changes, I cannot be responsible for updating every backend.

FWIW, I can help with the maintenance work for the Kompute backend.

@chaxu01 (Collaborator) commented Oct 29, 2024

Yes, that's exactly it. The check for hardware support may also be done in ggml_backend_dev_get_extra_bufts directly to avoid returning a buffer type that cannot be used, rather than in the supports_op function.

Do I understand correctly that with the extra buft added in ggml_backend_dev_get_extra_bufts the repack of Q4_0 could be done in ggml_backend_cpu_buffer_set_tensor?

BTW, there is a small naming discrepancy in ggml_backend_cpu_get_proc_address, where "ggml_backend_dev_get_extra_bufts" should be "ggml_backend_cpu_get_extra_bufts"

@slaren (Collaborator Author) commented Oct 29, 2024

Do I understand correctly that with the extra buft added in ggml_backend_dev_get_extra_bufts the repack of Q4_0 could be done in ggml_backend_cpu_buffer_set_tensor?

More or less. The process for implementing this would be something like this (a rough sketch of the supports_op step follows the list):

  • Create a new buffer type, implementing the buffer type and buffer interfaces
  • In the set_tensor function of this buffer, perform the repacking of the tensor
  • In ggml_backend_cpu_device_supports_buft, return true for this buffer type
  • In ggml_backend_cpu_device_supports_op, return true for this buffer type only if the op is GGML_OP_MUL_MAT, src0 is stored in this buffer type and has type Q4_0. This will prevent llama.cpp from using this buffer type for other tensors.
  • In ggml_compute_forward_mul_mat, check the buffer type of the src0 tensor, and if stored in this buffer type, handle the matrix multiplication using the repacked tensor layout
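
A minimal sketch of the supports_op step above, assuming the repacking buffer type handle is passed in for illustration; apart from the ggml enums and accessors, the names here are hypothetical and the real code may differ.

#include "ggml.h"
#include "ggml-backend.h"

// Sketch only: return true so that llama.cpp places a weight in the repacking
// buffer type only when it will be consumed by a Q4_0 matrix multiplication.
static bool cpu_repack_supports_op(const struct ggml_tensor * op,
                                   ggml_backend_buffer_type_t repack_buft /* hypothetical handle */) {
    if (op->op != GGML_OP_MUL_MAT) {
        return false;                               // only mat mul uses the repacked layout
    }
    const struct ggml_tensor * src0 = op->src[0];
    if (src0 == NULL || src0->type != GGML_TYPE_Q4_0) {
        return false;                               // only Q4_0 weights are repacked
    }
    // the weight must actually live in the repacking buffer type
    return src0->buffer != NULL && ggml_backend_buffer_get_type(src0->buffer) == repack_buft;
}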

BTW there is small naming discrepancy in ggml_backend_cpu_get_proc_address where "ggml_backend_dev_get_extra_bufts" should be "ggml_backend_cpu_get_extra_bufts"

This is intentional. ggml_backend_dev_get_extra_bufts is intended to be a generic function name that can be implemented by any backend; ggml_backend_cpu_get_extra_bufts is the specific implementation for the CPU backend.
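
For reference, a sketch of how a caller could go from the generic name to a backend's implementation; the function-pointer typedef and the NULL-terminated return convention are assumptions here, not something the PR spells out.

#include <cstdio>
#include "ggml-backend.h"

// Assumed signature for the extra-bufts entry point: returns a NULL-terminated
// array of buffer types for the given device.
typedef ggml_backend_buffer_type_t * (*get_extra_bufts_fn)(ggml_backend_dev_t dev);

static void print_extra_bufts(ggml_backend_dev_t dev) {
    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
    // look up the generic name; each backend maps it to its own implementation
    auto fn = (get_extra_bufts_fn) ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts");
    if (fn == nullptr) {
        return; // this backend does not expose extra buffer types
    }
    for (ggml_backend_buffer_type_t * it = fn(dev); it != nullptr && *it != nullptr; ++it) {
        printf("extra buffer type: %s\n", ggml_backend_buft_name(*it));
    }
}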

@slaren (Collaborator Author) commented Oct 30, 2024

Since there are a few changes waiting on this, I will merge this now. @8XXD8 please open an issue with more details if you still see that problem after this is merged.

@slaren slaren merged commit c5b0f4b into master Oct 30, 2024
60 of 61 checks passed
@slaren slaren deleted the sl/load-time-supports-op branch October 30, 2024 01:01
@8XXD8 commented Oct 30, 2024

Since there are a few changes waiting on this, I will merge this now. @8XXD8 please open an issue with more details if you still see that problem after this is merged.

The problem still persists, I'll try a few things before opening an issue.
