Commit: merge with main
LiuXiaoxuanPKU committed Aug 7, 2023
2 parents b61c8dd + 58a072b commit 58e7121
Showing 6 changed files with 87 additions and 35 deletions.
33 changes: 33 additions & 0 deletions examples/openai_chatcompletion_client.py
@@ -0,0 +1,33 @@
import openai

# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"

# List models API
models = openai.Model.list()
print("Models:", models)

model = models["data"][0]["id"]

# Chat completion API
chat_completion = openai.ChatCompletion.create(
    model=model,
    messages=[{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role": "user",
        "content": "Who won the world series in 2020?"
    }, {
        "role": "assistant",
        "content": "The Los Angeles Dodgers won the World Series in 2020."
    }, {
        "role": "user",
        "content": "Where was it played?"
    }])

print("Chat completion results:")
print(chat_completion)
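For comparison, here is a minimal streaming variant of the same chat call. It is an illustrative sketch, not part of the commit: it assumes the same pre-1.0 `openai` client and a vLLM OpenAI-compatible server on localhost:8000, and it relies on OpenAI's standard streaming "delta" chunk format.

import openai

# Same local vLLM server as in the example above (illustrative only).
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
model = openai.Model.list()["data"][0]["id"]

# stream=True yields incremental chunks instead of one full response.
for chunk in openai.ChatCompletion.create(
        model=model,
        messages=[{
            "role": "user",
            "content": "Where was the 2020 World Series played?"
        }],
        stream=True):
    # Each chunk carries a "delta"; "content" is absent on the role-only
    # first chunk and on the final finish chunk.
    delta = chunk["choices"][0]["delta"]
    print(delta.get("content", ""), end="", flush=True)
print()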
@@ -3,26 +3,26 @@
# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
-model = "facebook/opt-125m"

-# Test list models API
+# List models API
models = openai.Model.list()
print("Models:", models)

-# Test completion API
-stream = True
+model = models["data"][0]["id"]

+# Completion API
+stream = False
completion = openai.Completion.create(
    model=model,
    prompt="A robot may not injure a human being",
    echo=False,
    n=2,
    best_of=3,
    stream=stream,
    logprobs=3)

-# print the completion
+print("Completion results:")
if stream:
    for c in completion:
        print(c)
else:
-    print("Completion result:", completion)
+    print(completion)
1 change: 0 additions & 1 deletion requirements.txt
@@ -9,4 +9,3 @@ xformers >= 0.0.19
fastapi
uvicorn
pydantic < 2 # Required for OpenAI server.
-fschat # Required for OpenAI ChatCompletion Endpoint.
28 changes: 28 additions & 0 deletions vllm/config.py
@@ -107,6 +107,34 @@ def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
        total_num_attention_heads = self.hf_config.num_attention_heads
        return total_num_attention_heads // parallel_config.tensor_parallel_size

+    def get_max_model_len(self) -> int:
+        max_model_len = float("inf")
+        possible_keys = [
+            # OPT
+            "max_position_embeddings",
+            # GPT-2
+            "n_positions",
+            # MPT
+            "max_seq_len",
+            # Others
+            "max_sequence_length",
+            "max_seq_length",
+            "seq_len",
+        ]
+        for key in possible_keys:
+            max_len_key = getattr(self.hf_config, key, None)
+            if max_len_key is not None:
+                max_model_len = min(max_model_len, max_len_key)
+        rope_scaling = getattr(self.hf_config, 'rope_scaling', None)
+        if rope_scaling is not None:
+            scale_factor = rope_scaling['factor'] if rope_scaling else 1.0
+            max_position_embeddings = getattr(self.hf_config,
+                                              'max_position_embeddings',
+                                              None)
+            assert max_position_embeddings is not None
+            max_model_len = max_position_embeddings * scale_factor
+        return max_model_len

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size
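To make the new helper's behaviour concrete, here is a stand-alone sketch of the same logic; the function name `max_model_len_from` and the `SimpleNamespace` stand-in for the Hugging Face config are illustrative, not part of this commit.

from types import SimpleNamespace

def max_model_len_from(hf_config):
    # Illustrative rework of ModelConfig.get_max_model_len() above: take the
    # smallest context-length field the HF config exposes, then let a RoPE
    # scaling factor stretch max_position_embeddings.
    max_model_len = float("inf")
    for key in ("max_position_embeddings", "n_positions", "max_seq_len",
                "max_sequence_length", "max_seq_length", "seq_len"):
        value = getattr(hf_config, key, None)
        if value is not None:
            max_model_len = min(max_model_len, value)
    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        max_model_len = hf_config.max_position_embeddings * rope_scaling["factor"]
    return max_model_len

# A LLaMA-style config with linear RoPE scaling: 2048 * 4.0 -> 8192.0 tokens.
fake_config = SimpleNamespace(max_position_embeddings=2048,
                              rope_scaling={"type": "linear", "factor": 4.0})
print(max_model_len_from(fake_config))

Note that, as in the committed code, a RoPE scaling factor overrides the minimum taken over the other keys rather than being combined with it.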
11 changes: 2 additions & 9 deletions vllm/engine/arg_utils.py
@@ -156,16 +156,9 @@ def create_engine_configs(
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray)
-        rope_scaling = getattr(model_config.hf_config, 'rope_scaling', None)
-        scale_factor = rope_scaling['factor'] if rope_scaling else 1.0
-        max_seq_len = getattr(model_config.hf_config, 'max_sequence_length',
-                              float('inf'))
-        max_model_len = getattr(model_config.hf_config,
-                                'max_position_embeddings',
-                                float('inf')) * scale_factor
-        max_model_len = min(max_model_len, max_seq_len)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_model_len)
+                                           self.max_num_seqs,
+                                           model_config.get_max_model_len())
        return model_config, cache_config, parallel_config, scheduler_config


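A rough sketch of how the simplified path above is exercised end to end. It assumes `EngineArgs` can be constructed with only a model name and that `SchedulerConfig` exposes the limit as `max_model_len`; neither detail is shown in this diff, so treat both as assumptions.

from vllm.engine.arg_utils import EngineArgs

# Hypothetical usage: the scheduler's context limit now comes straight from
# ModelConfig.get_max_model_len() instead of ad-hoc attribute probing here.
engine_args = EngineArgs(model="facebook/opt-125m")
model_config, cache_config, parallel_config, scheduler_config = (
    engine_args.create_engine_configs())

# Assumes SchedulerConfig stores the limit as .max_model_len.
assert scheduler_config.max_model_len == model_config.get_max_model_len()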
35 changes: 17 additions & 18 deletions vllm/entrypoints/openai/api_server.py
@@ -13,9 +13,6 @@
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
-from fastchat.conversation import Conversation, SeparatorStyle
-from fastchat.model.model_adapter import get_conversation_template

import uvicorn

from vllm.engine.arg_utils import AsyncEngineArgs
@@ -33,6 +30,13 @@
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid

+try:
+    from fastchat.conversation import Conversation, SeparatorStyle
+    from fastchat.model.model_adapter import get_conversation_template
+    _fastchat_available = True
+except ImportError:
+    _fastchat_available = False

TIMEOUT_KEEP_ALIVE = 5  # seconds

logger = init_logger(__name__)
@@ -63,6 +67,11 @@ async def check_model(request) -> Optional[JSONResponse]:


async def get_gen_prompt(request) -> str:
+    if not _fastchat_available:
+        raise ModuleNotFoundError(
+            "fastchat is not installed. Please install fastchat to use "
+            "the chat completion and conversation APIs: `$ pip install fschat`"
+        )
    conv = get_conversation_template(request.model)
    conv = Conversation(
        name=conv.name,
@@ -98,25 +107,14 @@ async def get_gen_prompt(request) -> str:
    return prompt


-async def check_length(request, prompt, model_config):
-    if hasattr(model_config.hf_config, "max_sequence_length"):
-        context_len = model_config.hf_config.max_sequence_length
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    elif hasattr(model_config.hf_config, "max_position_embeddings"):
-        context_len = model_config.hf_config.max_position_embeddings
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    else:
-        context_len = 2048

+async def check_length(request, prompt):
    input_ids = tokenizer(prompt).input_ids
    token_num = len(input_ids)

-    if token_num + request.max_tokens > context_len:
+    if token_num + request.max_tokens > max_model_len:
        return create_error_response(
            HTTPStatus.BAD_REQUEST,
-            f"This model's maximum context length is {context_len} tokens. "
+            f"This model's maximum context length is {max_model_len} tokens. "
            f"However, you requested {request.max_tokens + token_num} tokens "
            f"({token_num} in the messages, "
            f"{request.max_tokens} in the completion). "
@@ -185,7 +183,7 @@ async def create_chat_completion(raw_request: Request):
                                     "logit_bias is not currently supported")

    prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt, engine_model_config)
+    error_check_ret = await check_length(request, prompt)
    if error_check_ret is not None:
        return error_check_ret
@@ -582,6 +580,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    engine_model_config = asyncio.run(engine.get_model_config())
+    max_model_len = engine_model_config.get_max_model_len()

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(engine_args.tokenizer,
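For reference, here is a stand-alone sketch of the request-length check that `check_length` now performs against the module-level `max_model_len`; the function and variable names below are illustrative, not the server's actual symbols.

from typing import Optional

from transformers import AutoTokenizer

def validate_request_length(prompt, max_tokens, max_model_len,
                            tokenizer) -> Optional[str]:
    # Mirrors check_length() above: reject a request whose prompt plus the
    # requested completion budget would overflow the model's context window.
    token_num = len(tokenizer(prompt).input_ids)
    if token_num + max_tokens > max_model_len:
        return (f"This model's maximum context length is {max_model_len} tokens. "
                f"However, you requested {max_tokens + token_num} tokens "
                f"({token_num} in the messages, {max_tokens} in the completion).")
    return None

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
error = validate_request_length("A robot may not injure a human being",
                                max_tokens=16, max_model_len=2048,
                                tokenizer=tokenizer)
print(error)  # None -> the request fits within the context window

The design change in this commit is that the bound comes from the single `max_model_len` computed once at server startup, rather than from per-request probing of the model config.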
