Commit
Run yapf and ruff
Signed-off-by: clark <panf2333@gmail.com>
panf2333 committed Jan 8, 2025
1 parent d9741bb commit 1bc97ec
Showing 7 changed files with 150 additions and 101 deletions.
49 changes: 28 additions & 21 deletions benchmarks/disagg_benchmarks/zmq/test_request.py
@@ -1,31 +1,34 @@
import asyncio
import json

import aiohttp

-# test connect completions we assume prefill and decode are on the same node
-# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 --chat-template ~/vllm/examples/template_chatglm2.jinja
+# test connect completions we assume prefill and decode are on the same node
+# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 \
+#    --chat-template ~/vllm/examples/template_chatglm2.jinja
# 2. vllm connect --prefill-addr nodeIp:7010 --decode-addr nodeIp:7010
# 3. python test_request.py

async def test_connect_completions(session):
    try:
        base_url = "http://localhost:8001/v1/connect/completions"
        body = {
-        "temperature": 0.5,
-        "top_p": 0.9,
-        "max_tokens": 150,
-        "frequency_penalty": 1.3,
-        "presence_penalty": 0.2,
-        "repetition_penalty": 1.2,
-        "model": "facebook/opt-125m",
-        "prompt": "Can you introduce vllm?",
-        "stream": True,
-        "stream_options": {
-            "include_usage": True
-        }}
-        print(f"Sending request to {base_url}, body {body}")
-        async with session.post(base_url, json= body) as response:
+            "temperature": 0.5,
+            "top_p": 0.9,
+            "max_tokens": 150,
+            "frequency_penalty": 1.3,
+            "presence_penalty": 0.2,
+            "repetition_penalty": 1.2,
+            "model": "facebook/opt-125m",
+            "prompt": "Can you introduce vllm?",
+            "stream": True,
+            "stream_options": {
+                "include_usage": True
+            }
+        }
+        print(f"Sending request to {base_url}, body {body}")
+        async with session.post(base_url, json=body) as response:

            print(response.status)
            print(response.headers)
            responseText = ""
@@ -40,13 +43,18 @@ async def test_connect_completions(session):
                    print(f"Error decoding chunk: {chunk!r}")
            else:
                # Print the headers and JSON response
-                print(f"Unexpected Transfer-Encoding: {transfer_encoding} {response.headers} {await response.json()}")
+                print("Unexpected Transfer-Encoding: {} {} {}".format(
+                    transfer_encoding, response.headers, await
+                    response.json()))
        else:
            print(f"Request failed with status code {response.status}")
-        print(f"baseurl {base_url} response data {extract_data(responseText)}")
+        print(
+            f"baseurl {base_url} response data {extract_data(responseText)}"
+        )
    except aiohttp.ClientError as e:
        print(f"Error: {e}")

+
def extract_data(responseText):
    reply = ""
    for data in responseText.split("\n\n"):
@@ -66,7 +74,7 @@ def extract_data(responseText):

    return reply


async def main():
    async with aiohttp.ClientSession() as session:
        tasks = []
@@ -76,4 +84,3 @@ async def main():


asyncio.run(main())
-
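The extract_data helper is truncated in the hunks above. For reference, a minimal standalone parser for this kind of buffered SSE stream (a sketch only, assuming OpenAI-style completion chunks that carry generated text under choices[0].text and a data: [DONE] sentinel; extract_sse_text is an illustrative name, not code from this commit) might look like:

import json

def extract_sse_text(response_text: str) -> str:
    # SSE events are separated by blank lines; each event is "data: <json>".
    reply = ""
    for event in response_text.split("\n\n"):
        event = event.strip()
        if not event.startswith("data:"):
            continue
        payload = event[len("data:"):].strip()
        if payload == "[DONE]":  # end-of-stream sentinel
            break
        chunk = json.loads(payload)
        # completion chunks carry the generated text under choices[0].text
        reply += chunk["choices"][0].get("text", "")
    return reply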
84 changes: 54 additions & 30 deletions vllm/entrypoints/disagg_connector.py
@@ -1,14 +1,16 @@
import json
+import signal
+import uuid
+# from fastapi.lifespan import Lifespan
+from asyncio import Queue
+from contextlib import asynccontextmanager
+
import uvicorn
import zmq
import zmq.asyncio
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
-from contextlib import asynccontextmanager
-# from fastapi.lifespan import Lifespan
-from asyncio import Queue
-import uuid
-import signal

from vllm.logger import init_logger

# default prefill and decode url
@@ -20,94 +22,117 @@
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.connect')

+
@asynccontextmanager
async def lifespan(app: FastAPI):
    # create socket pool with prefill and decode
    logger.info("start create_socket_pool")
    app.state.zmqctx = zmq.asyncio.Context()
-    app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
+    app.state.sockets_prefill = await create_socket_pool(
+        app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
    logger.info("success create_socket_pool sockets_prefill")
-    app.state.sockets_decode = await create_socket_pool(app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx)
+    app.state.sockets_decode = await create_socket_pool(
+        app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx)
    logger.info("success create_socket_pool sockets_decode")
    yield
    ## close zmq context
    logger.info("term zmqctx")
    app.state.zmqctx.destroy(linger=0)

+
app = FastAPI(lifespan=lifespan)

+
# create async socket pool with num_sockets use ZMQ_DEALER
-async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context) -> Queue:
-    sockets = Queue()
+async def create_socket_pool(url: str, num_sockets: int,
+                             zmqctx: zmq.asyncio.Context) -> Queue:
+    sockets: Queue = Queue()
    for i in range(num_sockets):
        sock = zmqctx.socket(zmq.DEALER)
        identity = f"worker-{i}-{uuid.uuid4()}"
        sock.setsockopt(zmq.IDENTITY, identity.encode())
        sock.connect(url)
-        logger.info(f"{identity} started at {url} {sockets.qsize()}")
+        logger.info("%s started at %s with queue size %s", identity, url,
+                    sockets.qsize())
        await sockets.put(sock)
    return sockets

+
# select a socket and execute task
-async def execute_task_async(route: str, headers: dict, request: dict, sockets: Queue):
+async def execute_task_async(route: str, headers: dict, request: dict,
+                             sockets: Queue):
    sock = await sockets.get()
    try:
        requestBody = json.dumps(request)
        headersJson = json.dumps(headers)
-        logger.info(f"Sending requestBody: {requestBody} to {route} with headers: {headersJson}")
-        await sock.send_multipart([route.encode(), headersJson.encode(), requestBody.encode()])
-        logger.info(f"Sent end")
+        logger.info("Sending requestBody: %s to %s with headers: %s",
+                    requestBody, route, headersJson)
+        await sock.send_multipart(
+            [route.encode(),
+             headersJson.encode(),
+             requestBody.encode()])
+        logger.info("Sent end")
        while True:
-            logger.info(f"Waiting for reply")
+            logger.info("Waiting for reply")
            [contentType, reply] = await sock.recv_multipart()
-            logger.info(f"Received result: {contentType}, {reply}")
+            logger.info("Received result: %s, %s", contentType, reply)
            reply = reply.decode()
            yield f"{reply}"
            if "[DONE]" in reply:
-                logger.info(f"Received stop signal, return socket")
+                logger.info("Received stop signal, return socket")
                break
    finally:
        await sockets.put(sock)

+
@app.post('/v1/connect/completions')
async def chat_completions(request: Request):
    try:
        # Add the X-Request-Id header to the raw headers list
        x_request_id = str(uuid.uuid4())
        header = dict(request.headers)
        if header.get("X-Request-Id") is None:
-            logger.info(f"add X-Request-Id: {x_request_id}")
+            logger.info("add X-Request-Id: %s", x_request_id)
            header["X-Request-Id"] = x_request_id
        original_request_data = await request.json()
-        logger.info(f"Received request: {original_request_data} header: {header}")
+        logger.info("Received request: %s header: %s", original_request_data,
+                    header)
        prefill_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request['max_tokens'] = 1
        route = "/v1/completions"
        # finish prefill
-        async for _ in execute_task_async(route, header, prefill_request, app.state.sockets_prefill):
+        async for _ in execute_task_async(route, header, prefill_request,
+                                          app.state.sockets_prefill):
            continue

        # return decode
-        return StreamingResponse(execute_task_async(route, header,original_request_data, app.state.sockets_decode), media_type="text/event-stream")
-
+        return StreamingResponse(execute_task_async(route, header,
+                                                    original_request_data,
+                                                    app.state.sockets_decode),
+                                 media_type="text/event-stream")

    except Exception as e:
        import sys
        import traceback
        exc_info = sys.exc_info()
        logger.error("Error occurred in disagg prefill proxy server")
        logger.error(e)
        logger.error("".join(traceback.format_exception(*exc_info)))

+
async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
-    logger.info(f"vLLM Disaggregate Connector start {args} {uvicorn_kwargs}")
+    logger.info("vLLM Disaggregate Connector start %s %s", args,
+                uvicorn_kwargs)
    logger.info(args.prefill_addr)

-    app.state.prefill_addr = f"tcp://{args.prefill_addr}" if args.prefill_addr is not None else url_prefill
-    app.state.decode_addr = f"tcp://{args.decode_addr}" if args.decode_addr is not None else url_decode
-    logger.info(f"start connect url_prefill: {app.state.prefill_addr} url_decode: {app.state.decode_addr}")
+    app.state.prefill_addr = (f"tcp://{args.prefill_addr}" if args.prefill_addr
+                              is not None else url_prefill)
+    app.state.decode_addr = (f"tcp://{args.decode_addr}"
+                             if args.decode_addr is not None else url_decode)
+    logger.info("start connect url_prefill: %s url_decode: %s",
+                app.state.prefill_addr, app.state.decode_addr)

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")
@@ -118,8 +143,7 @@ def signal_handler(*_) -> None:
    server = uvicorn.Server(config)
    await server.serve()


-
if __name__ == "__main__":
    # url = 'tcp://127.0.0.1:5555'
-    uvicorn.run(app, host="0.0.0.0", port=8001)
+    uvicorn.run(app, host="0.0.0.0", port=8001)
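execute_task_async above pins down the wire protocol the connector expects: it sends a three-frame multipart message [route, headers JSON, body JSON] on a DEALER socket, then consumes two-frame [content_type, reply] messages until a reply containing [DONE]. A minimal stand-in for the server side of that exchange (a sketch only; the real peer is vLLM's zmq server in launcher.py, and the URL and payloads here are illustrative) could be:

import asyncio

import zmq
import zmq.asyncio


async def stub_zmq_server(url: str = "tcp://0.0.0.0:7010") -> None:
    # ROUTER endpoint standing in for the vLLM zmq server.
    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(url)
    while True:
        # The DEALER's identity frame arrives first, then the three frames
        # sent by execute_task_async: route, headers JSON, body JSON.
        identity, route, headers, body = await sock.recv_multipart()
        # Stream one payload frame, then the sentinel the connector waits
        # for before returning the socket to its pool.
        for reply in (b'data: {"text": "hi"}\n\n', b'data: [DONE]\n\n'):
            await sock.send_multipart([identity, b"text/event-stream", reply])


asyncio.run(stub_zmq_server())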
25 changes: 15 additions & 10 deletions vllm/entrypoints/launcher.py
@@ -3,16 +3,15 @@
from http import HTTPStatus
from typing import Any

+import uvicorn
import zmq
import zmq.asyncio
-
-import uvicorn
from fastapi import FastAPI, Request, Response

-from vllm.entrypoints.openai.connect_worker import worker_routine
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.entrypoints.openai.connect_worker import worker_routine
from vllm.logger import init_logger
from vllm.utils import find_process_using_port
@@ -59,11 +58,13 @@ async def dummy_shutdown() -> None:
                "port %s is used by process %s launched with command:\n%s",
                port, process, " ".join(process.cmdline()))
            logger.info("Shutting down FastAPI HTTP server.")
-            return server.shutdown()
+        return server.shutdown()


async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
    """Server routine"""
-    logger.info(f"zmq Server start arg: {arg}, zmq_port: {zmq_server_port}")
+    logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
+                zmq_server_port)
    url_worker = "inproc://workers"
    url_client = f"tcp://0.0.0.0:{zmq_server_port}"
    # Prepare our context and sockets
@@ -72,15 +73,18 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:

    # Socket to talk to clients
    clients = context.socket(zmq.ROUTER)
    clients.bind(url_client)
-    logger.info(f"ZMQ Server ROUTER started at {url_client}")
+    logger.info("ZMQ Server ROUTER started at %s", url_client)
    # Socket to talk to workers
    workers = context.socket(zmq.DEALER)
    workers.bind(url_worker)
-    logger.info(f"ZMQ Worker DEALER started at {url_worker}")
+    logger.info("ZMQ Worker DEALER started at %s", url_worker)

-    tasks = [asyncio.create_task(worker_routine(url_worker, app, context, i)) for i in range(5)]
-    proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
+    tasks = [
+        asyncio.create_task(worker_routine(url_worker, app, context, i))
+        for i in range(5)
+    ]
+    proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)

    try:
        await asyncio.gather(*tasks, proxy_task)
    except KeyboardInterrupt:
@@ -93,6 +97,7 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
    workers.close()
    context.destroy(linger=0)

+
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    """Adds handlers for fatal errors that should crash the server"""

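Two notes on the serve_zmq hunk: zmq.proxy is a blocking call, which is why it is pushed onto a worker thread with asyncio.to_thread rather than awaited on the event loop; and worker_routine itself is not part of this diff. Under the classic ROUTER/DEALER proxy pattern the frames above imply, a worker could be sketched like this (hypothetical, not vLLM's actual implementation):

import zmq
import zmq.asyncio
from fastapi import FastAPI


async def worker_routine(url_worker: str, app: FastAPI,
                         context: zmq.asyncio.Context, worker_id: int) -> None:
    # Each worker DEALER connects to the inproc backend; the bound DEALER
    # in serve_zmq round-robins incoming requests across the workers.
    sock = context.socket(zmq.DEALER)
    sock.connect(url_worker)
    while True:
        # Frames as forwarded by the proxy: the requesting client's ROUTER
        # identity, then route, headers JSON, and body JSON.
        client_id, route, headers, body = await sock.recv_multipart()
        # ... dispatch to the engine here, then stream reply frames back;
        # keeping the identity frame first lets the frontend ROUTER deliver
        # them to the right client.
        await sock.send_multipart(
            [client_id, b"text/event-stream", b"data: [DONE]\n\n"])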
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -797,7 +797,8 @@ def signal_handler(*_) -> None:

    zmq_server_port = args.zmq_server_port
    if zmq_server_port is not None:
-        logger.info("asyncio.create_task Starting ZMQ server at port %d", zmq_server_port)
+        logger.info("asyncio.create_task Starting ZMQ server at port %d",
+                    zmq_server_port)
        asyncio.create_task(serve_zmq(args, zmq_server_port, app))

    shutdown_task = await serve_http(
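One caveat with the pattern in this hunk (independent of the formatting change): asyncio.create_task here is fire-and-forget, and the asyncio docs recommend holding a reference so a background task cannot be garbage-collected mid-flight. A common guard, as a hygiene sketch rather than code from this commit:

background_tasks = set()

task = asyncio.create_task(serve_zmq(args, zmq_server_port, app))
background_tasks.add(task)  # keep a strong reference for the task's lifetime
task.add_done_callback(background_tasks.discard)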
5 changes: 3 additions & 2 deletions vllm/entrypoints/openai/cli_args.py
@@ -250,8 +250,9 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        action='store_true',
        default=False,
        help="If set to True, enable prompt_tokens_details in usage.")

-    parser.add_argument('--zmq-server-port',
+    parser.add_argument(
+        '--zmq-server-port',
        type=int,
        default=5555,
        help='The port to serve the zmq server on.\n\nDefault: 5555')