From df7ab195e51cec0f27872c5f3b17b5aaa88030c0 Mon Sep 17 00:00:00 2001 From: Ryan Nguyen <96593302+xpbowler@users.noreply.github.com> Date: Fri, 31 Jan 2025 18:37:30 -0500 Subject: [PATCH] [Feature] Fix guided decoding blocking bitmask memcpy (#12563) **[Guided decoding performance optimization]** Sending the guided decoding bitmask in xgrammar to the GPU (`self.token_bitmask.to(scores.device)`) is a blocking operation that prevents the CPU from pre-launching the sampler kernels. The CPU waits until decode is complete, then copies the bitmask over. This PR changes the operation to async via setting `non-blocking=True`. (Current) The CPU is blocked on a `cudaStreamSynchronize` and only pre-empts the sampling kernels after bitmask application. Below is the Nsys profile for one decode phase from Llama 3.1 8B. ![image](https://github.com/user-attachments/assets/8997eae1-b822-4f52-beb8-ef19a7c6b824) With the optimization, this is no longer the case: ![image](https://github.com/user-attachments/assets/6d5ea83f-f169-4f98-a8c1-41c719b3e1e7) --------- Signed-off-by: Ryan N Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/guided_decoding/xgrammar_decoding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 2d8594cb8aafa..ee30ce96f0a1e 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -307,8 +307,8 @@ def __call__(self, input_ids: list[int], # Note: In this method, if the tensors have different dimensions # on CPU device fails, but on GPU it runs without error. Hence the # unsqueeze above for scores, to match the token bitmask shape - xgr.apply_token_bitmask_inplace(scores, - self.token_bitmask.to(scores.device)) + xgr.apply_token_bitmask_inplace( + scores, self.token_bitmask.to(scores.device, non_blocking=True)) if device_type != "cuda": scores = scores.to(dtype).to(device_type).squeeze()