[V1][VLM] Enable proper chunked prefill for multimodal models #9950

Closed
wants to merge 17 commits
update
Signed-off-by: Roger Wang <ywang@roblox.com>
ywang96 committed Nov 7, 2024
commit d918b0f3fa86398a23b86a2fef64210590096d84
3 changes: 3 additions & 0 deletions vllm/model_executor/models/blip2.py
@@ -679,6 +679,9 @@ def forward(
        if intermediate_tensors is not None:
            inputs_embeds = None

        # TODO (ywang96): This is currently needed since embedding generation
        # takes place in the model forward pass. Clean this up after V0 is
        # fully deprecated.
        elif inputs_embeds is None:
            vision_embeddings = self.process_mm_inputs(**kwargs)
            inputs_embeds = self.get_inputs_embeds(
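The TODO above marks the interim pattern these models share: when `inputs_embeds` is not supplied, the vision embeddings are still produced inside `forward()` (the V0 path), while the V1 engine is expected to call `process_mm_inputs` and `get_inputs_embeds` itself and pass the merged embeddings in. A minimal sketch of the two call paths, using a toy class and made-up tensor shapes rather than the real vLLM models:

```python
from typing import Optional
import torch


class ToyVLM:
    """Toy stand-in (not a vLLM class) for the pattern shown in the diff."""
    hidden_size = 8

    def process_mm_inputs(self, **kwargs) -> Optional[torch.Tensor]:
        # Run the vision encoder only if image inputs were provided.
        pixel_values = kwargs.get("pixel_values")
        if pixel_values is None:
            return None
        return torch.randn(pixel_values.shape[0], self.hidden_size)

    def get_inputs_embeds(self, input_ids: torch.Tensor,
                          vision_embeddings: Optional[torch.Tensor]) -> torch.Tensor:
        inputs_embeds = torch.randn(input_ids.shape[0], self.hidden_size)
        # Real models merge vision_embeddings into the placeholder positions here.
        return inputs_embeds

    def forward(self, input_ids: torch.Tensor,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        if inputs_embeds is None:
            # V0 path: embedding generation still happens inside forward().
            vision_embeddings = self.process_mm_inputs(**kwargs)
            inputs_embeds = self.get_inputs_embeds(input_ids, vision_embeddings)
        # V1 path: inputs_embeds was already computed by the caller.
        return inputs_embeds  # language-model call omitted


model = ToyVLM()
ids = torch.tensor([1, 2, 3])
pixels = torch.randn(1, 3, 224, 224)

# V0-style call: raw multimodal kwargs go straight into forward().
out_v0 = model.forward(ids, pixel_values=pixels)

# V1-style call: embeddings are produced once, outside forward(), which is what
# lets chunked prefill feed the language model one slice at a time.
embeds = model.get_inputs_embeds(ids, model.process_mm_inputs(pixel_values=pixels))
out_v1 = model.forward(ids, inputs_embeds=embeds)
```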
4 changes: 4 additions & 0 deletions vllm/model_executor/models/fuyu.py
@@ -338,6 +338,10 @@ def forward(
    ):
        if intermediate_tensors is not None:
            inputs_embeds = None

        # TODO (ywang96): This is currently needed since embedding generation
        # takes place in the model forward pass. Clean this up after V0 is
        # fully deprecated.
        elif inputs_embeds is None:
            vision_embeddings = self.process_mm_inputs(**kwargs)
            inputs_embeds = self.get_inputs_embeds(
72 changes: 55 additions & 17 deletions vllm/model_executor/models/internvl.py
@@ -26,7 +26,7 @@
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.base import MultiModalInputs, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@@ -323,9 +323,25 @@ def input_processor(
num_patches)
        new_prompt_token_ids = tokenizer.encode(new_prompt)

        return token_inputs(prompt=prompt,
                            prompt_token_ids=new_prompt_token_ids,
                            multi_modal_data=multi_modal_data)
        # Get precise tracking of placeholder positions
        token_idx = image_idx = 0
        placeholder_ranges = []
        while token_idx < len(new_prompt_token_ids):
            if new_prompt_token_ids[token_idx] == self.img_context_token:
                curr_image_feature_size = image_feature_sizes[image_idx]
                placeholder_ranges.append(
                    PlaceholderRange(offset=token_idx,
                                     length=curr_image_feature_size))
                image_idx += 1
                token_idx += curr_image_feature_size
            else:
                token_idx += 1

        return token_inputs(
            prompt=prompt,
            prompt_token_ids=new_prompt_token_ids,
            multi_modal_data=multi_modal_data,
            multi_modal_placeholders={"image": placeholder_ranges})

def input_mapper(
self,
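A self-contained sketch of the placeholder scan added in input_processor above, using a made-up image-context token id and feature sizes (the real values come from the InternVL tokenizer and image processor); `PlaceholderRange` is redefined locally as a stand-in for the vLLM type:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PlaceholderRange:  # stand-in for vllm.multimodal.base.PlaceholderRange
    offset: int
    length: int


def track_placeholders(token_ids: List[int], img_context_token_id: int,
                       image_feature_sizes: List[int]) -> List[PlaceholderRange]:
    """Record where each image's placeholder tokens start and how long they run."""
    token_idx = image_idx = 0
    ranges: List[PlaceholderRange] = []
    while token_idx < len(token_ids):
        if token_ids[token_idx] == img_context_token_id:
            size = image_feature_sizes[image_idx]
            ranges.append(PlaceholderRange(offset=token_idx, length=size))
            image_idx += 1
            token_idx += size  # jump over the repeated placeholder tokens
        else:
            token_idx += 1
    return ranges


# Two images whose placeholders (token id 99) take 3 and 2 slots respectively.
print(track_placeholders([1, 99, 99, 99, 2, 99, 99, 3], 99, [3, 2]))
# [PlaceholderRange(offset=1, length=3), PlaceholderRange(offset=5, length=2)]
```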
@@ -608,33 +624,55 @@ def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
            visual_token_mask = None
        return visual_token_mask

    def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_inputs_embeds(
            self, input_ids: torch.Tensor,
            vision_embeddings: Optional[torch.Tensor]) -> torch.Tensor:
        inputs_embeds = self.language_model.model.get_input_embeddings(
            input_ids)
        if vision_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                multimodal_embeddings=vision_embeddings,
                placeholder_token_id=self.img_context_token_id)

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
            visual_token_mask = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is not None:
                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.img_context_token_id)

        # TODO (ywang96): This is currently needed since embedding generation
        # takes place in the model forward pass. Clean this up after V0 is
        # fully deprecated.
        elif inputs_embeds is None:
            vision_embeddings = self.process_mm_inputs(**kwargs)
            if vision_embeddings is not None:
                visual_token_mask = self._get_visual_token_mask(input_ids)
                input_ids = None
            else:
                inputs_embeds = None
                visual_token_mask = None
            inputs_embeds = self.get_inputs_embeds(
                input_ids=input_ids, vision_embeddings=vision_embeddings)
            input_ids = None

        else:
            visual_token_mask = self._get_visual_token_mask(input_ids)

        forward_kwargs = {
            "input_ids": input_ids,
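Why these ranges matter for this PR: with chunked prefill, only a slice of the prompt is processed per scheduler step, so the engine has to know which image placeholders overlap a given chunk in order to supply the right encoder outputs. A hedged illustration of that idea (not vLLM's actual V1 scheduler code), assuming plain `{"offset", "length"}` dicts:

```python
from typing import Dict, List


def images_needed_for_chunk(placeholder_ranges: List[Dict[str, int]],
                            chunk_start: int, chunk_end: int) -> List[int]:
    """Return the indices of images whose placeholder tokens fall in the chunk."""
    needed = []
    for i, rng in enumerate(placeholder_ranges):
        start, end = rng["offset"], rng["offset"] + rng["length"]
        if start < chunk_end and end > chunk_start:  # half-open interval overlap
            needed.append(i)
    return needed


# Image 0 occupies tokens [4, 580), image 1 occupies [600, 1176).
ranges = [{"offset": 4, "length": 576}, {"offset": 600, "length": 576}]
print(images_needed_for_chunk(ranges, 0, 512))     # [0]
print(images_needed_for_chunk(ranges, 512, 1024))  # [0, 1]
```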