Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Cache uses_mrope in GPUModelRunner #12969

Merged
merged 1 commit into from
Feb 8, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope

# NOTE: Initialized input mapper is only used for processing dummy
# multimodal data into multimodal kwargs for GPU memory profiling.
Expand Down Expand Up @@ -147,7 +148,7 @@ def __init__(
device=self.device)

# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
# with torch compile.
Expand Down Expand Up @@ -284,7 +285,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
)

# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
if self.uses_mrope:
image_grid_thw = []
video_grid_thw = []
second_per_grid_ts = []
Expand Down Expand Up @@ -411,7 +412,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):

# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.model_config.uses_mrope:
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)

# Get token indices.
Expand Down Expand Up @@ -458,7 +459,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
# Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
if self.model_config.uses_mrope:
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
Expand Down Expand Up @@ -817,13 +818,14 @@ def execute_model(
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
else:
positions = self.positions[:num_input_tokens]

# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata, self.vllm_config):
positions = self.mrope_positions[:, :num_input_tokens] \
if self.model_config.uses_mrope \
else self.positions[:num_input_tokens]
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
Expand Down Expand Up @@ -1001,10 +1003,11 @@ def _dummy_run(
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.positions[:num_tokens]
with set_forward_context(None, self.vllm_config):
positions = self.mrope_positions[:, :num_tokens] \
if self.model_config.uses_mrope \
else self.positions[:num_tokens]
hidden_states = model(
input_ids=input_ids,
positions=positions,
Expand Down