From 12383b26ce92abcba09e13ab08a0bb46fc8151b8 Mon Sep 17 00:00:00 2001
From: Gabriel Marinho
Date: Wed, 5 Feb 2025 13:00:22 -0300
Subject: [PATCH v2] Support embedding models in the score API

Allow the OpenAI-compatible score API to serve models that are not
cross-encoders: when the served model is an embedding model, each text
of a (text_1, text_2) pair is tokenized and pooled separately, and the
pair score is the cosine similarity of the two pooled outputs.
Cross-encoder models keep the existing behavior.

Signed-off-by: Gabriel Marinho
---
 vllm/entrypoints/openai/serving_score.py | 357 +++++++++++++++++------
 1 file changed, 263 insertions(+), 94 deletions(-)

diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index 832aa8516cc35..34679bf569775 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
-
 import asyncio
 import time
 from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast

+import torch
 from fastapi import Request

 from vllm.config import ModelConfig
@@ -90,115 +90,284 @@ async def create_score(
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)

-            tokenizer = await self.engine_client.get_tokenizer(lora_request)
-
             if prompt_adapter_request is not None:
                 raise NotImplementedError("Prompt adapter is not supported "
                                           "for scoring models")

-            if isinstance(tokenizer, MistralTokenizer):
-                raise ValueError(
-                    "MistralTokenizer not supported for cross-encoding")
-
-            if not self.model_config.is_cross_encoder:
-                raise ValueError("Model is not cross encoder.")
+            tokenizer = await self.engine_client.get_tokenizer(lora_request)

             if truncate_prompt_tokens is not None and \
-                    truncate_prompt_tokens > self.max_model_len:
+                    truncate_prompt_tokens > self.max_model_len:
                 raise ValueError(
                     f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
                     f"is greater than max_model_len ({self.max_model_len})."
                     f" Please, select a smaller truncation size.")

-            input_pairs = make_pairs(request.text_1, request.text_2)
-            for q, t in input_pairs:
-                request_prompt = f"{q}{tokenizer.sep_token}{t}"
-
-                tokenization_kwargs: Dict[str, Any] = {}
-                if truncate_prompt_tokens is not None:
-                    tokenization_kwargs["truncation"] = True
-                    tokenization_kwargs["max_length"] = truncate_prompt_tokens
-
-                tokenize_async = make_async(tokenizer.__call__,
-                                            executor=self._tokenizer_executor)
-                prompt_inputs = await tokenize_async(text=q,
-                                                     text_pair=t,
-                                                     **tokenization_kwargs)
-
-                input_ids = prompt_inputs["input_ids"]
-                text_token_prompt = \
-                    self._validate_input(request, input_ids, request_prompt)
-                engine_prompt = TokensPrompt(
-                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
-                    token_type_ids=prompt_inputs.get("token_type_ids"))
-
-                request_prompts.append(request_prompt)
-                engine_prompts.append(engine_prompt)
-
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))

-        # Schedule the request and get the result generator.
- generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] - - try: - pooling_params = request.to_pooling_params() - - for i, engine_prompt in enumerate(engine_prompts): - request_id_item = f"{request_id}-{i}" - - self._log_inputs(request_id_item, - request_prompts[i], - params=pooling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) - - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) - - generator = self.engine_client.encode( - engine_prompt, - pooling_params, - request_id_item, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, + if self.model_config.is_cross_encoder: + try: + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + input_pairs = make_pairs(request.text_1, request.text_2) + for q, t in input_pairs: + request_prompt = f"{q}{tokenizer.sep_token}{t}" + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs[ + "max_length"] = truncate_prompt_tokens + + tokenize_async = make_async( + tokenizer.__call__, executor=self._tokenizer_executor) + prompt_inputs = await tokenize_async(text=q, + text_pair=t, + **tokenization_kwargs) + + input_ids = prompt_inputs["input_ids"] + text_token_prompt = \ + self._validate_input(request, input_ids, request_prompt) + engine_prompt = TokensPrompt( + prompt_token_ids=text_token_prompt["prompt_token_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + request_prompts.append(request_prompt) + engine_prompts.append(engine_prompt) + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. 
+            generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
+
+            try:
+                pooling_params = request.to_pooling_params()
+
+                for i, engine_prompt in enumerate(engine_prompts):
+                    request_id_item = f"{request_id}-{i}"
+
+                    self._log_inputs(
+                        request_id_item,
+                        request_prompts[i],
+                        params=pooling_params,
+                        lora_request=lora_request,
+                        prompt_adapter_request=prompt_adapter_request)
+
+                    trace_headers = (None if raw_request is None else await
+                                     self._get_trace_headers(
+                                         raw_request.headers))
+
+                    generator = self.engine_client.encode(
+                        engine_prompt,
+                        pooling_params,
+                        request_id_item,
+                        lora_request=lora_request,
+                        trace_headers=trace_headers,
+                        priority=request.priority,
+                    )
+
+                    generators.append(generator)
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))
+
+            result_generator = merge_async_iterators(*generators)
+
+            num_prompts = len(engine_prompts)
+
+            # Non-streaming response
+            final_res_batch: List[Optional[PoolingRequestOutput]]
+            final_res_batch = [None] * num_prompts
+
+            try:
+                async for i, res in result_generator:
+                    final_res_batch[i] = res
+
+                assert all(final_res is not None
+                           for final_res in final_res_batch)
+
+                final_res_batch_checked = cast(List[PoolingRequestOutput],
+                                               final_res_batch)
+
+                response = self.request_output_to_score_response(
+                    final_res_batch_checked,
+                    request_id,
+                    created_time,
+                    model_name,
                 )
-
-                generators.append(generator)
-        except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
-
-        result_generator = merge_async_iterators(*generators)
-
-        num_prompts = len(engine_prompts)
-
-        # Non-streaming response
-        final_res_batch: List[Optional[PoolingRequestOutput]]
-        final_res_batch = [None] * num_prompts
-
-        try:
-            async for i, res in result_generator:
-                final_res_batch[i] = res
-
-            assert all(final_res is not None for final_res in final_res_batch)
-
-            final_res_batch_checked = cast(List[PoolingRequestOutput],
-                                           final_res_batch)
-
-            response = self.request_output_to_score_response(
-                final_res_batch_checked,
-                request_id,
-                created_time,
-                model_name,
-            )
-        except asyncio.CancelledError:
-            return self.create_error_response("Client disconnected")
-        except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            except asyncio.CancelledError:
+                return self.create_error_response("Client disconnected")
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))
+
+        else:
+            # Embedding models: tokenize text_1 and text_2 separately; each
+            # pair is scored later by cosine similarity of its two pooled
+            # outputs.
+            try:
+                input_pairs = make_pairs(request.text_1, request.text_2)
+                for q, t in input_pairs:
+                    request_prompt = f"{q}{tokenizer.sep_token}{t}"
+
+                    tokenization_kwargs: Dict[str, Any] = {}
+                    if truncate_prompt_tokens is not None:
+                        tokenization_kwargs["truncation"] = True
+                        tokenization_kwargs[
+                            "max_length"] = truncate_prompt_tokens
+
+                    tokenize_async = make_async(
+                        tokenizer.__call__, executor=self._tokenizer_executor)
+
+                    # first text of the pair
+                    prompt_inputs_q = await tokenize_async(
+                        text=q, **tokenization_kwargs)
+
+                    input_ids_q = prompt_inputs_q["input_ids"]
+
+                    text_token_prompt_q = \
+                        self._validate_input(request, input_ids_q, q)
+
+                    engine_prompt_q = TokensPrompt(
+                        prompt_token_ids=text_token_prompt_q[
+                            "prompt_token_ids"],
+                        token_type_ids=prompt_inputs_q.get("token_type_ids"))
+
+                    # second text of the pair
+                    prompt_inputs_t = await tokenize_async(
+                        text=t, **tokenization_kwargs)
+                    input_ids_t = prompt_inputs_t["input_ids"]
+
+                    text_token_prompt_t = \
+                        self._validate_input(request, input_ids_t, t)
+
+                    engine_prompt_t = TokensPrompt(
+                        prompt_token_ids=text_token_prompt_t[
+                            "prompt_token_ids"],
+                        token_type_ids=prompt_inputs_t.get("token_type_ids"))
+
+                    request_prompts.append(request_prompt)
+                    engine_prompts.append((engine_prompt_q, engine_prompt_t))
+
+            except ValueError as e:
+                logger.exception("Error in preprocessing prompt inputs")
+                return self.create_error_response(str(e))
+
+            # Schedule the request and get the result generator.
+            generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
+
+            try:
+                pooling_params = request.to_pooling_params()
+
+                for i, engine_prompt in enumerate(engine_prompts):
+                    trace_headers = (None if raw_request is None else await
+                                     self._get_trace_headers(
+                                         raw_request.headers))
+
+                    # first element of the pair
+                    request_id_item_0 = f"{request_id}-{i}"
+
+                    self._log_inputs(
+                        request_id_item_0,
+                        request_prompts[i],
+                        params=pooling_params,
+                        lora_request=lora_request,
+                        prompt_adapter_request=prompt_adapter_request)
+
+                    generator_0 = self.engine_client.encode(
+                        engine_prompt[0],
+                        pooling_params,
+                        request_id_item_0,
+                        lora_request=lora_request,
+                        trace_headers=trace_headers,
+                        priority=request.priority,
+                    )
+
+                    # second element of the pair
+                    request_id_item_1 = f"{request_id}-{i}.1"
+
+                    self._log_inputs(
+                        request_id_item_1,
+                        request_prompts[i],
+                        params=pooling_params,
+                        lora_request=lora_request,
+                        prompt_adapter_request=prompt_adapter_request)
+
+                    generator_1 = self.engine_client.encode(
+                        engine_prompt[1],
+                        pooling_params,
+                        request_id_item_1,
+                        lora_request=lora_request,
+                        trace_headers=trace_headers,
+                        priority=request.priority,
+                    )
+
+                    generators.append(generator_0)
+                    generators.append(generator_1)
+
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))
+
+            result_generator = merge_async_iterators(*generators)
+
+            num_prompts = len(engine_prompts)
+
+            # Non-streaming response
+            final_res_batch: List[Optional[PoolingRequestOutput]]
+            final_res_batch = [None] * num_prompts
+
+            try:
+                # Collect results by generator index so the two embeddings of
+                # a pair stay together even if requests finish out of order:
+                # embeddings[2 * k] holds text_1 and embeddings[2 * k + 1]
+                # holds text_2 of pair k.
+                embeddings = [None] * len(generators)
+                async for i, res in result_generator:
+                    embeddings[i] = res
+
+                scores = []
+                scorer = torch.nn.CosineSimilarity(dim=0)
+
+                for i in range(0, len(embeddings), 2):
+                    # The pair score is the cosine similarity of the two
+                    # pooled embeddings.
+                    pair_score = scorer(embeddings[i].outputs.data,
+                                        embeddings[i + 1].outputs.data)
+
+                    if (pad_token_id := getattr(tokenizer, "pad_token_id",
+                                                None)) is not None:
+                        tokens = embeddings[i].prompt_token_ids + [
+                            pad_token_id
+                        ] + embeddings[i + 1].prompt_token_ids
+                    else:
+                        tokens = embeddings[i].prompt_token_ids + embeddings[
+                            i + 1].prompt_token_ids

+                    scores.append(
+                        PoolingRequestOutput(
+                            request_id=
+                            f"{embeddings[i].request_id}_{embeddings[i+1].request_id}",
+                            outputs=pair_score,
+                            prompt_token_ids=tokens,
+                            finished=True))
+
+                final_res_batch = scores
+                assert all(final_res is not None
+                           for final_res in final_res_batch)
+
+                final_res_batch_checked = cast(List[PoolingRequestOutput],
+                                               final_res_batch)
+
+                response = self.request_output_to_score_response(
+                    final_res_batch_checked,
+                    request_id,
+                    created_time,
+                    model_name,
+                )
+            except asyncio.CancelledError:
+                return self.create_error_response("Client disconnected")
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))

         return response
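
Reviewer note (not part of the patch): in the non-cross-encoder path above, each pair is scored by the cosine similarity of its two pooled embeddings. A minimal standalone sketch of that computation, assuming two 1-D pooled embedding tensors such as the outputs.data of a PoolingRequestOutput:

    import torch

    def pair_score(emb_1: torch.Tensor, emb_2: torch.Tensor) -> float:
        # Same operation as the scoring loop in the patch: cosine similarity
        # over dim 0 of two 1-D embeddings yields a single scalar.
        scorer = torch.nn.CosineSimilarity(dim=0)
        return scorer(emb_1, emb_2).item()

    # Hypothetical embedding values, for illustration only.
    emb_q = torch.tensor([0.1, 0.3, 0.5])
    emb_t = torch.tensor([0.2, 0.1, 0.4])
    print(pair_score(emb_q, emb_t))  # a value in [-1.0, 1.0]

Scores from this path are therefore cosine similarities in [-1, 1], while cross-encoder scores come from the model's own scoring head, so the ranges of the two paths may differ.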
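
For reference, an illustrative client call (the endpoint path, port, and model name below are assumptions for the example, not defined by this patch): with an embedding model served, a request like this would now be handled by the cosine-similarity path instead of being rejected with "Model is not cross encoder.".

    import requests

    resp = requests.post(
        "http://localhost:8000/score",  # assumed local vLLM server and score route
        json={
            "model": "BAAI/bge-base-en-v1.5",  # placeholder embedding model
            "text_1": "What is the capital of France?",
            "text_2": ["Paris is the capital of France.",
                       "The Eiffel Tower is in Paris."],
        },
    )
    print(resp.json())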