second version of api scoring
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
gmarinho2 committed Feb 5, 2025
1 parent 7851b44 commit 12383b2
Showing 1 changed file with 263 additions and 94 deletions.
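For context, a minimal sketch of how a client might call the scoring route served by this file once the server is running. The /score path, port, and model name are assumptions for illustration and are not part of this diff; only the text_1/text_2 request fields mirror the request.text_1 / request.text_2 handling in the code below.

# Hypothetical client call; endpoint path, port, and model name are assumed.
import requests

resp = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-reranker-base",  # assumed example model
        "text_1": "What is the capital of France?",
        "text_2": [
            "Paris is the capital of France.",
            "The Eiffel Tower is in Paris.",
        ],
    },
)
print(resp.json())  # expected: one score per (text_1, text_2) pair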
357 changes: 263 additions & 94 deletions vllm/entrypoints/openai/serving_score.py
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast

import torch
from fastapi import Request

from vllm.config import ModelConfig
@@ -90,115 +90,284 @@ async def create_score(
prompt_adapter_request,
) = self._maybe_get_adapters(request)

tokenizer = await self.engine_client.get_tokenizer(lora_request)

if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for scoring models")

if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")

if not self.model_config.is_cross_encoder:
raise ValueError("Model is not cross encoder.")
tokenizer = await self.engine_client.get_tokenizer(lora_request)

if truncate_prompt_tokens is not None and \
truncate_prompt_tokens > self.max_model_len:
truncate_prompt_tokens > self.max_model_len:
raise ValueError(
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
f"is greater than max_model_len ({self.max_model_len})."
f" Please, select a smaller truncation size.")

input_pairs = make_pairs(request.text_1, request.text_2)
for q, t in input_pairs:
request_prompt = f"{q}{tokenizer.sep_token}{t}"

tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens

tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(text=q,
text_pair=t,
**tokenization_kwargs)

input_ids = prompt_inputs["input_ids"]
text_token_prompt = \
self._validate_input(request, input_ids, request_prompt)
engine_prompt = TokensPrompt(
prompt_token_ids=text_token_prompt["prompt_token_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))

request_prompts.append(request_prompt)
engine_prompts.append(engine_prompt)

except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))

# Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []

try:
pooling_params = request.to_pooling_params()

for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"

self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)

trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))

generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
if self.model_config.is_cross_encoder:
try:

if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")

input_pairs = make_pairs(request.text_1, request.text_2)
for q, t in input_pairs:
request_prompt = f"{q}{tokenizer.sep_token}{t}"

tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs[
"max_length"] = truncate_prompt_tokens

tokenize_async = make_async(
tokenizer.__call__, executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(text=q,
text_pair=t,
**tokenization_kwargs)

input_ids = prompt_inputs["input_ids"]
text_token_prompt = \
self._validate_input(request, input_ids, request_prompt)
engine_prompt = TokensPrompt(
prompt_token_ids=text_token_prompt["prompt_token_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))

request_prompts.append(request_prompt)
engine_prompts.append(engine_prompt)

except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))

# Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []

try:
pooling_params = request.to_pooling_params()

for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"

self._log_inputs(
request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)

trace_headers = (None if raw_request is None else await
self._get_trace_headers(
raw_request.headers))

generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)

generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

result_generator = merge_async_iterators(*generators)

num_prompts = len(engine_prompts)

# Non-streaming response
final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts

try:
async for i, res in result_generator:
final_res_batch[i] = res

assert all(final_res is not None
for final_res in final_res_batch)

final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)

response = self.request_output_to_score_response(
final_res_batch_checked,
request_id,
created_time,
model_name,
)

generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

result_generator = merge_async_iterators(*generators)

num_prompts = len(engine_prompts)

# Non-streaming response
final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts

try:
async for i, res in result_generator:
final_res_batch[i] = res

assert all(final_res is not None for final_res in final_res_batch)

final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)

response = self.request_output_to_score_response(
final_res_batch_checked,
request_id,
created_time,
model_name,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

else:
try:
input_pairs = make_pairs(request.text_1, request.text_2)
for q, t in input_pairs:
request_prompt = f"{q}{tokenizer.sep_token}{t}"

tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs[
"max_length"] = truncate_prompt_tokens

tokenize_async = make_async(
tokenizer.__call__, executor=self._tokenizer_executor)

#first of the pair
prompt_inputs_q = await tokenize_async(
text=q, **tokenization_kwargs)

input_ids_q = prompt_inputs_q["input_ids"]

text_token_prompt_q = \
self._validate_input(request, input_ids_q, q)

engine_prompt_q = TokensPrompt(
prompt_token_ids=text_token_prompt_q[
"prompt_token_ids"],
token_type_ids=prompt_inputs_q.get("token_type_ids"))

#second of the pair

prompt_inputs_t = await tokenize_async(
text=t, **tokenization_kwargs)
input_ids_t = prompt_inputs_t["input_ids"]

text_token_prompt_t = \
self._validate_input(request, input_ids_t, t)

engine_prompt_t = TokensPrompt(
prompt_token_ids=text_token_prompt_t[
"prompt_token_ids"],
token_type_ids=prompt_inputs_t.get("token_type_ids"))

request_prompts.append(request_prompt)
engine_prompts.append((engine_prompt_q, engine_prompt_t))

except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))

# Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []

try:
pooling_params = request.to_pooling_params()

#for i in range(0, len(engine_prompts), 2):
for i, engine_prompt in enumerate(engine_prompts):
trace_headers = (None if raw_request is None else await
self._get_trace_headers(
raw_request.headers))

request_id_item_0 = f"{request_id}-{i}"

self._log_inputs(
request_id_item_0,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)

generator_0 = self.engine_client.encode(
engine_prompt[0],
pooling_params,
request_id_item_0,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
#second element of the pair
request_id_item_1 = f"{request_id}-{i}.1"

self._log_inputs(
request_id_item_1,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)

generator_1 = self.engine_client.encode(
engine_prompt[1],
pooling_params,
request_id_item_1,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)

generators.append(generator_0)
generators.append(generator_1)

except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

result_generator = merge_async_iterators(*generators)

num_prompts = len(engine_prompts)

# Non-streaming response
final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts

try:
embeddings = []
async for i, res in result_generator:
embeddings.append(res)

scores = []
scorer = torch.nn.CosineSimilarity(0)

for i in range(0, len(embeddings), 2):
pair_score = scorer(embeddings[i].outputs.data,
embeddings[i + 1].outputs.data)

if (pad_token_id := getattr(tokenizer, "pad_token_id",
None)) is not None:
tokens = embeddings[i].prompt_token_ids + [
pad_token_id
] + embeddings[i + 1].prompt_token_ids
else:
tokens = embeddings[i].prompt_token_ids + embeddings[
i + 1].prompt_token_ids

scores.append(
PoolingRequestOutput(
request_id=
f"{embeddings[i].request_id}_{embeddings[i+1].request_id}",
outputs=pair_score,
prompt_token_ids=tokens,
finished=True))

final_res_batch = scores
assert all(final_res is not None
for final_res in final_res_batch)

final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)

response = self.request_output_to_score_response(
final_res_batch_checked,
request_id,
created_time,
model_name,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

return response
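When the model is not a cross-encoder, the new branch above encodes each text of a pair separately and scores the pair with cosine similarity over the pooled embeddings. A self-contained sketch of that scoring step, using made-up vectors in place of embeddings[i].outputs.data and embeddings[i + 1].outputs.data:

# Sketch of the pairwise scoring used in the non-cross-encoder branch:
# cosine similarity over two 1-D pooled embedding vectors.
import torch

scorer = torch.nn.CosineSimilarity(dim=0)

emb_q = torch.tensor([0.1, 0.3, 0.5, 0.7])  # stand-in for embeddings[i].outputs.data
emb_t = torch.tensor([0.2, 0.1, 0.4, 0.8])  # stand-in for embeddings[i + 1].outputs.data

pair_score = scorer(emb_q, emb_t)  # scalar tensor in [-1, 1]
print(float(pair_score))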

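The cross-encoder branch instead tokenizes each (text_1, text_2) pair jointly, relying on the tokenizer's text_pair support to produce token_type_ids. A minimal standalone sketch of that joint encoding with a BERT-style tokenizer; the model name is an assumed example and transformers is used here only for illustration:

# Sketch of the joint pair tokenization done in the cross-encoder branch.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")  # assumed example
q, t = "what is the capital of France?", "Paris is the capital of France."

enc = tokenizer(text=q, text_pair=t, truncation=True, max_length=512)
print(enc["input_ids"])           # [CLS] query tokens [SEP] passage tokens [SEP]
print(enc.get("token_type_ids"))  # 0s for the first segment, 1s for the second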