From b92bc51725bd63dd63346a0fd14d60c3a4b85a0b Mon Sep 17 00:00:00 2001
From: zibai
Date: Thu, 16 Jan 2025 15:40:56 +0800
Subject: [PATCH] benchmark_serving: support --served-model-name param

---
 benchmarks/backend_request_func.py | 37 ++++++++++++++++--------------
 benchmarks/benchmark_serving.py    | 14 +++++++++++
 2 files changed, 34 insertions(+), 17 deletions(-)

diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index 9d71e4ecc4a37..3f81dc5baf767 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -22,6 +22,7 @@ class RequestFuncInput:
     prompt_len: int
     output_len: int
     model: str
+    model_name: Optional[str] = None
     best_of: int = 1
     logprobs: Optional[int] = None
     extra_body: Optional[dict] = None
@@ -43,8 +44,8 @@ class RequestFuncOutput:
 
 
 async def async_request_tgi(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+        request_func_input: RequestFuncInput,
+        pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
@@ -78,7 +79,7 @@ async def async_request_tgi(
                             continue
                         chunk_bytes = chunk_bytes.decode("utf-8")
 
-                        #NOTE: Sometimes TGI returns a ping response without
+                        # NOTE: Sometimes TGI returns a ping response without
                         # any data, we should skip it.
                         if chunk_bytes.startswith(":"):
                             continue
@@ -115,8 +116,8 @@ async def async_request_tgi(
 
 
 async def async_request_trt_llm(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+        request_func_input: RequestFuncInput,
+        pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
@@ -182,8 +183,8 @@ async def async_request_trt_llm(
 
 
 async def async_request_deepspeed_mii(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+        request_func_input: RequestFuncInput,
+        pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert request_func_input.best_of == 1
@@ -225,8 +226,8 @@ async def async_request_deepspeed_mii(
 
 
 async def async_request_openai_completions(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+        request_func_input: RequestFuncInput,
+        pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
@@ -235,7 +236,8 @@ async def async_request_openai_completions(
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
-            "model": request_func_input.model,
+            "model": request_func_input.model_name \
+            if request_func_input.model_name else request_func_input.model,
             "prompt": request_func_input.prompt,
             "temperature": 0.0,
             "best_of": request_func_input.best_of,
@@ -315,8 +317,8 @@ async def async_request_openai_completions(
 
 
 async def async_request_openai_chat_completions(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+        request_func_input: RequestFuncInput,
+        pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
@@ -328,7 +330,8 @@ async def async_request_openai_chat_completions(
         if request_func_input.multi_modal_content:
             content.append(request_func_input.multi_modal_content)
         payload = {
-            "model": request_func_input.model,
+            "model": request_func_input.model_name \
+            if request_func_input.model_name else request_func_input.model,
             "messages": [
                 {
                     "role": "user",
@@ -417,10 +420,10 @@ def get_model(pretrained_model_name_or_path: str) -> str:
 
 
 def get_tokenizer(
-    pretrained_model_name_or_path: str,
-    tokenizer_mode: str = "auto",
-    trust_remote_code: bool = False,
-    **kwargs,
+        pretrained_model_name_or_path: str,
+        tokenizer_mode: str = "auto",
+        trust_remote_code: bool = False,
+        **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     if pretrained_model_name_or_path is not None and not os.path.exists(
             pretrained_model_name_or_path):
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 4eb0e1f8ac903..d8f00866dba0b 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -525,6 +525,7 @@ async def benchmark(
     api_url: str,
     base_url: str,
     model_id: str,
+    model_name: str,
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     logprobs: Optional[int],
@@ -553,6 +554,7 @@ async def benchmark(
             "Multi-modal content is only supported on 'openai-chat' backend.")
     test_input = RequestFuncInput(
         model=model_id,
+        model_name=model_name,
         prompt=test_prompt,
         api_url=api_url,
         prompt_len=test_prompt_len,
@@ -573,6 +575,7 @@ async def benchmark(
     if profile:
         print("Starting profiler...")
         profile_input = RequestFuncInput(model=model_id,
+                                         model_name=model_name,
                                          prompt=test_prompt,
                                          api_url=base_url + "/start_profile",
                                          prompt_len=test_prompt_len,
@@ -616,6 +619,7 @@ async def limited_request_func(request_func_input, pbar):
     async for request in get_request(input_requests, request_rate, burstiness):
         prompt, prompt_len, output_len, mm_content = request
         request_func_input = RequestFuncInput(model=model_id,
+                                              model_name=model_name,
                                               prompt=prompt,
                                               api_url=api_url,
                                               prompt_len=prompt_len,
@@ -780,6 +784,7 @@ def main(args: argparse.Namespace):
 
     backend = args.backend
     model_id = args.model
+    model_name = args.served_model_name
     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
     tokenizer_mode = args.tokenizer_mode
 
@@ -877,6 +882,7 @@ def main(args: argparse.Namespace):
             api_url=api_url,
             base_url=base_url,
             model_id=model_id,
+            model_name=model_name,
             tokenizer=tokenizer,
             input_requests=input_requests,
             logprobs=args.logprobs,
@@ -1222,5 +1228,13 @@ def main(args: argparse.Namespace):
         'always use the slow tokenizer. \n* '
         '"mistral" will always use the `mistral_common` tokenizer.')
 
+    parser.add_argument(
+        "--served-model-name",
+        type=str,
+        default=None,
+        help="The model name used in the API. "
+        "If not specified, the model name will be the "
+        "same as the ``--model`` argument.")
+
     args = parser.parse_args()
     main(args)
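
For reference, a minimal standalone sketch of the fallback this patch introduces: the request payload's "model" field uses --served-model-name when it is set and otherwise falls back to --model. The model names below are placeholders for illustration only, not values from the patch.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class RequestFuncInput:  # trimmed to the two fields relevant here
        model: str
        model_name: Optional[str] = None


    def payload_model(req: RequestFuncInput) -> str:
        # --served-model-name wins when set; otherwise fall back to --model.
        return req.model_name if req.model_name else req.model


    # Placeholder names, purely for illustration.
    assert payload_model(RequestFuncInput(model="my-org/my-model")) == "my-org/my-model"
    assert payload_model(RequestFuncInput(
        model="my-org/my-model", model_name="my-alias")) == "my-alias"

Because --model remains the fallback, the flag is purely additive: existing benchmark invocations behave exactly as before, while a server that exposes the model under an alias (for example via vLLM's own --served-model-name option) can be benchmarked by passing the same alias here, with tokenization still driven by --model.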