Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[kernel] fix sliding window in prefix prefill Triton kernel #4405

Merged
merged 9 commits into from
May 2, 2024

Conversation

mmoskal
Copy link
Contributor

@mmoskal mmoskal commented Apr 26, 2024

This adds support for the sliding window in prefix prefill kernel.

I had to use a large negative value instead of -inf for masking, since otherwise in some situations we get '-inf - -inf' in softmax which leads to NaNs.

Added tests comparing with xformers.

Also added a bunch of comments with tensor shapes etc.

FIX #4057

CC @rkooo567 @cadedaniel @simon-mo

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@zhuohan123
Copy link
Member

CC @caoshiyi Can you help take a look at this?

@rkooo567
Copy link
Collaborator

yep! it is Sat here, but I will take a look at it very soon! @mmoskal thanks for the amazing contribution!

Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks pretty good! Also thanks for the comment on shapes. Some comments regarding tests

@@ -15,18 +15,21 @@
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
SLIDING_WINDOW = [0, 512]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you try various values? (16, 64, 128, 256, 512, 2048). 2048 is bigger than max seq, but just for sanity check

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also curious if we can do e2e test against mistral or other model that has sliding window attn enabled..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll try different sizes. I'm trying to get the sliding window to work with v2 block manager (somewhat based on #3967) which should exercise this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg. The test itself will be done in this PR right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, adding more sliding window parameters now

# exp().
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
(start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
-10000)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is -10000 small enough? Maybe consider even smaller values? like -10000000

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-65000 or so is the smallest value for f16. Anyway, this gets exp()ed so -10k should enough.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually good you point this out - I found that in one place I do sm_scale before and in the other after masking with -10000; it doesn't matter much (smallest scale is 1/16 or so, and exp(-10000/16) is still zero) but better to always scale and then mask so we get consistent mask values - fixed

# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
(start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this means

cur_batch_ctx_len + offs_m[:, None] == end of q
start_n + offs_n[None, :] == end of k

so q-k length is within slinding window, attend it, is this correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes; I added a comment to that effect

Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM given it passes with more variations in unit test!

@mmoskal
Copy link
Contributor Author

mmoskal commented May 2, 2024

@rkooo567 all tests passed, should be good

@rkooo567
Copy link
Collaborator

rkooo567 commented May 2, 2024

cc @simon-mo to merge!

@mmoskal
Copy link
Contributor Author

mmoskal commented May 2, 2024

@simon-mo can we get this merged? I need it to rebase #4545

@simon-mo simon-mo merged commit 32881f3 into vllm-project:main May 2, 2024
48 checks passed
@mmoskal mmoskal deleted the triton_sliding_window branch May 2, 2024 22:42
robertgshaw2-redhat pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 7, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature][Chunked prefill]: Make sliding window work
4 participants