[Core] Interface for accessing model from VllmRunner (vllm-project#10353)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Bowen Wang <abmfy@icloud.com>
DarkLight1337 authored and abmfy committed Jan 24, 2025
1 parent 1d88e80 commit 3427a91
Showing 35 changed files with 474 additions and 307 deletions.
5 changes: 5 additions & 0 deletions tests/conftest.py
@@ -244,6 +244,7 @@ def video_assets() -> _VideoAssets:
 
 
 _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
+_R = TypeVar("_R")
 
 
 class HfRunner:
@@ -930,6 +931,10 @@ def score(
         req_outputs = self.model.score(text_1, text_2)
         return [req_output.outputs.score for req_output in req_outputs]
 
+    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
+        executor = self.model.llm_engine.model_executor
+        return executor.apply_model(func)
+
     def __enter__(self):
         return self

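The new helper gives tests a stable way to reach the underlying nn.Module without going through model_executor.driver_worker.model_runner.model. A minimal usage sketch, assuming the vllm_runner fixture from this conftest (the model name is illustrative, not from this commit):

    import torch.nn as nn

    def check_model(model: nn.Module):
        # The callback receives the loaded model wherever it lives;
        # apply_model gathers the return values into a list.
        print(model)
        assert isinstance(model, nn.Module)

    # "BAAI/bge-base-en-v1.5" is only an example model name.
    with vllm_runner("BAAI/bge-base-en-v1.5", dtype="float16") as vllm_model:
        results = vllm_model.apply_model(check_model)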
4 changes: 3 additions & 1 deletion tests/engine/test_custom_executor.py
@@ -51,7 +51,9 @@ def test_custom_executor(model, tmp_path):
     assert not os.path.exists(".marker")
 
     engine_args = EngineArgs(
-        model=model, distributed_executor_backend=CustomUniExecutor)
+        model=model,
+        distributed_executor_backend=CustomUniExecutor,
+    )
     engine = LLMEngine.from_engine_args(engine_args)
     sampling_params = SamplingParams(max_tokens=1)

64 changes: 33 additions & 31 deletions tests/model_executor/test_model_load_with_params.py
@@ -25,13 +25,12 @@ def test_model_loading_with_params(vllm_runner):
     with vllm_runner(model_name=MODEL_NAME,
                      revision=REVISION,
                      dtype="float16",
-                     max_model_len=MAX_MODEL_LEN) as model:
-        output = model.encode("Write a short story about a robot that"
-                              " dreams for the first time.\n")
+                     max_model_len=MAX_MODEL_LEN) as vllm_model:
+        output = vllm_model.encode("Write a short story about a robot that"
+                                   " dreams for the first time.\n")
 
-        model_config = model.model.llm_engine.model_config
-
-        model_tokenizer = model.model.llm_engine.tokenizer
+        model_config = vllm_model.model.llm_engine.model_config
+        model_tokenizer = vllm_model.model.llm_engine.tokenizer
 
         # asserts on the bert model config file
         assert model_config.encoder_config["max_seq_length"] == 512
@@ -46,11 +45,13 @@ def test_model_loading_with_params(vllm_runner):
         assert model_tokenizer.tokenizer_config["do_lower_case"]
         assert model_tokenizer.tokenizer.model_max_length == 512
 
-        model = model.model.llm_engine.model_executor\
-            .driver_worker.model_runner.model
-        assert isinstance(model, BertEmbeddingModel)
-        assert model._pooler.pooling_type == PoolingType.CLS
-        assert model._pooler.normalize
+        def check_model(model):
+            assert isinstance(model, BertEmbeddingModel)
+            assert model._pooler.pooling_type == PoolingType.CLS
+            assert model._pooler.normalize
+
+        vllm_model.apply_model(check_model)
 
         # assert output
         assert output

@@ -64,13 +65,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
     with vllm_runner(model_name=MODEL_NAME_ROBERTA,
                      revision=REVISION_ROBERTA,
                      dtype="float16",
-                     max_model_len=MAX_MODEL_LEN) as model:
-        output = model.encode("Write a short story about a robot that"
-                              " dreams for the first time.\n")
+                     max_model_len=MAX_MODEL_LEN) as vllm_model:
+        output = vllm_model.encode("Write a short story about a robot that"
+                                   " dreams for the first time.\n")
 
-        model_config = model.model.llm_engine.model_config
-
-        model_tokenizer = model.model.llm_engine.tokenizer
+        model_config = vllm_model.model.llm_engine.model_config
+        model_tokenizer = vllm_model.model.llm_engine.tokenizer
 
         # asserts on the bert model config file
         assert model_config.encoder_config["max_seq_length"] == 512
@@ -84,11 +84,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
         assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large"
         assert not model_tokenizer.tokenizer_config["do_lower_case"]
 
-        model = model.model.llm_engine.model_executor\
-            .driver_worker.model_runner.model
-        assert isinstance(model, RobertaEmbeddingModel)
-        assert model._pooler.pooling_type == PoolingType.MEAN
-        assert model._pooler.normalize
+        def check_model(model):
+            assert isinstance(model, RobertaEmbeddingModel)
+            assert model._pooler.pooling_type == PoolingType.MEAN
+            assert model._pooler.normalize
+
+        vllm_model.apply_model(check_model)
 
         # assert output
         assert output
@@ -103,17 +104,18 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner):
     model_name = "FacebookAI/roberta-base"
     with vllm_runner(model_name=model_name,
                      dtype="float16",
-                     max_model_len=MAX_MODEL_LEN) as model:
-        output = model.encode("Write a short story about a robot that"
-                              " dreams for the first time.\n")
+                     max_model_len=MAX_MODEL_LEN) as vllm_model:
+        output = vllm_model.encode("Write a short story about a robot that"
+                                   " dreams for the first time.\n")
 
-        model_tokenizer = model.model.llm_engine.tokenizer
+        model_tokenizer = vllm_model.model.llm_engine.tokenizer
         assert model_tokenizer.tokenizer_id == model_name
 
-        model = model.model.llm_engine.model_executor\
-            .driver_worker.model_runner.model
-        assert not hasattr(model, "lm_head")
-        assert isinstance(model, RobertaEmbeddingModel)
-        assert isinstance(model._pooler, CLSPool)
+        def check_model(model):
+            assert isinstance(model, RobertaEmbeddingModel)
+            assert not hasattr(model, "lm_head")
+            assert isinstance(model._pooler, CLSPool)
+
+        vllm_model.apply_model(check_model)
 
         assert output
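The same migration repeats across the remaining test files: the attribute chain into the executor is replaced by a callback handed to apply_model. A condensed before/after, taken from the hunks above with the fixture setup omitted:

    # Before: reach through executor internals to grab the nn.Module
    model = vllm_model.model.llm_engine.model_executor \
        .driver_worker.model_runner.model
    assert isinstance(model, RobertaEmbeddingModel)

    # After: pass a callback to the runner; the executor applies it to the
    # loaded model and returns the results as a list
    def check_model(model):
        assert isinstance(model, RobertaEmbeddingModel)

    vllm_model.apply_model(check_model)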
7 changes: 5 additions & 2 deletions tests/models/decoder_only/language/test_jamba.py
@@ -33,10 +33,13 @@ def test_models(
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
+        def print_model(model):
+            print(model)
+
+        vllm_model.apply_model(print_model)
 
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]
7 changes: 5 additions & 2 deletions tests/models/decoder_only/language/test_mamba.py
@@ -51,10 +51,13 @@ def test_models(
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
+        def print_model(model):
+            print(model)
+
+        vllm_model.apply_model(print_model)
 
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]
7 changes: 5 additions & 2 deletions tests/models/decoder_only/language/test_models.py
@@ -73,10 +73,13 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
+
         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
+        def print_model(model):
+            print(model)
+
+        vllm_model.apply_model(print_model)
 
     check_logprobs_close(
         outputs_0_lst=hf_outputs,
49 changes: 26 additions & 23 deletions tests/models/decoder_only/vision_language/test_qwen2_vl.py
@@ -5,7 +5,6 @@
 import torch
 from PIL import Image
 
-from vllm.entrypoints.llm import LLM
 from vllm.multimodal.image import rescale_image_size
 from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
 
@@ -69,7 +68,7 @@ class Qwen2VLPromptVideoEmbeddingInput(TypedDict):
 
 def batch_make_image_embeddings(
         image_batches: List[Union[Image.Image, List[Image.Image]]], processor,
-        llm: LLM) -> List[Qwen2VLPromptImageEmbeddingInput]:
+        llm: VllmRunner) -> List[Qwen2VLPromptImageEmbeddingInput]:
     """batched image embeddings for Qwen2-VL
 
     This will infer all images' embeddings in a single batch,
@@ -106,16 +105,18 @@ def batch_make_image_embeddings(
     image_grid_thw = preprocess_result["image_grid_thw"]
 
     # pixel values to embeddings & grid_thws
-    with torch.no_grad():
-        visual = llm.llm_engine.model_executor.driver_worker. \
-            model_runner.model.visual
+    def get_image_embeds(model):
+        with torch.no_grad():
+            visual = model.visual
 
-        pixel_values_on_device = pixel_values.to(visual.device,
-                                                 dtype=visual.dtype)
-        image_grid_thw_on_device = image_grid_thw.to(visual.device,
-                                                     dtype=torch.int64)
-        image_embeds = visual(pixel_values_on_device,
-                              grid_thw=image_grid_thw_on_device)
+            pixel_values_on_device = pixel_values.to(visual.device,
+                                                     dtype=visual.dtype)
+            image_grid_thw_on_device = image_grid_thw.to(visual.device,
+                                                         dtype=torch.int64)
+            return visual(pixel_values_on_device,
+                          grid_thw=image_grid_thw_on_device)
+
+    image_embeds = torch.concat(llm.apply_model(get_image_embeds))
 
     # split into original batches
     result: List[Qwen2VLPromptImageEmbeddingInput] = []
@@ -150,7 +151,7 @@ def batch_make_image_embeddings(
 
 def batch_make_video_embeddings(
         video_batches: PromptVideoInput, processor,
-        llm: LLM) -> List[Qwen2VLPromptVideoEmbeddingInput]:
+        llm: VllmRunner) -> List[Qwen2VLPromptVideoEmbeddingInput]:
     """batched video embeddings for Qwen2-VL
 
     A NDArray represents a single video's all frames.
@@ -187,16 +188,18 @@ def batch_make_video_embeddings(
     video_grid_thw = preprocess_result["video_grid_thw"]
 
     # pixel values to embeddings & grid_thws
-    with torch.no_grad():
-        visual = llm.llm_engine.model_executor.driver_worker.\
-            model_runner.model.visual
+    def get_image_embeds(model):
+        with torch.no_grad():
+            visual = model.visual
 
-        pixel_values_on_device = pixel_values.to(visual.device,
-                                                 dtype=visual.dtype)
-        video_grid_thw_on_device = video_grid_thw.to(visual.device,
-                                                     dtype=torch.int64)
-        video_embeds = visual(pixel_values_on_device,
-                              grid_thw=video_grid_thw_on_device)
+            pixel_values_on_device = pixel_values.to(visual.device,
+                                                     dtype=visual.dtype)
+            video_grid_thw_on_device = video_grid_thw.to(visual.device,
+                                                         dtype=torch.int64)
+            return visual(pixel_values_on_device,
+                          grid_thw=video_grid_thw_on_device)
+
+    video_embeds = torch.concat(llm.apply_model(get_image_embeds))
 
     # split into original batches
     result: List[Qwen2VLPromptVideoEmbeddingInput] = []
@@ -278,9 +281,9 @@ def run_embedding_input_test(
             max_tokens,
             num_logprobs=num_logprobs,
             images=batch_make_image_embeddings(
-                images, processor, vllm_model.model) if images else None,
+                images, processor, vllm_model) if images else None,
             videos=batch_make_video_embeddings(
-                videos, processor, vllm_model.model) if videos else None)
+                videos, processor, vllm_model) if videos else None)
         for prompts, images, videos in inputs
     ]

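Because apply_model returns its results as a list (list[_R] in the conftest change above), the Qwen2-VL helpers concatenate the returned embeddings before splitting them back into per-prompt batches. A trimmed sketch of that flow; pixel_values and image_grid_thw are assumed to come from the processor call earlier in the helper:

    def get_image_embeds(model):
        # Runs against the loaded Qwen2-VL model; "visual" is its vision tower.
        with torch.no_grad():
            visual = model.visual
            return visual(pixel_values.to(visual.device, dtype=visual.dtype),
                          grid_thw=image_grid_thw.to(visual.device,
                                                     dtype=torch.int64))

    # One tensor per returned result, concatenated back into a single batch.
    image_embeds = torch.concat(llm.apply_model(get_image_embeds))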
7 changes: 5 additions & 2 deletions tests/models/embedding/language/test_cls_models.py
@@ -24,10 +24,13 @@ def test_classification_models(
 ) -> None:
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.classify(example_prompts)
+
         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
+        def print_model(model):
+            print(model)
+
+        vllm_model.apply_model(print_model)
 
     with hf_runner(model,
                    dtype=dtype,
7 changes: 5 additions & 2 deletions tests/models/embedding/language/test_embedding.py
@@ -62,10 +62,13 @@ def test_models(
                     max_model_len=None,
                     **vllm_extra_kwargs) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)
+
         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
+        def print_model(model):
+            print(model)
+
+        vllm_model.apply_model(print_model)
 
     check_embeddings_close(
         embeddings_0_lst=hf_outputs,
(Diffs for the remaining 26 changed files are not shown.)
