Ensure Lora Adapter Requests Return Lora Model Name Instead of Base Model
Jeffwan committed Dec 11, 2024
1 parent 9974fca commit f7b221c
Showing 3 changed files with 21 additions and 5 deletions.
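As context for the diff below, here is a minimal client-side sketch (not part of this commit) of the behavior being changed: when a chat completion is requested against a registered LoRA adapter, the response's model field should now report the adapter's name instead of the first base model name. The adapter name sql-lora, the local server URL, and the launch command in the comment are illustrative assumptions.

# Hypothetical check of the new behavior, assuming a vLLM OpenAI-compatible
# server was started with a LoRA adapter registered, e.g.:
#   vllm serve meta-llama/Llama-2-7b-hf --enable-lora \
#       --lora-modules sql-lora=/path/to/sql-lora-adapter
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Request the LoRA adapter by name, just like any other model id.
response = client.chat.completions.create(
    model="sql-lora",
    messages=[{"role": "user", "content": "Write a short SELECT statement."}],
)

# Before this commit the server echoed base_model_paths[0].name here;
# with this change it should echo the adapter name, "sql-lora".
print(response.model)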
11 changes: 7 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
@@ -27,6 +27,7 @@
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest

Check failure on line 30 in vllm/entrypoints/openai/serving_chat.py (GitHub Actions / ruff (3.12), Ruff F401): vllm/entrypoints/openai/serving_chat.py:30:31: F401 `vllm.lora.request.LoRARequest` imported but unused
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
@@ -123,6 +124,8 @@ async def create_chat_completion(
prompt_adapter_request,
) = self._maybe_get_adapters(request)

model_name = self._get_model_name(lora_request)

Check failures on line 127 in vllm/entrypoints/openai/serving_chat.py (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): Argument 1 to "_get_model_name" of "OpenAIServing" has incompatible type "Optional[LoRARequest]" (reported as "LoRARequest | None" on 3.10+); expected "LoRARequest" [arg-type]

tokenizer = await self.engine_client.get_tokenizer(lora_request)

tool_parser = self.tool_parser
@@ -238,12 +241,12 @@ async def create_chat_completion(
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer,
request, result_generator, request_id, model_name, conversation, tokenizer,

Check failure on line 244 in vllm/entrypoints/openai/serving_chat.py (GitHub Actions / ruff (3.12), Ruff E501): vllm/entrypoints/openai/serving_chat.py:244:81: E501 Line too long (91 > 80)
request_metadata)

try:
return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer,
request, result_generator, request_id, model_name, conversation, tokenizer,

Check failure on line 249 in vllm/entrypoints/openai/serving_chat.py (GitHub Actions / ruff (3.12), Ruff E501): vllm/entrypoints/openai/serving_chat.py:249:81: E501 Line too long (91 > 80)
request_metadata)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
@@ -259,11 +262,11 @@ async def chat_completion_stream_generator(
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
model_name: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
model_name = self.base_model_paths[0].name
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True
@@ -592,12 +595,12 @@ async def chat_completion_full_generator(
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
model_name: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = self.base_model_paths[0].name
created_time = int(time.time())
final_res: Optional[RequestOutput] = None

2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -85,7 +85,6 @@ async def create_completion(
return self.create_error_response(
"suffix is not currently supported")

model_name = self.base_model_paths[0].name
request_id = f"cmpl-{self._base_request_id(raw_request)}"
created_time = int(time.time())

@@ -162,6 +161,7 @@ async def create_completion(
result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)

model_name = self._get_model_name(lora_request)

Check failures on line 164 in vllm/entrypoints/openai/serving_completion.py (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): Argument 1 to "_get_model_name" of "OpenAIServing" has incompatible type "Optional[LoRARequest]" (reported as "LoRARequest | None" on 3.10+); expected "LoRARequest" [arg-type]
num_prompts = len(engine_prompts)

# Similar to the OpenAI API, when n != best_of, we do not stream the
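The serving_completion.py change above moves the model_name lookup to after the LoRA adapter is resolved. A similarly hedged sketch for the completions endpoint, reusing the hypothetical sql-lora adapter and local server from the earlier example:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="sql-lora",
    prompt="SELECT",
    max_tokens=8,
)

# model_name is now derived from the resolved LoRA request, so the reported
# model is expected to be the adapter name rather than the base model.
print(completion.model)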
13 changes: 13 additions & 0 deletions vllm/entrypoints/openai/serving_engine.py
@@ -661,3 +661,16 @@ async def unload_lora_adapter(

def _is_model_supported(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)

def _get_model_name(self, lora: LoRARequest):
"""
Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contains a base_model_name.
Returns:
- str: The LoRA adapter name if a LoRA request is given, otherwise the name of the first base model path.
"""
if lora is not None:
return lora.lora_name
return self.base_model_paths[0].name
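The mypy failures above stem from call sites that pass an Optional LoRA request into this non-Optional parameter. Below is a small, self-contained sketch of the same fallback logic with an Optional-typed parameter, using stand-in dataclasses rather than vLLM's real LoRARequest and BaseModelPath classes; it is one possible way to reconcile the annotation, not the commit's code.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeLoRARequest:
    # Stand-in for vllm.lora.request.LoRARequest; only lora_name is needed here.
    lora_name: str


@dataclass
class FakeBaseModelPath:
    # Stand-in for the entries held in OpenAIServing.base_model_paths.
    name: str
    model_path: str


def get_model_name(lora: Optional[FakeLoRARequest],
                   base_model_paths: List[FakeBaseModelPath]) -> str:
    # Same fallback as the new _get_model_name: prefer the adapter name,
    # otherwise fall back to the first registered base model.
    if lora is not None:
        return lora.lora_name
    return base_model_paths[0].name


base = [FakeBaseModelPath(name="base-model", model_path="/models/base")]
print(get_model_name(None, base))                         # -> base-model
print(get_model_name(FakeLoRARequest("sql-lora"), base))  # -> sql-lora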
