-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
/
Copy pathserving_embedding.py
242 lines (202 loc) · 8.91 KB
/
serving_embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import time
from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast
import numpy as np
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
PoolingRequestOutput)
from vllm.utils import merge_async_iterators
# Module-level logger obtained from vLLM's `init_logger`, keyed by this
# module's `__name__`.
logger = init_logger(__name__)
def _get_embedding(
    output: EmbeddingOutput,
    encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
    """Serialize one embedding in the format requested by the client.

    Args:
        output: Result object whose ``embedding`` attribute holds the vector.
        encoding_format: ``"float"`` hands back ``output.embedding``
            unchanged; ``"base64"`` returns the vector's float32 bytes
            encoded as a base64 string.

    Returns:
        ``output.embedding`` itself, or its base64 text representation.
    """
    if encoding_format == "float":
        return output.embedding

    if encoding_format == "base64":
        # The vector is always cast to float32 before encoding so that the
        # payload matches the OpenAI python client behavior.
        as_float32 = np.array(output.embedding, dtype="float32")
        encoded = base64.b64encode(as_float32.tobytes())
        return encoded.decode("utf-8")

    # Neither literal matched; hand the value to the exhaustiveness check.
    assert_never(encoding_format)
class OpenAIServingEmbedding(OpenAIServing):
    """Handler for embedding requests, modelled on the OpenAI Embedding API.

    Helpers used below that are not defined in this class (``_check_model``,
    ``_get_model_name``, ``_base_request_id``, ``_maybe_get_adapters``,
    ``_preprocess_chat``, ``_preprocess_completion``, ``_log_inputs``,
    ``_get_trace_headers``, ``create_error_response``, ``max_model_len``,
    ``engine_client``) are presumably inherited from ``OpenAIServing``.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
    ) -> None:
        """Store the chat-template settings and delegate the rest to the base.

        Args:
            engine_client: Forwarded to ``OpenAIServing.__init__``.
            model_config: Forwarded to ``OpenAIServing.__init__``.
            models: Forwarded to ``OpenAIServing.__init__``.
            request_logger: Forwarded to ``OpenAIServing.__init__``.
            chat_template: Default chat template, used for chat-style
                embedding requests that do not carry their own template.
            chat_template_content_format: Content format option passed to
                chat preprocessing.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Args:
            request: The embedding request (completion-style or chat-style).
            raw_request: The underlying HTTP request, if any. It is used for
                the base request id and for the trace headers.

        Returns:
            An ``EmbeddingResponse`` on success. An ``ErrorResponse`` is
            returned when the model check fails, when ``dimensions`` is set,
            when ``truncate_prompt_tokens`` exceeds ``max_model_len``, when
            preprocessing or scheduling raises ``ValueError``, or when the
            request is cancelled while results are being collected.

        Raises:
            NotImplementedError: If a prompt adapter is resolved for the
                request. It is raised inside the preprocessing ``try`` but is
                not a ``ValueError``, so it propagates to the caller.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = request.encoding_format
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = self._get_model_name(request.model)
        request_id = f"embd-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        # Only accept a truncation size that fits within the model length.
        truncate_prompt_tokens = None

        if request.truncate_prompt_tokens is not None:
            if request.truncate_prompt_tokens <= self.max_model_len:
                truncate_prompt_tokens = request.truncate_prompt_tokens
            else:
                return self.create_error_response(
                    "truncate_prompt_tokens value is "
                    "greater than max_model_len."
                    " Please, select a smaller truncation size.")

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for embedding models")

            if isinstance(request, EmbeddingChatRequest):
                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    # In embedding requests, we are not generating tokens,
                    # so there is no need to append extra tokens to the input
                    add_generation_prompt=False,
                    continue_final_message=False,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                (request_prompts,
                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.input,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
        try:
            pooling_params = request.to_pooling_params()

            # The trace headers depend only on the incoming HTTP request, so
            # resolve them once here instead of awaiting the same lookup for
            # every prompt inside the loop.
            trace_headers = (None if raw_request is None else await
                             self._get_trace_headers(raw_request.headers))

            for i, engine_prompt in enumerate(engine_prompts):
                # Each prompt gets its own id derived from the request id.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts
        try:
            # The merged iterator yields (index, result) pairs; storing by
            # index keeps the results in prompt order.
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_embedding_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
                encoding_format,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_embedding_response(
        self,
        final_res_batch: List[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
    ) -> EmbeddingResponse:
        """Assemble the final ``EmbeddingResponse`` from per-prompt results.

        Args:
            final_res_batch: One pooling result per prompt; the position in
                the list becomes the ``index`` of the response item.
            request_id: Used as the response ``id``.
            created_time: Used as the response ``created`` value.
            model_name: Used as the response ``model`` value.
            encoding_format: Passed to ``_get_embedding`` to choose between
                the float and the base64 representation.

        Returns:
            The response, whose usage reports the summed length of every
            result's ``prompt_token_ids`` as both prompt and total tokens.
        """
        items: List[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            embedding_res = EmbeddingRequestOutput.from_base(final_res)

            item = EmbeddingResponseData(
                index=idx,
                embedding=_get_embedding(embedding_res.outputs,
                                         encoding_format),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )