[Bug]: OLMo 2 does not split qkv correctly for grouped query attention #13686

Closed
1 task done
2015aroras opened this issue Feb 21, 2025 · 0 comments · Fixed by #13687
Labels
bug Something isn't working

Comments

@2015aroras (Contributor)

Your current environment

I do not have access to the environment anymore, but the bug and its fix are straightforward.

🐛 Describe the bug

OLMo 2 does not perform attention correctly when the number of attention heads differs from the number of KV heads (i.e. when GQA or MQA is used instead of MHA). Specifically, it splits the fused qkv projection output into three equal chunks rather than into chunks sized according to the q, k, and v projections. The fix is a one-liner.
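To illustrate the failure mode, here is a minimal sketch assuming made-up GQA shapes (32 query heads, 8 KV heads, head dimension 128; not the actual OLMo 2 configuration). Equal chunking is only correct for MHA, where the q, k, and v projections all have the same width:

```python
import torch

# Illustrative GQA shapes (assumptions, not the real OLMo 2 config).
num_heads, num_kv_heads, head_dim = 32, 8, 128
q_size = num_heads * head_dim       # 4096: width of the q projection
kv_size = num_kv_heads * head_dim   # 1024: width of each of k and v

# Output of the fused qkv projection for one token: 4096 + 1024 + 1024.
qkv = torch.randn(1, q_size + 2 * kv_size)

# Buggy pattern: equal thirds, valid only when q_size == kv_size (MHA).
q_bad, k_bad, v_bad = qkv.chunk(3, dim=-1)
print(q_bad.shape[-1])  # 2048 -- a downstream q_norm expecting q_size rejects this

# Fixed pattern (the one-liner): split by the actual per-projection widths.
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape[-1], k.shape[-1], v.shape[-1])  # 4096 1024 1024
```

This is why the traceback below ends in `_apply_qk_norm`: the equal-width chunk hands `q_norm` a tensor whose last dimension does not match the expected hidden size.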

I don't have a minimal repro, but below is the stack trace caused by using OLMo 2 with GQA.

Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/vllm/executor/multiproc_worker_utils.py", line 236, in _run_worker_process
    output = run_method(worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/utils.py", line 2220, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
    self.model_runner.profile_run()
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 1235, in profile_run
    self._dummy_run(max_num_batched_tokens, max_num_seqs)
  File "/opt/conda/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 1346, in _dummy_run
    self.execute_model(model_input, kv_caches, intermediate_tensors)
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 1719, in execute_model
    hidden_or_intermediate_states = model_executable(
                                    ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/olmo2.py", line 364, in forward
    hidden_states = self.model(
                    ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/olmo2.py", line 312, in forward
    hidden_states = self.layers[i](
                    ^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/olmo2.py", line 247, in forward
    hidden_states = self.self_attn(positions, hidden_states, kv_cache,
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/olmo2.py", line 161, in forward
    q, k = self._apply_qk_norm(q, k)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/olmo2.py", line 143, in _apply_qk_norm
    q = self.q_norm.forward_native(q)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/layers/layernorm.py", line 52, in forward_native
    raise ValueError("Expected hidden_size to be "
ValueError: Expected hidden_size to be 5120, but found: 2392

Before submitting a new issue...

  • Make sure you have already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer many frequently asked questions.